PySDK Version
PySDK V3 (3.x)
Describe the bug
HyperparameterTuner.tune() does not honor the enable_managed_spot_training flag set on the model_trainer.compute object; instead, it launches the training jobs with EnableManagedSpotTraining set to False.
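For reference, when spot training is enabled the underlying CreateTrainingJob request should carry the flag alongside the stopping condition (field names are from the SageMaker API; values below mirror this report's configuration):

```json
{
  "EnableManagedSpotTraining": true,
  "StoppingCondition": {
    "MaxRuntimeInSeconds": 3600,
    "MaxWaitTimeInSeconds": 3600
  }
}
```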
To reproduce
Two files:
training.py:
print("Test set: Average loss: 0.9")
demo.py (a simplified version of the SageMaker V3 hyperparameter tuning example):
# V3 Imports
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import Compute, SourceCode, StoppingCondition
from sagemaker.train.tuner import HyperparameterTuner
from sagemaker.core.parameter import ContinuousParameter
from sagemaker.core.helper.session_helper import Session, get_execution_role

sagemaker_session = Session()
region = sagemaker_session.boto_region_name
role = get_execution_role()

# Configure source code
source_code = SourceCode(
    source_dir=".",
    entry_script="training.py",
)

# Configure compute resources
compute = Compute(
    enable_managed_spot_training=True,
)

# Configure stopping condition
stopping_condition = StoppingCondition(
    max_runtime_in_seconds=3600,  # 1 hour
    max_wait_time_in_seconds=3600,
)

# Get PyTorch training image
training_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:1.10.0-gpu-py38"

# Create ModelTrainer
model_trainer = ModelTrainer(
    training_image=training_image,
    source_code=source_code,
    compute=compute,
    stopping_condition=stopping_condition,
    role=role,
)

# Define hyperparameter ranges to tune
hyperparameter_ranges = {
    "lr": ContinuousParameter(0.001, 0.1),
}

# Define objective metric
objective_metric_name = "average test loss"

# Define metric definitions
metric_definitions = [
    {
        "Name": "average test loss",
        "Regex": "Test set: Average loss: ([0-9\\.]+)",
    }
]

# Create HyperparameterTuner
tuner = HyperparameterTuner(
    model_trainer=model_trainer,
    objective_metric_name=objective_metric_name,
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=metric_definitions,
)

tuner.tune(wait=True)

# Check the spot flag on the first launched training job
job = next(tuner.describe().get_all_training_jobs())
describe_job = sagemaker_session.sagemaker_client.describe_training_job(
    TrainingJobName=job.training_job_name
)
print("EnableManagedSpotTraining", describe_job["EnableManagedSpotTraining"])
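(As a local sanity check, unrelated to the bug itself: the metric Regex in demo.py does extract the loss that training.py prints, so the tuner's objective metric is well-formed. Plain Python, no SageMaker calls:)

```python
import re

# The Regex from metric_definitions and the line training.py prints.
metric_regex = r"Test set: Average loss: ([0-9\.]+)"
log_line = "Test set: Average loss: 0.9"

match = re.search(metric_regex, log_line)
print(match.group(1))  # "0.9"
```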
Expected behavior
When running the provided script, the final line should be:
EnableManagedSpotTraining True
Screenshots or logs
$ python demo.py
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml
[02/26/26 16:47:04] INFO SageMaker session not provided. Using default Session.  defaults.py:61
INFO Base name not provided. Using default name: pytorch-training-job  defaults.py:90
INFO Instance type not provided. Using default: ml.m5.xlarge  defaults.py:105
INFO Instance count not provided. Using default: 1  defaults.py:108
INFO OutputDataConfig not provided. Using default: s3_output_path='s3://sagemaker-us-west-2-xxxxxxxxxxxx/pytorch-training-job' kms_key_id=None compression_type='GZIP' remove_job_name_from_s3_output_path=<Unassigned> disable_model_upload=<Unassigned> channels=<Unassigned>  defaults.py:150
INFO Training image URI: 763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38  model_trainer.py:548
INFO Creating hyper_parameter_tuning_job resource.  resources.py:17214
INFO Runs on sagemaker prod, region:us-west-2  utils.py:354
[02/26/26 16:52:42] INFO Final Resource Status: Completed  resources.py:17443
EnableManagedSpotTraining False
System information
A description of your system. Please provide:
- SageMaker Python SDK version: 3.4.1
- Framework name (eg. PyTorch) or algorithm (eg. KMeans): PyTorch
- Framework version: 1.10.0
- Python version: 3.12.9
- CPU or GPU: CPU
- Custom Docker image (Y/N): N
Additional context
The following branch appears to fix this issue:
toddstep/sagemaker-python-sdk/tree/fix_hypertuner_managed_spot
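I have not inspected that branch, but conceptually the tuner needs to forward the compute object's spot flag into each training job definition instead of leaving the request's default in place. A pure-Python sketch of that propagation (simplified, hypothetical names; not the SDK's actual internals):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Compute:
    # Mirrors the relevant field of sagemaker.train.configs.Compute (simplified).
    enable_managed_spot_training: Optional[bool] = None


def build_training_job_request(compute: Compute) -> dict:
    """Hypothetical helper: forward the compute flag into the request
    rather than always submitting EnableManagedSpotTraining=False."""
    request = {"EnableManagedSpotTraining": False}  # current buggy default
    if compute.enable_managed_spot_training is not None:
        request["EnableManagedSpotTraining"] = compute.enable_managed_spot_training
    return request


print(build_training_job_request(Compute(enable_managed_spot_training=True)))
# {'EnableManagedSpotTraining': True}
```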