# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""ModelTrainer class module."""
from __future__ import absolute_import
from enum import Enum
import os
import json
import shutil
from tempfile import TemporaryDirectory
from typing import Optional, List, Union, Dict, Any, ClassVar
import yaml
from graphene.utils.str_converters import to_camel_case, to_snake_case
from sagemaker.core.config.config_manager import SageMakerConfig
from sagemaker.core import resources
from sagemaker.core.resources import TrainingJob
from sagemaker.core import shapes
from sagemaker.core.shapes import AlgorithmSpecification
from sagemaker.core.utils.utils import serialize
from sagemaker.core.apiutils._boto_functions import to_pascal_case
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, validate_call
from sagemaker.core.config.config_schema import (
_simple_path,
SAGEMAKER,
MODEL_TRAINER,
MODULES,
PYTHON_SDK,
TRAINING_JOB_ENVIRONMENT_PATH,
TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH,
TRAINING_JOB_VPC_CONFIG_PATH,
TRAINING_JOB_SUBNETS_PATH,
TRAINING_JOB_SECURITY_GROUP_IDS_PATH,
TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH,
TRAINING_JOB_RESOURCE_CONFIG_PATH,
TRAINING_JOB_ROLE_ARN_PATH,
TRAINING_JOB_TAGS_PATH,
)
from sagemaker.core.helper.session_helper import Session
from sagemaker.train import configs
from sagemaker.train.configs import (
Compute,
StoppingCondition,
RetryStrategy,
SourceCode,
TrainingImageConfig,
Channel,
DataSource,
S3DataSource,
FileSystemDataSource,
Networking,
Tag,
InfraCheckConfig,
RemoteDebugConfig,
SessionChainingConfig,
InputData,
MetricDefinition,
)
from sagemaker.train.distributed import Torchrun, DistributedConfig
from sagemaker.train.utils import (
_default_s3_uri,
_get_unique_name,
_is_valid_path,
_is_valid_s3_uri,
safe_serialize,
)
from sagemaker.train.types import DataSourceType
from sagemaker.train.constants import (
SM_CODE,
SM_CODE_CONTAINER_PATH,
SM_DRIVERS,
SM_DRIVERS_LOCAL_PATH,
SM_RECIPE,
SM_RECIPE_YAML,
SM_RECIPE_CONTAINER_PATH,
TRAIN_SCRIPT,
DEFAULT_CONTAINER_ENTRYPOINT,
DEFAULT_CONTAINER_ARGUMENTS,
SOURCE_CODE_JSON,
DISTRIBUTED_JSON,
)
from sagemaker.train.templates import (
TRAIN_SCRIPT_TEMPLATE,
EXECUTE_BASE_COMMANDS,
EXEUCTE_DISTRIBUTED_DRIVER,
EXECUTE_BASIC_SCRIPT_DRIVER,
INSTALL_AUTO_REQUIREMENTS,
INSTALL_REQUIREMENTS,
)
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
from sagemaker.train import logger
from sagemaker.train.sm_recipes.utils import (
_get_args_from_recipe,
_determine_device_type,
_is_nova_recipe,
_is_llmft_recipe,
_load_base_recipe,
)
from sagemaker.core.jumpstart.configs import JumpStartConfig
from sagemaker.core.jumpstart.document import get_hub_content_and_document
from sagemaker.core.jumpstart.utils import get_eula_url
from sagemaker.train.defaults import TrainDefaults, JumpStartTrainDefaults
from sagemaker.core.workflow.pipeline_context import PipelineSession, runnable_by_pipeline
from sagemaker.core.helper.pipeline_variable import StrPipeVar
from sagemaker.train.local.local_container import _LocalContainer
class Mode(Enum):
    """Enum class for training mode.

    Members:
        LOCAL_CONTAINER: run training in a local container.
        SAGEMAKER_TRAINING_JOB: run training as a SageMaker training job.
    """

    LOCAL_CONTAINER = "LOCAL_CONTAINER"
    SAGEMAKER_TRAINING_JOB = "SAGEMAKER_TRAINING_JOB"
class ModelTrainer(BaseModel):
"""Class that trains a model using AWS SageMaker.
Example:
.. code:: python
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import SourceCode, Compute, InputData
ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
model_trainer = ModelTrainer(
training_image=training_image,
source_code=source_code,
)
train_data = InputData(channel_name="train", data_source="s3://bucket/train")
model_trainer.train(input_data_config=[train_data])
training_job = model_trainer._latest_training_job
Parameters:
training_mode (Mode):
The training mode. Valid values are "Mode.LOCAL_CONTAINER" or
"Mode.SAGEMAKER_TRAINING_JOB".
sagemaker_session (Optional[Session]):
The SageMaker Session. This object can be used to manage underlying
boto3 clients and to specify artifact upload paths via the ``default_bucket``
and ``default_bucket_prefix`` attributes.
If not specified, a new session will be created.
role (Optional[str]):
The IAM role ARN for the training job.
If not specified, the default SageMaker execution role will be used.
base_job_name (Optional[str]):
The base name for the training job.
If not specified, a default name will be generated using the algorithm name
or training image name.
source_code (Optional[SourceCode]):
The source code configuration. This is used to configure the source code for
running the training job.
distributed (Optional[DistributedConfig]):
The distributed runner for the training job. This is used to configure
a distributed training job. If specified, ``source_code`` must also
be provided.
compute (Optional[Compute]):
The compute configuration. This is used to specify the compute resources for
the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
networking (Optional[Networking]):
The networking configuration. This is used to specify the networking settings
for the training job.
stopping_condition (Optional[StoppingCondition]):
The stopping condition. This is used to specify the different stopping
conditions for the training job.
If not specified, will default to 1 hour max run time.
algorithm_name (Optional[str]):
The SageMaker marketplace algorithm name/arn to use for the training job.
algorithm_name cannot be specified if training_image is specified.
training_image (Optional[str]):
The training image URI to use for the training job container.
training_image cannot be specified if algorithm_name is specified.
To find available sagemaker distributed images,
see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths
training_image_config (Optional[TrainingImageConfig]):
Training image Config. This is the configuration to use an image from a private
Docker registry for a training job.
output_data_config (Optional[OutputDataConfig]):
The output data configuration. This is used to specify the output data location
for the training job.
If not specified in the session, will default to
``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
input_data_config (Optional[List[Union[Channel, InputData]]]):
The input data config for the training job.
Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
string, local file path string, S3DataSource object, or FileSystemDataSource object.
checkpoint_config (Optional[CheckpointConfig]):
Contains information about the output location for managed spot training checkpoint
data.
training_input_mode (Optional[str]):
The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
Defaults to "File".
environment (Optional[Dict[str, str]]):
The environment variables for the training job.
hyperparameters (Optional[Union[Dict[str, Any], str]]):
The hyperparameters for the training job. Can be a dictionary of hyperparameters
or a path to hyperparameters json/yaml file.
tags (Optional[List[Tag]]):
An array of key-value pairs. You can use tags to categorize your AWS resources
in different ways, for example, by purpose, owner, or environment.
local_container_root (Optional[str]):
The local root directory to store artifacts from a training job launched in
"LOCAL_CONTAINER" mode.
"""
# arbitrary_types_allowed: fields hold non-pydantic types such as Session.
# validate_assignment: assignments (including those in model_post_init) are validated.
# extra="forbid": unknown constructor keywords are rejected.
model_config = ConfigDict(
    arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
)

