Source code for sagemaker.core.local.utils

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import

import os
import logging
import shutil
import subprocess
import json
import re
import errno

from sagemaker.core import s3
from six.moves.urllib.parse import urlparse


logger = logging.getLogger(__name__)

STUDIO_APP_TYPES = ["KernelGateway", "CodeEditor", "JupyterLab"]


def copy_directory_structure(destination_directory, relative_path):
    """Creates intermediate directory structure for relative_path.

    Create all the intermediate directories required for relative_path to exist
    within destination_directory. This assumes that relative_path is a directory
    located within destination_directory.

    Examples:
        destination_directory: /tmp/destination
        relative_path: test/unit/

        will create: /tmp/destination/test/unit

    Args:
        destination_directory (str): root of the destination directory where the
            directory structure will be created.
        relative_path (str): relative path that will be created within
            destination_directory.
    """
    full_path = os.path.join(destination_directory, relative_path)
    os.makedirs(full_path, exist_ok=True)
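
# A hypothetical usage sketch; the paths are illustrative:
#
#     copy_directory_structure("/tmp/destination", "test/unit")
#     # /tmp/destination/test/unit now exists, including intermediate directories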

def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
    """Move source to destination.

    Can handle uploading to S3.

    Args:
        source (str): root directory to move.
        destination (str): file:// or s3:// URI that source will be moved to.
        job_name (str): SageMaker job name.
        sagemaker_session (sagemaker.Session): a sagemaker_session to interact
            with S3 if needed.
        prefix (str, optional): the directory on S3 used to save files, defaults
            to the root of ``destination``.

    Returns:
        (str): destination URI
    """
    parsed_uri = urlparse(destination)
    if parsed_uri.scheme == "file":
        dir_path = os.path.abspath(parsed_uri.netloc + parsed_uri.path)
        recursive_copy(source, dir_path)
        final_uri = destination
    elif parsed_uri.scheme == "s3":
        bucket = parsed_uri.netloc
        path = s3.s3_path_join(parsed_uri.path, job_name, prefix)
        final_uri = s3.s3_path_join("s3://", bucket, path)
        sagemaker_session.upload_data(source, bucket, path)
    else:
        raise ValueError("Invalid destination URI, must be s3:// or file://, got: %s" % destination)

    try:
        shutil.rmtree(source)
    except OSError as exc:
        # On Linux, when docker writes to a mounted volume, it uses the container's
        # user, which in most cases is root. When the container exits and we try to
        # delete those files, we can't, because root owns them. We expect this, so
        # we handle EACCES; any other error is raised up.
        if exc.errno == errno.EACCES:
            logger.warning("Failed to delete: %s Please remove it manually.", source)
        else:
            logger.error("Failed to delete: %s", source)
            raise

    return final_uri
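
# A hypothetical usage sketch; "my-bucket", "my-job", and `session` are
# illustrative values, not defaults:
#
#     # file:// destination: copies /tmp/data into /tmp/output, then removes the source
#     move_to_destination("/tmp/data", "file:///tmp/output", "my-job", None)
#
#     # s3:// destination: uploads to s3://my-bucket/prefix/my-job and returns that URI
#     move_to_destination("/tmp/data", "s3://my-bucket/prefix", "my-job", session)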

def recursive_copy(source, destination):
    """A wrapper around shutil.copytree.

    This won't throw any exception when the source directory does not exist.

    Args:
        source (str): source path
        destination (str): destination path
    """
    if os.path.isdir(source):
        shutil.copytree(source, destination, dirs_exist_ok=True)
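
# For example (paths illustrative):
#
#     recursive_copy("/tmp/src", "/tmp/dst")          # copies the tree if /tmp/src exists
#     recursive_copy("/does/not/exist", "/tmp/dst")   # silently does nothing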

def kill_child_processes(pid):
    """Kill child processes.

    Kills all nested child process ids for a specific pid.

    Args:
        pid (int): process id
    """
    child_pids = get_child_process_ids(pid)
    for child_pid in child_pids:
        os.kill(child_pid, 15)  # 15 == SIGTERM

def get_child_process_ids(pid):
    """Retrieve all child pids for a certain pid.

    Recursively scan each child's process tree and add it to the output.

    Args:
        pid (int): process id

    Returns:
        (List[int]): Child process ids
    """
    if not str(pid).isdigit():
        raise ValueError("Invalid PID")
    cmd = ["pgrep", "-P", str(pid)]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, err = process.communicate()
    if err:
        return []
    pids = [int(child) for child in output.decode("utf-8").split()]
    # Recurse into every child's subtree, not only the first child's, so the
    # whole process tree is covered.
    result = list(pids)
    for child_pid in pids:
        result += get_child_process_ids(child_pid)
    return result
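
# Illustrative only: given a process tree 100 -> {200, 300} and 200 -> {250},
#
#     get_child_process_ids(100)   # -> [200, 300, 250]
#     kill_child_processes(100)    # sends SIGTERM to 200, 300, and 250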

def get_docker_host():
    """Discover the remote docker host address (if applicable) or use "localhost".

    Uses "docker context inspect" to read the current docker host endpoint URL;
    the URL must start with "tcp://" for a remote host to be returned.

    Returns:
        docker_host (str): Docker host DNS or IP address
    """
    cmd = "docker context inspect".split()
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, err = process.communicate()
    if err:
        return "localhost"
    docker_context_string = output.decode("utf-8")
    docker_context_host_url = json.loads(docker_context_string)[0]["Endpoints"]["docker"]["Host"]
    parsed_url = urlparse(docker_context_host_url)
    if parsed_url.hostname and parsed_url.scheme == "tcp":
        return parsed_url.hostname
    return "localhost"
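
# "docker context inspect" emits a JSON array; a remote context might look like
# this (values illustrative):
#
#     [{"Name": "remote", "Endpoints": {"docker": {"Host": "tcp://10.0.0.5:2375"}}}]
#
# for which get_docker_host() returns "10.0.0.5". For the default local endpoint
# ("unix:///var/run/docker.sock") it falls back to "localhost".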

def get_using_dot_notation(dictionary, keys):
    """Extract `keys` from dictionary where keys is a string in dot notation.

    Args:
        dictionary (Dict): input dictionary
        keys (str): keys in dot notation

    Returns:
        Nested object within dictionary as defined by `keys`.

    Raises:
        ValueError: if the provided key does not exist in the input dictionary.
    """
    try:
        if keys is None:
            return dictionary
        split_keys = keys.split(".", 1)
        key = split_keys[0]
        rest = None
        if len(split_keys) > 1:
            rest = split_keys[1]
        bracket_accessors = re.findall(r"\[(.+?)]", key)
        if bracket_accessors:
            pre_bracket_key = key.split("[", 1)[0]
            inner_dict = dictionary[pre_bracket_key]
        else:
            inner_dict = dictionary[key]
        for bracket_accessor in bracket_accessors:
            if (
                bracket_accessor.startswith("'")
                and bracket_accessor.endswith("'")
                or bracket_accessor.startswith('"')
                and bracket_accessor.endswith('"')
            ):
                # quoted accessor, e.g. ['foo']: treat as a dictionary key
                inner_key = bracket_accessor[1:-1]
            else:
                # unquoted accessor, e.g. [0]: treat as a list index
                inner_key = int(bracket_accessor)
            inner_dict = inner_dict[inner_key]
        return get_using_dot_notation(inner_dict, rest)
    except (KeyError, IndexError, TypeError):
        raise ValueError(f"{keys} does not exist in input dictionary.")
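
# A hypothetical usage sketch (the dictionary is illustrative):
#
#     config = {"local": {"ports": [8080, 8081], "docker": {"shm_size": "2G"}}}
#     get_using_dot_notation(config, "local.docker.shm_size")  # -> "2G"
#     get_using_dot_notation(config, "local.ports[1]")         # -> 8081
#     get_using_dot_notation(config, "local['docker']")        # -> {"shm_size": "2G"}
#     get_using_dot_notation(config, "local.missing")          # raises ValueError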

def check_for_studio():
    """Helper function to determine if the run environment is Studio.

    Returns:
        (bool): True if this is a valid Studio request.

    Raises:
        NotImplementedError: if the run environment is Studio and the AppType is
            not in STUDIO_APP_TYPES.
    """
    is_studio = False
    if os.path.exists("/opt/ml/metadata/resource-metadata.json"):
        with open("/opt/ml/metadata/resource-metadata.json", "r") as handle:
            metadata = json.load(handle)

        app_type = metadata.get("AppType")
        if app_type:
            # check if the execution is triggered from a supported Studio app
            if app_type in STUDIO_APP_TYPES:
                is_studio = True
            else:
                raise NotImplementedError(
                    f"AppType {app_type} in Studio does not support Local Mode."
                )
        # if there is no AppType, we are in a classic notebook instance
    return is_studio
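
# In Studio, /opt/ml/metadata/resource-metadata.json holds app metadata, e.g.
# (fields abbreviated):
#
#     {"AppType": "JupyterLab", "DomainId": "d-...", "ResourceName": "..."}
#
# check_for_studio() returns True for the app types in STUDIO_APP_TYPES, False
# when the file or AppType is absent (classic notebook instances), and raises
# NotImplementedError for any other Studio AppType.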