Skip to content
4 changes: 3 additions & 1 deletion src/sagemaker/local/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import sagemaker.local.data
import sagemaker.local.utils
import sagemaker.utils
from sagemaker.utils import check_tarfile_data_filter_attribute

CONTAINER_PREFIX = "algo"
STUDIO_HOST_NAME = "sagemaker-local"
Expand Down Expand Up @@ -686,7 +687,8 @@ def _prepare_serving_volumes(self, model_location):
for filename in model_data_source.get_file_list():
if tarfile.is_tarfile(filename):
with tarfile.open(filename) as tar:
tar.extractall(path=model_data_source.get_root_dir())
check_tarfile_data_filter_attribute()
tar.extractall(path=model_data_source.get_root_dir(), filter="data")

volumes.append(_Volume(model_data_source.get_root_dir(), "/opt/ml/model"))

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/serve/model_server/djl_serving/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from typing import List
from pathlib import Path

from sagemaker.utils import _tmpdir
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
from sagemaker.s3 import S3Downloader
from sagemaker.djl_inference import DJLModel
from sagemaker.djl_inference.model import _read_existing_serving_properties
Expand Down Expand Up @@ -53,7 +53,8 @@ def _extract_js_resource(js_model_dir: str, js_id: str):
"""Uncompress the jumpstart resource"""
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
with tarfile.open(str(tmp_sourcedir)) as resources:
resources.extractall(path=js_model_dir)
check_tarfile_data_filter_attribute()
resources.extractall(path=js_model_dir, filter="data")


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path):
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/serve/model_server/tgi/prepare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from pathlib import Path

from sagemaker.serve.utils.local_hardware import _check_disk_space, _check_docker_disk_usage
from sagemaker.utils import _tmpdir
from sagemaker.utils import _tmpdir, check_tarfile_data_filter_attribute
from sagemaker.s3 import S3Downloader

logger = logging.getLogger(__name__)
Expand All @@ -29,7 +29,8 @@ def _extract_js_resource(js_model_dir: str, code_dir: Path, js_id: str):
"""Uncompress the jumpstart resource"""
tmp_sourcedir = Path(js_model_dir).joinpath(f"infer-prepack-{js_id}.tar.gz")
with tarfile.open(str(tmp_sourcedir)) as resources:
resources.extractall(path=code_dir)
check_tarfile_data_filter_attribute()
resources.extractall(path=code_dir, filter="data")


def _copy_jumpstart_artifacts(model_data: str, js_id: str, code_dir: Path) -> bool:
Expand Down
29 changes: 27 additions & 2 deletions src/sagemaker/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import random
import re
import shutil
import sys
import tarfile
import tempfile
import time
Expand Down Expand Up @@ -591,7 +592,8 @@ def _create_or_update_code_dir(
download_file_from_url(source_directory, local_code_path, sagemaker_session)

with tarfile.open(name=local_code_path, mode="r:gz") as t:
t.extractall(path=code_dir)
check_tarfile_data_filter_attribute()
t.extractall(path=code_dir, filter="data")

elif source_directory:
if os.path.exists(code_dir):
Expand Down Expand Up @@ -628,7 +630,8 @@ def _extract_model(model_uri, sagemaker_session, tmp):
else:
local_model_path = model_uri.replace("file://", "")
with tarfile.open(name=local_model_path, mode="r:gz") as t:
t.extractall(path=tmp_model_dir)
check_tarfile_data_filter_attribute()
t.extractall(path=tmp_model_dir, filter="data")
return tmp_model_dir


Expand Down Expand Up @@ -1489,3 +1492,25 @@ def format_tags(tags: Tags) -> List[TagsDict]:
return [{"Key": str(k), "Value": str(v)} for k, v in tags.items()]

return tags


class PythonVersionError(Exception):
    """Signal that the running Python lacks the required security patch.

    Raised by ``check_tarfile_data_filter_attribute`` when the interpreter's
    ``tarfile`` module does not expose ``data_filter``.
    """


def check_tarfile_data_filter_attribute():
"""Check if tarfile has data_filter utility.

Tarfile-data_filter utility has guardrails against untrusted de-serialisation.

Raises:
PythonVersionError: if `tarfile.data_filter` is not available.
"""
# The function and it's usages can be deprecated post support of python >= 3.12
if not hasattr(tarfile, "data_filter"):
raise PythonVersionError(
f"Since tarfile extraction is unsafe the operation is prohibited "
f"per PEP-721. Please update your Python [{sys.version}] "
f"to latest patch [refer to https://www.python.org/downloads/] "
f"to consume the security patch"
)
10 changes: 8 additions & 2 deletions src/sagemaker/workflow/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,12 @@
Step,
ConfigurableRetryStep,
)
from sagemaker.utils import _save_model, download_file_from_url, format_tags
from sagemaker.utils import (
_save_model,
download_file_from_url,
format_tags,
check_tarfile_data_filter_attribute,
)
from sagemaker.workflow.retry import RetryPolicy
from sagemaker.workflow.utilities import trim_request_dict