# Public, user-configurable fields (see the class docstring for their meaning).
training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
sagemaker_session: Optional[Session] = None
role: Optional[str] = None
base_job_name: Optional[str] = None
source_code: Optional[SourceCode] = None
distributed: Optional[DistributedConfig] = None
compute: Optional[Compute] = None
networking: Optional[Networking] = None
stopping_condition: Optional[StoppingCondition] = None
training_image: Optional[StrPipeVar] = None
training_image_config: Optional[TrainingImageConfig] = None
algorithm_name: Optional[StrPipeVar] = None
output_data_config: Optional[shapes.OutputDataConfig] = None
input_data_config: Optional[List[Union[Channel, InputData]]] = None
checkpoint_config: Optional[shapes.CheckpointConfig] = None
training_input_mode: Optional[StrPipeVar] = "File"
environment: Optional[Dict[str, StrPipeVar]] = {}
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
tags: Optional[List[Tag]] = None
# The working directory is read when each instance is created, not once at module
# import time, so a directory change made after importing this module is honored.
local_container_root: Optional[str] = Field(default_factory=os.getcwd)

# Created Artifacts
_latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)

# Private TrainingJob Parameters
_tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = PrivateAttr(default=None)
_retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None)
_infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None)
_session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
_remote_debug_config: Optional[RemoteDebugConfig] = PrivateAttr(default=None)
_metric_definitions: Optional[List[MetricDefinition]] = PrivateAttr(default=None)
_is_nova_recipe: Optional[bool] = PrivateAttr(default=None)
_is_llmft_recipe: Optional[bool] = PrivateAttr(default=None)

# Private Attributes for Recipes
_temp_recipe_train_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

# Private Attributes for JumpStart
_jumpstart_config: Optional[JumpStartConfig] = PrivateAttr(default=None)

# Private Attributes for AWS_Batch; also holds the SDK drivers / generated train
# script directory built in _create_training_job_args.
_temp_code_dir: Optional[TemporaryDirectory] = PrivateAttr(default=None)

# Attributes that _populate_intelligent_defaults_from_model_trainer_space fills from
# the ModelTrainer section of the SageMaker config when they are still None.
CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = [
    "role",
    "base_job_name",
    "source_code",
    "compute",
    "networking",
    "stopping_condition",
    "training_image",
    "training_image_config",
    "algorithm_name",
    "output_data_config",
    "checkpoint_config",
    "training_input_mode",
    "environment",
    "hyperparameters",
]

# For configurable attributes whose config value is a mapping: the class used to
# build the attribute from that mapping.
SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = {
    "source_code": SourceCode,
    "compute": Compute,
    "networking": Networking,
    "stopping_condition": StoppingCondition,
    "training_image_config": TrainingImageConfig,
    "output_data_config": configs.OutputDataConfig,
    "checkpoint_config": configs.CheckpointConfig,
}

# NOTE(review): annotated without ClassVar, so pydantic treats this as a model
# field rather than a class-level constant — confirm that is intended.
config_mgr: SageMakerConfig = SageMakerConfig()
def _populate_intelligent_defaults(self):
    """Fill unset attributes from the SageMaker config.

    The ModelTrainer-specific section is consulted first, so its values win; the
    generic TrainingJob section then only fills attributes that are still unset
    or empty.
    """
    fill_steps = (
        self._populate_intelligent_defaults_from_model_trainer_space,
        self._populate_intelligent_defaults_from_training_job_space,
    )
    for fill_step in fill_steps:
        fill_step()
def _populate_intelligent_defaults_from_training_job_space(self):
    """Fill still-unset attributes from the generic TrainingJob section of the config.

    Only attributes that are unset or empty are filled. Values supplied by the user,
    or already filled from the ModelTrainer section, are kept. For an existing
    ``networking`` object, only its sub-fields that are ``None`` are filled.
    """
    session = self.sagemaker_session

    def _resolve(config_path):
        # Every lookup uses the same session, so session-specific config applies
        # consistently to all TrainingJob defaults.
        return self.config_mgr.resolve_value_from_config(
            config_path=config_path, sagemaker_session=session
        )

    if not self.environment:
        self.environment = _resolve(TRAINING_JOB_ENVIRONMENT_PATH)

    default_enable_network_isolation = _resolve(TRAINING_JOB_ENABLE_NETWORK_ISOLATION_PATH)
    default_vpc_config = _resolve(TRAINING_JOB_VPC_CONFIG_PATH)

    if not self.networking:
        if default_enable_network_isolation is not None or default_vpc_config is not None:
            self.networking = Networking(
                enable_network_isolation=default_enable_network_isolation,
                subnets=_resolve(TRAINING_JOB_SUBNETS_PATH),
                security_group_ids=_resolve(TRAINING_JOB_SECURITY_GROUP_IDS_PATH),
            )
    else:
        if self.networking.enable_network_isolation is None:
            self.networking.enable_network_isolation = default_enable_network_isolation
        if self.networking.subnets is None:
            self.networking.subnets = _resolve(TRAINING_JOB_SUBNETS_PATH)
        if self.networking.security_group_ids is None:
            self.networking.security_group_ids = _resolve(
                TRAINING_JOB_SECURITY_GROUP_IDS_PATH
            )

    if not self.output_data_config:
        default_output_data_config = _resolve(TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH)
        if default_output_data_config:
            self.output_data_config = configs.OutputDataConfig(
                **self._convert_keys_to_snake(default_output_data_config)
            )

    if not self.compute:
        default_resource_config = _resolve(TRAINING_JOB_RESOURCE_CONFIG_PATH)
        if default_resource_config:
            self.compute = Compute(**self._convert_keys_to_snake(default_resource_config))

    if not self.role:
        self.role = _resolve(TRAINING_JOB_ROLE_ARN_PATH)

    if not self.tags:
        # NOTE(review): config tags are assigned as-is (keys not converted to snake
        # case); confirm their shape matches what ``Tag`` validation expects.
        self.tags = _resolve(TRAINING_JOB_TAGS_PATH)
def _convert_keys_to_snake(self, config: dict) -> dict:
    """Return a new dict whose top-level keys are converted to snake_case.

    Values are carried over unchanged; keys of nested dictionaries are not converted.
    """
    snake_config = {}
    for original_key, value in config.items():
        snake_config[to_snake_case(original_key)] = value
    return snake_config
def _populate_intelligent_defaults_from_model_trainer_space(self):
    """Fill ``None`` attributes from the ModelTrainer section of the SageMaker config.

    For each name in ``CONFIGURABLE_ATTRIBUTES`` whose value is ``None``, the config
    value at the path built from SAGEMAKER / PYTHON_SDK / MODULES / MODEL_TRAINER and
    the camelCase attribute name is looked up. If the attribute is listed in
    ``SERIALIZABLE_CONFIG_ATTRIBUTES``, the value is expanded into that class.
    """
    for attribute_name in self.CONFIGURABLE_ATTRIBUTES:
        # NOTE(review): only ``None`` counts as unset here, so attributes left at a
        # non-None field default (``environment``/``hyperparameters`` default to {},
        # ``training_input_mode`` to "File") are not filled from this section —
        # confirm that is intended.
        if getattr(self, attribute_name) is not None:
            continue
        config_value = self.config_mgr.resolve_value_from_config(
            config_path=_simple_path(
                SAGEMAKER,
                PYTHON_SDK,
                MODULES,
                MODEL_TRAINER,
                to_camel_case(attribute_name),
            ),
            sagemaker_session=self.sagemaker_session,
        )
        if config_value is None:
            continue
        config_class = self.SERIALIZABLE_CONFIG_ATTRIBUTES.get(attribute_name)
        if config_class is not None:
            config_value = config_class(**config_value)  # pylint: disable=E1134
        setattr(self, attribute_name, config_value)
def __del__(self):
    """Remove the temporary recipe and code directories, if any were created.

    Nothing is done unless the instance has ``__pydantic_fields_set__``; this guards
    against running on an object whose initialization did not complete.
    """
    if not hasattr(self, "__pydantic_fields_set__"):
        return
    for attr_name in ("_temp_recipe_train_dir", "_temp_code_dir"):
        temp_dir = getattr(self, attr_name)
        if temp_dir is not None:
            temp_dir.cleanup()
