# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Define Queue class for AWS Batch service"""
from __future__ import absolute_import
from typing import Dict, Optional, List
import logging
from sagemaker.train.model_trainer import ModelTrainer, Mode
from .training_queued_job import TrainingQueuedJob
from .batch_api_helper import _submit_service_job, _list_service_job
from .exception import MissingRequiredArgument
from .constants import DEFAULT_TIMEOUT, JOB_STATUS_RUNNING
class TrainingQueue:
    """Handle for an AWS Batch job queue, identified by its name.

    Instances submit ``ModelTrainer`` training jobs to the named queue as
    Batch service jobs and list or look up the jobs in that queue.
    """

    def __init__(self, queue_name: str):
        """Remember the queue this object operates on.

        Args:
            queue_name: Name of the Batch job queue handed to the Batch
                helper calls made by the other methods.
        """
        self.queue_name = queue_name

    @staticmethod
    def _ensure_single_share(
        share_identifier: Optional[str], quota_share_name: Optional[str]
    ) -> None:
        """Raise ValueError when both share selectors are supplied."""
        if share_identifier is not None and quota_share_name is not None:
            raise ValueError(
                "Either share_identifier or quota_share_name can be specified, but not both"
            )

    def _collect_jobs(
        self, status: Optional[str], filters: Optional[List[Dict]]
    ) -> List[TrainingQueuedJob]:
        """Turn every job summary yielded by the list helper into a TrainingQueuedJob."""
        # No continuation token is supplied by this class.
        next_token = None
        jobs = []
        for page in _list_service_job(self.queue_name, status, filters, next_token):
            for summary in page.get("jobSummaryList", []):
                if "jobArn" not in summary or "jobName" not in summary:
                    # Incomplete summaries are skipped rather than failing the listing.
                    logging.warning("Missing JobArn or JobName in Batch ListJobs API")
                    continue
                jobs.append(
                    TrainingQueuedJob(
                        summary["jobArn"],
                        summary["jobName"],
                        summary.get("shareIdentifier", None),
                        summary.get("quotaShareName", None),
                    )
                )
        return jobs

    def submit(
        self,
        training_job: ModelTrainer,
        inputs,
        job_name: Optional[str] = None,
        retry_config: Optional[Dict] = None,
        priority: Optional[int] = None,
        share_identifier: Optional[str] = None,
        timeout: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        quota_share_name: Optional[str] = None,
        preemption_config: Optional[Dict] = None,
    ) -> TrainingQueuedJob:
        """Submit one training job to the queue and return its queued-job handle.

        Args:
            training_job: Training job ModelTrainer object. Its training mode
                must be ``Mode.SAGEMAKER_TRAINING_JOB``.
            inputs: Training job inputs, forwarded as ``input_data_config``
                when the training payload is built.
            job_name: Batch job name. Defaults to the ``TrainingJobName`` of
                the generated training payload.
            retry_config: Retry configuration for Batch job.
            priority: Scheduling priority for Batch job.
            share_identifier: Share identifier for Batch job. Mutually
                exclusive with ``quota_share_name``.
            timeout: Timeout configuration for Batch job. Defaults to
                ``DEFAULT_TIMEOUT``.
            tags: Tags apply to Batch job. These tags are for Batch job only.
            quota_share_name: Quota Share name for the Batch job. Mutually
                exclusive with ``share_identifier``.
            preemption_config: Preemption configuration.

        Returns:
            A TrainingQueuedJob object with Batch job ARN and job name.

        Raises:
            TypeError: If ``training_job`` is not a ModelTrainer.
            ValueError: If the trainer is not in SAGEMAKER_TRAINING_JOB mode,
                or both ``share_identifier`` and ``quota_share_name`` are set.
            MissingRequiredArgument: If the submit response lacks ``jobArn``
                or ``jobName``.
        """
        if not isinstance(training_job, ModelTrainer):
            raise TypeError(
                "training_job must be an instance of ModelTrainer, "
                f"but got {type(training_job)}"
            )
        if training_job.training_mode != Mode.SAGEMAKER_TRAINING_JOB:
            raise ValueError(
                "TrainingQueue requires using a ModelTrainer with Mode.SAGEMAKER_TRAINING_JOB"
            )
        self._ensure_single_share(share_identifier, quota_share_name)

        training_payload = training_job._create_training_job_args(
            input_data_config=inputs, boto3=True
        )
        if timeout is None:
            timeout = DEFAULT_TIMEOUT
        if job_name is None:
            job_name = training_payload["TrainingJobName"]

        resp = _submit_service_job(
            training_payload,
            job_name,
            self.queue_name,
            retry_config,
            priority,
            timeout,
            share_identifier,
            tags,
            quota_share_name,
            preemption_config,
        )
        if "jobArn" not in resp or "jobName" not in resp:
            raise MissingRequiredArgument(
                "jobArn or jobName is missing in response from Batch submit_service_job API"
            )
        return TrainingQueuedJob(resp["jobArn"], resp["jobName"])

    def map(
        self,
        training_job: ModelTrainer,
        inputs,
        job_names: Optional[List[str]] = None,
        retry_config: Optional[Dict] = None,
        priority: Optional[int] = None,
        share_identifier: Optional[str] = None,
        timeout: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        quota_share_name: Optional[str] = None,
        preemption_config: Optional[Dict] = None,
    ) -> List[TrainingQueuedJob]:
        """Submit one job per entry of ``inputs`` and return the queued jobs.

        Every job is submitted through :meth:`submit` with the same
        ``training_job`` and the same Batch options; only the inputs and the
        job name vary per job.

        Args:
            training_job: Training job ModelTrainer object.
            inputs: List of Training job inputs.
            job_names: List of Batch job names. When given, its length must
                equal the number of inputs; when omitted, each job name is
                left for ``submit`` to derive.
            retry_config: Retry config for the Batch jobs.
            priority: Scheduling priority for the Batch jobs.
            share_identifier: Share identifier for the Batch jobs.
            timeout: Timeout configuration for the Batch jobs.
            tags: Tags apply to Batch job. These tags are for Batch job only.
            quota_share_name: Quota share name for the Batch jobs.
            preemption_config: Preemption configuration for the Batch jobs.

        Returns:
            A list of TrainingQueuedJob objects with each Batch job ARN and
            job name, in the order of ``inputs``.

        Raises:
            ValueError: If ``job_names`` is given and its length differs from
                the number of inputs.
        """
        if job_names is None:
            job_names = [None] * len(inputs)
        elif len(job_names) != len(inputs):
            raise ValueError(
                "When specified, the number of job names must match the number of inputs"
            )

        # Keyword arguments keep this call correct even if submit() gains or
        # reorders optional parameters.
        return [
            self.submit(
                training_job,
                job_input,
                job_name=name,
                retry_config=retry_config,
                priority=priority,
                share_identifier=share_identifier,
                timeout=timeout,
                tags=tags,
                quota_share_name=quota_share_name,
                preemption_config=preemption_config,
            )
            for job_input, name in zip(inputs, job_names)
        ]

    def list_jobs(
        self, job_name: Optional[str] = None, status: Optional[str] = JOB_STATUS_RUNNING
    ) -> List[TrainingQueuedJob]:
        """List Batch jobs in the queue by job name or by status.

        Args:
            job_name: Batch job name. When provided, jobs are filtered by
                name and ``status`` is ignored.
            status: Batch job status. Defaults to ``JOB_STATUS_RUNNING``.

        Returns:
            A list of TrainingQueuedJob objects.
        """
        filters = None
        if job_name:
            filters = [{"name": "JOB_NAME", "values": [job_name]}]
            status = None  # job_status is ignored when job_name is specified.
        return self._collect_jobs(status, filters)

    def list_jobs_by_share(
        self,
        status: Optional[str] = JOB_STATUS_RUNNING,
        share_identifier: Optional[str] = None,
        quota_share_name: Optional[str] = None,
    ) -> List[TrainingQueuedJob]:
        """List Batch jobs in the queue by status and, optionally, by share.

        Args:
            status: Batch job status. Defaults to ``JOB_STATUS_RUNNING``.
            share_identifier: Batch fairshare share identifier.
            quota_share_name: Batch quota management share name.

        Returns:
            A list of TrainingQueuedJob objects.

        Raises:
            ValueError: If both ``share_identifier`` and ``quota_share_name``
                are specified.
        """
        self._ensure_single_share(share_identifier, quota_share_name)

        filters = None
        if share_identifier:
            filters = [{"name": "SHARE_IDENTIFIER", "values": [share_identifier]}]
        elif quota_share_name:
            filters = [{"name": "QUOTA_SHARE_NAME", "values": [quota_share_name]}]
        return self._collect_jobs(status, filters)

    def get_job(self, job_name: str) -> TrainingQueuedJob:
        """Return the first Batch job in the queue whose name matches ``job_name``.

        Args:
            job_name: Batch job name.

        Returns:
            The first TrainingQueuedJob returned by ``list_jobs(job_name)``.

        Raises:
            ValueError: If no job with that name is found.
        """
        matches = self.list_jobs(job_name)
        if not matches:
            raise ValueError(f"Cannot find job: {job_name}")
        return matches[0]