# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utils module."""
from __future__ import absolute_import

import json
import os
import re
import subprocess
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Literal, Optional

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.shapes import Unassigned
from sagemaker.core.workflow.parameters import PipelineVariable
from sagemaker.train import logger
def _default_bucket_and_prefix(session: Session) -> str:
"""Helper function to get the bucket name with the corresponding prefix if applicable
Returns a string like:
* ``default_bucket/default_bucket_prefix`` if the prefix is set
* ``default_bucket`` if the prefix is not set
Args:
session (Session): The SageMaker session to use
Returns:
str: The bucket name with the prefix if applicable
"""
if session.default_bucket_prefix is not None:
return f"{session.default_bucket()}/{session.default_bucket_prefix}"
return session.default_bucket()
def _default_s3_uri(session: Session, additional_path: str = "") -> str:
    """Helper function to get the default S3 URI for the SageMaker session.

    Returns a string like:
    * ``s3://default_bucket/default_bucket_prefix`` if the prefix is set
    * ``s3://default_bucket`` if the prefix is not set

    Args:
        session (Session): The SageMaker session to use
        additional_path (str): Additional path to append to the S3 URI. Defaults to "".

    Returns:
        str: The default S3 URI for the SageMaker session
    """
    base_uri = f"s3://{_default_bucket_and_prefix(session)}"
    # Normalize away a leading slash so we never emit "s3://bucket//path".
    suffix = additional_path.lstrip("/")
    if suffix:
        return f"{base_uri}/{suffix}"
    return base_uri
def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
"""Check if the path is a valid S3 URI.
This method checks if the path is a valid S3 URI. If the path_type is specified,
it will also check if the path is a file or a directory.
This method does not check if the S3 bucket or object exists.
Args:
path (str): S3 URI to validate
path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate.
Defaults to "Any".
Returns:
bool: True if the path is a valid S3 URI, False otherwise
"""
# Check if the path is a valid S3 URI
if not path.startswith("s3://"):
return False
if path_type == "File":
# If it's a file, it should not end with a slash
return not path.endswith("/")
if path_type == "Directory":
# If it's a directory, it should end with a slash
return path.endswith("/")
return path_type == "Any"
def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
"""Check if the path is a valid local path.
Args:
path (str): Local path to validate
path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate.
Defaults to "Any".
Returns:
bool: True if the path is a valid local path, False otherwise
"""
if not os.path.exists(path):
return False
if path_type == "File":
return os.path.isfile(path)
if path_type == "Directory":
return os.path.isdir(path)
return path_type == "Any"
def _get_unique_name(base, max_length=63):
"""Generate a unique name based on the base name.
This method generates a unique name based on the base name.
The unique name is generated by appending the current timestamp
to the base name.
Args:
base (str): The base name to use
max_length (int): The maximum length of the unique name. Defaults to 63.
Returns:
str: The unique name
"""
current_time = datetime.now().strftime("%Y%m%d%H%M%S")
base = base.replace("_", "-")
unique_name = f"{base}-{current_time}"
unique_name = unique_name[:max_length] # Truncate to max_length
return unique_name
def _get_repo_name_from_image(image: str) -> str:
"""Get the repository name from the image URI.
Example:
``` python
_get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest")
# Returns "my-repo"
```
Args:
image (str): The image URI
Returns:
str: The repository name
"""
return image.split("/")[-1].split(":")[0].split("@")[0]
def convert_unassigned_to_none(instance) -> Any:
    """Replace every ``Unassigned`` attribute value on ``instance`` with ``None``.

    The instance is mutated in place and also returned for convenience.

    Args:
        instance: Any object whose ``__dict__`` may contain ``Unassigned``
            sentinel values.

    Returns:
        Any: The same instance, with ``Unassigned`` values replaced by ``None``.
    """
    # Overwriting existing keys while iterating is safe here: the dict's key
    # set never changes, only its values.
    for name, value in instance.__dict__.items():
        if isinstance(value, Unassigned):
            setattr(instance, name, None)
    return instance