def _validate_training_image_and_algorithm_name(
    self, training_image: Optional[str], algorithm_name: Optional[str]
):
    """Require exactly one of ``training_image`` and ``algorithm_name``.

    Raises:
        ValueError: if neither is provided, or if both are provided.
    """
    if not (training_image or algorithm_name):
        raise ValueError(
            "Atleast one of 'training_image' or 'algorithm_name' must be provided.",
        )
    both_given = bool(training_image and algorithm_name)
    if both_given:
        raise ValueError(
            "Only one of 'training_image' or 'algorithm_name' must be provided.",
        )
def _validate_distributed_config(
    self,
    source_code: Optional[SourceCode],
    distributed: Optional[DistributedConfig],
):
    """Validate the distributed runner configuration.

    When ``distributed`` is set, ``source_code`` must be provided and must define an
    ``entry_script``.

    Raises:
        ValueError: if ``distributed`` is set but ``source_code`` is missing or has
            no ``entry_script``.
    """
    if not distributed:
        return
    if source_code is None:
        # Checked explicitly so a missing source_code is reported as a validation
        # error instead of an AttributeError on ``None.entry_script``.
        raise ValueError(
            "Must provide 'source_code' with an 'entry_script' if 'distributed' is provided.",
        )
    if not source_code.entry_script:
        raise ValueError(
            "Must provide 'entry_script' if 'distribution' " + "is provided in 'source_code'.",
        )
