Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
782 changes: 782 additions & 0 deletions src/sagemaker/image_uri_config/instance_gpu_info.json

Large diffs are not rendered by default.

43 changes: 43 additions & 0 deletions src/sagemaker/instance_types_gpu_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve instance types GPU info."""
from __future__ import absolute_import

import json
import os
from typing import Dict


def retrieve(region: str) -> Dict[str, Dict[str, int]]:
    """Retrieves instance types GPU info of the given region.

    The data is read from ``image_uri_config/instance_gpu_info.json``, located
    next to this module.

    Args:
        region (str): The AWS region.

    Returns:
        dict[str, dict[str, int]]: A dictionary that contains instance types as keys
            and GPU info as values, or an empty dictionary if the config has no
            entry for the given region.

    Raises:
        ValueError: If the GPU info config file cannot be found.
    """
    config_path = os.path.join(
        os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json"
    )
    try:
        # JSON is UTF-8 by specification; do not depend on the platform's locale encoding.
        with open(config_path, encoding="utf-8") as f:
            instance_types_gpu_info_config = json.load(f)
    except FileNotFoundError as err:
        # Chain the original error so the missing path stays visible in the traceback.
        raise ValueError("Could not find instance types gpu info.") from err

    # An unknown region is not an error: callers receive an empty mapping.
    return instance_types_gpu_info_config.get(region, {})
110 changes: 110 additions & 0 deletions src/sagemaker/serve/utils/hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance"""
from __future__ import absolute_import

import logging
from typing import Tuple

from botocore.exceptions import ClientError

from sagemaker import Session
from sagemaker import instance_types_gpu_info

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]:
    """Get GPU info for the provided instance type from the EC2 API.

    Args:
        instance_type (str): The instance type name. It is normalized with
            ``_format_instance_type`` before EC2 is queried.
        session (Session): The session whose ``boto_session`` is used to create
            the EC2 client.

    Returns:
        tuple[int, int]: A tuple that contains number of GPUs available at index 0,
            and aggregate memory size in MiB at index 1.

    Raises:
        ValueError: If the ``describe_instance_types`` call fails, returns no
            instance types, or the returned description carries no usable GPU
            count and total GPU memory.
    """
    ec2_client = session.boto_session.client("ec2")
    ec2_instance = _format_instance_type(instance_type)

    try:
        response = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance])
    except ClientError as err:
        # Keep the ValueError contract, but chain the client error so the real
        # cause is not lost.
        raise ValueError(
            f"Provided instance_type is not GPU enabled: [{ec2_instance}]"
        ) from err

    # A missing or empty "InstanceTypes" list is reported as ValueError below
    # rather than surfacing as TypeError/IndexError from blind indexing.
    instance_types = response.get("InstanceTypes") or []
    instance_info = instance_types[0] if instance_types else None

    # NOTE(review): only "GpuInfo" is consulted. Per the review discussion on this
    # change, accelerator info for Inferentia/Trainium instances is reportedly
    # stored elsewhere and is out of scope here - confirm before extending.
    gpus_info = instance_info.get("GpuInfo") if instance_info is not None else None
    gpus = gpus_info.get("Gpus") if gpus_info is not None else None

    if gpus:
        # The per-GPU count is taken from the first "Gpus" entry; the memory figure
        # is the aggregate reported on "GpuInfo".
        count = gpus[0].get("Count")
        total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB")
        if count and total_gpu_memory_in_mib:
            instance_gpu_info = (count, total_gpu_memory_in_mib)
            logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info)
            return instance_gpu_info

    raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")


def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]:
    """Get GPU info for the provided instance type from the bundled config.

    Args:
        instance_type (str): The instance type name, used as-is as the lookup key
            in the region's config.
        region (str): The AWS region whose config is consulted.

    Returns:
        tuple[int, int]: A tuple that contains number of GPUs available at index 0,
            and aggregate memory size in MiB at index 1.

    Raises:
        ValueError: If the region's config has no entry for the given instance type.
    """
    region_gpu_info = instance_types_gpu_info.retrieve(region)
    config_entry = region_gpu_info.get(instance_type)

    # The normalized name is only used for the log line and the error message.
    ec2_instance = _format_instance_type(instance_type)

    if config_entry is not None:
        gpu_info = (config_entry.get("Count"), config_entry.get("TotalGpuMemoryInMiB"))
        logger.info("GPU Info [%s]: %s", ec2_instance, gpu_info)
        return gpu_info

    raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]")


def _format_instance_type(instance_type: str) -> str:
"""Formats provided instance type name

Args:
instance_type (str):

Returns: formatted instance type.
"""
split_instance = instance_type.split(".")

if len(split_instance) > 2:
split_instance.pop(0)

ec2_instance = ".".join(split_instance)
return ec2_instance
44 changes: 44 additions & 0 deletions tests/integ/sagemaker/serve/utils/test_hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import pytest

from sagemaker.serve.utils import hardware_detector

