-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Logic to detect hardware GPU count and aggregate GPU memory size in MiB #4389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
ecbe66f
Add logic to detect hardware GPU count and aggregate GPU memory size …
27a620a
Fix all formatting
e5d7c16
Merge branch 'master' into master
makungaj1 cf49ca8
Addressed PR review comments
3b63301
Merge branch 'master' into master
makungaj1 477d6c2
Addressed PR Review messages
6adfa69
Merge branch 'master' into master
makungaj1 27abb4c
Addressed PR Review Messages
51c8649
Addressed PR Review comments
9521f87
Addressed PR Review Comments
7592fac
Add integration tests
680b466
Merge branch 'master' into master
makungaj1 4f63b2b
Merge branch 'master' into master
makungaj1 b19c286
Add config
f3e4a2a
Fix integration tests
e9547fe
Include Instance Types GPU infor Config files
46fb391
Addressed PR review comments
625ba7c
Fix unit tests
8c75b1e
Fix unit test: 'Mock' object is not subscriptable
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Accessors to retrieve instance types GPU info.""" | ||
| from __future__ import absolute_import | ||
|
|
||
| import json | ||
| import os | ||
| from typing import Dict | ||
|
|
||
|
|
||
| def retrieve(region: str) -> Dict[str, Dict[str, int]]: | ||
| """Retrieves instance types GPU info of the given region. | ||
|
|
||
| Args: | ||
| region (str): The AWS region. | ||
|
|
||
| Returns: | ||
| dict[str, dict[str, int]]: A dictionary that contains instance types as keys | ||
| and GPU info as values or empty dictionary if the | ||
| config for the given region is not found. | ||
|
|
||
| Raises: | ||
| ValueError: If no config found. | ||
| """ | ||
| config_path = os.path.join( | ||
| os.path.dirname(__file__), "image_uri_config", "instance_gpu_info.json" | ||
| ) | ||
| try: | ||
| with open(config_path) as f: | ||
| instance_types_gpu_info_config = json.load(f) | ||
| return instance_types_gpu_info_config.get(region, {}) | ||
| except FileNotFoundError: | ||
| raise ValueError("Could not find instance types gpu info.") |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| """Utilities for detecting available GPUs and Aggregate GPU Memory size of an instance""" | ||
| from __future__ import absolute_import | ||
|
|
||
| import logging | ||
| from typing import Tuple | ||
|
|
||
| from botocore.exceptions import ClientError | ||
|
|
||
| from sagemaker import Session | ||
| from sagemaker import instance_types_gpu_info | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| def _get_gpu_info(instance_type: str, session: Session) -> Tuple[int, int]: | ||
| """Get GPU info for the provided instance | ||
|
|
||
| Args: | ||
| instance_type (str) | ||
| session: The session to use. | ||
|
|
||
| Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0, | ||
| and aggregate memory size in MiB at index 1. | ||
|
|
||
| Raises: | ||
| ValueError: If The given instance type does not exist or GPU is not enabled. | ||
| """ | ||
| ec2_client = session.boto_session.client("ec2") | ||
| ec2_instance = _format_instance_type(instance_type) | ||
|
|
||
| try: | ||
| instance_info = ec2_client.describe_instance_types(InstanceTypes=[ec2_instance]).get( | ||
| "InstanceTypes" | ||
| )[0] | ||
| except ClientError: | ||
| raise ValueError(f"Provided instance_type is not GPU enabled: [#{ec2_instance}]") | ||
|
|
||
| if instance_info is not None: | ||
| gpus_info = instance_info.get("GpuInfo") | ||
| if gpus_info is not None: | ||
| gpus = gpus_info.get("Gpus") | ||
| if gpus is not None and len(gpus) > 0: | ||
| count = gpus[0].get("Count") | ||
| total_gpu_memory_in_mib = gpus_info.get("TotalGpuMemoryInMiB") | ||
| if count and total_gpu_memory_in_mib: | ||
| instance_gpu_info = ( | ||
| count, | ||
| total_gpu_memory_in_mib, | ||
| ) | ||
| logger.info("GPU Info [%s]: %s", ec2_instance, instance_gpu_info) | ||
| return instance_gpu_info | ||
|
|
||
| raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]") | ||
|
|
||
|
|
||
| def _get_gpu_info_fallback(instance_type: str, region: str) -> Tuple[int, int]: | ||
| """Get GPU info for the provided from the config | ||
|
|
||
| Args: | ||
| instance_type (str): | ||
| region: The AWS region. | ||
|
|
||
| Returns: tuple[int, int]: A tuple that contains number of GPUs available at index 0, | ||
| and aggregate memory size in MiB at index 1. | ||
|
|
||
| Raises: | ||
| ValueError: If The given instance type does not exist. | ||
| """ | ||
| instance_types_gpu_info_config = instance_types_gpu_info.retrieve(region) | ||
| fallback_instance_gpu_info = instance_types_gpu_info_config.get(instance_type) | ||
|
|
||
| ec2_instance = _format_instance_type(instance_type) | ||
| if fallback_instance_gpu_info is None: | ||
| raise ValueError(f"Provided instance_type is not GPU enabled: [{ec2_instance}]") | ||
|
|
||
| fallback_instance_gpu_info = ( | ||
| fallback_instance_gpu_info.get("Count"), | ||
| fallback_instance_gpu_info.get("TotalGpuMemoryInMiB"), | ||
| ) | ||
| logger.info("GPU Info [%s]: %s", ec2_instance, fallback_instance_gpu_info) | ||
| return fallback_instance_gpu_info | ||
|
|
||
|
|
||
| def _format_instance_type(instance_type: str) -> str: | ||
| """Formats provided instance type name | ||
|
|
||
| Args: | ||
| instance_type (str): | ||
|
|
||
| Returns: formatted instance type. | ||
| """ | ||
| split_instance = instance_type.split(".") | ||
|
|
||
| if len(split_instance) > 2: | ||
| split_instance.pop(0) | ||
|
|
||
| ec2_instance = ".".join(split_instance) | ||
| return ec2_instance | ||
44 changes: 44 additions & 0 deletions
44
tests/integ/sagemaker/serve/utils/test_hardware_detector.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,44 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| from __future__ import absolute_import | ||
|
|
||
| import pytest | ||
|
|
||
| from sagemaker.serve.utils import hardware_detector | ||
|
|
||
| REGION = "us-west-2" | ||
| VALID_INSTANCE_TYPE = "ml.g5.48xlarge" | ||
| INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" | ||
| EXPECTED_INSTANCE_GPU_INFO = (8, 196608) | ||
|
|
||
|
|
||
| def test_get_gpu_info_success(sagemaker_session): | ||
| gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session) | ||
|
|
||
| assert gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
|
||
|
|
||
| def test_get_gpu_info_throws(sagemaker_session): | ||
| with pytest.raises(ValueError): | ||
| hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session) | ||
|
|
||
|
|
||
| def test_get_gpu_info_fallback_success(): | ||
| gpu_info = hardware_detector._get_gpu_info_fallback(VALID_INSTANCE_TYPE, REGION) | ||
|
|
||
| assert gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
|
||
|
|
||
| def test_get_gpu_info_fallback_throws(): | ||
| with pytest.raises(ValueError): | ||
| hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION) |
98 changes: 98 additions & 0 deletions
98
tests/unit/sagemaker/serve/utils/test_hardware_detector.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,98 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| from __future__ import absolute_import | ||
|
|
||
| from botocore.exceptions import ClientError | ||
| import pytest | ||
|
|
||
| from sagemaker.serve.utils import hardware_detector | ||
|
|
||
| REGION = "us-west-2" | ||
| VALID_INSTANCE_TYPE = "ml.g5.48xlarge" | ||
| INVALID_INSTANCE_TYPE = "fl.c5.57xxlarge" | ||
| EXPECTED_INSTANCE_GPU_INFO = (8, 196608) | ||
|
|
||
|
|
||
| def test_get_gpu_info_success(sagemaker_session, boto_session): | ||
| boto_session.client("ec2").describe_instance_types.return_value = { | ||
| "InstanceTypes": [ | ||
| { | ||
| "GpuInfo": { | ||
| "Gpus": [ | ||
| { | ||
| "Name": "A10G", | ||
| "Manufacturer": "NVIDIA", | ||
| "Count": 8, | ||
| "MemoryInfo": {"SizeInMiB": 24576}, | ||
| } | ||
| ], | ||
| "TotalGpuMemoryInMiB": 196608, | ||
| }, | ||
| } | ||
| ] | ||
| } | ||
|
|
||
| instance_gpu_info = hardware_detector._get_gpu_info(VALID_INSTANCE_TYPE, sagemaker_session) | ||
|
|
||
| boto_session.client("ec2").describe_instance_types.assert_called_once_with( | ||
| InstanceTypes=["g5.48xlarge"] | ||
| ) | ||
| assert instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
|
||
|
|
||
| def test_get_gpu_info_throws(sagemaker_session, boto_session): | ||
| boto_session.client("ec2").describe_instance_types.return_value = {"InstanceTypes": [{}]} | ||
|
|
||
| with pytest.raises(ValueError): | ||
| hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session) | ||
|
|
||
|
|
||
| def test_get_gpu_info_describe_instance_types_throws(sagemaker_session, boto_session): | ||
| boto_session.client("ec2").describe_instance_types.side_effect = ClientError( | ||
| { | ||
| "Error": { | ||
| "Code": "InvalidInstanceType", | ||
| "Message": f"An error occurred (InvalidInstanceType) when calling the DescribeInstanceTypes " | ||
| f"operation: The following supplied instance types do not exist: [{INVALID_INSTANCE_TYPE}]", | ||
| } | ||
| }, | ||
| "DescribeInstanceTypes", | ||
| ) | ||
|
|
||
| with pytest.raises(ValueError): | ||
| hardware_detector._get_gpu_info(INVALID_INSTANCE_TYPE, sagemaker_session) | ||
|
|
||
|
|
||
| def test_get_gpu_info_fallback_success(): | ||
| fallback_instance_gpu_info = hardware_detector._get_gpu_info_fallback( | ||
| VALID_INSTANCE_TYPE, REGION | ||
| ) | ||
|
|
||
| assert fallback_instance_gpu_info == EXPECTED_INSTANCE_GPU_INFO | ||
|
|
||
|
|
||
| def test_get_gpu_info_fallback_throws(): | ||
| with pytest.raises(ValueError): | ||
| hardware_detector._get_gpu_info_fallback(INVALID_INSTANCE_TYPE, REGION) | ||
|
|
||
|
|
||
| def test_format_instance_type_success(): | ||
| formatted_instance_type = hardware_detector._format_instance_type(VALID_INSTANCE_TYPE) | ||
|
|
||
| assert formatted_instance_type == "g5.48xlarge" | ||
|
|
||
|
|
||
| def test_format_instance_type_without_ml_success(): | ||
| formatted_instance_type = hardware_detector._format_instance_type("g5.48xlarge") | ||
|
|
||
| assert formatted_instance_type == "g5.48xlarge" |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,30 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"). You | ||
| # may not use this file except in compliance with the License. A copy of | ||
| # the License is located at | ||
| # | ||
| # http://aws.amazon.com/apache2.0/ | ||
| # | ||
| # or in the "license" file accompanying this file. This file is | ||
| # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF | ||
| # ANY KIND, either express or implied. See the License for the specific | ||
| # language governing permissions and limitations under the License. | ||
| from __future__ import absolute_import | ||
|
|
||
| from sagemaker import instance_types_gpu_info | ||
|
|
||
| REGION = "us-west-2" | ||
| INVALID_REGION = "invalid-region" | ||
|
|
||
|
|
||
| def test_retrieve_success(): | ||
| data = instance_types_gpu_info.retrieve(REGION) | ||
|
|
||
| assert len(data) > 0 | ||
|
|
||
|
|
||
| def test_retrieve_throws(): | ||
| data = instance_types_gpu_info.retrieve(INVALID_REGION) | ||
|
|
||
| assert len(data) == 0 |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we only care about GpuInfo, iirc, inf2 and trn instances store this info under
InferenceAcceleratorInfoThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Scope of this milestone is GPU, inferentia support is out of scope due to some blockers.