def _validate_source_code(self, source_code: Optional[SourceCode]):
    """Validate the paths referenced by a SourceCode configuration.

    Checks only run when ``requirements`` or ``entry_script`` is set. In that case
    ``source_dir`` is required and must be a local directory, an S3 directory URI,
    or a ``.tar.gz`` file (local or S3). Unless ``source_dir`` is a ``.tar.gz``,
    ``requirements`` and ``entry_script`` must be files inside it.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if not source_code:
        return
    if not (source_code.requirements or source_code.entry_script):
        return

    source_dir = source_code.source_dir
    requirements = source_code.requirements
    entry_script = source_code.entry_script

    if not source_dir:
        raise ValueError(
            "If 'requirements' or 'entry_script' is provided in 'source_code', "
            "'source_dir' must also be provided."
        )

    is_tarball = source_dir.endswith(".tar.gz")
    location_ok = (
        _is_valid_path(source_dir, path_type="Directory")
        or _is_valid_s3_uri(source_dir, path_type="Directory")
        or (_is_valid_path(source_dir, path_type="File") and is_tarball)
        or (_is_valid_s3_uri(source_dir, path_type="File") and is_tarball)
    )
    if not location_ok:
        raise ValueError(
            f"Invalid 'source_dir' path: {source_dir}. "
            "Must be a valid local directory, "
            "s3 uri or path to tar.gz file stored locally or in s3."
        )

    if is_tarball:
        # Paths inside a .tar.gz source_dir are not checked here.
        return

    for label, relative_path in (("requirements", requirements), ("entry_script", entry_script)):
        if not relative_path:
            continue
        candidate = f"{source_dir}/{relative_path}"
        if _is_valid_path(candidate, path_type="File") or _is_valid_s3_uri(
            candidate, path_type="File"
        ):
            continue
        raise ValueError(
            f"Invalid '{label}': {relative_path}. "
            "Must be a valid file within the 'source_dir'.",
        )
@staticmethod
def _validate_and_fetch_hyperparameters_file(hyperparameters_file: str):
    """Load hyperparameters from a JSON or YAML file.

    The contents are parsed as JSON first. If that fails, they are parsed with
    ``yaml.safe_load``. Either way, the top-level value must be a mapping.

    Args:
        hyperparameters_file (str): Path to a local JSON or YAML file.

    Returns:
        dict: The parsed hyperparameters.

    Raises:
        ValueError: if the file does not exist, or its contents are not a valid
            JSON/YAML mapping.
    """
    if not os.path.isfile(hyperparameters_file):
        raise ValueError(f"Hyperparameters file not found: {hyperparameters_file}")
    logger.info(f"Loading hyperparameters from file: {hyperparameters_file}")
    with open(hyperparameters_file, "r", encoding="utf-8") as f:
        contents = f.read()

    invalid_file_error = (
        f"Invalid hyperparameters file: {hyperparameters_file}. "
        "Must be a valid JSON or YAML file."
    )
    try:
        hyperparameters = json.loads(contents)
        logger.debug("Hyperparameters loaded as JSON")
    except json.JSONDecodeError:
        # Raw contents and parsed values are logged at debug level only, to keep
        # INFO output quiet and avoid echoing values that may be sensitive.
        logger.debug(f"contents: {contents}")
        try:
            hyperparameters = yaml.safe_load(contents)
        except (yaml.YAMLError, ValueError) as err:
            # ValueError is included because some YAML scalars (e.g. invalid dates)
            # fail in their constructors with ValueError rather than YAMLError.
            raise ValueError(invalid_file_error) from err
        logger.debug("Hyperparameters loaded as YAML")

    # Applies to both formats: a JSON list/number is as invalid as a YAML scalar.
    if not isinstance(hyperparameters, dict):
        raise ValueError(invalid_file_error)
    logger.debug(f"hyperparameters: {hyperparameters}")
    return hyperparameters
def model_post_init(self, __context: Any):
    """Post init method to perform custom validation and set default values.

    Steps:
        * Run the cross-field validators for image/algorithm, source code and the
          distributed config.
        * Load ``hyperparameters`` from a file when a path string was given.
        * Resolve defaults through ``TrainDefaults``. The session, role and output
          data config are resolved only in ``Mode.SAGEMAKER_TRAINING_JOB``; the base
          job name, compute and stopping condition are resolved in every mode.
    """
    self._validate_training_image_and_algorithm_name(self.training_image, self.algorithm_name)
    self._validate_source_code(self.source_code)
    self._validate_distributed_config(self.source_code, self.distributed)

    if self.hyperparameters and isinstance(self.hyperparameters, str):
        self.hyperparameters = self._validate_and_fetch_hyperparameters_file(
            hyperparameters_file=self.hyperparameters
        )

    if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
        self.sagemaker_session = TrainDefaults.get_sagemaker_session(self.sagemaker_session)
        self.role = TrainDefaults.get_role(
            role=self.role, sagemaker_session=self.sagemaker_session
        )

    self.base_job_name = TrainDefaults.get_base_job_name(
        base_job_name=self.base_job_name,
        algorithm_name=self.algorithm_name,
        training_image=self.training_image,
    )
    self.compute = TrainDefaults.get_compute(compute=self.compute)
    self.stopping_condition = TrainDefaults.get_stopping_condition(
        stopping_condition=self.stopping_condition
    )

    if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
        # Resolved after base_job_name because the default output location uses it.
        self.output_data_config = TrainDefaults.get_output_data_config(
            base_job_name=self.base_job_name,
            output_data_config=self.output_data_config,
            sagemaker_session=self.sagemaker_session,
        )

    if self.training_image:
        from sagemaker.core.helper.pipeline_variable import PipelineVariable

        if isinstance(self.training_image, PipelineVariable):
            logger.info("Training image URI: (PipelineVariable - resolved at pipeline execution)")
        else:
            logger.info(f"Training image URI: {self.training_image}")
def _create_training_job_args(
    self,
    input_data_config: Optional[List[Union[Channel, InputData]]] = None,
    boto3: bool = False,
) -> Dict[str, Any]:
    """Create the training job arguments.

    Args:
        input_data_config (Optional[List[Union[Channel, InputData]]]):
            The input data config for the training job. Takes a list of Channel or
            InputData objects. An InputData data source can be an S3 URI string,
            local file path string, S3DataSource object, or FileSystemDataSource
            object. Channels given here replace channels of ``self.input_data_config``
            that have the same ``channel_name``.
        boto3 (bool): Whether to return the arguments in boto3 format. Defaults to False.
            By default, the arguments are returned in the format used by the SageMaker Core.
            With a ``PipelineSession`` the boto3 (PascalCase, serialized) format is
            always returned, without a training job name.

    Returns:
        Dict[str, Any]: The training job arguments.
    """
    self._populate_intelligent_defaults()
    current_training_job_name = _get_unique_name(self.base_job_name)
    input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"

    base_channels = self.input_data_config.copy() if self.input_data_config else []
    final_input_data_config = self._merge_channel_overrides(base_channels, input_data_config)

    if self._is_nova_recipe or self._is_llmft_recipe:
        for channel in final_input_data_config:
            if channel.channel_name == SM_RECIPE:
                raise ValueError(
                    "Cannot use reserved channel name 'recipe' as an input channel name "
                    " for Nova or LLMFT Recipe"
                )
        recipe_file_path = os.path.join(self._temp_recipe_train_dir.name, SM_RECIPE_YAML)
        recipe_channel = self.create_input_data_channel(
            channel_name=SM_RECIPE,
            data_source=recipe_file_path,
            key_prefix=input_data_key_prefix,
        )
        final_input_data_config.append(recipe_channel)
        # Expose the recipe's container path as a hyperparameter. ``hyperparameters``
        # may have been set to None explicitly, so start from an empty dict then.
        if self.hyperparameters is None:
            self.hyperparameters = {}
        self.hyperparameters.update({"sagemaker_recipe_local_path": SM_RECIPE_CONTAINER_PATH})

    if final_input_data_config:
        final_input_data_config = self._get_input_data_config(
            final_input_data_config, input_data_key_prefix
        )

    if self.checkpoint_config and not self.checkpoint_config.s3_uri:
        self.checkpoint_config.s3_uri = _default_s3_uri(
            self.sagemaker_session,
            f"{self.base_job_name}/{current_training_job_name}/checkpoints",
        )
    if self._tensorboard_output_config and not self._tensorboard_output_config.s3_output_path:
        self._tensorboard_output_config.s3_output_path = _default_s3_uri(
            self.sagemaker_session, self.base_job_name
        )

    string_hyper_parameters = {
        name: safe_serialize(value) for name, value in (self.hyperparameters or {}).items()
    }

    container_entrypoint = None
    container_arguments = None
    if self.source_code:
        if self.training_mode == Mode.LOCAL_CONTAINER:
            # The root is used as the directory prefix so the drivers directory is
            # created under local_container_root (for an absolute root path).
            self._temp_code_dir = TemporaryDirectory(prefix=self.local_container_root + "/")
        else:
            self._temp_code_dir = TemporaryDirectory()
        # Copy everything under container_drivers/ to a temporary directory
        shutil.copytree(SM_DRIVERS_LOCAL_PATH, self._temp_code_dir.name, dirs_exist_ok=True)

        # If distributed is provided, overwrite code under <root>/distributed_drivers
        if self.distributed:
            distributed_driver_dir = self.distributed.driver_dir
            driver_dir = os.path.join(self._temp_code_dir.name, "distributed_drivers")
            shutil.copytree(distributed_driver_dir, driver_dir, dirs_exist_ok=True)

        # If source code is provided, create a channel for the source code
        # The source code will be mounted at /opt/ml/input/data/code in the container
        if self.source_code.source_dir:
            source_code_channel = self.create_input_data_channel(
                channel_name=SM_CODE,
                data_source=self.source_code.source_dir,
                key_prefix=input_data_key_prefix,
                ignore_patterns=self.source_code.ignore_patterns,
            )
            final_input_data_config.append(source_code_channel)

        self._prepare_train_script(
            tmp_dir=self._temp_code_dir,
            source_code=self.source_code,
            distributed=self.distributed,
        )

        if isinstance(self.distributed, Torchrun) and self.distributed.smp:
            mp_parameters = self.distributed.smp._to_mp_hyperparameters()
            string_hyper_parameters.update(mp_parameters)

        self._write_source_code_json(tmp_dir=self._temp_code_dir, source_code=self.source_code)
        self._write_distributed_json(tmp_dir=self._temp_code_dir, distributed=self.distributed)

        # Create an input channel for drivers packaged by the sdk
        sm_drivers_channel = self.create_input_data_channel(
            channel_name=SM_DRIVERS,
            data_source=self._temp_code_dir.name,
            key_prefix=input_data_key_prefix,
            ignore_patterns=self.source_code.ignore_patterns,
        )
        final_input_data_config.append(sm_drivers_channel)

        # If source_code is provided, we will always use
        # the default container entrypoint and arguments
        # to execute the sm_train.sh script.
        # Any commands generated from the source_code will be
        # executed from the sm_train.sh script.
        container_entrypoint = DEFAULT_CONTAINER_ENTRYPOINT
        container_arguments = DEFAULT_CONTAINER_ARGUMENTS

    algorithm_specification = AlgorithmSpecification(
        algorithm_name=self.algorithm_name,
        training_image=self.training_image,
        training_input_mode=self.training_input_mode,
        training_image_config=self.training_image_config,
        container_entrypoint=container_entrypoint,
        container_arguments=container_arguments,
        metric_definitions=self._metric_definitions,
    )

    resource_config = self.compute._to_resource_config()
    vpc_config = self.networking._to_vpc_config() if self.networking else None
    tags_as_dicts = self._tags_as_dicts(self.tags)

    # Build training request with snake_case keys (Python SDK convention)
    training_request = {
        "training_job_name": current_training_job_name,
        "algorithm_specification": algorithm_specification,
        "hyper_parameters": string_hyper_parameters,
        "input_data_config": final_input_data_config,
        "resource_config": resource_config,
        "vpc_config": vpc_config,
        "role_arn": self.role,
        "tags": tags_as_dicts,
        "stopping_condition": self.stopping_condition,
        "output_data_config": self.output_data_config,
        "checkpoint_config": self.checkpoint_config,
        "environment": self.environment,
        "enable_managed_spot_training": self.compute.enable_managed_spot_training,
        "enable_inter_container_traffic_encryption": (
            self.networking.enable_inter_container_traffic_encryption
            if self.networking
            else None
        ),
        "enable_network_isolation": (
            self.networking.enable_network_isolation if self.networking else None
        ),
        "remote_debug_config": self._remote_debug_config,
        "tensor_board_output_config": self._tensorboard_output_config,
        "retry_strategy": self._retry_strategy,
        "infra_check_config": self._infra_check_config,
        "session_chaining_config": self._session_chaining_config,
    }

    if boto3 or isinstance(self.sagemaker_session, PipelineSession):
        if isinstance(self.sagemaker_session, PipelineSession):
            # Pipeline requests are built without a training job name.
            training_request.pop("training_job_name", None)
        # Convert snake_case to PascalCase for AWS API
        pipeline_request = {to_pascal_case(k): v for k, v in training_request.items()}
        serialized_request = serialize(pipeline_request)
        return serialized_request
    return training_request

@staticmethod
def _merge_channel_overrides(
    base_channels: List[Union[Channel, InputData]],
    override_channels: Optional[List[Union[Channel, InputData]]],
) -> List[Union[Channel, InputData]]:
    """Merge channel lists by ``channel_name``, letting ``override_channels`` win.

    Base channels keep their position (replaced in place when overridden); override
    channels with new names are appended in their given order. When there are no
    overrides, ``base_channels`` itself is returned.
    """
    if not override_channels:
        return base_channels
    merged = {channel.channel_name: channel for channel in base_channels}
    new_channels = []
    for channel in override_channels:
        if channel.channel_name in merged:
            merged[channel.channel_name] = channel
        else:
            new_channels.append(channel)
    return list(merged.values()) + new_channels

@staticmethod
def _tags_as_dicts(tags: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
    """Convert tags to plain dicts for the training request.

    Objects exposing ``model_dump()`` are dumped, dicts are kept as-is, and any
    other object is read through its ``key``/``value`` attributes. Returns None
    when ``tags`` is empty or None.
    """
    if not tags:
        return None
    tags_as_dicts = []
    for tag in tags:
        if hasattr(tag, "model_dump"):
            tags_as_dicts.append(tag.model_dump())
        elif isinstance(tag, dict):
            tags_as_dicts.append(tag)
        else:
            # Fallback for any other tag-like object
            tags_as_dicts.append(
                {"key": getattr(tag, "key", ""), "value": getattr(tag, "value", "")}
            )
    return tags_as_dicts
@_telemetry_emitter(feature=Feature.MODEL_TRAINER, func_name="model_trainer.train")
@runnable_by_pipeline
@validate_call
def train(
    self,
    input_data_config: Optional[List[Union[Channel, InputData]]] = None,
    wait: Optional[bool] = True,
    logs: Optional[bool] = True,
):
    """Train a model using AWS SageMaker.

    In ``Mode.SAGEMAKER_TRAINING_JOB`` a training job is created through
    ``TrainingJob.create``; with a ``PipelineSession`` the request is handed to the
    session instead and nothing is launched. In any other mode the job runs through
    ``_LocalContainer``.

    Args:
        input_data_config (Optional[List[Union[Channel, InputData]]]):
            The input data config for the training job. Takes a list of Channel or
            InputData objects. An InputData data source can be an S3 URI string,
            local file path string, S3DataSource object, or FileSystemDataSource
            object. Channels given here replace channels of ``self.input_data_config``
            that have the same ``channel_name``.
        wait (Optional[bool]):
            Whether to wait for the training job to complete before returning.
            Defaults to True.
        logs (Optional[bool]):
            Whether to display the training container logs while training.
            Defaults to True.
    """
    training_request = self._create_training_job_args(input_data_config=input_data_config)

    if self.training_mode == Mode.SAGEMAKER_TRAINING_JOB:
        # Handle PipelineSession
        if isinstance(self.sagemaker_session, PipelineSession):
            self.sagemaker_session._intercept_create_request(training_request, None, "train")
            return

        training_job = TrainingJob.create(
            session=self.sagemaker_session.boto_session,
            **training_request
        )
        self._latest_training_job = training_job

        if wait:
            training_job.wait(logs=logs)
        if logs and not wait:
            logger.warning(
                "Not displaying the training container logs as 'wait' is set to False."
            )
    else:
        local_container = _LocalContainer(
            training_job_name=training_request["training_job_name"],
            instance_type=training_request["resource_config"].instance_type,
            instance_count=training_request["resource_config"].instance_count,
            image=training_request["algorithm_specification"].training_image,
            container_root=self.local_container_root,
            sagemaker_session=self.sagemaker_session,
            container_entrypoint=training_request["algorithm_specification"].container_entrypoint,
            container_arguments=training_request["algorithm_specification"].container_arguments,
            input_data_config=training_request["input_data_config"],
            hyper_parameters=training_request["hyper_parameters"],
            environment=training_request["environment"],
        )
        try:
            local_container.train(wait)
        finally:
            # Remove the temporary drivers/code directory even if local training raises.
            if self._temp_code_dir is not None:
                self._temp_code_dir.cleanup()
def _get_input_data_config(
    self,
    input_data_channels: Optional[List[Union[Channel, InputData]]],
    key_prefix: Optional[str] = None,
) -> List[Channel]:
    """Normalize a list of Channel / InputData entries into Channel objects.

    Args:
        input_data_channels (Optional[List[Union[Channel, InputData]]]):
            The input data config for the training job. Channel entries are kept as
            they are; InputData entries are converted with
            ``create_input_data_channel``, and their ``content_type`` is copied over
            when set.
        key_prefix (Optional[str]): Passed to ``create_input_data_channel`` for each
            InputData entry.

    Returns:
        List[Channel]: The channels in input order; an empty list when
        ``input_data_channels`` is None.

    Raises:
        ValueError: if an entry is neither a Channel nor an InputData.
    """
    if input_data_channels is None:
        return []

    def _to_channel(entry):
        if isinstance(entry, Channel):
            return entry
        if not isinstance(entry, InputData):
            raise ValueError(
                f"Invalid input data channel: {entry}. "
                "Must be a Channel or InputDataSource."
            )
        converted = self.create_input_data_channel(
            entry.channel_name, entry.data_source, key_prefix=key_prefix
        )
        if entry.content_type:
            converted.content_type = entry.content_type
        return converted

    return [_to_channel(entry) for entry in input_data_channels]
def _write_source_code_json(self, tmp_dir: TemporaryDirectory, source_code: SourceCode):
    """Write ``source_code.model_dump()`` as JSON to ``SOURCE_CODE_JSON`` in ``tmp_dir``.

    An empty JSON object is written when ``source_code`` is falsy.
    """
    target_path = os.path.join(tmp_dir.name, SOURCE_CODE_JSON)
    with open(target_path, "w") as json_file:
        payload = {} if not source_code else source_code.model_dump()
        serialized = json.dumps(payload)
        json_file.write(serialized)
def _write_distributed_json(
    self,
    tmp_dir: TemporaryDirectory,
    distributed: Optional[DistributedConfig] = None,
):
    """Write ``distributed.model_dump()`` as JSON to ``DISTRIBUTED_JSON`` in ``tmp_dir``.

    An empty JSON object is written when no distributed runner is configured.
    """
    target_path = os.path.join(tmp_dir.name, DISTRIBUTED_JSON)
    with open(target_path, "w") as json_file:
        if distributed:
            payload = distributed.model_dump()
        else:
            payload = {}
        json_file.write(json.dumps(payload))
def _prepare_train_script(
    self,
    tmp_dir: TemporaryDirectory,
    source_code: SourceCode,
    distributed: Optional[DistributedConfig] = None,
):
    """Prepare the training script to be executed in the training job container.

    Renders ``TRAIN_SCRIPT_TEMPLATE`` and writes it to ``TRAIN_SCRIPT`` inside
    ``tmp_dir``. The driver section is chosen in this order:

    1. ``source_code.command`` (takes precedence over ``entry_script``).
    2. The ``distributed`` driver, when a DistributedConfig is provided.
    3. ``source_code.entry_script`` (must end in ``.py`` or ``.sh``).

    Args:
        tmp_dir (TemporaryDirectory): Directory the rendered script is written to.
        source_code (SourceCode): The source code configuration.
        distributed (Optional[DistributedConfig]): The distributed runner configuration.

    Raises:
        ValueError: If the entry script has an unsupported extension, or if no
            driver can be selected from ``source_code`` and ``distributed``.
    """
    import shlex

    base_command = ""
    if source_code.command:
        if source_code.entry_script:
            logger.warning(
                "Both 'command' and 'entry_script' are provided in the SourceCode. "
                "Defaulting to 'command'."
            )
        # Collapse runs of whitespace in the user command into single spaces.
        base_command = " ".join(source_code.command.split())

    install_requirements = ""
    if source_code.requirements:
        if self._jumpstart_config and source_code.requirements == "auto":
            install_requirements = INSTALL_AUTO_REQUIREMENTS
        else:
            install_requirements = INSTALL_REQUIREMENTS.format(
                requirements_file=source_code.requirements
            )

    working_dir = ""
    if source_code.source_dir:
        working_dir = f"cd {SM_CODE_CONTAINER_PATH} \n"
        if source_code.source_dir.endswith(".tar.gz"):
            # Quote the archive name so names containing spaces or shell
            # metacharacters survive the generated command; plain names
            # (letters, digits, '.', '-', '_', ...) are emitted unchanged.
            tarfile_name = shlex.quote(os.path.basename(source_code.source_dir))
            working_dir += f"tar -xzf {tarfile_name} \n"

    if base_command:
        execute_driver = EXECUTE_BASE_COMMANDS.format(base_command=base_command)
    elif distributed:
        execute_driver = EXEUCTE_DISTRIBUTED_DRIVER.format(
            driver_name=distributed.__class__.__name__,
            driver_script=distributed.driver_script,
        )
    elif source_code.entry_script and not source_code.command:
        # `distributed` is falsy here (handled above). A whitespace-only `command`
        # still blocks this branch and falls through to the error below.
        if not source_code.entry_script.endswith((".py", ".sh")):
            raise ValueError(
                f"Unsupported entry script: {source_code.entry_script}. "
                "Only .py and .sh scripts are supported."
            )
        execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
    else:
        # Expected to be unreachable: source_code should already have been
        # validated to contain at least one of 'command' or 'entry_script'.
        raise ValueError(
            f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}. "
            "Please provide a valid configuration with at least one of 'command'"
            " or 'entry_script'."
        )

    train_script = TRAIN_SCRIPT_TEMPLATE.format(
        working_dir=working_dir,
        install_requirements=install_requirements,
        execute_driver=execute_driver,
    )
    with open(os.path.join(tmp_dir.name, TRAIN_SCRIPT), "w") as f:
        f.write(train_script)
[docs]
@classmethod
def from_recipe(
    cls,
    training_recipe: str,
    compute: Compute,
    recipe_overrides: Optional[Dict[str, Any]] = None,
    networking: Optional[Networking] = None,
    stopping_condition: Optional[StoppingCondition] = None,
    requirements: Optional[str] = None,
    training_image: Optional[str] = None,
    training_image_config: Optional[TrainingImageConfig] = None,
    output_data_config: Optional[shapes.OutputDataConfig] = None,
    input_data_config: Optional[List[Union[Channel, InputData]]] = None,
    checkpoint_config: Optional[shapes.CheckpointConfig] = None,
    training_input_mode: Optional[str] = "File",
    environment: Optional[Dict[str, str]] = None,
    hyperparameters: Optional[Union[Dict[str, Any], str]] = None,
    tags: Optional[List[Tag]] = None,
    sagemaker_session: Optional[Session] = None,
    role: Optional[str] = None,
    base_job_name: Optional[str] = None,
) -> "ModelTrainer":  # noqa: D412
    """Create a ModelTrainer from a training recipe.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer
        from sagemaker.train.configs import Compute

        recipe_overrides = {
            "run": {
                "results_dir": "/opt/ml/model",
            },
            "model": {
                "data": {
                    "use_synthetic_data": True
                }
            }
        }

        compute = Compute(
            instance_type="ml.p5.48xlarge",
            keep_alive_period_in_seconds=3600
        )

        model_trainer = ModelTrainer.from_recipe(
            training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
            recipe_overrides=recipe_overrides,
            compute=compute,
        )

        model_trainer.train(wait=False)

    Args:
        training_recipe (str):
            The training recipe to use for training the model. This must be the name of
            a sagemaker training recipe or a path to a local training recipe .yaml file.
            For available training recipes, see: https://github.com/aws/sagemaker-hyperpod-recipes/
        compute (Compute):
            The compute configuration. This is used to specify the compute resources for
            the training job. Specifying instance_type is required for training recipes.
            Must be a GPU or Trainium instance type (CPU instances are only accepted for
            Nova and LLMFT recipes).
        recipe_overrides (Optional[Dict[str, Any]]):
            The recipe overrides. This is used to override the default recipe parameters.
        networking (Optional[Networking]):
            The networking configuration. This is used to specify the networking settings
            for the training job.
        stopping_condition (Optional[StoppingCondition]):
            The stopping condition. This is used to specify the different stopping
            conditions for the training job.
            If not specified, will default to 1 hour max run time.
        requirements (Optional[str]):
            The path to a requirements file to install in the training job container.
        training_image (Optional[str]):
            The training image URI to use for the training job container. If not specified,
            the training image will be determined from the recipe. Required for Nova and
            LLMFT recipes, and whenever ``training_image_config`` is set.
        training_image_config (Optional[TrainingImageConfig]):
            Training image Config. This is the configuration to use an image from a private
            Docker registry for a training job.
        output_data_config (Optional[OutputDataConfig]):
            The output data configuration. This is used to specify the output data location
            for the training job.
            If not specified, will default to ``s3://<default_bucket>/<base_job_name>/output/``.
        input_data_config (Optional[List[Union[Channel, InputData]]]):
            The input data config for the training job.
            Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
            string, local file path string, S3DataSource object, or FileSystemDataSource object.
        checkpoint_config (Optional[CheckpointConfig]):
            Contains information about the output location for managed spot training checkpoint
            data.
        training_input_mode (Optional[str]):
            The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
            Defaults to "File".
        environment (Optional[Dict[str, str]]):
            The environment variables for the training job.
        hyperparameters (Optional[Union[Dict[str, Any], str]]):
            Only used for Nova recipes: a dictionary of hyperparameters, or a path to a
            hyperparameters file, merged over the hyperparameters derived from the recipe.
            For other recipes a warning is logged and the value is ignored.
        tags (Optional[List[Tag]]):
            An array of key-value pairs. You can use tags to categorize your AWS resources
            in different ways, for example, by purpose, owner, or environment.
        sagemaker_session (Optional[Session]):
            The SageMakerCore session.
            If not specified, a new session will be created.
        role (Optional[str]):
            The IAM role ARN for the training job.
            If not specified, the default SageMaker execution role will be used.
        base_job_name (Optional[str]):
            The base name for the training job.
            If not specified, a default name will be generated using the algorithm name
            or training image.

    Raises:
        ValueError: If ``compute.instance_type`` is not set; if a CPU instance is used
            with a recipe that is neither Nova nor LLMFT; if ``training_image`` is missing
            for a Nova/LLMFT recipe; or if ``training_image_config`` is given without
            ``training_image``.
    """
    if compute.instance_type is None:
        raise ValueError("Must set ``instance_type`` in Compute when using training recipes.")
    device_type = _determine_device_type(compute.instance_type)
    recipe = _load_base_recipe(
        training_recipe=training_recipe, recipe_overrides=recipe_overrides
    )
    is_nova = _is_nova_recipe(recipe=recipe)
    is_llmft = _is_llmft_recipe(recipe=recipe)

    if device_type == "cpu" and not (is_nova or is_llmft):
        raise ValueError(
            "Training recipes are not supported for CPU instances. "
            "Please provide a GPU or Trainium instance type."
        )
    if training_image is None and (is_nova or is_llmft):
        raise ValueError("training_image must be provided when using recipe for Nova or LLMFT")
    if training_image_config and training_image is None:
        raise ValueError("training_image must be provided when using training_image_config.")

    sagemaker_session = TrainDefaults.get_sagemaker_session(sagemaker_session)
    role = TrainDefaults.get_role(role=role, sagemaker_session=sagemaker_session)

    # The training recipe is used to prepare the following args:
    # - source_code
    # - training_image
    # - distributed
    # - compute
    # - hyperparameters
    model_trainer_args, tmp_dir = _get_args_from_recipe(
        training_recipe=recipe,
        recipe_overrides=recipe_overrides,
        requirements=requirements,
        compute=compute,
        region_name=sagemaker_session.boto_region_name,
        role=role,
    )
    # An explicit image always wins over the one derived from the recipe.
    if training_image is not None:
        model_trainer_args["training_image"] = training_image

    if hyperparameters and not is_nova:
        logger.warning(
            "Hyperparameters are not supported for general and LLMFT training recipes. "
            "Ignoring hyperparameters input."
        )
    if is_nova:
        if hyperparameters and isinstance(hyperparameters, str):
            hyperparameters = cls._validate_and_load_hyperparameters_file(hyperparameters)
            model_trainer_args["hyperparameters"].update(hyperparameters)
        elif hyperparameters and isinstance(hyperparameters, dict):
            model_trainer_args["hyperparameters"].update(hyperparameters)

    model_trainer = cls(
        sagemaker_session=sagemaker_session,
        role=role,
        base_job_name=base_job_name,
        networking=networking,
        stopping_condition=stopping_condition,
        training_image_config=training_image_config,
        output_data_config=output_data_config,
        input_data_config=input_data_config,
        checkpoint_config=checkpoint_config,
        training_input_mode=training_input_mode,
        environment=environment,
        tags=tags,
        **model_trainer_args,
    )
    model_trainer._is_nova_recipe = is_nova
    model_trainer._is_llmft_recipe = is_llmft
    model_trainer._temp_recipe_train_dir = tmp_dir
    return model_trainer
[docs]
@classmethod
def from_jumpstart_config(
    cls,
    jumpstart_config: JumpStartConfig,
    source_code: Optional[SourceCode] = None,
    compute: Optional[Compute] = None,
    networking: Optional[Networking] = None,
    stopping_condition: Optional[StoppingCondition] = None,
    training_image: Optional[str] = None,
    training_image_config: Optional[TrainingImageConfig] = None,
    output_data_config: Optional[shapes.OutputDataConfig] = None,
    input_data_config: Optional[List[Union[Channel, InputData]]] = None,
    checkpoint_config: Optional[shapes.CheckpointConfig] = None,
    training_input_mode: Optional[str] = "File",
    environment: Optional[Dict[str, str]] = None,
    hyperparameters: Optional[Union[Dict[str, Any], str]] = None,
    tags: Optional[List[Tag]] = None,
    sagemaker_session: Optional[Session] = None,
    role: Optional[str] = None,
    base_job_name: Optional[str] = None,
) -> "ModelTrainer":  # noqa: D412
    """Create a ModelTrainer from a JumpStart Model ID.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer
        from sagemaker.train.configs import InputData
        from sagemaker.core.jumpstart import JumpStartConfig

        jumpstart_config = JumpStartConfig(model_id="xxxxxxx")
        model_trainer = ModelTrainer.from_jumpstart_config(
            jumpstart_config=jumpstart_config
        )

        training_data = InputData(channel_name="training", data_source="s3://bucket/path")
        model_trainer.train(input_data_config=[training_data])

    Args:
        jumpstart_config (JumpStartConfig):
            The JumpStart model configuration. This is used to specify the model ID,
            version, and other parameters for the training job.
        source_code (Optional[SourceCode]):
            The source code configuration. This is used to configure the source code for
            running the training job.
        compute (Optional[Compute]):
            The compute configuration. This is used to specify the compute resources for
            the training job. If ``instance_type`` is set, it must be one of the model's
            supported training instance types.
        networking (Optional[Networking]):
            The networking configuration. This is used to specify the networking settings
            for the training job.
        stopping_condition (Optional[StoppingCondition]):
            The stopping condition. This is used to specify the different stopping
            conditions for the training job.
            If not specified, will default to 1 hour max run time.
        training_image (Optional[str]):
            The training image URI to use for the training job container. If not specified,
            the training image will be determined from the JumpStart model.
        training_image_config (Optional[TrainingImageConfig]):
            Training image Config. This is the configuration to use an image from a private
            Docker registry for a training job.
        output_data_config (Optional[OutputDataConfig]):
            The output data configuration. This is used to specify the output data location
            for the training job. If not specified, will default
            to ``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
        input_data_config (Optional[List[Union[Channel, InputData]]]):
            The input data config for the training job.
            Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
            string, local file path string, S3DataSource object, or FileSystemDataSource object.
        checkpoint_config (Optional[CheckpointConfig]):
            Contains information about the output location for managed spot training checkpoint
            data.
        training_input_mode (Optional[str]):
            The input mode for the training job. Valid values are "Pipe", "File", "FastFile".
            Defaults to "File".
        environment (Optional[Dict[str, str]]):
            The environment variables for the training job.
        hyperparameters (Optional[Union[Dict[str, Any], str]]):
            The hyperparameters for the training job. Can be a dictionary of hyperparameters
            or a path to hyperparameters json/yaml file.
        tags (Optional[List[Tag]]):
            An array of key-value pairs. You can use tags to categorize your AWS resources
            in different ways, for example, by purpose, owner, or environment.
        sagemaker_session (Optional[Session]):
            The SageMaker Session. This object can be used to manage underlying
            boto3 clients and to specify artifact upload paths via the ``default_bucket``
            and ``default_bucket_prefix`` attributes.
            If not specified, a new session will be created.
        role (Optional[str]):
            The IAM role ARN for the training job.
            If not specified, the default SageMaker execution role will be used.
        base_job_name (Optional[str]):
            The base name for the training job.
            If not specified, a default name will be generated using the algorithm name
            or training image name.

    Raises:
        ValueError: If the model does not support training, if ``compute.instance_type``
            is not a supported training instance type for the model, or if the model is
            gated and ``jumpstart_config.accept_eula`` is not set.
    """
    # Build a fresh dict per call rather than sharing a mutable default, in case a
    # JumpStart defaults helper below mutates the dict it is given.
    environment = {} if environment is None else environment
    hyperparameters = {} if hyperparameters is None else hyperparameters

    sagemaker_session = TrainDefaults.get_sagemaker_session(sagemaker_session=sagemaker_session)
    role = TrainDefaults.get_role(role=role, sagemaker_session=sagemaker_session)
    _, document = get_hub_content_and_document(
        jumpstart_config=jumpstart_config, sagemaker_session=sagemaker_session
    )

    # Basic Validation
    if not document.TrainingSupported:
        raise ValueError(
            f"Training is not supported for the model ID: {jumpstart_config.model_id}.\n"
            "Please check that the model ID is available for training."
        )
    # Only validate an explicitly chosen instance type; a Compute without one is
    # presumably completed by JumpStartTrainDefaults.get_compute below.
    if (
        compute
        and compute.instance_type is not None
        and document.SupportedTrainingInstanceTypes
    ):
        if compute.instance_type not in document.SupportedTrainingInstanceTypes:
            raise ValueError(
                f"Training is not supported for model ID {jumpstart_config.model_id} "
                f"with instance type: {compute.instance_type}.\n"
                "This model ID is only supported for the following instance types:\n"
                f"{document.SupportedTrainingInstanceTypes}.\n"
            )
    if document.GatedBucket:
        eula_url = get_eula_url(sagemaker_session=sagemaker_session, document=document)
        if not jumpstart_config.accept_eula:
            raise ValueError(
                f"Model {jumpstart_config.model_id} is a gated model "
                "and requires accepting the EULA via the `accept_eula` parameter.\n"
                f"See {eula_url} for terms of use."
            )
        logger.warning(f"Model {jumpstart_config.model_id} is a gated model ")
        print(f"See {eula_url} for terms of use.")

    compute = JumpStartTrainDefaults.get_compute(
        jumpstart_config=jumpstart_config,
        compute=compute,
        sagemaker_session=sagemaker_session,
    )
    networking = JumpStartTrainDefaults.get_networking(
        jumpstart_config=jumpstart_config,
        networking=networking,
        sagemaker_session=sagemaker_session,
    )
    training_image = JumpStartTrainDefaults.get_training_image(
        jumpstart_config=jumpstart_config,
        compute=compute,
        training_image=training_image,
        sagemaker_session=sagemaker_session,
    )
    base_job_name = JumpStartTrainDefaults.get_base_job_name(
        jumpstart_config=jumpstart_config,
        base_job_name=base_job_name,
    )
    environment = JumpStartTrainDefaults.get_enviornment(
        jumpstart_config=jumpstart_config,
        compute=compute,
        environment=environment,
        sagemaker_session=sagemaker_session,
    )
    if hyperparameters and isinstance(hyperparameters, str):
        # NOTE(review): from_recipe loads hyperparameter files via
        # `_validate_and_load_hyperparameters_file`; confirm this differently
        # named helper exists on the class.
        hyperparameters = cls._validate_and_fetch_hyperparameters_file(hyperparameters)
    hyperparameters = JumpStartTrainDefaults.get_hyperparameters(
        jumpstart_config=jumpstart_config,
        compute=compute,
        hyperparameters=hyperparameters,
        environment=environment,
        sagemaker_session=sagemaker_session,
    )
    source_code = JumpStartTrainDefaults.get_source_code(
        jumpstart_config=jumpstart_config,
        source_code=source_code,
        sagemaker_session=sagemaker_session,
    )
    input_data_config = JumpStartTrainDefaults.get_training_dataset_input(
        jumpstart_config=jumpstart_config,
        input_data_config=input_data_config,
        sagemaker_session=sagemaker_session,
    )
    input_data_config = JumpStartTrainDefaults.get_model_artifact_input(
        jumpstart_config=jumpstart_config,
        compute=compute,
        input_data_config=input_data_config,
        environment=environment,
        sagemaker_session=sagemaker_session,
    )
    output_data_config = JumpStartTrainDefaults.get_output_data_config(
        jumpstart_config=jumpstart_config,
        base_job_name=base_job_name,
        output_data_config=output_data_config,
        sagemaker_session=sagemaker_session,
    )
    tags = JumpStartTrainDefaults.get_tags(
        jumpstart_config=jumpstart_config,
        tags=tags,
        sagemaker_session=sagemaker_session,
    )

    model_trainer = cls(
        source_code=source_code,
        compute=compute,
        networking=networking,
        stopping_condition=stopping_condition,
        training_image=training_image,
        training_image_config=training_image_config,
        output_data_config=output_data_config,
        input_data_config=input_data_config,
        checkpoint_config=checkpoint_config,
        training_input_mode=training_input_mode,
        environment=environment,
        hyperparameters=hyperparameters,
        tags=tags,
        sagemaker_session=sagemaker_session,
        role=role,
        base_job_name=base_job_name,
    )
    model_trainer._jumpstart_config = jumpstart_config
    return model_trainer
[docs]
def with_tensorboard_output_config(
    self, tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = None
) -> "ModelTrainer":  # noqa: D412
    """Configure the TensorBoard output for the training job.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer

        model_trainer = ModelTrainer(
            ...
        ).with_tensorboard_output_config()

    Args:
        tensorboard_output_config (sagemaker.train.configs.TensorBoardOutputConfig):
            The TensorBoard output configuration. When omitted (or falsy), a default
            ``configs.TensorBoardOutputConfig()`` is used.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    chosen_config = tensorboard_output_config
    if not chosen_config:
        chosen_config = configs.TensorBoardOutputConfig()
    self._tensorboard_output_config = chosen_config
    return self
