Source code for sagemaker.train.aws_batch.training_queued_job

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Define QueuedJob class for AWS Batch service"""
from __future__ import absolute_import

import logging
import time
import asyncio
import re
from typing import Optional, Dict
import nest_asyncio
from sagemaker.core.resources import TrainingJob
from sagemaker.core.shapes import Unassigned
from sagemaker.train.model_trainer import ModelTrainer
from sagemaker.train.configs import (
    Compute,
    Networking,
    StoppingCondition,
    SourceCode,
    TrainingImageConfig,
)
from .batch_api_helper import _terminate_service_job, _describe_service_job, _update_service_job
from .exception import NoTrainingJob, MissingRequiredArgument
from ..utils import _get_training_job_name_from_training_job_arn
from .constants import JOB_STATUS_COMPLETED, JOB_STATUS_FAILED, POLL_IN_SECONDS

# Configure the root logger when this module is imported so that the
# INFO-level messages emitted below (attempt history, wait timeouts) are
# printed with a "YYYY-MM-DD HH:MM:SS" timestamp prefix. Per the standard
# library, basicConfig() is a no-op if the root logger already has handlers.
logging.basicConfig(
    format="%(asctime)s %(message)s", level=logging.INFO, datefmt="%Y-%m-%d %H:%M:%S"
)


class TrainingQueuedJob:
    """TrainingQueuedJob class for AWS Batch service.

    With this class, customers are able to attach the latest training job to a
    ModelTrainer, and to terminate, update, describe or wait on the Batch job.
    """

    def __init__(
        self,
        job_arn: str,
        job_name: str,
        share_identifier: Optional[str] = None,
        quota_share_name: Optional[str] = None,
    ):
        """Store the identifiers of a queued Batch job.

        Args:
            job_arn: ARN of the Batch job; passed to every Batch API helper call.
            job_name: Name of the Batch job.
            share_identifier: Optional share identifier associated with the job.
            quota_share_name: Optional quota share name associated with the job.
        """
        self.job_arn = job_arn
        self.job_name = job_name
        self.share_identifier = share_identifier
        self.quota_share_name = quota_share_name
        # Batch statuses for which no Training job is considered to exist yet.
        self._no_training_job_status = {"SUBMITTED", "PENDING", "RUNNABLE"}

    def get_model_trainer(self) -> "ModelTrainer":
        """Attach the latest training job to a ModelTrainer and return.

        Returns:
            a ModelTrainer instance.

        Raises:
            MissingRequiredArgument: If the job status indicates a Training job
                exists but the describe response has no ``latestAttempt``.
            NoTrainingJob: If the job is still in a status for which no
                Training job has been created.
        """
        describe_resp = self.describe()
        job_status = describe_resp.get("status", "")

        if not self._training_job_created(job_status):
            # Log why earlier attempts (if any) ended before reporting failure.
            _output_attempt_history(describe_resp)
            raise NoTrainingJob("No Training job created. Job is still waiting in queue")

        if "latestAttempt" not in describe_resp:
            raise MissingRequiredArgument("No LatestAttempt in describe call")

        new_training_job_name = _get_new_training_job_name_from_latest_attempt(
            describe_resp["latestAttempt"]
        )
        output_model_trainer = _construct_model_trainer_from_training_job_name(
            new_training_job_name
        )
        _remove_system_tags_in_place_in_model_trainer_object(output_model_trainer)
        return output_model_trainer

    def terminate(self, reason: Optional[str] = "Default terminate reason") -> None:
        """Terminate Batch job.

        Args:
            reason: Reason for terminating a job.

        Returns:
            None
        """
        _terminate_service_job(self.job_arn, reason)

    def update(self, scheduling_priority: int) -> Dict:
        """Update Batch job.

        Args:
            scheduling_priority: An integer representing scheduling priority.

        Returns:
            A dict which includes jobArn, jobName and jobId.
        """
        return _update_service_job(self.job_arn, scheduling_priority)

    def describe(self) -> Dict:
        """Describe Batch job.

        Returns:
            A dict which includes job parameters, job status, attempts and so on.
        """
        return _describe_service_job(self.job_arn)

    def _training_job_created(self, status: str) -> bool:
        """Return True if a Training job has been created

        Args:
            status: Job status returned from Batch API.

        Returns:
            a boolean indicating whether a Training job has been created.
        """
        return status not in self._no_training_job_status

    def result(self, timeout: Optional[int] = None) -> Dict:
        """Fetch the terminal result of the Batch job.

        Args:
            timeout: The time to wait for the Batch job to complete. Defaults to ``None``.

        Returns:
            The results of the Batch job, represented as a Dict.

        Raises:
            RuntimeError: If the Batch job failed.
            TimeoutError: If the timeout elapsed before a terminal status.
        """
        # nest_asyncio lets run_until_complete() be used even when an event
        # loop is already running (e.g. inside a notebook).
        nest_asyncio.apply()
        loop = asyncio.get_event_loop()
        task = loop.create_task(self.fetch_job_results(timeout))
        return loop.run_until_complete(task)

    async def fetch_job_results(self, timeout: Optional[int] = None) -> Dict:
        """Async method that waits for the Batch job to complete or until timeout.

        Args:
            timeout: The time to wait for the Batch job to complete. Defaults to ``None``.

        Returns:
            The results of the Batch job, represented as a Dict, or an Error.

        Raises:
            RuntimeError: If the Batch job reached the failed status.
            TimeoutError: If the timeout elapsed before a terminal status.
        """
        # NOTE: wait() polls synchronously, so this coroutine blocks the event
        # loop while waiting. Its return value is the last describe response,
        # so no extra describe call is needed here.
        describe_resp = self.wait(timeout)
        job_status = describe_resp.get("status", "")

        if job_status == JOB_STATUS_COMPLETED:
            return describe_resp
        if job_status == JOB_STATUS_FAILED:
            # A failed job may come back without a statusReason; still surface
            # the failure as RuntimeError rather than an unrelated KeyError.
            raise RuntimeError(
                describe_resp.get("statusReason", "Batch job failed without a statusReason")
            )
        raise TimeoutError("Reached timeout before the Batch job reached a terminal status")

    def wait(self, timeout: Optional[int] = None) -> Dict:
        """Wait for the Batch job to finish.

        This method blocks on the job completing for up to the timeout value (if specified).
        If timeout is ``None``, this method will block until the job is completed.

        Args:
            timeout (int): Timeout in seconds to wait until the job is completed. ``None`` by
                default. A timeout of ``0`` checks the status once and returns.

        Returns:
            The last describe_service_job response for the Batch job.
        """
        terminal_statuses = (JOB_STATUS_COMPLETED, JOB_STATUS_FAILED)
        # Only ``None`` means "wait forever"; 0 is a real (immediate) deadline.
        deadline = None if timeout is None else time.time() + timeout

        describe_resp = self.describe()
        while describe_resp.get("status", "") not in terminal_statuses:
            sleep_seconds = POLL_IN_SECONDS
            if deadline is not None:
                remaining = deadline - time.time()
                if remaining <= 0:
                    logging.info(
                        "Timeout exceeded: %d seconds elapsed. Returning current results",
                        timeout,
                    )
                    break
                # Never sleep past the deadline.
                sleep_seconds = min(POLL_IN_SECONDS, remaining)

            time.sleep(sleep_seconds)
            describe_resp = self.describe()

        return describe_resp
def _is_assigned(value) -> bool:
    """Return True when ``value`` is truthy and is not an ``Unassigned`` placeholder."""
    return bool(value and not isinstance(value, Unassigned))


def _construct_model_trainer_from_training_job_name(training_job_name: str) -> "ModelTrainer":
    """Build ModelTrainer instance from training job name.

    The TrainingJob resource is fetched by name and every field that is set
    (truthy and not ``Unassigned``) is mapped onto the matching ModelTrainer
    constructor argument.

    Args:
        training_job_name: Training job name.

    Returns:
        a ModelTrainer instance with _latest_training_job set.
    """
    # Step 1: Get the TrainingJob resource
    training_job = TrainingJob.get(training_job_name=training_job_name)

    # Step 2: Extract parameters from training_job to reconstruct ModelTrainer
    init_params = {
        "role": training_job.role_arn,
        "base_job_name": _extract_base_job_name(training_job_name),
    }

    # Training image or algorithm
    algorithm_spec = training_job.algorithm_specification
    if _is_assigned(algorithm_spec):
        for field_name in ("training_image", "algorithm_name", "training_input_mode"):
            field_value = getattr(algorithm_spec, field_name)
            if _is_assigned(field_value):
                init_params[field_name] = field_value

    # Compute config
    resource_config = training_job.resource_config
    if _is_assigned(resource_config):
        compute_params = {}
        for field_name in ("instance_type", "instance_count", "volume_size_in_gb"):
            field_value = getattr(resource_config, field_name)
            if _is_assigned(field_value):
                compute_params[field_name] = field_value

        # Managed spot training lives directly on TrainingJob, not on resource_config.
        if _is_assigned(training_job.enable_managed_spot_training):
            compute_params["enable_managed_spot_training"] = (
                training_job.enable_managed_spot_training
            )

        if compute_params:  # Only create Compute if we have valid params
            init_params["compute"] = Compute(**compute_params)

    # Output config - pass the raw training job output config directly
    if _is_assigned(training_job.output_data_config):
        init_params["output_data_config"] = training_job.output_data_config

    # Stopping condition (only the max runtime is carried over)
    stopping_condition = training_job.stopping_condition
    if _is_assigned(stopping_condition) and _is_assigned(
        stopping_condition.max_runtime_in_seconds
    ):
        init_params["stopping_condition"] = StoppingCondition(
            max_runtime_in_seconds=stopping_condition.max_runtime_in_seconds,
        )

    # Networking
    # NOTE(review): the isolation/encryption flags are only carried over when
    # a vpc_config is present, matching the existing behavior.
    vpc_config = training_job.vpc_config
    if _is_assigned(vpc_config):
        networking_params = {}
        for field_name in ("subnets", "security_group_ids"):
            field_value = getattr(vpc_config, field_name)
            if _is_assigned(field_value):
                networking_params[field_name] = field_value

        # These two flags live directly on TrainingJob, not on vpc_config.
        for field_name in (
            "enable_network_isolation",
            "enable_inter_container_traffic_encryption",
        ):
            field_value = getattr(training_job, field_name)
            if _is_assigned(field_value):
                networking_params[field_name] = field_value

        if networking_params:  # Only create Networking if we have valid params
            init_params["networking"] = Networking(**networking_params)

    # Hyperparameters
    if _is_assigned(training_job.hyper_parameters):
        init_params["hyperparameters"] = training_job.hyper_parameters

    # Environment
    if _is_assigned(training_job.environment):
        init_params["environment"] = training_job.environment

    # Checkpoint config
    if _is_assigned(training_job.checkpoint_config):
        init_params["checkpoint_config"] = training_job.checkpoint_config

    # Step 3: Create ModelTrainer
    model_trainer = ModelTrainer(**init_params)

    # Step 4: Set _latest_training_job
    model_trainer._latest_training_job = training_job

    return model_trainer


def _extract_base_job_name(training_job_name: str) -> str:
    """Extract base job name from full training job name.

    Args:
        training_job_name: Full training job name.

    Returns:
        Base job name: the text before a ``-<timestamp>`` suffix, or the
        unchanged input when no timestamp suffix is found.
    """
    # Use the same regex pattern as PySDK V2's base_from_name() function
    # Matches timestamps like: YYYY-MM-DD-HH-MM-SS-SSS or YYMMDD-HHMM
    match = re.match(
        r"^(.+)-(\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-\d{3}|\d{6}-\d{4})", training_job_name
    )
    return match.group(1) if match else training_job_name


def _output_attempt_history(describe_resp: Dict) -> None:
    """Print attempt history if no Training job created.

    Logs the ``statusReason`` of every attempt that has one, numbered from 1,
    or a single fallback message when none is available.

    Args:
        describe_resp: Describe response from Batch API.

    Returns:
        None
    """
    has_seen_status_reason = False
    for attempt_number, attempt_dict in enumerate(describe_resp.get("attempts", []), start=1):
        if "statusReason" in attempt_dict:
            logging.info("Attempt %d - %s", attempt_number, attempt_dict["statusReason"])
            has_seen_status_reason = True
    if not has_seen_status_reason:
        logging.info("No attempts found or no statusReason found.")


def _get_new_training_job_name_from_latest_attempt(latest_attempt: Dict) -> str:
    """Extract new Training job name from latest attempt in Batch Describe response.

    Args:
        latest_attempt: a Dict containing Training job arn.

    Returns:
        The Training job name derived from the ARN found at
        ``latest_attempt["serviceResourceId"]["value"]``. When that value is
        absent, ``None`` is handed to the ARN helper, which decides the result.
    """
    # ``or {}`` also covers a serviceResourceId that is present but None.
    service_resource_id = latest_attempt.get("serviceResourceId") or {}
    training_job_arn = service_resource_id.get("value", None)
    return _get_training_job_name_from_training_job_arn(training_job_arn)


def _get_tag_key(tag):
    """Return the key of a tag, or None when it cannot be determined.

    Supports the V2 dict format ``{"Key": ..., "Value": ...}`` and object
    formats exposing a ``key`` (V3) or ``Key`` attribute.
    """
    if isinstance(tag, dict):
        return tag.get("Key", "")
    for attribute_name in ("key", "Key"):
        key = getattr(tag, attribute_name, None)
        if key is not None:
            return key
    return None


def _remove_system_tags_in_place_in_model_trainer_object(model_trainer: "ModelTrainer") -> None:
    """Remove system tags in place.

    A system tag is one whose key starts with ``aws:``. Tags whose key cannot
    be determined (or is not a string) are kept to be safe.

    Args:
        model_trainer: input ModelTrainer object.

    Returns:
        None. Remove system tags in place.
    """
    if not model_trainer.tags:
        return

    def _is_system_tag(tag) -> bool:
        key = _get_tag_key(tag)
        return isinstance(key, str) and key.startswith("aws:")

    model_trainer.tags = [tag for tag in model_trainer.tags if not _is_system_tag(tag)]