HyperparameterTuner does not use spot instances #5584

@toddstep

Description

PySDK Version
PySDK V3 (3.x)

Describe the bug
HyperparameterTuner.tune() does not use the enable_managed_spot_training flag provided in the model_trainer.compute object.
Instead, it launches the training jobs with EnableManagedSpotTraining set to False.

To reproduce
Two files:
training.py:

print("Test set: Average loss: 0.9")

demo.py (a simplified version of the SageMaker V3 Hyperparameter Tuning Example):

# V3 Imports
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import Compute, SourceCode, StoppingCondition
from sagemaker.train.tuner import HyperparameterTuner
from sagemaker.core.parameter import ContinuousParameter
from sagemaker.core.helper.session_helper import Session, get_execution_role

sagemaker_session = Session()
region = sagemaker_session.boto_region_name
role = get_execution_role()

# Configure source code
source_code = SourceCode(
    source_dir=".",
    entry_script="training.py"
)

# Configure compute resources
compute = Compute(
    enable_managed_spot_training=True
)

# Configure stopping condition
stopping_condition = StoppingCondition(
    max_runtime_in_seconds=3600,  # 1 hour
    max_wait_time_in_seconds=3600
)

# Get PyTorch training image
training_image = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-training:1.10.0-gpu-py38"

# Create ModelTrainer
model_trainer = ModelTrainer(
    training_image=training_image,
    source_code=source_code,
    compute=compute,
    stopping_condition=stopping_condition,
    role=role,
)

# Define hyperparameter ranges to tune
hyperparameter_ranges = {
    "lr": ContinuousParameter(0.001, 0.1),
}

# Define objective metric
objective_metric_name = "average test loss"

# Define metric definitions
metric_definitions = [
    {
        "Name": "average test loss",
        "Regex": "Test set: Average loss: ([0-9\\.]+)"
    }
]

# Create HyperparameterTuner
tuner = HyperparameterTuner(
    model_trainer=model_trainer,
    objective_metric_name=objective_metric_name,
    hyperparameter_ranges=hyperparameter_ranges,
    metric_definitions=metric_definitions,
)

tuner.tune(wait=True)

job = next(tuner.describe().get_all_training_jobs())
describe_job = sagemaker_session.sagemaker_client.describe_training_job(
    TrainingJobName=job.training_job_name
)
print("EnableManagedSpotTraining", describe_job['EnableManagedSpotTraining'])

Expected behavior
When running the provided script, the final line of output should be:

EnableManagedSpotTraining True

Screenshots or logs

$ python demo.py 
sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/sagemaker-user/.config/sagemaker/config.yaml
[02/26/26 16:47:04] INFO     SageMaker session not provided. Using default Session.                                            defaults.py:61
                    INFO     Base name not provided. Using default name:                                                       defaults.py:90
                             pytorch-training-job                                                                                            
                    INFO     Instance type not provided. Using default:                                                       defaults.py:105
                             ml.m5.xlarge                                                                                                    
                    INFO     Instance count not provided. Using default:                                                      defaults.py:108
                             1                                                                                                               
                    INFO     OutputDataConfig not provided. Using default:                                                    defaults.py:150
                             s3_output_path='s3://sagemaker-us-west-2-xxxxxxxxxxxx/pytorch-training-job' kms_key_id=None                     
                             compression_type='GZIP'                                                                                         
                             remove_job_name_from_s3_output_path=<sagemaker.core.utils.utils.Unassigned object at                            
                             0x7f298f0fa150> disable_model_upload=<sagemaker.core.utils.utils.Unassigned object at                           
                             0x7f298f0fa150> channels=<sagemaker.core.utils.utils.Unassigned object at 0x7f298f0fa150>                       
                    INFO     Training image URI:                                                                         model_trainer.py:548
                             763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.10.0-gpu-py38                                   
                    INFO     Creating hyper_parameter_tuning_job resource.                                                 resources.py:17214
                    INFO     Runs on sagemaker prod, region:us-west-2                                                            utils.py:354
[02/26/26 16:52:42] INFO     Final Resource Status: Completed                                                              resources.py:17443
EnableManagedSpotTraining False

System information
  • SageMaker Python SDK version: 3.4.1
  • Framework name (eg. PyTorch) or algorithm (eg. KMeans): PyTorch
  • Framework version: 1.10.0
  • Python version: 3.12.9
  • CPU or GPU: CPU
  • Custom Docker image (Y/N): N

Additional context
The following branch appears to fix this issue:

toddstep/sagemaker-python-sdk/tree/fix_hypertuner_managed_spot
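For context, a fix along these lines presumably needs to copy the spot settings from the ModelTrainer's Compute and StoppingCondition configs into the TrainingJobDefinition of the CreateHyperParameterTuningJob request. A minimal sketch of the shape of those fields, where build_spot_fields is a hypothetical helper and not part of the SDK:

```python
def build_spot_fields(enable_spot, max_runtime_in_seconds, max_wait_time_in_seconds=None):
    """Return the spot-related fields of a CreateHyperParameterTuningJob
    TrainingJobDefinition (hypothetical helper, not part of the SageMaker SDK)."""
    fields = {
        "EnableManagedSpotTraining": enable_spot,
        "StoppingCondition": {"MaxRuntimeInSeconds": max_runtime_in_seconds},
    }
    if enable_spot:
        # The SageMaker API requires MaxWaitTimeInSeconds >= MaxRuntimeInSeconds
        # when managed spot training is enabled.
        wait = max_wait_time_in_seconds or max_runtime_in_seconds
        fields["StoppingCondition"]["MaxWaitTimeInSeconds"] = max(wait, max_runtime_in_seconds)
    return fields

# Mirrors the Compute/StoppingCondition values from demo.py above.
print(build_spot_fields(True, 3600, 3600))
```

With the demo's settings this yields EnableManagedSpotTraining set to True alongside matching 3600-second runtime and wait limits, which is what the launched training jobs should report.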