[docs]
def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":  # noqa: D412
    """Set the retry strategy for the training job.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer
        from sagemaker.train.configs import RetryStrategy

        retry_strategy = RetryStrategy(maximum_retry_attempts=3)

        model_trainer = ModelTrainer(
            ...
        ).with_retry_strategy(retry_strategy)

    Args:
        retry_strategy (sagemaker.train.configs.RetryStrategy):
            The retry strategy for the training job. Stored as given; no default is
            substituted.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    self._retry_strategy = retry_strategy
    return self
[docs]
def with_infra_check_config(
    self, infra_check_config: Optional[InfraCheckConfig] = None
) -> "ModelTrainer":  # noqa: D412
    """Set the infra check configuration for the training job.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer
        from sagemaker.train.configs import InfraCheckConfig

        # Enable infra checks with the default configuration ...
        model_trainer = ModelTrainer(
            ...
        ).with_infra_check_config()

        # ... or pass an explicit configuration.
        model_trainer = ModelTrainer(
            ...
        ).with_infra_check_config(InfraCheckConfig(enable_infra_check=True))

    Args:
        infra_check_config (sagemaker.train.configs.InfraCheckConfig):
            The infra check configuration for the training job. When omitted (or
            falsy), ``InfraCheckConfig(enable_infra_check=True)`` is used.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    self._infra_check_config = infra_check_config or InfraCheckConfig(enable_infra_check=True)
    return self
