Skip to content

fix: jumpstart estimator for gated uncompressed training #5175

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

68 changes: 64 additions & 4 deletions src/sagemaker/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1310,6 +1310,7 @@ def fit(
logs: str = "All",
job_name: Optional[str] = None,
experiment_config: Optional[Dict[str, str]] = None,
accept_eula: Optional[bool] = None,
):
"""Train a model using the input training dataset.

Expand Down Expand Up @@ -1363,14 +1364,21 @@ def fit(
* Both `ExperimentName` and `TrialName` will be ignored if the Estimator instance
is built with :class:`~sagemaker.workflow.pipeline_context.PipelineSession`.
However, the value of `TrialComponentDisplayName` is honored for display in Studio.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
Returns:
None or pipeline step arguments in case the Estimator instance is built with
:class:`~sagemaker.workflow.pipeline_context.PipelineSession`
"""
self._prepare_for_training(job_name=job_name)

experiment_config = check_and_get_run_experiment_config(experiment_config)
self.latest_training_job = _TrainingJob.start_new(self, inputs, experiment_config)
self.latest_training_job = _TrainingJob.start_new(
self, inputs, experiment_config, accept_eula
)
self.jobs.append(self.latest_training_job)
forward_to_mlflow_tracking_server = False
if os.environ.get("MLFLOW_TRACKING_URI") and self.enable_network_isolation():
Expand Down Expand Up @@ -2484,7 +2492,7 @@ class _TrainingJob(_Job):
"""Placeholder docstring"""

@classmethod
def start_new(cls, estimator, inputs, experiment_config):
def start_new(cls, estimator, inputs, experiment_config, accept_eula=None):
"""Create a new Amazon SageMaker training job from the estimator.

Args:
Expand All @@ -2504,19 +2512,24 @@ def start_new(cls, estimator, inputs, experiment_config):
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
* `RunName` is used to record an experiment run.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).
Returns:
sagemaker.estimator._TrainingJob: Constructed object that captures
all information about the started training job.
"""
train_args = cls._get_train_args(estimator, inputs, experiment_config)
train_args = cls._get_train_args(estimator, inputs, experiment_config, accept_eula)

logger.debug("Train args after processing defaults: %s", train_args)
estimator.sagemaker_session.train(**train_args)

return cls(estimator.sagemaker_session, estimator._current_job_name)

@classmethod
def _get_train_args(cls, estimator, inputs, experiment_config):
def _get_train_args(cls, estimator, inputs, experiment_config, accept_eula=None):
"""Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.

Args:
Expand All @@ -2536,6 +2549,11 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
will be unassociated.
* `TrialComponentDisplayName` is used for display in Studio.
* `RunName` is used to record an experiment run.
accept_eula (bool): For models that require a Model Access Config, specify True or
False to indicate whether model terms of use have been accepted.
The `accept_eula` value must be explicitly defined as `True` in order to
accept the end-user license agreement (EULA) that some
models require. (Default: None).

Returns:
Dict: dict for `sagemaker.session.Session.train` method
Expand Down Expand Up @@ -2652,6 +2670,9 @@ def _get_train_args(cls, estimator, inputs, experiment_config):
if estimator.get_session_chaining_config() is not None:
train_args["session_chaining_config"] = estimator.get_session_chaining_config()

if accept_eula is not None:
cls._set_accept_eula_for_input_data_config(train_args, accept_eula)

return train_args

@classmethod
Expand All @@ -2674,6 +2695,45 @@ def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):
raise ValueError("Setting checkpoint_local_path is not supported in local mode.")
train_args["checkpoint_local_path"] = estimator.checkpoint_local_path

@classmethod
def _set_accept_eula_for_input_data_config(cls, train_args, accept_eula):
"""Set the AcceptEula flag for all input data configurations.

This method sets the AcceptEula flag in the ModelAccessConfig for all S3DataSources
in the InputDataConfig array. It handles cases where keys might not exist in the
nested dictionary structure.

