Skip to content
66 changes: 52 additions & 14 deletions src/sagemaker/jumpstart/artifacts/resource_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"""This module contains functions for obtaining JumpStart resoure requirements."""
from __future__ import absolute_import

from typing import Optional
from typing import Dict, Optional, Tuple

from sagemaker.jumpstart.constants import (
DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -28,6 +28,20 @@
from sagemaker.session import Session
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

# Lookup table: requirement type -> model-spec field name -> 2-tuple.
#   tuple[0]: key written into that requirement type's kwargs dict
#             (e.g. the "requests" / "limits" dict handed to ResourceRequirements).
#   tuple[1]: name of the matching ResourceRequirements attribute.
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP: Dict[
    str, Dict[str, Tuple[str, str]]
] = {
    # Spec fields that feed the "requests" dict.
    "requests": {
        "num_accelerators": ("num_accelerators", "num_accelerators"),
        "num_cpus": ("num_cpus", "num_cpus"),
        "copies": ("copies", "copy_count"),
        "min_memory_mb": ("memory", "min_memory"),
    },
    # Spec fields that feed the "limits" dict.
    "limits": {
        "max_memory_mb": ("memory", "max_memory"),
    },
}


def _retrieve_default_resources(
model_id: str,
Expand All @@ -37,6 +51,7 @@ def _retrieve_default_resources(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
instance_type: Optional[str] = None,
) -> ResourceRequirements:
"""Retrieves the default resource requirements for the model.

Expand All @@ -60,6 +75,8 @@ def _retrieve_default_resources(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
host requirements specific for the instance type.
Returns:
ResourceRequirements: The default resource requirements to use for the model or None.

Expand Down Expand Up @@ -87,23 +104,44 @@ def _retrieve_default_resources(
is_dynamic_container_deployment_supported = (
model_specs.dynamic_container_deployment_supported
)
default_resource_requirements = model_specs.hosting_resource_requirements
default_resource_requirements: Dict[str, int] = (
model_specs.hosting_resource_requirements or {}
)
else:
raise NotImplementedError(
f"Unsupported script scope for retrieving default resource requirements: '{scope}'"
)

instance_specific_resource_requirements: Dict[str, int] = (
model_specs.hosting_instance_type_variants.get_instance_specific_resource_requirements(
instance_type
)
if instance_type
and getattr(model_specs, "hosting_instance_type_variants", None) is not None
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: could we simplify to hasattr()?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this because model_specs.hosting_instance_type_variants could be defined but equal None

else {}
)

default_resource_requirements = {
**default_resource_requirements,
**instance_specific_resource_requirements,
}

if is_dynamic_container_deployment_supported:
requests = {}
if "num_accelerators" in default_resource_requirements:
requests["num_accelerators"] = default_resource_requirements["num_accelerators"]
if "min_memory_mb" in default_resource_requirements:
requests["memory"] = default_resource_requirements["min_memory_mb"]
if "num_cpus" in default_resource_requirements:
requests["num_cpus"] = default_resource_requirements["num_cpus"]

limits = {}
if "max_memory_mb" in default_resource_requirements:
limits["memory"] = default_resource_requirements["max_memory_mb"]
return ResourceRequirements(requests=requests, limits=limits)

all_resource_requirement_kwargs = {}

for (
requirement_type,
spec_field_to_resource_requirement_map,
) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP.items():
requirement_kwargs = {}
for spec_field, resource_requirement in spec_field_to_resource_requirement_map.items():
if spec_field in default_resource_requirements:
requirement_kwargs[resource_requirement[0]] = default_resource_requirements[
spec_field
]

all_resource_requirement_kwargs[requirement_type] = requirement_kwargs

return ResourceRequirements(**all_resource_requirement_kwargs)
return None
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,7 @@ def _add_resources_to_kwargs(kwargs: JumpStartModelInitKwargs) -> JumpStartModel
tolerate_deprecated_model=kwargs.tolerate_deprecated_model,
tolerate_vulnerable_model=kwargs.tolerate_vulnerable_model,
sagemaker_session=kwargs.sagemaker_session,
instance_type=kwargs.instance_type,
)

return kwargs
Expand Down
23 changes: 23 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,29 @@ def get_instance_specific_artifact_key(self, instance_type: str) -> Optional[str
instance_type=instance_type, property_name="artifact_key"
)

def get_instance_specific_resource_requirements(self, instance_type: str) -> dict:
    """Returns instance specific resource requirements.

    Reads ``properties.resource_requirements`` from ``self.variants`` twice: once under
    the exact instance type and once under its instance type family, then merges the
    two dicts. If a value exists for both the instance family and instance type, the
    instance type value is chosen. The merge is shallow: a nested value defined for the
    instance type replaces the family's value for that key as a whole.

    Args:
        instance_type (str): The instance type to look up (for example
            ``ml.g5.xlarge``).

    Returns:
        dict: The merged resource requirements. Empty when neither the instance type
            nor its family defines any.
    """

    def _resource_requirements_for(variant_key: str) -> dict:
        # Each level falls back to an empty dict, so a missing variant, a variant
        # without "properties", or properties without "resource_requirements" all
        # yield {} rather than raising.
        return (
            self.variants.get(variant_key, {})
            .get("properties", {})
            .get("resource_requirements", {})
        )

    instance_type_family = get_instance_type_family(instance_type)

    # Family values go first so that instance-type values overwrite them on conflict.
    return {
        **_resource_requirements_for(instance_type_family),
        **_resource_requirements_for(instance_type),
    }

def _get_instance_specific_property(
self, instance_type: str, property_name: str
) -> Optional[str]:
Expand Down
7 changes: 6 additions & 1 deletion src/sagemaker/resource_requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import logging
from typing import Optional
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements

from sagemaker.jumpstart import utils as jumpstart_utils
from sagemaker.jumpstart import artifacts
Expand All @@ -33,7 +34,8 @@ def retrieve_default(
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
) -> str:
instance_type: Optional[str] = None,
) -> ResourceRequirements:
"""Retrieves the default resource requirements for the model matching the given arguments.

Args:
Expand All @@ -56,6 +58,8 @@ def retrieve_default(
object, used for SageMaker interactions. If not
specified, one is created using the default AWS configuration
chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
instance_type (str): An instance type to optionally supply in order to get
host requirements specific for the instance type.
Returns:
ResourceRequirements: The default resource requirements to use for the model.

Expand All @@ -79,4 +83,5 @@ def retrieve_default(
tolerate_vulnerable_model,
tolerate_deprecated_model,
sagemaker_session=sagemaker_session,
instance_type=instance_type,
)
20 changes: 20 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,22 @@
"model_package_arn": "$gpu_model_package_arn",
}
},
"g5": {
"properties": {
"resource_requirements": {
"num_accelerators": 888810,
"randon-field-2": 2222,
}
}
},
"m2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"c2": {"regional_properties": {"image_uri": "$cpu_image_uri"}},
"ml.g5.xlarge": {
"properties": {
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"},
"resource_requirements": {"num_accelerators": 10},
}
},
"ml.g5.48xlarge": {
"properties": {"environment_variables": {"TENSOR_PARALLEL_DEGREE": "8"}}
},
Expand All @@ -857,6 +871,12 @@
"framework_version": "1.5.0",
"py_version": "py3",
},
"dynamic_container_deployment_supported": True,
"hosting_resource_requirements": {
"min_memory_mb": 81999,
"num_accelerators": 1,
"random_field_1": 1,
},
"hosting_artifact_key": "pytorch-infer/infer-pytorch-ic-mobilenet-v2.tar.gz",
"training_artifact_key": "pytorch-training/train-pytorch-ic-mobilenet-v2.tar.gz",
"hosting_script_key": "source-directory-tarballs/pytorch/inference/ic/v1.0.0/sourcedir.tar.gz",
Expand Down
23 changes: 23 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"variants": {
"ml.p2.12xlarge": {
"properties": {
"resource_requirements": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9},
"environment_variables": {"TENSOR_PARALLEL_DEGREE": "4"},
"supported_inference_instance_types": ["ml.p5.xlarge"],
"default_inference_instance_type": "ml.p5.xlarge",
Expand All @@ -60,6 +61,11 @@
"p2": {
"regional_properties": {"image_uri": "$gpu_image_uri"},
"properties": {
"resource_requirements": {
"req2": {"2": 5, "9": 999},
"req3": 999,
"req4": "blah",
},
"supported_inference_instance_types": ["ml.p2.xlarge", "ml.p3.xlarge"],
"default_inference_instance_type": "ml.p2.xlarge",
"metrics": [
Expand Down Expand Up @@ -879,3 +885,20 @@ def test_jumpstart_training_artifact_key_instance_variants():
)
is None
)


def test_jumpstart_resource_requirements_instance_variants():
    """Checks the merged resource requirements returned for three instance types."""
    expected_by_instance_type = {
        "ml.p2.xlarge": {"req2": {"2": 5, "9": 999}, "req3": 999, "req4": "blah"},
        "ml.p2.12xlarge": {"req1": 1, "req2": {"1": 2, "2": 3}, "req3": 9, "req4": "blah"},
        "ml.p99.12xlarge": {},
    }

    for instance_type, expected_requirements in expected_by_instance_type.items():
        actual_requirements = (
            INSTANCE_TYPE_VARIANT.get_instance_specific_resource_requirements(
                instance_type=instance_type
            )
        )
        assert actual_requirements == expected_requirements
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@
import pytest

from sagemaker import resource_requirements
from sagemaker.compute_resource_requirements.resource_requirements import ResourceRequirements
from sagemaker.jumpstart.artifacts.resource_requirements import (
REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP,
)

from tests.unit.sagemaker.jumpstart.utils import get_spec_from_base_spec, get_special_model_spec

Expand Down Expand Up @@ -50,6 +54,55 @@ def test_jumpstart_resource_requirements(patched_get_model_specs):
patched_get_model_specs.reset_mock()


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_resource_requirements_instance_type_variants(patched_get_model_specs):
    """Checks the ``requests`` returned by ``retrieve_default`` for three instance types.

    The memory request is expected to be 81999 in every case; only the expected
    ``num_accelerators`` differs per instance type.
    """
    patched_get_model_specs.side_effect = get_special_model_spec
    s3_client = boto3.client("s3")
    session = Mock(s3_client=s3_client)

    # NOTE(review): presumably these cover an exact instance-type match, a
    # family-only match and no match at all; confirm against the "variant-model" spec.
    expected_num_accelerators = {
        "ml.g5.xlarge": 10,
        "ml.g5.555xlarge": 888810,
        "ml.f9.555xlarge": 1,
    }

    for instance_type, num_accelerators in expected_num_accelerators.items():
        requirements = resource_requirements.retrieve_default(
            region="us-west-2",
            model_id="variant-model",
            model_version="*",
            scope="inference",
            sagemaker_session=session,
            instance_type=instance_type,
        )
        assert requirements.requests == {
            "memory": 81999,
            "num_accelerators": num_accelerators,
        }


@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
patched_get_model_specs.side_effect = get_special_model_spec
Expand Down Expand Up @@ -80,3 +133,18 @@ def test_jumpstart_no_supported_resource_requirements(patched_get_model_specs):
resource_requirements.retrieve_default(
region=region, model_id=model_id, model_version=model_version, scope="training"
)


def test_jumpstart_supports_all_resource_requirement_fields():
    """Checks the mapping names every ``ResourceRequirements`` attribute.

    The attribute names listed in the second slot of each mapping tuple must equal the
    attributes of a default ``ResourceRequirements`` instance, ignoring ``requests``
    and ``limits``.
    """
    mapping = REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP

    tracked_fields = set()
    for spec_field_map in mapping.values():
        for _, attribute_name in spec_field_map.values():
            tracked_fields.add(attribute_name)

    ignored_fields = {"requests", "limits"}
    instance_fields = set(vars(ResourceRequirements()))

    assert instance_fields - ignored_fields == tracked_fields