[docs]
def with_session_chaining_config(
    self, session_chaining_config: Optional[SessionChainingConfig] = None
) -> "ModelTrainer":  # noqa: D412
    """Set the session chaining configuration for the training job.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer

        model_trainer = ModelTrainer(
            ...
        ).with_session_chaining_config()

    Args:
        session_chaining_config (sagemaker.train.configs.SessionChainingConfig):
            The session chaining configuration for the training job. When omitted (or
            falsy), ``SessionChainingConfig(enable_session_tag_chaining=True)`` is used.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    self._session_chaining_config = session_chaining_config or SessionChainingConfig(
        enable_session_tag_chaining=True
    )
    return self
[docs]
def with_remote_debug_config(
    self, remote_debug_config: Optional[RemoteDebugConfig] = None
) -> "ModelTrainer":  # noqa: D412
    """Set the remote debug configuration for the training job.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer

        model_trainer = ModelTrainer(
            ...
        ).with_remote_debug_config()

    Args:
        remote_debug_config (sagemaker.train.configs.RemoteDebugConfig):
            The remote debug configuration for the training job. Optional: when
            omitted (or falsy), ``RemoteDebugConfig(enable_remote_debug=True)`` is used.
            Previously this argument had no default, so the argument-less call shown
            above raised ``TypeError``.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    self._remote_debug_config = remote_debug_config or RemoteDebugConfig(
        enable_remote_debug=True
    )
    return self