Args:
train_args (dict): The training job arguments dictionary
accept_eula (bool): The value to set for AcceptEula flag
"""
if "InputDataConfig" not in train_args:
return

if accept_eula is None:
return

eula_count = 0
s3_uris = []

for idx in range(len(train_args["InputDataConfig"])):
if "DataSource" in train_args["InputDataConfig"][idx]:
data_source = train_args["InputDataConfig"][idx]["DataSource"]
if "S3DataSource" in data_source:
s3_data_source = data_source["S3DataSource"]
if "ModelAccessConfig" not in s3_data_source:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Who sets this ModelAccessConfig? Is it the default artifacts set by JumpStart or is it something user would explicitly add to their inputs?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it's set by us when the customer inputs accept_eula=True

s3_data_source["ModelAccessConfig"] = {}
s3_data_source["ModelAccessConfig"]["AcceptEula"] = accept_eula
eula_count += 1

# Collect S3 URI if available
if "S3Uri" in s3_data_source:
s3_uris.append(s3_data_source["S3Uri"])

# Log info if more than one EULA needs to be accepted
if eula_count > 1:
logger.info("Accepting EULA for %d S3 data sources: %s", eula_count, ", ".join(s3_uris))

@classmethod
def _is_local_channel(cls, input_uri):
"""Placeholder docstring"""
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,6 +713,7 @@ def fit(
sagemaker_session=self.sagemaker_session,
config_name=self.config_name,
hub_access_config=self.hub_access_config,
accept_eula=accept_eula,
)
remove_env_var_from_estimator_kwargs_if_model_access_config_present(
self.init_kwargs, self.model_access_config
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,7 @@ def get_fit_kwargs(
sagemaker_session: Optional[Session] = None,
config_name: Optional[str] = None,
hub_access_config: Optional[Dict] = None,
accept_eula: Optional[bool] = None,
) -> JumpStartEstimatorFitKwargs:
"""Returns kwargs required call `fit` on `sagemaker.estimator.Estimator` object."""

Expand All @@ -283,6 +284,7 @@ def get_fit_kwargs(
tolerate_vulnerable_model=tolerate_vulnerable_model,
sagemaker_session=sagemaker_session,
config_name=config_name,
accept_eula=accept_eula,
)

estimator_fit_kwargs, _ = _set_temp_sagemaker_session_if_not_set(kwargs=estimator_fit_kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1940,9 +1940,6 @@ def use_inference_script_uri(self) -> bool:

def use_training_model_artifact(self) -> bool:
"""Returns True if the model should use a model uri when kicking off training job."""
# gated model never use training model artifact
if self.gated_bucket:
return False

# otherwise, return true is a training model package is not set
return len(self.training_model_package_artifact_uris or {}) == 0
Expand Down Expand Up @@ -2595,6 +2592,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"sagemaker_session",
"config_name",
"specs",
"accept_eula",
]

SERIALIZATION_EXCLUSION_SET = {
Expand Down Expand Up @@ -2625,6 +2623,7 @@ def __init__(
tolerate_vulnerable_model: Optional[bool] = None,
sagemaker_session: Optional[Session] = None,
config_name: Optional[str] = None,
accept_eula: Optional[bool] = None,
) -> None:
"""Instantiates JumpStartEstimatorInitKwargs object."""

Expand All @@ -2642,6 +2641,7 @@ def __init__(
self.tolerate_vulnerable_model = tolerate_vulnerable_model
self.sagemaker_session = sagemaker_session
self.config_name = config_name
self.accept_eula = accept_eula


class JumpStartEstimatorDeployKwargs(JumpStartKwargs):
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@ def deploy(
container_startup_health_check_timeout=container_startup_health_check_timeout,
explainer_config_dict=explainer_config_dict,
async_inference_config_dict=async_inference_config_dict,
serverless_inference_config_dict=serverless_inference_config_dict,
serverless_inference_config=serverless_inference_config_dict,
routing_config=routing_config,
inference_ami_version=inference_ami_version,
)
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/sagemaker/experiments/test_run_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def test_auto_pass_in_exp_config_to_train_job(mock_start_job, run_obj, sagemaker
assert _RunContext.get_current_run() == run_obj

expected_exp_config = run_obj.experiment_config
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None)

# _RunContext is cleaned up after exiting the with statement
assert not _RunContext.get_current_run()
Expand Down Expand Up @@ -94,7 +94,7 @@ def test_auto_pass_in_exp_config_under_load_run(
assert loaded_run.experiment_config == run_obj.experiment_config

expected_exp_config = run_obj.experiment_config
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config)
mock_start_job.assert_called_once_with(estimator, _train_input_path, expected_exp_config, None)

# _RunContext is cleaned up after exiting the with statement
assert not _RunContext.get_current_run()
Expand Down Expand Up @@ -174,7 +174,7 @@ def test_user_supply_exp_config_to_train_job(mock_start_job, run_obj, sagemaker_

assert _RunContext.get_current_run() == run_obj

mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg)
mock_start_job.assert_called_once_with(estimator, _train_input_path, supplied_exp_cfg, None)

# _RunContext is cleaned up after exiting the with statement
assert not _RunContext.get_current_run()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,7 @@ def test_gated_model_s3_uri_with_eula_in_fit(
inputs=channels,
wait=True,
job_name="meta-textgeneration-llama-2-7b-f-8675309",
accept_eula=True,
)

assert hasattr(estimator, "model_access_config")
Expand Down Expand Up @@ -688,6 +689,7 @@ def test_gated_model_non_model_package_s3_uri(
instance_count=1,
image_uri="763104351884.dkr.ecr.us-west-2.amazonaws.com/huggingface-pyt"
"orch-training:2.0.0-transformers4.28.1-gpu-py310-cu118-ubuntu20.04",
model_uri="s3://jumpstart-private-cache-prod-us-west-2/some/dummy/key",
source_dir="s3://jumpstart-cache-prod-us-west-2/source-d"
"irectory-tarballs/meta/transfer_learning/textgeneration/prepack/v1.0.1/sourcedir.tar.gz",
entry_point="transfer_learning.py",
Expand Down Expand Up @@ -1346,7 +1348,7 @@ def test_jumpstart_estimator_kwargs_match_parent_class(self):
and reach out to JumpStart team."""

init_args_to_skip: Set[str] = set(["kwargs"])
fit_args_to_skip: Set[str] = set(["accept_eula"])
fit_args_to_skip: Set[str] = set([])
deploy_args_to_skip: Set[str] = set(["kwargs"])

parent_class_init = Estimator.__init__
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/sagemaker/jumpstart/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ def test_use_training_model_artifact():
specs1 = JumpStartModelSpecs(BASE_SPEC)
assert specs1.use_training_model_artifact()
specs1.gated_bucket = True
assert not specs1.use_training_model_artifact()
assert specs1.use_training_model_artifact()
specs1.gated_bucket = False
specs1.training_model_package_artifact_uris = {"region1": "blah", "region2": "blah2"}
assert not specs1.use_training_model_artifact()
Expand Down
Loading
Loading