sagemaker.train.model_trainer#
ModelTrainer class module.
Classes

Mode        Enum class for training mode.
ModelTrainer        Class that trains a model using AWS SageMaker.
- class sagemaker.train.model_trainer.Mode(value)[source]#
Bases: Enum

Enum class for training mode.
- LOCAL_CONTAINER = 'LOCAL_CONTAINER'#
- SAGEMAKER_TRAINING_JOB = 'SAGEMAKER_TRAINING_JOB'#
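The two modes behave like a standard Python Enum. A minimal stand-in (for illustration only, not the actual SageMaker class) shows the documented values and lookup-by-value, which is useful when a mode comes from a config file:

```python
from enum import Enum

# Illustrative stand-in mirroring the documented values; the real class
# lives at sagemaker.train.model_trainer.Mode.
class Mode(Enum):
    LOCAL_CONTAINER = "LOCAL_CONTAINER"
    SAGEMAKER_TRAINING_JOB = "SAGEMAKER_TRAINING_JOB"

# Enum members can be looked up by their string value.
mode = Mode("LOCAL_CONTAINER")
print(mode is Mode.LOCAL_CONTAINER)  # True
```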
- class sagemaker.train.model_trainer.ModelTrainer(*, training_mode: ~sagemaker.train.model_trainer.Mode = Mode.SAGEMAKER_TRAINING_JOB, sagemaker_session: ~sagemaker.core.helper.session_helper.Session | None = None, role: str | None = None, base_job_name: str | None = None, source_code: ~sagemaker.core.training.configs.SourceCode | None = None, distributed: ~sagemaker.train.distributed.DistributedConfig | None = None, compute: ~sagemaker.core.training.configs.Compute | None = None, networking: ~sagemaker.core.training.configs.Networking | None = None, stopping_condition: ~sagemaker.core.shapes.shapes.StoppingCondition | None = None, training_image: str | ~sagemaker.core.helper.pipeline_variable.PipelineVariable | None = None, training_image_config: ~sagemaker.core.shapes.shapes.TrainingImageConfig | None = None, algorithm_name: str | ~sagemaker.core.helper.pipeline_variable.PipelineVariable | None = None, output_data_config: ~sagemaker.core.shapes.shapes.OutputDataConfig | None = None, input_data_config: ~typing.List[~sagemaker.core.shapes.shapes.Channel | ~sagemaker.core.training.configs.InputData] | None = None, checkpoint_config: ~sagemaker.core.shapes.shapes.CheckpointConfig | None = None, training_input_mode: str | ~sagemaker.core.helper.pipeline_variable.PipelineVariable | None = 'File', environment: ~typing.Dict[str, str | ~sagemaker.core.helper.pipeline_variable.PipelineVariable] | None = {}, hyperparameters: ~typing.Dict[str, ~typing.Any] | str | None = {}, tags: ~typing.List[~sagemaker.core.shapes.shapes.Tag] | None = None, local_container_root: str | None = '/home/docs/checkouts/readthedocs.org/user_builds/sagemaker/checkouts/5737/docs', config_mgr: ~sagemaker.core.config.config_manager.SageMakerConfig = <sagemaker.core.config.config_manager.SageMakerConfig object>)[source]#
Bases: BaseModel

Class that trains a model using AWS SageMaker.
Example:
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import SourceCode, Compute, InputData

ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
source_code = SourceCode(
    source_dir="source",
    entry_script="train.py",
    ignore_patterns=ignore_patterns,
)
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"

model_trainer = ModelTrainer(
    training_image=training_image,
    source_code=source_code,
)

train_data = InputData(channel_name="train", data_source="s3://bucket/train")
model_trainer.train(input_data_config=[train_data])

training_job = model_trainer._latest_training_job
- Parameters:
training_mode (Mode) – The training mode. Valid values are Mode.LOCAL_CONTAINER or Mode.SAGEMAKER_TRAINING_JOB.
sagemaker_session (Optional[Session]) – The SageMaker Session. This object can be used to manage underlying boto3 clients and to specify artifact upload paths via the default_bucket and default_bucket_prefix attributes. If not specified, a new session will be created.
role (Optional[str]) – The IAM role ARN for the training job. If not specified, the default SageMaker execution role will be used.
base_job_name (Optional[str]) – The base name for the training job. If not specified, a default name will be generated using the algorithm name or training image name.
source_code (Optional[SourceCode]) – The source code configuration. This is used to configure the source code for running the training job.
distributed (Optional[DistributedConfig]) – The distributed runner for the training job. This is used to configure a distributed training job. If specified, source_code must also be provided.
compute (Optional[Compute]) – The compute configuration. This is used to specify the compute resources for the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
networking (Optional[Networking]) – The networking configuration. This is used to specify the networking settings for the training job.
stopping_condition (Optional[StoppingCondition]) – The stopping condition. This is used to specify the different stopping conditions for the training job. If not specified, will default to 1 hour max run time.
algorithm_name (Optional[str]) – The SageMaker Marketplace algorithm name/ARN to use for the training job. algorithm_name cannot be specified if training_image is specified.
training_image (Optional[str]) – The training image URI to use for the training job container. training_image cannot be specified if algorithm_name is specified. To find available sagemaker distributed images, see: https://docs.aws.amazon.com/sagemaker/latest/dg-ecr-paths/sagemaker-algo-docker-registry-paths
training_image_config (Optional[TrainingImageConfig]) – Training image Config. This is the configuration to use an image from a private Docker registry for a training job.
output_data_config (Optional[OutputDataConfig]) – The output data configuration. This is used to specify the output data location for the training job. If not specified in the session, will default to s3://<default_bucket>/<default_prefix>/<base_job_name>/.
input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.
checkpoint_config (Optional[CheckpointConfig]) – Contains information about the output location for managed spot training checkpoint data.
training_input_mode (Optional[str]) – The input mode for the training job. Valid values are “Pipe”, “File”, “FastFile”. Defaults to “File”.
environment (Optional[Dict[str, str]]) – The environment variables for the training job.
hyperparameters (Optional[Union[Dict[str, Any], str]]) – The hyperparameters for the training job. Can be a dictionary of hyperparameters or a path to a hyperparameters JSON/YAML file.
tags (Optional[List[Tag]]) – An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment.
local_container_root (Optional[str]) – The local root directory to store artifacts from a training job launched in “LOCAL_CONTAINER” mode.
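Since hyperparameters may be passed either inline or as a path to a JSON/YAML file, a hypothetical loader (not part of the SDK; JSON only here, to stay within the standard library) sketches that dual acceptance:

```python
import json
from pathlib import Path
from typing import Any, Dict, Union

# Hypothetical helper mirroring the documented behavior: hyperparameters
# may be a dict or a path string pointing at a JSON/YAML file. Only the
# JSON branch is shown; the SDK's actual parsing may differ.
def resolve_hyperparameters(hp: Union[Dict[str, Any], str]) -> Dict[str, Any]:
    if isinstance(hp, str):
        # Treat a string as a path to a hyperparameters file.
        return json.loads(Path(hp).read_text())
    return dict(hp)

# Either form yields a plain dict of hyperparameters.
inline = resolve_hyperparameters({"epochs": 3, "lr": 1e-4})
```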
- CONFIGURABLE_ATTRIBUTES: ClassVar[List[str]] = ['role', 'base_job_name', 'source_code', 'compute', 'networking', 'stopping_condition', 'training_image', 'training_image_config', 'algorithm_name', 'output_data_config', 'checkpoint_config', 'training_input_mode', 'environment', 'hyperparameters']#
- SERIALIZABLE_CONFIG_ATTRIBUTES: ClassVar[Any] = {'checkpoint_config': <class 'sagemaker.core.training.configs.CheckpointConfig'>, 'compute': <class 'sagemaker.core.training.configs.Compute'>, 'networking': <class 'sagemaker.core.training.configs.Networking'>, 'output_data_config': <class 'sagemaker.core.training.configs.OutputDataConfig'>, 'source_code': <class 'sagemaker.core.training.configs.SourceCode'>, 'stopping_condition': <class 'sagemaker.core.shapes.shapes.StoppingCondition'>, 'training_image_config': <class 'sagemaker.core.shapes.shapes.TrainingImageConfig'>}#
- algorithm_name: str | PipelineVariable | None#
- base_job_name: str | None#
- checkpoint_config: CheckpointConfig | None#
- config_mgr: SageMakerConfig#
- create_input_data_channel(channel_name: str, data_source: str | S3DataSource | FileSystemDataSource, key_prefix: str | None = None, ignore_patterns: List[str] | None = None) Channel[source]#
Create an input data channel for the training job.
- Parameters:
channel_name (str) – The name of the input data channel.
data_source (DataSourceType) – The data source for the input data channel. DataSourceType can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.
key_prefix (Optional[str]) – The key prefix to use when uploading data to S3. Only applicable when data_source is a local file path string. If not specified, local data will be uploaded to: s3://<default_bucket_path>/<base_job_name>/input/<channel_name>/. If specified, local data will be uploaded to: s3://<default_bucket_path>/<key_prefix>/<channel_name>/.
ignore_patterns (Optional[List[str]]) – The ignore patterns to ignore specific files/folders when uploading to S3. If not specified, defaults to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
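The two upload destinations described for key_prefix can be pictured as a small path-building helper (hypothetical, for illustration only; the actual upload is handled inside create_input_data_channel):

```python
from typing import Optional

# Hypothetical sketch (not an SDK function) of the documented S3
# destination logic for local data uploads:
#   without key_prefix: s3://<bucket_path>/<base_job_name>/input/<channel>/
#   with key_prefix:    s3://<bucket_path>/<key_prefix>/<channel>/
def upload_destination(default_bucket_path: str, base_job_name: str,
                       channel_name: str, key_prefix: Optional[str] = None) -> str:
    if key_prefix:
        return f"s3://{default_bucket_path}/{key_prefix}/{channel_name}/"
    return f"s3://{default_bucket_path}/{base_job_name}/input/{channel_name}/"

print(upload_destination("my-bucket", "my-job", "train"))
# s3://my-bucket/my-job/input/train/
```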
- distributed: DistributedConfig | None#
- environment: Dict[str, str | PipelineVariable] | None#
- classmethod from_jumpstart_config(jumpstart_config: JumpStartConfig, source_code: SourceCode | None = None, compute: Compute | None = None, networking: Networking | None = None, stopping_condition: StoppingCondition | None = None, training_image: str | None = None, training_image_config: TrainingImageConfig | None = None, output_data_config: OutputDataConfig | None = None, input_data_config: List[Channel | InputData] | None = None, checkpoint_config: CheckpointConfig | None = None, training_input_mode: str | None = 'File', environment: Dict[str, str] | None = {}, hyperparameters: Dict[str, Any] | str | None = {}, tags: List[Tag] | None = None, sagemaker_session: Session | None = None, role: str | None = None, base_job_name: str | None = None) ModelTrainer[source]#
Create a ModelTrainer from a JumpStart Model ID.
Example:
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import InputData
from sagemaker.core.jumpstart import JumpStartConfig

jumpstart_config = JumpStartConfig(model_id="xxxxxxx")

model_trainer = ModelTrainer.from_jumpstart_config(
    jumpstart_config=jumpstart_config
)

training_data = InputData(channel_name="training", data_source="s3://bucket/path")
model_trainer.train(input_data_config=[training_data])
- Parameters:
jumpstart_config (JumpStartConfig) – The JumpStart model configuration. This is used to specify the model ID, version, and other parameters for the training job.
source_code (Optional[SourceCode]) – The source code configuration. This is used to configure the source code for running the training job.
compute (Optional[Compute]) – The compute configuration. This is used to specify the compute resources for the training job.
networking (Optional[Networking]) – The networking configuration. This is used to specify the networking settings for the training job.
stopping_condition (Optional[StoppingCondition]) – The stopping condition. This is used to specify the different stopping conditions for the training job. If not specified, will default to 1 hour max run time.
training_image (Optional[str]) – The training image URI to use for the training job container. If not specified, the training image will be determined from the recipe.
training_image_config (Optional[TrainingImageConfig]) – Training image Config. This is the configuration to use an image from a private Docker registry for a training job.
output_data_config (Optional[OutputDataConfig]) – The output data configuration. This is used to specify the output data location for the training job. If not specified, will default to s3://<default_bucket>/<default_prefix>/<base_job_name>/.
input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.
checkpoint_config (Optional[CheckpointConfig]) – Contains information about the output location for managed spot training checkpoint data.
training_input_mode (Optional[str]) – The input mode for the training job. Valid values are “Pipe”, “File”, “FastFile”. Defaults to “File”.
environment (Optional[Dict[str, str]]) – The environment variables for the training job.
hyperparameters (Optional[Union[Dict[str, Any], str]]) – The hyperparameters for the training job. Can be a dictionary of hyperparameters or a path to hyperparameters json/yaml file.
tags (Optional[List[Tag]]) – An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment.
sagemaker_session (Optional[Session]) – The SageMaker Session. This object can be used to manage underlying boto3 clients and to specify artifact upload paths via the default_bucket and default_bucket_prefix attributes. If not specified, a new session will be created.
role (Optional[str]) – The IAM role ARN for the training job. If not specified, the default SageMaker execution role will be used.
base_job_name (Optional[str]) – The base name for the training job. If not specified, a default name will be generated using the algorithm name or training image name.
- classmethod from_recipe(training_recipe: str, compute: Compute, recipe_overrides: Dict[str, Any] | None = None, networking: Networking | None = None, stopping_condition: StoppingCondition | None = None, requirements: str | None = None, training_image: str | None = None, training_image_config: TrainingImageConfig | None = None, output_data_config: OutputDataConfig | None = None, input_data_config: List[Channel | InputData] | None = None, checkpoint_config: CheckpointConfig | None = None, training_input_mode: str | None = 'File', environment: Dict[str, str] | None = None, hyperparameters: Dict[str, Any] | str | None = {}, tags: List[Tag] | None = None, sagemaker_session: Session | None = None, role: str | None = None, base_job_name: str | None = None) ModelTrainer[source]#
Create a ModelTrainer from a training recipe.
Example:
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import Compute

recipe_overrides = {
    "run": {
        "results_dir": "/opt/ml/model",
    },
    "model": {
        "data": {
            "use_synthetic_data": True
        }
    }
}

compute = Compute(
    instance_type="ml.p5.48xlarge",
    keep_alive_period_in_seconds=3600
)

model_trainer = ModelTrainer.from_recipe(
    training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
    recipe_overrides=recipe_overrides,
    compute=compute,
)

model_trainer.train(wait=False)
- Parameters:
training_recipe (str) – The training recipe to use for training the model. This must be the name of a sagemaker training recipe or a path to a local training recipe .yaml file. For available training recipes, see: aws/sagemaker-hyperpod-recipes
compute (Compute) – The compute configuration. This is used to specify the compute resources for the training job. Specifying instance_type is required for training recipes. Must be a GPU or Trainium instance type.
recipe_overrides (Optional[Dict[str, Any]]) – The recipe overrides. This is used to override the default recipe parameters.
networking (Optional[Networking]) – The networking configuration. This is used to specify the networking settings for the training job.
stopping_condition (Optional[StoppingCondition]) – The stopping condition. This is used to specify the different stopping conditions for the training job. If not specified, will default to 1 hour max run time.
requirements (Optional[str]) – The path to a requirements file to install in the training job container.
training_image (Optional[str]) – The training image URI to use for the training job container. If not specified, the training image will be determined from the recipe.
training_image_config (Optional[TrainingImageConfig]) – Training image Config. This is the configuration to use an image from a private Docker registry for a training job.
output_data_config (Optional[OutputDataConfig]) – The output data configuration. This is used to specify the output data location for the training job. If not specified, will default to s3://<default_bucket>/<base_job_name>/output/.
input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.
checkpoint_config (Optional[CheckpointConfig]) – Contains information about the output location for managed spot training checkpoint data.
training_input_mode (Optional[str]) – The input mode for the training job. Valid values are “Pipe”, “File”, “FastFile”. Defaults to “File”.
environment (Optional[Dict[str, str]]) – The environment variables for the training job.
tags (Optional[List[Tag]]) – An array of key-value pairs. You can use tags to categorize your AWS resources in different ways, for example, by purpose, owner, or environment.
sagemaker_session (Optional[Session]) – The SageMakerCore session. If not specified, a new session will be created.
role (Optional[str]) – The IAM role ARN for the training job. If not specified, the default SageMaker execution role will be used.
base_job_name (Optional[str]) – The base name for the training job. If not specified, a default name will be generated using the algorithm name or training image.
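recipe_overrides is applied on top of the recipe's nested parameter dictionary (as in the example above, where only run.results_dir is overridden). A generic deep merge illustrates that layering; this is an assumption about the mechanism for illustration, and the SDK's actual merge logic may differ:

```python
from typing import Any, Dict

# Illustrative deep merge: override values win, and nested dicts are
# merged recursively rather than replaced wholesale.
def deep_merge(base: Dict[str, Any], overrides: Dict[str, Any]) -> Dict[str, Any]:
    merged = dict(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged

# Only results_dir changes; sibling keys in "run" survive the merge.
recipe = {"run": {"results_dir": "/tmp/results", "name": "my-run"}}
overrides = {"run": {"results_dir": "/opt/ml/model"}}
merged = deep_merge(recipe, overrides)
```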
- hyperparameters: Dict[str, Any] | str | None#
- local_container_root: str | None#
- model_config: ClassVar[ConfigDict] = {'arbitrary_types_allowed': True, 'extra': 'forbid', 'validate_assignment': True}#
Configuration for the model, should be a dictionary conforming to [ConfigDict][pydantic.config.ConfigDict].
- model_post_init(_ModelTrainer__context: Any)[source]#
Post init method to perform custom validation and set default values.
- networking: Networking | None#
- output_data_config: OutputDataConfig | None#
- role: str | None#
- source_code: SourceCode | None#
- stopping_condition: StoppingCondition | None#
- train(input_data_config: List[Channel | InputData] | None = None, wait: bool | None = True, logs: bool | None = True)[source]#
Train a model using AWS SageMaker.
- Parameters:
input_data_config (Optional[List[Union[Channel, InputData]]]) – The input data config for the training job. Takes a list of Channel objects or a dictionary of channel names to DataSourceType. DataSourceType can be an S3 URI string, local file path string, S3DataSource object, or FileSystemDataSource object.
wait (Optional[bool]) – Whether to wait for the training job to complete before returning. Defaults to True.
logs (Optional[bool]) – Whether to display the training container logs while training. Defaults to True.
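The input_data_config description mentions both a list-of-channels form and a mapping of channel names to data sources. A hypothetical normalization (illustration only; real Channel/InputData objects come from the SDK) sketches how a mapping could be flattened into channels, distinguishing S3 URIs from local paths the way DataSourceType is described:

```python
from typing import Dict, List

# Hypothetical sketch, not SDK code: turn a {channel_name: data_source}
# mapping into a list of channel records, tagging each source as an S3
# URI or a local path.
def to_channel_list(config: Dict[str, str]) -> List[Dict[str, str]]:
    channels = []
    for name, source in config.items():
        kind = "s3" if source.startswith("s3://") else "local"
        channels.append({"channel_name": name, "data_source": source, "kind": kind})
    return channels

channels = to_channel_list({"train": "s3://bucket/train", "extra": "./data"})
```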
- training_image: str | PipelineVariable | None#
- training_image_config: TrainingImageConfig | None#
- training_input_mode: str | PipelineVariable | None#
- with_checkpoint_config(checkpoint_config: CheckpointConfig | None = None) ModelTrainer[source]#
Set the checkpoint configuration for the training job.
Example:
from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_checkpoint_config()
- Parameters:
checkpoint_config (sagemaker.train.configs.CheckpointConfig) – The checkpoint configuration for the training job.
- with_infra_check_config(infra_check_config: InfraCheckConfig | None = None) ModelTrainer[source]#
Set the infra check configuration for the training job.
Example:
from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_infra_check_config(infra_check_config)
- Parameters:
infra_check_config (sagemaker.train.configs.InfraCheckConfig) – The infra check configuration for the training job.
- with_metric_definitions(metric_definitions: List[MetricDefinition]) ModelTrainer[source]#
Set the metric definitions for the training job.
Example:
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import MetricDefinition

metric_definitions = [
    MetricDefinition(
        name="loss",
        regex="Loss: (.*?)",
    )
]

model_trainer = ModelTrainer(
    ...
).with_metric_definitions(metric_definitions)
- Parameters:
metric_definitions (List[MetricDefinition]) – The metric definitions for the training job.
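A metric definition pairs a name with a regex whose first capture group extracts the value from training container logs. The pattern below is tightened for this sketch: a trailing lazy group like "Loss: (.*?)" matches the empty string under Python's re semantics, so a numeric character class is used instead:

```python
import re

# Illustration of how a metric regex's first capture group pulls a value
# out of a log line; SageMaker applies such patterns server-side, so this
# only demonstrates the capture-group idea.
pattern = re.compile(r"Loss: ([0-9.]+)")
log_line = "epoch 3 | Loss: 0.4521 | lr 1e-4"

match = pattern.search(log_line)
loss = float(match.group(1))  # 0.4521
```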
- with_remote_debug_config(remote_debug_config: RemoteDebugConfig) ModelTrainer[source]#
Set the remote debug configuration for the training job.
Example:
from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_remote_debug_config()
- Parameters:
remote_debug_config (sagemaker.train.configs.RemoteDebugConfig) – The remote debug configuration for the training job.
- with_retry_strategy(retry_strategy: RetryStrategy) ModelTrainer[source]#
Set the retry strategy for the training job.
Example:
from sagemaker.train import ModelTrainer
from sagemaker.train.configs import RetryStrategy

retry_strategy = RetryStrategy(maximum_retry_attempts=3)

model_trainer = ModelTrainer(
    ...
).with_retry_strategy(retry_strategy)
- Parameters:
retry_strategy (sagemaker.train.configs.RetryStrategy) – The retry strategy for the training job.
- with_session_chaining_config(session_chaining_config: SessionChainingConfig | None = None) ModelTrainer[source]#
Set the session chaining configuration for the training job.
Example:
from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_session_chaining_config()
- Parameters:
session_chaining_config (sagemaker.train.configs.SessionChainingConfig) – The session chaining configuration for the training job.
- with_tensorboard_output_config(tensorboard_output_config: TensorBoardOutputConfig | None = None) ModelTrainer[source]#
Set the TensorBoard output configuration.
Example:
from sagemaker.train import ModelTrainer

model_trainer = ModelTrainer(
    ...
).with_tensorboard_output_config()
- Parameters:
tensorboard_output_config (sagemaker.train.configs.TensorBoardOutputConfig) – The TensorBoard output configuration.