Source code for sagemaker.train.common

from typing import Dict, Any
from enum import Enum
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature

# Job-type label string. It is not referenced in the visible part of this
# module; presumably it is consumed by importing code (verify against callers).
JOB_TYPE = "FineTuning"

class TrainingType(Enum):
    """Enumeration of the training types available for fine-tuning.

    Each member's value is the same upper-case string as its name.
    """

    # Member names and values are identical strings.
    LORA = "LORA"
    FULL = "FULL"
class CustomizationTechnique(Enum):
    """Enumeration of the customization techniques available for fine-tuning.

    Each member's value is the same upper-case string as its name.
    """

    # Member names and values are identical strings.
    SFT = "SFT"
    RLVR = "RLVR"
    RLAIF = "RLAIF"
    DPO = "DPO"
class FineTuningOptions:
    """Dynamic class for fine-tuning options with validation.

    Every key of the options dictionary becomes an instance attribute. A
    value in that dictionary may be either:

    * a specification dict, from which the keys ``'default'``, ``'type'``
      (``'float'``, ``'integer'`` or ``'string'``), ``'min'``, ``'max'``,
      ``'enum'`` and ``'required'`` are read, or
    * any other object, which is used directly as the option's value and is
      never validated.

    After construction, assigning to an option validates the new value
    against its specification dict; assigning to an unknown public name
    raises ``AttributeError``.
    """

    def __init__(self, options_dict: Dict[str, Any]):
        """Create the options object and populate defaults.

        Args:
            options_dict: Mapping of option name to specification dict or
                plain value. A shallow copy is kept as the spec table.
        """
        self._specs = options_dict.copy()
        self._initialized = False

        # Defaults are written through object.__setattr__ so they bypass
        # validation: a spec's own default is stored as-is (None if absent).
        for key, spec in options_dict.items():
            default_value = spec.get('default') if isinstance(spec, dict) else spec
            super().__setattr__(key, default_value)

        self._initialized = True

    def to_dict(self) -> Dict[str, Any]:
        """Convert back to dictionary for hyperparameters with string values.

        Returns:
            A dict with one entry per option, each value passed through
            ``str()``. An option whose current value is ``None`` is therefore
            emitted as the string ``"None"``.
        """
        return {k: str(getattr(self, k)) for k in self._specs}

    def __setattr__(self, name: str, value: Any):
        """Set an attribute, validating it when it is a known option.

        Names starting with an underscore are internal and always stored
        unchecked. Once ``_specs`` exists, only names present in it may be
        assigned.

        Raises:
            AttributeError: If ``name`` is not a known option.
            ValueError: If the value fails validation against its spec.
        """
        if name.startswith('_'):
            super().__setattr__(name, value)
        elif hasattr(self, '_specs') and name in self._specs:
            # Only validate if initialized (user is setting values)
            if getattr(self, '_initialized', False):
                spec = self._specs[name]
                if isinstance(spec, dict):
                    self._validate_value(name, value, spec)
            super().__setattr__(name, value)
        elif hasattr(self, '_specs'):
            raise AttributeError(f"'{name}' is not a valid fine-tuning option. Valid options: {list(self._specs.keys())}")
        else:
            super().__setattr__(name, value)

    def _validate_value(self, name: str, value: Any, spec: Dict[str, Any]):
        """Validate value against parameter specification.

        Checks, in order: the declared ``'type'``, the ``'min'``/``'max'``
        bounds, and membership in ``'enum'``. Each check runs only when the
        corresponding key is present in ``spec``.

        Raises:
            ValueError: If any check fails.
        """
        # Type validation. bool is a subclass of int, so it has to be
        # excluded explicitly; otherwise True/False would pass as numbers.
        expected_type = spec.get('type')
        is_bool = isinstance(value, bool)
        if expected_type == 'float' and (is_bool or not isinstance(value, (int, float))):
            raise ValueError(f"{name} must be a number, got {type(value).__name__}")
        elif expected_type == 'integer' and (is_bool or not isinstance(value, int)):
            raise ValueError(f"{name} must be an integer, got {type(value).__name__}")
        elif expected_type == 'string' and not isinstance(value, str):
            raise ValueError(f"{name} must be a string, got {type(value).__name__}")

        # Range validation
        if 'min' in spec and value < spec['min']:
            raise ValueError(f"{name} must be >= {spec['min']}, got {value}")
        if 'max' in spec and value > spec['max']:
            raise ValueError(f"{name} must be <= {spec['max']}, got {value}")

        # Enum validation
        if 'enum' in spec and value not in spec['enum']:
            raise ValueError(f"{name} must be one of {spec['enum']}, got {value}")

    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="FineTuningOptions.get_info")
    def get_info(self, param_name: str = None):
        """Display parameter information in a user-friendly format.

        Args:
            param_name: Name of a single option to describe. When omitted
                (or otherwise falsy), every option is printed.

        Raises:
            ValueError: If ``param_name`` is given but is not a known option.
        """
        if param_name:
            if param_name not in self._specs:
                raise ValueError(f"Parameter '{param_name}' not found. Available: {list(self._specs.keys())}")
            params_to_show = {param_name: self._specs[param_name]}
        else:
            params_to_show = self._specs

        for name, spec in params_to_show.items():
            if not isinstance(spec, dict):
                # Plain-value options carry no metadata; print the value only.
                print(f"\n{name}: {getattr(self, name)}")
                continue

            print(f"\n{name}:")
            print(f" Current value: {getattr(self, name)}")
            print(f" Type: {spec.get('type', 'unknown')}")
            print(f" Default: {spec.get('default', 'N/A')}")

            if 'min' in spec and 'max' in spec:
                print(f" Range: {spec['min']} - {spec['max']}")
            elif 'min' in spec:
                print(f" Min: {spec['min']}")
            elif 'max' in spec:
                print(f" Max: {spec['max']}")

            if 'enum' in spec:
                print(f" Valid options: {spec['enum']}")

            if spec.get('required'):
                print(" Required: Yes")