Expand Down Expand Up @@ -257,7 +262,8 @@ def _inject_repack_script_and_launcher(self):
download_file_from_url(self._source_dir, old_targz_path, self.sagemaker_session)

with tarfile.open(name=old_targz_path, mode="r:gz") as t:
t.extractall(path=targz_contents_dir)
check_tarfile_data_filter_attribute()
t.extractall(path=targz_contents_dir, filter="data")

shutil.copy2(fname, os.path.join(targz_contents_dir, REPACK_SCRIPT))
with open(
Expand Down
5 changes: 4 additions & 1 deletion tests/integ/s3_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import boto3
from six.moves.urllib.parse import urlparse

from sagemaker.utils import check_tarfile_data_filter_attribute


def assert_s3_files_exist(sagemaker_session, s3_url, files):
parsed_url = urlparse(s3_url)
Expand Down Expand Up @@ -55,4 +57,5 @@ def extract_files_from_s3(s3_url, tmpdir, sagemaker_session):
s3.Bucket(parsed_url.netloc).download_file(parsed_url.path.lstrip("/"), model)

with tarfile.open(model, "r") as tar_file:
tar_file.extractall(tmpdir)
check_tarfile_data_filter_attribute()
tar_file.extractall(tmpdir, filter="data")
Original file line number Diff line number Diff line change
Expand Up @@ -272,4 +272,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path):

mock_path.assert_called_once_with(js_model_dir)
mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz")
mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir)
mock_resource_obj.extractall.assert_called_once_with(path=js_model_dir, filter="data")
Original file line number Diff line number Diff line change
Expand Up @@ -156,4 +156,4 @@ def test_extract_js_resources_success(self, mock_tarfile, mock_path):

mock_path.assert_called_once_with(js_model_dir)
mock_path_obj.joinpath.assert_called_once_with(f"infer-prepack-{MOCK_JUMPSTART_ID}.tar.gz")
mock_resource_obj.extractall.assert_called_once_with(path=code_dir)
mock_resource_obj.extractall.assert_called_once_with(path=code_dir, filter="data")
5 changes: 3 additions & 2 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from mock import Mock, patch

from sagemaker import fw_utils
from sagemaker.utils import name_from_image
from sagemaker.utils import name_from_image, check_tarfile_data_filter_attribute
from sagemaker.session_settings import SessionSettings
from sagemaker.instance_group import InstanceGroup

Expand Down Expand Up @@ -424,7 +424,8 @@ def list_tar_files(folder, tar_ball, tmpdir):
startpath = str(tmpdir.ensure(folder, dir=True))

with tarfile.open(name=tar_ball, mode="r:gz") as t:
t.extractall(path=startpath)
check_tarfile_data_filter_attribute()
t.extractall(path=startpath, filter="data")

def walk():
for root, dirs, files in os.walk(startpath):
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
resolve_nested_dict_value_from_config,
update_list_of_dicts_with_values_from_config,
volume_size_supported,
PythonVersionError,
check_tarfile_data_filter_attribute,
)
from tests.unit.sagemaker.workflow.helpers import CustomStep
from sagemaker.workflow.parameters import ParameterString, ParameterInteger
Expand Down Expand Up @@ -1748,3 +1750,15 @@ def test_instance_family_from_full_instance_type(self):

for instance_type, family in instance_type_to_family_test_dict.items():
self.assertEqual(family, get_instance_type_family(instance_type))


class TestCheckTarfileDataFilterAttribute(TestCase):
    """Unit tests for ``check_tarfile_data_filter_attribute``."""

    def test_check_tarfile_data_filter_attribute_unhappy_case(self):
        """The check raises ``PythonVersionError`` when the attribute is gone."""
        with pytest.raises(PythonVersionError), patch("tarfile.data_filter", None):
            # Remove the attribute inside the patch context so that
            # ``hasattr(tarfile, "data_filter")`` is False for the call below.
            del tarfile.data_filter
            check_tarfile_data_filter_attribute()

    def test_check_tarfile_data_filter_attribute_happy_case(self):
        """The check passes silently when the attribute is present."""
        with patch("tarfile.data_filter", "some_value"):
            check_tarfile_data_filter_attribute()