diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index f38bc1fbe5..7893ee9260 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -40,6 +40,7 @@ import sagemaker.local.data import sagemaker.local.utils import sagemaker.utils +from sagemaker.utils import check_tarfile_data_filter_attribute CONTAINER_PREFIX = "algo" STUDIO_HOST_NAME = "sagemaker-local" @@ -686,7 +687,8 @@ def _prepare_serving_volumes(self, model_location): for filename in model_data_source.get_file_list(): if tarfile.is_tarfile(filename): with tarfile.open(filename) as tar: - tar.extractall(path=model_data_source.get_root_dir()) + check_tarfile_data_filter_attribute() + tar.extractall(path=model_data_source.get_root_dir(), filter="data") volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model")) diff --git a/src/sagemaker/serve/model_server/djl_serving/prepare.py b/src/sagemaker/serve/model_server/djl_serving/prepare.py index 386c5fb66e..6bdada0b6c 100644 --- a/src/sagemaker/serve/model_server/djl_serving/prepare.py +++ b/src/sagemaker/serve/model_server/djl_serving/prepare.py @@ -20,7 +20,7 @@ from typing import List from pathlib import Path -from sagemaker.utils import _tmpdir +from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute from sagemaker.s3 import S3Downloader from sagemaker.djl_inference import DJLModel from sagemaker.djl_inference.model import _read_existing_serving_properties @@ -53,7 +53,8 @@ def _extract_js_resource(js_model_dir: str, js_id: str): """Uncompress the jumpstart resource""" tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") with tarfile.open(str(tmp_sourcedir)) as resources: - resources.extractall(path=js_model_dir) + check_tarfile_data_filter_attribute() + resources.extractall(path=js_model_dir, filter="data") def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path): diff --git a/src/sagemaker/serve/model_server/tgi/prepare.py b/src/sagemaker/serve/model_server/tgi/prepare.py index fe1162e505..9b187dd2ed 100644 --- a/src/sagemaker/serve/model_server/tgi/prepare.py +++ b/src/sagemaker/serve/model_server/tgi/prepare.py @@ -19,7 +19,7 @@ from pathlib import Path from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage -from sagemaker.utils import _tmpdir +from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute from sagemaker.s3 import S3Downloader logger = logging.getLogger(__name__) @@ -29,7 +29,8 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str): """Uncompress the jumpstart resource""" tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz") with tarfile.open(str(tmp_sourcedir)) as resources: - resources.extractall(path=code_dir) + check_tarfile_data_filter_attribute() + resources.extractall(path=code_dir, filter="data") def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool: diff --git a/src/sagemaker/utils.py b/src/sagemaker/utils.py index 15a3d128de..a6d26db48b 100644 --- a/src/sagemaker/utils.py +++ b/src/sagemaker/utils.py @@ -22,6 +22,7 @@ import random import re import shutil +import sys import tarfile import tempfile import time @@ -591,7 +592,8 @@ def _create_or_update_code_dir( download_file_from_url(source_directory, local_code_path, sagemaker_session) with tarfile.open(name=local_code_path, mode="r:gz") as t: - t.extractall(path=code_dir) + check_tarfile_data_filter_attribute() + t.extractall(path=code_dir, filter="data") elif source_directory: if os.path.exists(code_dir): @@ -628,7 +630,8 @@ def _extract_model(model_uri, sagemaker_session, tmp): else: local_model_path = model_uri.replace("file://", "") with tarfile.open(name=local_model_path, mode="r:gz") as t: - t.extractall(path=tmp_model_dir) + check_tarfile_data_filter_attribute() + t.extractall(path=tmp_model_dir, filter="data") return tmp_model_dir @@ -1489,3 +1492,25 @@ def format_tags(tags: Tags) -> List[TagsDict]: return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()] return tags + + +class PythonVersionError(Exception): + """Raise when a secure [/patched] version of Python is not used.""" + + +def check_tarfile_data_filter_attribute(): + """Check if tarfile has data_filter utility. + + Tarfile-data_filter utility has guardrails against untrusted de-serialisation. + + Raises: + PythonVersionError: if `tarfile.data_filter` is not available. + """ + # The function and it's usages can be deprecated post support of python >= 3.12 + if not hasattr(tarfile, "data_filter"): + raise PythonVersionError( + f"Since tarfile extraction is unsafe the operation is prohibited " + f"per PEP-721. Please update your Python [{sys.version}] " + f"to latest patch [refer to https://www.python.org/downloads/] " + f"to consume the security patch" + ) diff --git a/src/sagemaker/workflow/_utils.py b/src/sagemaker/workflow/_utils.py index 9c4fa114ab..1b88bfd924 100644 --- a/src/sagemaker/workflow/_utils.py +++ b/src/sagemaker/workflow/_utils.py @@ -32,7 +32,12 @@ Step, ConfigurableRetryStep, ) -from sagemaker.utils import _save_model, download_file_from_url, format_tags +from sagemaker.utils import ( + _save_model, + download_file_from_url, + format_tags, + check_tarfile_data_filter_attribute, +) from sagemaker.workflow.retry import RetryPolicy from sagemaker.workflow.utilities import trim_request_dict @@ -257,7 +262,8 @@ def _inject_repack_script_and_launcher(self): download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session) with tarfile.open(name=old_targz_path, mode="r:gz") as t: - t.extractall(path=targz_contents_dir) + check_tarfile_data_filter_attribute() + t.extractall(path=targz_contents_dir, filter="data") shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT)) with open( diff --git a/tests/integ/s3_utils.py b/tests/integ/s3_utils.py index 58a403341e..500dc4a33a 100644 --- a/tests/integ/s3_utils.py +++ b/tests/integ/s3_utils.py @@ -19,6 +19,8 @@ import boto3 from six.moves.urllib.parse import urlparse +from sagemaker.utils import check_tarfile_data_filter_attribute + def assert_s3_files_exist(sagemaker_session, s3_url, files): parsed_url = urlparse(s3_url) @@ -55,4 +57,5 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session): s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model) with tarfile.open(model, "r") as tar_file: - tar_file.extractall(tmpdir) + check_tarfile_data_filter_attribute() + tar_file.extractall(tmpdir, filter="data") diff --git a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py index 40d3edb251..caa8884186 100644 --- a/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/djl_serving/test_djl_prepare.py @@ -272,4 +272,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path): mock_path.assert_called_once_with(js_model_dir) mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz") - mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir) + mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir, filter="data") diff --git a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py index c055be1f7d..c072f3cb99 100644 --- a/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py +++ b/tests/unit/sagemaker/serve/model_server/tgi/test_tgi_prepare.py @@ -156,4 +156,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path): mock_path.assert_called_once_with(js_model_dir) mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz") - mock_resource_obj.extractall.assert_called_once_with(path=code_dir) + mock_resource_obj.extractall.assert_called_once_with(path=code_dir, filter="data") diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index fa419eb848..4600785159 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -24,7 +24,7 @@ from mock import Mock, patch from sagemaker import fw_utils -from sagemaker.utils import name_from_image +from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute from sagemaker.session_settings import SessionSettings from sagemaker.instance_group import InstanceGroup @@ -424,7 +424,8 @@ def list_tar_files(folder, tar_ball, tmpdir): startpath = str(tmpdir.ensure(folder, dir=True)) with tarfile.open(name=tar_ball, mode="r:gz") as t: - t.extractall(path=startpath) + check_tarfile_data_filter_attribute() + t.extractall(path=startpath, filter="data") def walk(): for root, dirs, files in os.walk(startpath): diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index d733752428..8488a8308e 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -42,6 +42,8 @@ resolve_nested_dict_value_from_config, update_list_of_dicts_with_values_from_config, volume_size_supported, + PythonVersionError, + check_tarfile_data_filter_attribute, ) from tests.unit.sagemaker.workflow.helpers import CustomStep from sagemaker.workflow.parameters import ParameterString, ParameterInteger @@ -1748,3 +1750,15 @@ def test_instance_family_from_full_instance_type(self): for instance_type, family in instance_type_to_family_test_dict.items(): self.assertEqual(family, get_instance_type_family(instance_type)) + + +class TestCheckTarfileDataFilterAttribute(TestCase): + def test_check_tarfile_data_filter_attribute_unhappy_case(self): + with pytest.raises(PythonVersionError): + with patch("tarfile.data_filter", None): + delattr(tarfile, "data_filter") + check_tarfile_data_filter_attribute() + + def test_check_tarfile_data_filter_attribute_happy_case(self): + with patch("tarfile.data_filter", "some_value"): + check_tarfile_data_filter_attribute()