# AWS region passed to the fallback lookup in these tests.
REGION = "us-west-2"
# Instance type the tests expect GPU info for.
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
# Instance type name the tests expect a ValueError for.
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
# (GPU count, total GPU memory in MiB) the tests expect for VALID_INSTANCE_TYPE.
EXPECTED_INSTANCE_GPU_INFO = (8, 196608)


def test_get_gpu_info_success(sagemaker_session):
    """The valid instance type resolves to the expected GPU info tuple."""
    detected = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)
    assert detected == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_throws(sagemaker_session):
    """The invalid instance type is rejected with ValueError."""
    bad_type = INVALID_INSTANCE_TYPE
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(bad_type, sagemaker_session)


def test_get_gpu_info_fallback_success():
    """The config fallback returns the expected GPU info for the valid instance type."""
    from_config = hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)
    assert from_config == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_fallback_throws():
    """The config fallback raises ValueError for the invalid instance type."""
    bad_type = INVALID_INSTANCE_TYPE
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info_fallback(bad_type, REGION)
98 changes: 98 additions & 0 deletions tests/unit/sagemaker/serve/utils/test_hardware_detector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from botocore.exceptions import ClientError
import pytest

from sagemaker.serve.utils import hardware_detector

# AWS region passed to the fallback lookup in these tests.
REGION = "us-west-2"
# Instance type the tests expect GPU info for.
VALID_INSTANCE_TYPE = "ml.g5.48xlarge"
# Instance type name the tests expect a ValueError for.
INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge"
# (GPU count, total GPU memory in MiB) the tests expect for VALID_INSTANCE_TYPE.
EXPECTED_INSTANCE_GPU_INFO = (8, 196608)


def test_get_gpu_info_success(sagemaker_session, boto_session):
    """A stubbed EC2 description with GPU data yields the expected tuple."""
    a10g_gpu = {
        "Name": "A10G",
        "Manufacturer": "NVIDIA",
        "Count": 8,
        "MemoryInfo": {"SizeInMiB": 24576},
    }
    stubbed_response = {
        "InstanceTypes": [
            {
                "GpuInfo": {
                    "Gpus": [a10g_gpu],
                    "TotalGpuMemoryInMiB": 196608,
                },
            }
        ]
    }
    boto_session.client("ec2").describe_instance_types.return_value = stubbed_response

    detected = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session)

    # EC2 must be queried with the name stripped of its leading "ml." segment.
    boto_session.client("ec2").describe_instance_types.assert_called_once_with(
        InstanceTypes=["g5.48xlarge"]
    )
    assert detected == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_throws(sagemaker_session, boto_session):
    """An EC2 description without GPU data is rejected with ValueError."""
    description_without_gpu = {"InstanceTypes": [{}]}
    boto_session.client("ec2").describe_instance_types.return_value = description_without_gpu

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)


def test_get_gpu_info_describe_instance_types_throws(sagemaker_session, boto_session):
    """A ClientError from describe_instance_types surfaces as ValueError."""
    error_response = {
        "Error": {
            "Code": "InvalidInstanceType",
            "Message": f"An error occurred (InvalidInstanceType) when calling the DescribeInstanceTypes "
            f"operation: The following supplied instance types do not exist: [{INVALID_INSTANCE_TYPE}]",
        }
    }
    client_error = ClientError(error_response, "DescribeInstanceTypes")
    boto_session.client("ec2").describe_instance_types.side_effect = client_error

    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session)


def test_get_gpu_info_fallback_success():
    """The config fallback returns the expected GPU info for the valid instance type."""
    from_config = hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION)
    assert from_config == EXPECTED_INSTANCE_GPU_INFO


def test_get_gpu_info_fallback_throws():
    """The config fallback raises ValueError for the invalid instance type."""
    bad_type = INVALID_INSTANCE_TYPE
    with pytest.raises(ValueError):
        hardware_detector._get_gpu_info_fallback(bad_type, REGION)


def test_format_instance_type_success():
    """The leading "ml." segment is stripped from a three-segment name."""
    assert hardware_detector._format_instance_type(VALID_INSTANCE_TYPE) == "g5.48xlarge"


def test_format_instance_type_without_ml_success():
    """A two-segment name is returned unchanged."""
    already_formatted = "g5.48xlarge"
    assert hardware_detector._format_instance_type(already_formatted) == "g5.48xlarge"
30 changes: 30 additions & 0 deletions tests/unit/sagemaker/test_instance_types_gpu_info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

from sagemaker import instance_types_gpu_info

# Region the tests expect a non-empty GPU info config for.
REGION = "us-west-2"
# Region name the tests expect no config entry for.
INVALID_REGION = "invalid-region"


def test_retrieve_success():
    """A known region yields a non-empty GPU info mapping."""
    region_config = instance_types_gpu_info.retrieve(REGION)
    assert len(region_config) > 0


def test_retrieve_throws():
    """An unknown region yields an empty mapping; despite the name, nothing is raised."""
    assert len(instance_types_gpu_info.retrieve(INVALID_REGION)) == 0