[docs]
def with_checkpoint_config(
    self, checkpoint_config: Optional[shapes.CheckpointConfig] = None
) -> "ModelTrainer":  # noqa: D412
    """Set the checkpoint configuration for the training job.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer

        model_trainer = ModelTrainer(
            ...
        ).with_checkpoint_config()

    Args:
        checkpoint_config (sagemaker.train.configs.CheckpointConfig):
            The checkpoint configuration for the training job. When omitted (or
            falsy), a default ``configs.CheckpointConfig()`` is used.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    # Unlike the other ``with_*`` setters, this writes the public
    # ``checkpoint_config`` field (the same one accepted by the constructor).
    self.checkpoint_config = checkpoint_config or configs.CheckpointConfig()
    return self
[docs]
def with_metric_definitions(
    self,
    metric_definitions: List[MetricDefinition]
) -> "ModelTrainer":  # noqa: D412
    """Set the metric definitions for the training job.

    Each definition names a metric and gives a regex whose first capture group
    extracts the metric value from the training logs.

    Example:

    .. code:: python

        from sagemaker.train import ModelTrainer
        from sagemaker.train.configs import MetricDefinition

        metric_definitions = [
            MetricDefinition(
                name="loss",
                regex=r"Loss: ([0-9.]+)",
            )
        ]

        model_trainer = ModelTrainer(
            ...
        ).with_metric_definitions(metric_definitions)

    Args:
        metric_definitions (List[MetricDefinition]):
            The metric definitions for the training job. Stored as given.

    Returns:
        ModelTrainer: This instance, so further ``with_*`` calls can be chained.
    """
    self._metric_definitions = metric_definitions
    return self