Source code for sagemaker.train.base_trainer
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Union
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.training.configs import Tag, Networking, InputData, Channel
from sagemaker.core.shapes import shapes
from sagemaker.core.resources import TrainingJob
[docs]
class BaseTrainer(ABC):
    """Abstract base class for all SageMaker training workflows.

    This class provides the common interface and shared functionality for all trainer
    implementations including SFT, DPO, RLVR, and RLAIF trainers. It defines the standard
    parameters and abstract methods that concrete trainer classes must implement.

    Note:
        ``__init__`` in this base class only stores the values it receives. The exception
        is ``hyperparameters`` and ``environment``, which become empty dicts when ``None``
        is passed. The defaulting behaviour described below for the other parameters is
        NOT performed here. This covers the default session, execution role, generated
        job name and default output configuration. That behaviour is expected to be
        implemented by concrete trainers.

    Parameters:
        sagemaker_session (Optional[Session]):
            The SageMaker session for managing API calls and resources.
            If not specified, a default session will be created.
        role (Optional[str]):
            The IAM role ARN for the training job execution.
            If not specified, the default SageMaker execution role will be used.
        base_job_name (Optional[str]):
            The base name for training jobs. A unique suffix will be appended.
            If not specified, a default name will be generated based on the trainer type.
        tags (Optional[List[Tag]]):
            List of tags to apply to the training job for resource management and billing.
        hyperparameters (Optional[Dict[str, Any]]):
            Dictionary of hyperparameters for the training job.
            Trainer-specific defaults will be applied if not specified.
        output_data_config (Optional[shapes.OutputDataConfig]):
            Configuration for training job outputs including S3 paths and encryption.
            If not specified, default output configuration will be used.
        input_data_config (Optional[List[Union[Channel, InputData]]]):
            List of input data channels for the training job.
            Can include training and validation datasets.
        environment (Optional[Dict[str, str]]):
            Environment variables to set in the training container.
    """

    # Class-level defaults. ``__init__`` overrides all of these per instance except
    # ``latest_training_job``, which therefore reads as ``None`` until something assigns
    # it on the instance.
    sagemaker_session: Optional[Session] = None
    role: Optional[str] = None
    base_job_name: Optional[str] = None
    tags: Optional[List[Tag]] = None
    hyperparameters: Optional[Dict[str, Any]] = None
    output_data_config: Optional[shapes.OutputDataConfig] = None
    input_data_config: Optional[List[Union[Channel, InputData]]] = None
    environment: Optional[Dict[str, str]] = None
    latest_training_job: Optional[TrainingJob] = None

    def __init__(
        self,
        sagemaker_session: Optional[Session] = None,
        role: Optional[str] = None,
        base_job_name: Optional[str] = None,
        tags: Optional[List[Tag]] = None,
        hyperparameters: Optional[Dict[str, Any]] = None,
        output_data_config: Optional[shapes.OutputDataConfig] = None,
        input_data_config: Optional[List[Union[Channel, InputData]]] = None,
        environment: Optional[Dict[str, str]] = None,
    ):
        """Store the shared trainer configuration on the instance.

        Values are stored as given (no copying). ``hyperparameters`` and
        ``environment`` fall back to fresh empty dicts only when ``None`` is passed.
        """
        self.sagemaker_session = sagemaker_session
        self.role = role
        self.base_job_name = base_job_name
        self.tags = tags
        # Use explicit ``is not None`` checks rather than ``or {}``. An empty dict
        # supplied by the caller is kept as-is instead of being silently replaced by a
        # different empty dict.
        self.hyperparameters = hyperparameters if hyperparameters is not None else {}
        self.output_data_config = output_data_config
        self.input_data_config = input_data_config
        self.environment = environment if environment is not None else {}

    def _is_nova_model_for_telemetry(self) -> bool:
        """Return whether this trainer's model is a Nova model, for telemetry tracking.

        Reads the optional ``_model_name`` attribute, which this base class never sets.
        It is presumably assigned by concrete trainers; verify against subclasses.
        Returns ``False`` when the attribute is missing or falsy. Otherwise the result
        of ``recipe_utils._is_nova_model`` is returned.
        """
        # Imported lazily, presumably to avoid an import cycle or load cost at module
        # import time. TODO confirm.
        from sagemaker.train.common_utils.recipe_utils import _is_nova_model

        model_name = getattr(self, "_model_name", None)
        return _is_nova_model(model_name) if model_name else False

    @abstractmethod
    def train(
        self,
        input_data_config: List[Union[Channel, InputData]],
        wait: bool = True,
        logs: bool = True,
    ):
        """Run the training workflow. This must be implemented by concrete trainers.

        This base implementation is abstract and does nothing.

        Args:
            input_data_config: Input data channels for the training job. The accepted
                element types match those accepted by ``__init__``.
            wait: Presumably whether to block until the training job finishes. The
                semantics are defined by the implementing subclass.
            logs: Presumably whether to show training job logs while waiting. The
                semantics are defined by the implementing subclass.

        Returns:
            Defined by the implementing subclass. Not specified here.
        """
        pass