def safe_serialize(data):
    """Serialize ``data`` without wrapping plain strings in quotes.

    This function handles the following cases:

    1. If ``data`` is a string, it is returned as-is without wrapping in quotes.
    2. If ``data`` is a ``PipelineVariable``, it is returned unchanged so it
       can be resolved later at pipeline execution time.
    3. If ``data`` is JSON-serializable (e.g., a dictionary, list, int, float),
       the JSON-encoded string from ``json.dumps()`` is returned.
    4. If ``data`` cannot be serialized (e.g., a custom object), the string
       representation from ``str(data)`` is returned.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string, or the original
        string/PipelineVariable.
    """
    if isinstance(data, str):
        return data
    if isinstance(data, PipelineVariable):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        return str(data)
def _run_clone_command_silent(repo_url, dest_dir):
    """Run the 'git clone' command with the repo url and the directory to clone the repo into.

    All git output is suppressed, and interactive credential prompts are
    disabled so a clone that needs missing credentials fails fast instead of
    hanging. NOTE(review): URLs that are neither ``https://`` nor
    ``git@``/``ssh://`` are silently ignored — confirm this is intended.

    Args:
        repo_url (str): Git repo url to be cloned.
        dest_dir: (str): Local path where the repo should be cloned into.

    Raises:
        CalledProcessError: If failed to clone git repo.
    """
    my_env = os.environ.copy()
    if repo_url.startswith("https://"):
        try:
            # Prevent git from prompting for username/password on the terminal.
            my_env["GIT_TERMINAL_PROMPT"] = "0"
            subprocess.check_call(
                ["git", "clone", repo_url, dest_dir],
                env=my_env,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
        except subprocess.CalledProcessError as e:
            logger.error(f"Failed to clone repository: {repo_url}")
            logger.error(f"Error output:\n{e}")
            raise
    elif repo_url.startswith("git@") or repo_url.startswith("ssh://"):
        try:
            # Wrap ssh in a throwaway script that forces BatchMode, so ssh
            # fails instead of prompting for a passphrase/host confirmation.
            with tempfile.TemporaryDirectory() as tmp_dir:
                custom_ssh_executable = Path(tmp_dir) / "ssh_batch"
                with open(custom_ssh_executable, "w") as pipe:
                    print("#!/bin/sh", file=pipe)
                    print("ssh -oBatchMode=yes $@", file=pipe)
                # 0o511 = r-x/--x/--x: executable by all, writable by none.
                os.chmod(custom_ssh_executable, 0o511)
                my_env["GIT_SSH"] = str(custom_ssh_executable)
                subprocess.check_call(
                    ["git", "clone", repo_url, dest_dir],
                    env=my_env,
                    stdout=subprocess.DEVNULL,
                    stderr=subprocess.DEVNULL,
                )
        except subprocess.CalledProcessError as e:
            # NOTE(review): my_env is a local copy, so this cleanup has no
            # effect outside this function.
            del my_env["GIT_SSH"]
            logger.error(f"Failed to clone repository: {repo_url}")
            logger.error(f"Error output:\n{e}")
            raise
def _get_studio_tags(model_id: str, hub_name: str):
return [
{
"key": "sagemaker-studio:jumpstart-model-id",
"value": model_id
},
{
"key": "sagemaker-studio:jumpstart-hub-name",
"value": hub_name
}
]
def _get_training_job_name_from_training_job_arn(training_job_arn: str) -> str:
"""Extract Training job name from Training job arn.
Args:
training_job_arn: Training job arn.
Returns: Training job name.
"""
if training_job_arn is None:
return None
pattern = "arn:aws[a-z-]*:sagemaker:[a-z0-9-]*:[0-9]{12}:training-job/(.+)"
match = re.match(pattern, training_job_arn)
if match:
return match.group(1)
return None