# Source code for sagemaker.train.utils

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utils module."""
from __future__ import absolute_import

import re
import os
import json
import subprocess
import tempfile
from pathlib import Path

from datetime import datetime
from typing import Literal, Any

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.shapes import Unassigned
from sagemaker.train import logger
from sagemaker.core.workflow.parameters import PipelineVariable


def _default_bucket_and_prefix(session: Session) -> str:
    """Helper function to get the bucket name with the corresponding prefix if applicable

    Returns a string like:
    * ``default_bucket/default_bucket_prefix`` if the prefix is set
    * ``default_bucket`` if the prefix is not set

    Args:
        session (Session): The SageMaker session to use

    Returns:
        str: The bucket name with the prefix if applicable
    """
    if session.default_bucket_prefix is not None:
        return f"{session.default_bucket()}/{session.default_bucket_prefix}"
    return session.default_bucket()


def _default_s3_uri(session: Session, additional_path: str = "") -> str:
    """Build the default S3 URI for the session, optionally extended by a sub-path.

    Returns a string like:
    * ``s3://default_bucket/default_bucket_prefix`` if the prefix is set
    * ``s3://default_bucket`` if the prefix is not set

    Args:
        session (Session): The SageMaker session to use
        additional_path (str): Additional path to append to the S3 URI. Defaults to "".

    Returns:
        str: The default S3 URI for the SageMaker session
    """
    base = f"s3://{_default_bucket_and_prefix(session)}"
    # Strip a leading slash so we never emit a double "//" in the key.
    suffix = additional_path.lstrip("/")
    if not suffix:
        return base
    return f"{base}/{suffix}"


def _is_valid_s3_uri(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
    """Check if the path is a valid S3 URI.

    This method checks if the path is a valid S3 URI. If the path_type is specified,
    it will also check if the path is a file or a directory.
    This method does not check if the S3 bucket or object exists.

    Args:
        path (str): S3 URI to validate
        path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate.
            Defaults to "Any".

    Returns:
        bool: True if the path is a valid S3 URI, False otherwise
    """
    # Check if the path is a valid S3 URI
    if not path.startswith("s3://"):
        return False

    if path_type == "File":
        # If it's a file, it should not end with a slash
        return not path.endswith("/")
    if path_type == "Directory":
        # If it's a directory, it should end with a slash
        return path.endswith("/")

    return path_type == "Any"


def _is_valid_path(path: str, path_type: Literal["File", "Directory", "Any"] = "Any") -> bool:
    """Check if the path is a valid local path.

    Args:
        path (str): Local path to validate
        path_type (Optional(Literal["File", "Directory", "Any"])): The type of the path to validate.
            Defaults to "Any".

    Returns:
        bool: True if the path is a valid local path, False otherwise
    """
    if not os.path.exists(path):
        return False

    if path_type == "File":
        return os.path.isfile(path)
    if path_type == "Directory":
        return os.path.isdir(path)

    return path_type == "Any"


def _get_unique_name(base, max_length=63):
    """Generate a unique name based on the base name.

    This method generates a unique name based on the base name.
    The unique name is generated by appending the current timestamp
    to the base name.

    Args:
        base (str): The base name to use
        max_length (int): The maximum length of the unique name. Defaults to 63.

    Returns:
        str: The unique name
    """
    current_time = datetime.now().strftime("%Y%m%d%H%M%S")
    base = base.replace("_", "-")
    unique_name = f"{base}-{current_time}"
    unique_name = unique_name[:max_length]  # Truncate to max_length
    return unique_name


def _get_repo_name_from_image(image: str) -> str:
    """Get the repository name from the image URI.

    Example:
    ``` python
    _get_repo_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-repo:latest")
    # Returns "my-repo"
    ```

    Args:
        image (str): The image URI

    Returns:
        str: The repository name
    """
    return image.split("/")[-1].split(":")[0].split("@")[0]


def convert_unassigned_to_none(instance) -> Any:
    """Replace every ``Unassigned`` attribute value on ``instance`` with ``None``.

    Mutates ``instance`` in place and returns it for call-chaining convenience.

    Args:
        instance: Any object whose ``__dict__`` attributes should be scrubbed
            of ``Unassigned`` sentinel values.

    Returns:
        Any: The same ``instance``, with ``Unassigned`` values replaced by ``None``.
    """
    # setattr only overwrites existing keys, so iterating __dict__ here is safe.
    for name, value in instance.__dict__.items():
        if isinstance(value, Unassigned):
            setattr(instance, name, None)
    return instance
def safe_serialize(data):
    """Serialize the data without wrapping strings in quotes.

    This function handles the following cases:
    1. If `data` is a string, it returns the string as-is without wrapping in quotes.
    2. If `data` is of type `PipelineVariable`, it is returned unchanged so the
       pipeline can resolve it at execution time.
    3. If `data` is serializable (e.g., a dictionary, list, int, float), it returns
       the JSON-encoded string using `json.dumps()`.
    4. If `data` cannot be serialized (e.g., a custom object), it returns the string
       representation of the data using `str(data)`.

    Args:
        data (Any): The data to serialize.

    Returns:
        str: The serialized JSON-compatible string or the string representation
        of the input (or the ``PipelineVariable`` itself, see case 2).
    """
    if isinstance(data, str):
        return data
    elif isinstance(data, PipelineVariable):
        return data
    try:
        return json.dumps(data)
    except TypeError:
        # Non-JSON-serializable objects fall back to their string representation.
        return str(data)
def _run_clone_command_silent(repo_url, dest_dir): """Run the 'git clone' command with the repo url and the directory to clone the repo into. Args: repo_url (str): Git repo url to be cloned. dest_dir: (str): Local path where the repo should be cloned into. Raises: CalledProcessError: If failed to clone git repo. """ my_env = os.environ.copy() if repo_url.startswith("https://"): try: my_env["GIT_TERMINAL_PROMPT"] = "0" subprocess.check_call( ["git", "clone", repo_url, dest_dir], env=my_env, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) except subprocess.CalledProcessError as e: logger.error(f"Failed to clone repository: {repo_url}") logger.error(f"Error output:\n{e}") raise elif repo_url.startswith("git@") or repo_url.startswith("ssh://"): try: with tempfile.TemporaryDirectory() as tmp_dir: custom_ssh_executable = Path(tmp_dir) / "ssh_batch" with open(custom_ssh_executable, "w") as pipe: print("#!/bin/sh", file=pipe) print("ssh -oBatchMode=yes $@", file=pipe) os.chmod(custom_ssh_executable, 0o511) my_env["GIT_SSH"] = str(custom_ssh_executable) subprocess.check_call( ["git", "clone", repo_url, dest_dir], env=my_env, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, ) except subprocess.CalledProcessError as e: del my_env["GIT_SSH"] logger.error(f"Failed to clone repository: {repo_url}") logger.error(f"Error output:\n{e}") raise def _get_studio_tags(model_id: str, hub_name: str): return [ { "key": "sagemaker-studio:jumpstart-model-id", "value": model_id }, { "key": "sagemaker-studio:jumpstart-hub-name", "value": hub_name } ] def _get_training_job_name_from_training_job_arn(training_job_arn: str) -> str: """Extract Training job name from Training job arn. Args: training_job_arn: Training job arn. Returns: Training job name. """ if training_job_arn is None: return None pattern = "arn:aws[a-z-]*:sagemaker:[a-z0-9-]*:[0-9]{12}:training-job/(.+)" match = re.match(pattern, training_job_arn) if match: return match.group(1) return None