sagemaker.train.model_trainer

ModelTrainer class module.

Classes

Mode(value)

Enum class for training mode.

ModelTrainer(*, training_mode, ...)

Class that trains a model using AWS SageMaker.

class sagemaker.train.model_trainer.Mode(value)

Bases: Enum

Enum class for training mode.

LOCAL_CONTAINER = 'LOCAL_CONTAINER'
SAGEMAKER_TRAINING_JOB = 'SAGEMAKER_TRAINING_JOB'
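Mode is a plain Enum, so a member can be looked up from its string value. A minimal sketch, assuming you want to switch between local debugging and a managed training job through an environment variable: the Mode class below is a local stand-in that mirrors the two documented members (it is not the SDK class), and TRAINING_MODE is a hypothetical variable name.

```python
import os
from enum import Enum
from typing import Mapping


class Mode(Enum):
    # Local stand-in mirroring the documented members of
    # sagemaker.train.model_trainer.Mode; not the SDK class itself.
    LOCAL_CONTAINER = "LOCAL_CONTAINER"
    SAGEMAKER_TRAINING_JOB = "SAGEMAKER_TRAINING_JOB"


def mode_from_env(env: Mapping[str, str] = os.environ,
                  var: str = "TRAINING_MODE") -> Mode:
    # TRAINING_MODE is a hypothetical variable name chosen for this sketch.
    # Unset -> the documented default; an unknown value raises ValueError
    # because Enum lookup by value rejects it.
    return Mode(env.get(var, Mode.SAGEMAKER_TRAINING_JOB.value))
```

With the SDK installed, you would import the real Mode from sagemaker.train.model_trainer instead and pass the result as training_mode to ModelTrainer.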
class sagemaker.train.model_trainer.ModelTrainer(*, training_mode: sagemaker.train.model_trainer.Mode = Mode.SAGEMAKER_TRAINING_JOB, sagemaker_session: sagemaker.core.helper.session_helper.Session | None = None, role: str | None = None, base_job_name: str | None = None, source_code: sagemaker.core.training.configs.SourceCode | None = None, distributed: sagemaker.train.distributed.DistributedConfig | None = None, compute: sagemaker.core.training.configs.Compute | None = None, networking: sagemaker.core.training.configs.Networking | None = None, stopping_condition: sagemaker.core.shapes.shapes.StoppingCondition | None = None, training_image: str | sagemaker.core.helper.pipeline_variable.PipelineVariable | None = None, training_image_config: sagemaker.core.shapes.shapes.TrainingImageConfig | None = None, algorithm_name: str | sagemaker.core.helper.pipeline_variable.PipelineVariable | None = None, output_data_config: sagemaker.core.shapes.shapes.OutputDataConfig | None = None, input_data_config: typing.List[sagemaker.core.shapes.shapes.Channel | sagemaker.core.training.configs.InputData] | None = None, checkpoint_config: sagemaker.core.shapes.shapes.CheckpointConfig | None = None, training_input_mode: str | sagemaker.core.helper.pipeline_variable.PipelineVariable | None = 'File', environment: typing.Dict[str, str | sagemaker.core.helper.pipeline_variable.PipelineVariable] | None = {}, hyperparameters: typing.Dict[str, typing.Any] | str | None = {}, tags: typing.List[sagemaker.core.shapes.shapes.Tag] | None = None, local_container_root: str | None = <current working directory at import time>, config_mgr: sagemaker.core.config.config_manager.SageMakerConfig = <sagemaker.core.config.config_manager.SageMakerConfig object>)

Bases: BaseModel

Class that trains a model using AWS SageMaker.

Example:

from sagemaker.train import ModelTrainer
from sagemaker.train.configs import SourceCode, Compute, InputData

ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
model_trainer = ModelTrainer(
    training_image=training_image,
    source_code=source_code,
)

train_data = InputData(channel_name="train", data_source="s3://bucket/train")
model_trainer.train(input_data_config=[train_data])

training_job = model_trainer._latest_training_job
Parameters:
  • training_mode (Mode) – The training mode. Valid values are Mode.LOCAL_CONTAINER and Mode.SAGEMAKER_TRAINING_JOB. Defaults to Mode.SAGEMAKER_TRAINING_JOB.

  • sagemaker_session (Optional[Session]) – The SageMaker session. This object can be used to manage the underlying boto3 clients and to specify artifact upload paths via the default_bucket and default_bucket_prefix attributes. If not specified, a new session will be created.

  • role (Optional[str]) – The IAM role ARN for the training job. If not specified, the default SageMaker execution role will be used.

  • base_job_name (Optional[str]) – The base name for the training job. If not specified, a default name will be generated using the algorithm name or training image name.

  • source_code (Optional[SourceCode]) – The source code configuration, which specifies the local code (for example, source_dir and entry_script) to run in the training job.

  • distributed (Optional[DistributedConfig]) – The distributed runner configuration, used to set up a distributed training job. If specified, source_code must also be provided.

  • compute (Optional[Compute]) – The compute configuration. This is used to specify the compute resources for the training job. If not specified, defaults to one ml.m5.xlarge instance.

  • networking (Optional[Networking]) – The networking configuration. This is used to specify the networking settings for the training job.

  • stopping_condition (Optional[StoppingCondition]) – The stopping condition. This is used to specify the different stopping conditions for the training job. If not specified, defaults to a maximum runtime of 1 hour.

  • algorithm_name (Optional[str]) – The name or ARN of the AWS Marketplace algorithm to use for the training job. algorithm_name cannot be specified if training_image is specified.

  • training_image (Optional[str]) – The training image URI to use for the training job container. training_image cannot be specified if algorithm_name is specified. To find available SageMaker-provided images, see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths

  • training_image_config (Optional[TrainingImageConfig]) – The training image configuration. Use this to pull the training image from a private Docker registry.

  • output_data_config (Optional[OutputDataConfig]) – The output data configuration. This is used to specify the output data location for the training job. If not specified, defaults to s3://<default_bucket>/<default_prefix>/<base_job_name>/, using the default bucket and prefix of the session.

  • input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. The data source of an InputData can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.

  • checkpoint_config (Optional[CheckpointConfig]) – Contains information about the output location for managed spot training checkpoint data.

  • training_input_mode (Optional[str]) – The input mode for the training job. Valid values are “Pipe”, “File”, “FastFile”. Defaults to “File”.

  • environment (Optional[Dict[str, str]]) – The environment variables for the training job.

  • hyperparameters (Optional[Union[Dict[str, Any], str]]) – The hyperparameters for the training job. Can be a dictionary of hyperparameters or a path to a JSON or YAML hyperparameters file.

  • tags (Optional[List[Tag]]) – An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment.

  • local_container_root (Optional[str]) – The local root directory to store artifacts from a training job launched in “LOCAL_CONTAINER” mode.
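Several of the parameters above constrain one another. The sketch below restates those documented rules as plain checks; check_documented_constraints is a hypothetical helper for illustration, not the validation ModelTrainer performs in model_post_init, which may check more or report errors differently.

```python
from typing import Any, Optional

# Documented values for training_input_mode.
VALID_INPUT_MODES = {"Pipe", "File", "FastFile"}


def check_documented_constraints(
    training_image: Optional[str] = None,
    algorithm_name: Optional[str] = None,
    source_code: Optional[Any] = None,
    distributed: Optional[Any] = None,
    training_input_mode: Optional[str] = "File",
) -> None:
    """Raise ValueError when the arguments break a documented rule (hypothetical helper)."""
    # training_image and algorithm_name are mutually exclusive.
    if training_image is not None and algorithm_name is not None:
        raise ValueError("Specify only one of training_image or algorithm_name.")
    # A distributed configuration needs source code to run.
    if distributed is not None and source_code is None:
        raise ValueError("distributed requires source_code.")
    # Only the documented input modes are accepted.
    if training_input_mode is not None and training_input_mode not in VALID_INPUT_MODES:
        raise ValueError(
            f"training_input_mode must be one of {sorted(VALID_INPUT_MODES)}."
        )
```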

CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = ['role', 'base_job_name', 'source_code', 'compute', 'networking', 'stopping_condition', 'training_image', 'training_image_config', 'algorithm_name', 'output_data_config', 'checkpoint_config', 'training_input_mode', 'environment', 'hyperparameters']
SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = {'checkpoint_config': <class 'sagemaker.core.training.configs.CheckpointConfig'>, 'compute': <class 'sagemaker.core.training.configs.Compute'>, 'networking': <class 'sagemaker.core.training.configs.Networking'>, 'output_data_config': <class 'sagemaker.core.training.configs.OutputDataConfig'>, 'source_code': <class 'sagemaker.core.training.configs.SourceCode'>, 'stopping_condition': <class 'sagemaker.core.shapes.shapes.StoppingCondition'>, 'training_image_config': <class 'sagemaker.core.shapes.shapes.TrainingImageConfig'>}
algorithm_name: str | PipelineVariable | None
base_job_name: str | None
checkpoint_config: CheckpointConfig | None
compute: Compute | None
config_mgr: SageMakerConfig
create_input_data_channel(channel_name: str, data_source: str | S3DataSource | FileSystemDataSource, key_prefix: str | None = None, ignore_patterns: List[str] | None = None) → Channel

Create an input data channel for the training job.

Parameters:
  • channel_name (str) – The name of the input data channel.

  • data_source (DataSourceType) – The data source for the input data channel. DataSourceType can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.

  • key_prefix (Optional[str]) –

    The key prefix to use when uploading data to S3. Only applicable when data_source is a local file path string. If not specified, local data will be uploaded to: s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/

    If specified, local data will be uploaded to: s3://<default_bucket_path>/<key_prefix>/<channel_name>/

  • ignore_patterns (Optional[List[str]]) – Patterns of files and folders to exclude when uploading local data to S3. If not specified, defaults to ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
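The key_prefix and ignore_patterns rules above decide which local files are uploaded and where they land. The sketch below restates them in plain Python. It assumes a file is skipped when any component of its relative path matches a pattern using glob-style fnmatch matching; that matching rule and the helper names (is_ignored, channel_s3_prefix, planned_uploads) are hypothetical, and the SDK's own matching may differ.

```python
import fnmatch
from pathlib import PurePosixPath
from typing import Iterable, List, Optional

# Documented default ignore patterns.
DEFAULT_IGNORE_PATTERNS = [
    ".env", ".git", "__pycache__", ".DS_Store", ".cache", ".ipynb_checkpoints",
]


def is_ignored(rel_path: str, patterns: Iterable[str]) -> bool:
    # Assumption: skip the file if any path component matches any pattern.
    patterns = list(patterns)
    parts = PurePosixPath(rel_path).parts
    return any(fnmatch.fnmatch(part, pat) for part in parts for pat in patterns)


def channel_s3_prefix(default_bucket_path: str, channel_name: str,
                      base_job_name: str, key_prefix: Optional[str] = None) -> str:
    # Mirrors the two documented upload layouts for local data.
    if key_prefix is None:
        return f"s3://{default_bucket_path}/{base_job_name}/input/{channel_name}/"
    return f"s3://{default_bucket_path}/{key_prefix}/{channel_name}/"


def planned_uploads(rel_paths: Iterable[str], s3_prefix: str,
                    patterns: Optional[List[str]] = None) -> List[str]:
    # Return the S3 destinations of the files that survive the filter.
    patterns = DEFAULT_IGNORE_PATTERNS if patterns is None else patterns
    return [s3_prefix + p for p in rel_paths if not is_ignored(p, patterns)]
```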

distributed: DistributedConfig | None
environment: Dict[str, str | PipelineVariable] | None
classmethod from_jumpstart_config(jumpstart_config: JumpStartConfig, source_code: SourceCode | None = None, compute: Compute | None = None, networking: Networking | None = None, stopping_condition: StoppingCondition | None = None, training_image: str | None = None, training_image_config: TrainingImageConfig | None = None, output_data_config: OutputDataConfig | None = None, input_data_config: List[Channel | InputData] | None = None, checkpoint_config: CheckpointConfig | None = None, training_input_mode: str | None = 'File', environment: Dict[str, str] | None = {}, hyperparameters: Dict[str, Any] | str | None = {}, tags: List[Tag] | None = None, sagemaker_session: Session | None = None, role: str | None = None, base_job_name: str | None = None) → ModelTrainer

Create a ModelTrainer from a JumpStart model ID.

Example:

from sagemaker.train import ModelTrainer
from sagemaker.train.configs import InputData
from sagemaker.core.jumpstart import JumpStartConfig

jumpstart_config = JumpStartConfig(model_id="xxxxxxx")

model_trainer = ModelTrainer.from_jumpstart_config(
    jumpstart_config=jumpstart_config
)

training_data = InputData(channel_name="training", data_source="s3://bucket/path")
model_trainer.train(input_data_config=[training_data])
Parameters:
  • jumpstart_config (JumpStartConfig) – The JumpStart model configuration. This is used to specify the model ID, version, and other parameters for the training job.

  • source_code (Optional[SourceCode]) – The source code configuration, which specifies the local code (for example, source_dir and entry_script) to run in the training job.

  • compute (Optional[Compute]) – The compute configuration. This is used to specify the compute resources for the training job.

  • networking (Optional[Networking]) – The networking configuration. This is used to specify the networking settings for the training job.

  • stopping_condition (Optional[StoppingCondition]) – The stopping condition. This is used to specify the different stopping conditions for the training job. If not specified, defaults to a maximum runtime of 1 hour.

  • training_image (Optional[str]) – The training image URI to use for the training job container. If not specified, the training image will be determined from the JumpStart model configuration.

  • training_image_config (Optional[TrainingImageConfig]) – The training image configuration. Use this to pull the training image from a private Docker registry.

  • output_data_config (Optional[OutputDataConfig]) – The output data configuration. This is used to specify the output data location for the training job. If not specified, will default to s3://<default_bucket>/<default_prefix>/<base_job_name>/.

  • input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. The data source of an InputData can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.

  • checkpoint_config (Optional[CheckpointConfig]) – Contains information about the output location for managed spot training checkpoint data.

  • training_input_mode (Optional[str]) – The input mode for the training job. Valid values are “Pipe”, “File”, “FastFile”. Defaults to “File”.

  • environment (Optional[Dict[str, str]]) – The environment variables for the training job.

  • hyperparameters (Optional[Union[Dict[str, Any], str]]) – The hyperparameters for the training job. Can be a dictionary of hyperparameters or a path to a JSON or YAML hyperparameters file.

  • tags (Optional[List[Tag]]) – An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment.

  • sagemaker_session (Optional[Session]) – The SageMaker session. This object can be used to manage the underlying boto3 clients and to specify artifact upload paths via the default_bucket and default_bucket_prefix attributes. If not specified, a new session will be created.

  • role (Optional[str]) – The IAM role ARN for the training job. If not specified, the default SageMaker execution role will be used.

  • base_job_name (Optional[str]) – The base name for the training job. If not specified, a default name will be generated using the algorithm name or training image name.

classmethod from_recipe(training_recipe: str, compute: Compute, recipe_overrides: Dict[str, Any] | None = None, networking: Networking | None = None, stopping_condition: StoppingCondition | None = None, requirements: str | None = None, training_image: str | None = None, training_image_config: TrainingImageConfig | None = None, output_data_config: OutputDataConfig | None = None, input_data_config: List[Channel | InputData] | None = None, checkpoint_config: CheckpointConfig | None = None, training_input_mode: str | None = 'File', environment: Dict[str, str] | None = None, hyperparameters: Dict[str, Any] | str | None = {}, tags: List[Tag] | None = None, sagemaker_session: Session | None = None, role: str | None = None, base_job_name: str | None = None) → ModelTrainer

Create a ModelTrainer from a training recipe.

Example:

from sagemaker.train import ModelTrainer
from sagemaker.train.configs import Compute

recipe_overrides = {
    "run": {
        "results_dir": "/opt/ml/model",
    },
    "model": {
        "data": {
            "use_synthetic_data": True
        }
    }
}

compute = Compute(
    instance_type="ml.p5.48xlarge",
    keep_alive_period_in_seconds=3600
)

model_trainer = ModelTrainer.from_recipe(
    training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
    recipe_overrides=recipe_overrides,
    compute=compute,
)

model_trainer.train(wait=False)
Parameters:
  • training_recipe (str) – The training recipe to use for training the model. This must be the name of a SageMaker training recipe or a path to a local training recipe .yaml file. For available training recipes, see the aws/sagemaker-hyperpod-recipes repository on GitHub.

  • compute (Compute) – The compute configuration. This is used to specify the compute resources for the training job. Specifying instance_type is required for training recipes, and it must be a GPU or Trainium instance type.

  • recipe_overrides (Optional[Dict[str, Any]]) – The recipe overrides. This is used to override the default recipe parameters.

  • networking (Optional[Networking]) – The networking configuration. This is used to specify the networking settings for the training job.

  • stopping_condition (Optional[StoppingCondition]) – The stopping condition. This is used to specify the different stopping conditions for the training job. If not specified, defaults to a maximum runtime of 1 hour.

  • requirements (Optional[str]) – The path to a requirements file to install in the training job container.

  • training_image (Optional[str]) – The training image URI to use for the training job container. If not specified, the training image will be determined from the recipe.

  • training_image_config (Optional[TrainingImageConfig]) – The training image configuration. Use this to pull the training image from a private Docker registry.

  • output_data_config (Optional[OutputDataConfig]) – The output data configuration. This is used to specify the output data location for the training job. If not specified, will default to s3://<default_bucket>/<base_job_name>/output/.

  • input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. The data source of an InputData can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.

  • checkpoint_config (Optional[CheckpointConfig]) – Contains information about the output location for managed spot training checkpoint data.

  • training_input_mode (Optional[str]) – The input mode for the training job. Valid values are “Pipe”, “File”, “FastFile”. Defaults to “File”.

  • environment (Optional[Dict[str, str]]) – The environment variables for the training job.

  • tags (Optional[List[Tag]]) – An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment.

  • sagemaker_session (Optional[Session]) – The SageMakerCore session. If not specified, a new session will be created.

  • role (Optional[str]) – The IAM role ARN for the training job. If not specified, the default SageMaker execution role will be used.

  • base_job_name (Optional[str]) – The base name for the training job. If not specified, a default name will be generated using the algorithm name or training image.
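recipe_overrides follows the nesting of the recipe itself, as in the example above. As a sketch of the idea only, assuming overrides are applied key by key with nested dictionaries merged recursively and any other value replacing the recipe's value (the SDK's exact merge semantics are not documented here): apply_overrides is a hypothetical helper, and the base recipe keys other than run.results_dir and model.data.use_synthetic_data are made up for the test.

```python
import copy
from typing import Any, Dict


def apply_overrides(recipe: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of `recipe` with `overrides` applied (hypothetical helper).

    Nested dicts are merged recursively; any other override value replaces
    the recipe value outright. The input recipe is left unmodified.
    """
    merged = copy.deepcopy(recipe)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = apply_overrides(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged
```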

hyperparameters: Dict[str, Any] | str | None
input_data_config: List[Channel | InputData] | None
local_container_root: str | None
model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'validate_assignment': True}

Configuration for the model, should be a dictionary conforming to pydantic.config.ConfigDict.

model_post_init(_ModelTrainer__context: Any)

Post init method to perform custom validation and set default values.

networking: Networking | None
output_data_config: OutputDataConfig | None
role: str | None
sagemaker_session: Session | None
source_code: SourceCode | None
stopping_condition: StoppingCondition | None
tags: List[Tag] | None
train(input_data_config: List[Channel | InputData] | None = None, wait: bool | None = True, logs: bool | None = True)

Train a model using AWS SageMaker.

Parameters:
  • input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. The data source of an InputData can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.

  • wait (Optional[bool]) – Whether to wait for the training job to complete before returning. Defaults to True.

  • logs (Optional[bool]) – Whether to display the training container logs while training. Defaults to True.
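Inside the training container, the entry script reads what the constructor and train() supplied. The sketch below is a minimal entry-script helper, assuming the standard SageMaker container layout: hyperparameters stored as string values in /opt/ml/input/config/hyperparameters.json, and each input channel available under /opt/ml/input/data/<channel_name> in File and FastFile modes. The helper names are hypothetical.

```python
import json
from pathlib import Path
from typing import Any, Dict

HYPERPARAMETERS_PATH = "/opt/ml/input/config/hyperparameters.json"
INPUT_DATA_ROOT = "/opt/ml/input/data"


def parse_hyperparameters(raw: Dict[str, str]) -> Dict[str, Any]:
    # Decode each string value as JSON when possible ("3" -> 3, "0.1" -> 0.1);
    # values that are not valid JSON (for example "adam") stay as strings.
    parsed: Dict[str, Any] = {}
    for key, value in raw.items():
        try:
            parsed[key] = json.loads(value)
        except (TypeError, ValueError):
            parsed[key] = value
    return parsed


def load_hyperparameters(path: str = HYPERPARAMETERS_PATH) -> Dict[str, Any]:
    # Read the file SageMaker writes into the container, then decode values.
    with open(path) as f:
        return parse_hyperparameters(json.load(f))


def channel_dir(channel_name: str, root: str = INPUT_DATA_ROOT) -> Path:
    # File/FastFile modes: a channel's data appears under <root>/<channel_name>.
    return Path(root) / channel_name
```

In a script launched by model_trainer.train(input_data_config=[train_data]) with a channel named "train", load_hyperparameters() and channel_dir("train") would give the decoded hyperparameters and the local data directory.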

training_image: str | PipelineVariable | None
training_image_config: TrainingImageConfig | None
training_input_mode: str | PipelineVariable | None
training_mode: Mode
with_checkpoint_config(checkpoint_config: CheckpointConfig | None = None) → ModelTrainer

Set the checkpoint configuration for the training job.

Example:

from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_checkpoint_config()
Parameters:

checkpoint_config (sagemaker.train.configs.CheckpointConfig) – The checkpoint configuration for the training job.
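For managed spot training, the script should write checkpoints to the local checkpoint directory (/opt/ml/checkpoints unless CheckpointConfig sets another local path); SageMaker syncs that directory with the configured S3 location and copies it back when a job restarts. A minimal sketch, assuming JSON-serializable state as a stand-in for real framework checkpoints; save_checkpoint and latest_checkpoint are hypothetical helpers.

```python
import json
import os
from pathlib import Path
from typing import Any, Dict, Optional

# Default local checkpoint path in SageMaker training containers.
CHECKPOINT_DIR = "/opt/ml/checkpoints"


def save_checkpoint(state: Dict[str, Any], epoch: int,
                    ckpt_dir: str = CHECKPOINT_DIR) -> Path:
    directory = Path(ckpt_dir)
    directory.mkdir(parents=True, exist_ok=True)
    target = directory / f"epoch-{epoch:04d}.json"
    tmp = directory / f"epoch-{epoch:04d}.json.tmp"
    # Write a temporary file first, then rename it into place, so a reader
    # never picks up a partially written checkpoint.
    tmp.write_text(json.dumps({"epoch": epoch, "state": state}))
    os.replace(tmp, target)
    return target


def latest_checkpoint(ckpt_dir: str = CHECKPOINT_DIR) -> Optional[Dict[str, Any]]:
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None  # fresh start: nothing was restored
    # Zero-padded epoch numbers make lexicographic order match epoch order.
    files = sorted(directory.glob("epoch-*.json"))
    return json.loads(files[-1].read_text()) if files else None
```

At startup, a script would call latest_checkpoint() and resume from epoch + 1 when it returns a value.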

with_infra_check_config(infra_check_config: InfraCheckConfig | None = None) → ModelTrainer

Set the infra check configuration for the training job.

Example:

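from sagemaker.train import ModelTrainer
from sagemaker.train.configs import InfraCheckConfig

infra_check_config = InfraCheckConfig(enable_infra_check=True)

model_trainer = ModelTrainer(
    ...
).with_infra_check_config(infra_check_config)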

Parameters:

infra_check_config (sagemaker.train.configs.InfraCheckConfig) – The infra check configuration for the training job.

with_metric_definitions(metric_definitions: List[MetricDefinition]) → ModelTrainer

Set the metric definitions for the training job.

Example:

from sagemaker.train import ModelTrainer
from sagemaker.train.configs import MetricDefinition

metric_definitions = [
    MetricDefinition(
        name="loss",
        regex=r"Loss: ([0-9.]+)",
    )
]

model_trainer = ModelTrainer(
    ...
).with_metric_definitions(metric_definitions)

Parameters:

metric_definitions (List[MetricDefinition]) – The metric definitions for the training job.
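Each MetricDefinition regex is applied to the training job's log output, and the captured group becomes the metric value. The sketch below, a local approximation rather than SageMaker's own parser, shows why the capture group needs an explicit character class: a pattern ending in a lazy (.*?) group matches an empty string. extract_metric is a hypothetical helper.

```python
import re
from typing import Iterable, List

# Capture one or more digits/dots after "Loss: ".
LOSS_REGEX = r"Loss: ([0-9.]+)"


def extract_metric(log_lines: Iterable[str], regex: str = LOSS_REGEX) -> List[float]:
    pattern = re.compile(regex)
    values: List[float] = []
    for line in log_lines:
        match = pattern.search(line)
        if match:
            # The first capture group holds the metric value.
            values.append(float(match.group(1)))
    return values
```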

with_remote_debug_config(remote_debug_config: RemoteDebugConfig) → ModelTrainer

Set the remote debug configuration for the training job.

Example:

from sagemaker.train import ModelTrainer
from sagemaker.train.configs import RemoteDebugConfig

model_trainer = ModelTrainer(
    ...
).with_remote_debug_config(RemoteDebugConfig(enable_remote_debug=True))

Parameters:

remote_debug_config (sagemaker.train.configs.RemoteDebugConfig) – The remote debug configuration for the training job.

with_retry_strategy(retry_strategy: RetryStrategy) → ModelTrainer

Set the retry strategy for the training job.

Example:

from sagemaker.train import ModelTrainer
from sagemaker.train.configs import RetryStrategy

retry_strategy = RetryStrategy(maximum_retry_attempts=3)

model_trainer = ModelTrainer(
    ...
).with_retry_strategy(retry_strategy)

Parameters:

retry_strategy (sagemaker.train.configs.RetryStrategy) – The retry strategy for the training job.

with_session_chaining_config(session_chaining_config: SessionChainingConfig | None = None) → ModelTrainer

Set the session chaining configuration for the training job.

Example:

from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_session_chaining_config()

Parameters:

session_chaining_config (sagemaker.train.configs.SessionChainingConfig) – The session chaining configuration for the training job.

with_tensorboard_output_config(tensorboard_output_config: TensorBoardOutputConfig | None = None) → ModelTrainer

Set the TensorBoard output configuration.

Example:

from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_tensorboard_output_config()
Parameters:

tensorboard_output_config (sagemaker.train.configs.TensorBoardOutputConfig) – The TensorBoard output configuration.