diff --git a/src/sagemaker/remote_function/core/serialization.py b/src/sagemaker/remote_function/core/serialization.py index 821744ee6b..229cf1ed0d 100644 --- a/src/sagemaker/remote_function/core/serialization.py +++ b/src/sagemaker/remote_function/core/serialization.py @@ -141,7 +141,12 @@ def deserialize(s3_uri: str, bytes_to_deserialize: bytes) -> Any: return cloudpickle.loads(bytes_to_deserialize) except Exception as e: raise DeserializationError( - "Error when deserializing bytes downloaded from {}: {}".format(s3_uri, repr(e)) + "Error when deserializing bytes downloaded from {}: {}. " + "NOTE: this may be caused by inconsistent sagemaker python sdk versions " + "where remote function runs versus the one used on client side. " + "If the sagemaker versions do not match, a warning message would " + "be logged starting with 'Inconsistent sagemaker versions found'. " + "Please check it to validate.".format(s3_uri, repr(e)) ) from e diff --git a/src/sagemaker/remote_function/job.py b/src/sagemaker/remote_function/job.py index 71530ac4dd..a854de9135 100644 --- a/src/sagemaker/remote_function/job.py +++ b/src/sagemaker/remote_function/job.py @@ -786,6 +786,12 @@ def compile( container_args.extend( ["--client_python_version", RuntimeEnvironmentManager()._current_python_version()] ) + container_args.extend( + [ + "--client_sagemaker_pysdk_version", + RuntimeEnvironmentManager()._current_sagemaker_pysdk_version(), + ] + ) container_args.extend( [ "--dependency_settings", diff --git a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py index 0fbc926aae..d5d879cb08 100644 --- a/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py +++ b/src/sagemaker/remote_function/runtime_environment/bootstrap_runtime_environment.py @@ -56,6 +56,7 @@ def main(sys_args=None): try: args = _parse_args(sys_args) client_python_version = 
def _current_sagemaker_pysdk_version(self):
    """Return the version of the sagemaker python sdk imported by this process.

    The value is read from ``sagemaker.__version__``, so it reflects the
    package installed in the environment where this code is executing.
    """
    sdk_version = sagemaker.__version__
    return sdk_version
def _validate_sagemaker_pysdk_version(self, client_sagemaker_pysdk_version):
    """Validate the sagemaker python sdk version

    Checks whether the sagemaker python sdk version where the remote function
    runs matches the one used on the client side. A mismatch does not raise:
    a warning is logged to call out that unexpected behaviors may occur.

    Args:
        client_sagemaker_pysdk_version (str): sagemaker python sdk version
            reported by the client. A falsy value (``None`` or an empty
            string) skips the comparison and nothing is logged.
    """
    # Looked up unconditionally, before the client value is inspected.
    job_sagemaker_pysdk_version = self._current_sagemaker_pysdk_version()

    if not client_sagemaker_pysdk_version:
        return
    if client_sagemaker_pysdk_version == job_sagemaker_pysdk_version:
        return

    # The "Inconsistent sagemaker versions found" prefix is kept verbatim: the
    # deserialization error message tells users to search the logs for it.
    # The remediation sentence names the sagemaker version (not the python
    # version), since that is what this check compares.
    logger.warning(
        "Inconsistent sagemaker versions found: "
        "sagemaker pysdk version found in the container is "
        "'%s' which does not match the '%s' on the local client. "
        "Please make sure that the sagemaker version used in the training container "
        "is the same as the local sagemaker version in case of unexpected behaviors.",
        job_sagemaker_pysdk_version,
        client_sagemaker_pysdk_version,
    )
" + + r"NOTE: this may be caused by inconsistent sagemaker python sdk versions", ): deserialize_func_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) @@ -397,7 +398,8 @@ def __init__(self, x): with pytest.raises( DeserializationError, match=rf"Error when deserializing bytes downloaded from {s3_uri}/payload.pkl: " - + r"RuntimeError\('some failure when loads'\)", + + r"RuntimeError\('some failure when loads'\). " + + r"NOTE: this may be caused by inconsistent sagemaker python sdk versions", ): deserialize_obj_from_s3(sagemaker_session=Mock(), s3_uri=s3_uri, hmac_key=HMAC_KEY) diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py index ee83388a15..b7d9e10047 100644 --- a/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_bootstrap_runtime_environment.py @@ -27,6 +27,7 @@ CURR_WORKING_DIR = "/user/set/workdir" TEST_DEPENDENCIES_PATH = "/user/set/workdir/sagemaker_remote_function_workspace" TEST_PYTHON_VERSION = "3.10" +TEST_SAGEMAKER_PYSDK_VERSION = "2.205.0" TEST_WORKSPACE_ARCHIVE_DIR_PATH = "/opt/ml/input/data/sm_rf_user_ws" TEST_WORKSPACE_ARCHIVE_PATH = "/opt/ml/input/data/sm_rf_user_ws/workspace.zip" TEST_EXECUTION_ID = "test_execution_id" @@ -44,6 +45,8 @@ def args_for_remote(): TEST_JOB_CONDA_ENV, "--client_python_version", TEST_PYTHON_VERSION, + "--client_sagemaker_pysdk_version", + TEST_SAGEMAKER_PYSDK_VERSION, "--dependency_settings", _DependencySettings(TEST_DEPENDENCY_FILE_NAME).to_string(), ] @@ -55,6 +58,8 @@ def args_for_step(): TEST_JOB_CONDA_ENV, "--client_python_version", TEST_PYTHON_VERSION, + "--client_sagemaker_pysdk_version", + TEST_SAGEMAKER_PYSDK_VERSION, "--pipeline_execution_id", TEST_EXECUTION_ID, "--func_step_s3_dir", @@ -63,6 +68,10 @@ def args_for_step(): 
@patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) @patch( "sagemaker.remote_function.runtime_environment.runtime_environment_manager." "RuntimeEnvironmentManager._validate_python_version" @@ -90,12 +99,75 @@ def test_main_success_remote_job_with_root_user( run_pre_exec_script, bootstrap_runtime, validate_python, + validate_sagemaker, _exit_process, ): bootstrap.main(args_for_remote()) change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) + bootstrap_remote.assert_called_once_with( + TEST_PYTHON_VERSION, + TEST_JOB_CONDA_ENV, + _DependencySettings(TEST_DEPENDENCY_FILE_NAME), + ) + run_pre_exec_script.assert_not_called() + bootstrap_runtime.assert_not_called() + _exit_process.assert_called_with(0) + + +@patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_python_version" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager.bootstrap" +) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager.run_pre_exec_script" +) +@patch( + "sagemaker.remote_function.runtime_environment.bootstrap_runtime_environment." + "_bootstrap_runtime_env_for_remote_function" +) +@patch("getpass.getuser", MagicMock(return_value="root")) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
+ "RuntimeEnvironmentManager.change_dir_permission" +) +def test_main_success_with_obsoleted_args_that_missing_sagemaker_version( + change_dir_permission, + bootstrap_remote, + run_pre_exec_script, + bootstrap_runtime, + validate_python, + validate_sagemaker, + _exit_process, +): + # This test is to test the backward compatibility + # In old version of SDK, the client side sagemaker_pysdk_version is not passed to job + # thus it would be None and would not lead to the warning + obsoleted_args = [ + "--job_conda_env", + TEST_JOB_CONDA_ENV, + "--client_python_version", + TEST_PYTHON_VERSION, + "--dependency_settings", + _DependencySettings(TEST_DEPENDENCY_FILE_NAME).to_string(), + ] + bootstrap.main(obsoleted_args) + + change_dir_permission.assert_not_called() + validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(None) bootstrap_remote.assert_called_once_with( TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV, @@ -107,6 +179,10 @@ def test_main_success_remote_job_with_root_user( @patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) @patch( "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
"RuntimeEnvironmentManager._validate_python_version" @@ -134,11 +210,13 @@ def test_main_success_pipeline_step_with_root_user( run_pre_exec_script, bootstrap_runtime, validate_python, + validate_sagemaker, _exit_process, ): bootstrap.main(args_for_step()) change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) bootstrap_step.assert_called_once_with( TEST_PYTHON_VERSION, FUNC_STEP_WORKSPACE, @@ -150,6 +228,10 @@ def test_main_success_pipeline_step_with_root_user( _exit_process.assert_called_with(0) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) @patch( "sagemaker.remote_function.runtime_environment.runtime_environment_manager." "RuntimeEnvironmentManager._validate_python_version" @@ -178,6 +260,7 @@ def test_main_failure_remote_job_with_root_user( write_failure, _exit_process, validate_python, + validate_sagemaker, ): runtime_err = RuntimeEnvironmentError("some failure reason") bootstrap_runtime.side_effect = runtime_err @@ -186,12 +269,17 @@ def test_main_failure_remote_job_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) _exit_process.assert_called_with(1) +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) @patch( "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
"RuntimeEnvironmentManager._validate_python_version" @@ -220,6 +308,7 @@ def test_main_failure_pipeline_step_with_root_user( write_failure, _exit_process, validate_python, + validate_sagemaker, ): runtime_err = RuntimeEnvironmentError("some failure reason") bootstrap_runtime.side_effect = runtime_err @@ -228,6 +317,7 @@ def test_main_failure_pipeline_step_with_root_user( change_dir_permission.assert_not_called() validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) run_pre_exec_script.assert_not_called() bootstrap_runtime.assert_called() write_failure.assert_called_with(str(runtime_err)) @@ -235,6 +325,10 @@ def test_main_failure_pipeline_step_with_root_user( @patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) @patch( "sagemaker.remote_function.runtime_environment.runtime_environment_manager." "RuntimeEnvironmentManager._validate_python_version" @@ -262,6 +356,7 @@ def test_main_remote_job_with_non_root_user( run_pre_exec_script, bootstrap_runtime, validate_python, + validate_sagemaker, _exit_process, ): bootstrap.main(args_for_remote()) @@ -270,6 +365,7 @@ def test_main_remote_job_with_non_root_user( dirs=bootstrap.JOB_OUTPUT_DIRS, new_permission="777" ) validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) bootstrap_remote.assert_called_once_with( TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV, @@ -281,6 +377,10 @@ def test_main_remote_job_with_non_root_user( @patch("sys.exit") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." + "RuntimeEnvironmentManager._validate_sagemaker_pysdk_version" +) @patch( "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
"RuntimeEnvironmentManager._validate_python_version" @@ -308,6 +408,7 @@ def test_main_pipeline_step_with_non_root_user( run_pre_exec_script, bootstrap_runtime, validate_python, + validate_sagemaker, _exit_process, ): bootstrap.main(args_for_step()) @@ -316,6 +417,7 @@ def test_main_pipeline_step_with_non_root_user( dirs=bootstrap.JOB_OUTPUT_DIRS, new_permission="777" ) validate_python.assert_called_once_with(TEST_PYTHON_VERSION, TEST_JOB_CONDA_ENV) + validate_sagemaker.assert_called_once_with(TEST_SAGEMAKER_PYSDK_VERSION) bootstrap_step.assert_called_once_with( TEST_PYTHON_VERSION, FUNC_STEP_WORKSPACE, diff --git a/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py b/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py index 45198f3388..ce14a8c977 100644 --- a/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py +++ b/tests/unit/sagemaker/remote_function/runtime_environment/test_runtime_environment_manager.py @@ -30,6 +30,7 @@ TEST_REQUIREMENTS_TXT = "usr/local/requirements.txt" TEST_CONDA_YML = "usr/local/conda_env.yml" CLIENT_PYTHON_VERSION = "3.10" +JOB_SAGEMAKER_PYSDK_VERSION = "2.205.0" def test_snapshot_no_dependencies(): @@ -371,6 +372,32 @@ def test_validate_python_version_error(python_version_in_conda_env): RuntimeEnvironmentManager()._validate_python_version(CLIENT_PYTHON_VERSION, "conda_env") +@patch( + "sagemaker.remote_function.runtime_environment.runtime_environment_manager." 
@patch(
    "sagemaker.remote_function.runtime_environment.runtime_environment_manager."
    "RuntimeEnvironmentManager._current_sagemaker_pysdk_version",
    return_value=JOB_SAGEMAKER_PYSDK_VERSION,
)
def test_validate_sagemaker_pysdk_version(mock_sagemaker_version_in_job):
    # If the client sagemaker version differs from the job's, a warning is logged.
    # The module logger is patched so the warning itself can be asserted on.
    client_version = "version-not-the-same-and-get-a-warning"
    with patch(
        "sagemaker.remote_function.runtime_environment.runtime_environment_manager.logger"
    ) as mock_logger:
        RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(client_version)

    mock_sagemaker_version_in_job.assert_called_once()
    mock_logger.warning.assert_called_once()
    warning_args = mock_logger.warning.call_args[0]
    assert warning_args[0].startswith("Inconsistent sagemaker versions found")
    # Lazy %-args: job (container) version first, then the client version.
    assert warning_args[1:] == (JOB_SAGEMAKER_PYSDK_VERSION, client_version)


@patch(
    "sagemaker.remote_function.runtime_environment.runtime_environment_manager."
    "RuntimeEnvironmentManager._current_sagemaker_pysdk_version",
    return_value=JOB_SAGEMAKER_PYSDK_VERSION,
)
def test_validate_sagemaker_pysdk_version_with_none_input(mock_sagemaker_version_in_job):
    # This test covers backward compatibility.
    # In old versions of the SDK, the client side sagemaker_pysdk_version is not
    # passed to the job, so it is None and must not lead to the warning.
    with patch(
        "sagemaker.remote_function.runtime_environment.runtime_environment_manager.logger"
    ) as mock_logger:
        RuntimeEnvironmentManager()._validate_sagemaker_pysdk_version(None)

    mock_sagemaker_version_in_job.assert_called_once()
    mock_logger.warning.assert_not_called()
null}', "--run_in_context", @@ -510,6 +513,7 @@ def test_start_with_checkpoint_location( ) mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() session().sagemaker_client.create_training_job.assert_called_once_with( TrainingJobName=job.job_name, @@ -555,6 +559,8 @@ def test_start_with_checkpoint_location( TEST_REGION, "--client_python_version", mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": null}', "--run_in_context", @@ -657,6 +663,7 @@ def test_start_with_complete_job_settings( local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() mock_bootstrap_script_upload.assert_called_once_with( spark_config=None, @@ -716,6 +723,8 @@ def test_start_with_complete_job_settings( TEST_REGION, "--client_python_version", mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": "req.txt"}', "--s3_kms_key", @@ -824,6 +833,7 @@ def test_get_train_args_under_pipeline_context( local_dependencies_path = mock_runtime_manager().snapshot() mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() mock_bootstrap_scripts_upload.assert_called_once_with( spark_config=None, @@ -905,6 +915,8 @@ def test_get_train_args_under_pipeline_context( TEST_REGION, "--client_python_version", mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": "req.txt"}', "--s3_kms_key", @@ -989,6 +1001,7 @@ def test_start_with_spark( job = _Job.start(job_settings, job_function, func_args=(1, 2), 
func_kwargs={"c": 3, "d": 4}) mock_python_version = mock_runtime_manager()._current_python_version() + mock_sagemaker_pysdk_version = mock_runtime_manager()._current_sagemaker_pysdk_version() assert job.job_name.startswith("job-function") @@ -1062,6 +1075,8 @@ def test_start_with_spark( TEST_REGION, "--client_python_version", mock_python_version, + "--client_sagemaker_pysdk_version", + mock_sagemaker_pysdk_version, "--dependency_settings", '{"dependency_file": null}', "--run_in_context",