# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Utility functions for SageMaker local mode: moving/copying job artifacts, process-tree management, docker host discovery, and Studio environment detection."""
from __future__ import absolute_import
import os
import logging
import shutil
import subprocess
import json
import re
import errno
from sagemaker.core import s3
from six.moves.urllib.parse import urlparse
logger = logging.getLogger(__name__)
STUDIO_APP_TYPES = ["KernelGateway", "CodeEditor", "JupyterLab"]
def copy_directory_structure(destination_directory, relative_path):
"""Creates intermediate directory structure for relative_path.
Create all the intermediate directories required for relative_path to
exist within destination_directory. This assumes that relative_path is a
directory located within root_dir.
Examples:
destination_directory: /tmp/destination relative_path: test/unit/
will create: /tmp/destination/test/unit
Args:
destination_directory (str): root of the destination directory where the
directory structure will be created.
relative_path (str): relative path that will be created within
destination_directory
"""
full_path = os.path.join(destination_directory, relative_path)
os.makedirs(full_path, exist_ok=True)
def move_to_destination(source, destination, job_name, sagemaker_session, prefix=""):
    """Move source to destination.

    Can handle uploading to S3.

    Args:
        source (str): root directory to move
        destination (str): file:// or s3:// URI that source will be moved to.
        job_name (str): SageMaker job name.
        sagemaker_session (sagemaker.Session): a sagemaker_session to interact
            with S3 if needed
        prefix (str, optional): the directory on S3 used to save files, default
            to the root of ``destination``

    Returns:
        (str): destination URI

    Raises:
        ValueError: when ``destination`` is neither a file:// nor an s3:// URI.
    """
    parsed = urlparse(destination)
    if parsed.scheme == "file":
        local_dir = os.path.abspath(parsed.netloc + parsed.path)
        recursive_copy(source, local_dir)
        final_uri = destination
    elif parsed.scheme == "s3":
        bucket = parsed.netloc
        path = s3.s3_path_join(parsed.path, job_name, prefix)
        final_uri = s3.s3_path_join("s3://", bucket, path)
        sagemaker_session.upload_data(source, bucket, path)
    else:
        raise ValueError("Invalid destination URI, must be s3:// or file://, got: %s" % destination)

    try:
        shutil.rmtree(source)
    except OSError as exc:
        # on Linux, when docker writes to any mounted volume, it uses the container's user. In most
        # cases this is root. When the container exits and we try to delete them we can't because
        # root owns those files. We expect this to happen, so we handle EACCESS. Any other error
        # we will raise the exception up.
        if exc.errno != errno.EACCES:
            logger.error("Failed to delete: %s", source)
            raise
        logger.warning("Failed to delete: %s Please remove it manually.", source)

    return final_uri
def recursive_copy(source, destination):
"""A wrapper around shutil.copy_tree.
This won't throw any exception when the source directory does not exist.
Args:
source (str): source path
destination (str): destination path
"""
if os.path.isdir(source):
shutil.copytree(source, destination, dirs_exist_ok=True)
def kill_child_processes(pid):
    """Kill child processes.

    Sends signal 15 (SIGTERM) to every nested child process id of ``pid``.

    Args:
        pid (int): process id
    """
    for descendant_pid in get_child_process_ids(pid):
        os.kill(descendant_pid, 15)  # 15 == SIGTERM
def get_child_process_ids(pid):
    """Retrieve all child pids for a certain pid.

    Recursively scans each child's process tree and adds every descendant
    to the output.

    Args:
        pid (int): process id

    Returns:
        (List[int]): Child process ids

    Raises:
        ValueError: if ``pid`` is not composed solely of decimal digits.
    """
    if not str(pid).isdigit():
        raise ValueError("Invalid PID")
    cmd = ["pgrep", "-P", str(pid)]
    process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    output, err = process.communicate()
    if err:
        return []
    # pgrep -P lists only direct children; recurse into each one to collect
    # the full descendant tree.
    # Bug fix: the previous implementation returned from inside the loop, so
    # it only descended into the FIRST child's subtree and dropped the rest.
    # (It also shadowed the `pid` parameter in its comprehension.)
    children = [int(child) for child in output.decode("utf-8").split()]
    descendants = list(children)
    for child_pid in children:
        descendants.extend(get_child_process_ids(child_pid))
    return descendants
def get_docker_host():
    """Discover remote docker host address (if applicable) or use "localhost".

    Runs ``docker context inspect`` and reads the current context's endpoint
    URL. A ``tcp://`` endpoint with a hostname yields that hostname; any other
    scheme — or any error from the docker CLI — falls back to "localhost".

    Returns:
        docker_host (str): Docker host DNS or IP address
    """
    inspect_proc = subprocess.Popen(
        "docker context inspect".split(),
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    stdout, stderr = inspect_proc.communicate()
    if stderr:
        return "localhost"
    context = json.loads(stdout.decode("utf-8"))
    endpoint_url = context[0]["Endpoints"]["docker"]["Host"]
    parsed = urlparse(endpoint_url)
    if parsed.scheme == "tcp" and parsed.hostname:
        return parsed.hostname
    return "localhost"
def get_using_dot_notation(dictionary, keys):
    """Extract `keys` from dictionary where keys is a string in dot notation.

    Supports bracket accessors within a segment: quoted strings are treated as
    dict keys (``a["b"]`` / ``a['b']``) and unquoted values as list indices
    (``a[0]``).

    Args:
        dictionary (Dict): object to traverse.
        keys (str): dot-notation path; ``None`` returns ``dictionary`` as-is.

    Returns:
        Nested object within dictionary as defined by "keys"

    Raises:
        ValueError: if the provided key does not exist in input dictionary
    """
    try:
        if keys is None:
            return dictionary
        # Split off only the first path segment; the remainder is resolved
        # by the recursive call below.
        split_keys = keys.split(".", 1)
        key = split_keys[0]
        rest = split_keys[1] if len(split_keys) > 1 else None
        bracket_accessors = re.findall(r"\[(.+?)]", key)
        if bracket_accessors:
            pre_bracket_key = key.split("[", 1)[0]
            inner_dict = dictionary[pre_bracket_key]
        else:
            inner_dict = dictionary[key]
        for bracket_accessor in bracket_accessors:
            if (
                bracket_accessor.startswith("'")
                and bracket_accessor.endswith("'")
                or bracket_accessor.startswith('"')
                and bracket_accessor.endswith('"')
            ):
                # key accessor
                inner_key = bracket_accessor[1:-1]
            else:
                # list accessor
                inner_key = int(bracket_accessor)
            inner_dict = inner_dict[inner_key]
        return get_using_dot_notation(inner_dict, rest)
    except (KeyError, IndexError, TypeError) as err:
        # Fix: chain the underlying lookup error (PEP 3134) so the root cause
        # (which key/index failed) is preserved in the traceback instead of
        # being silently discarded.
        raise ValueError(f"{keys} does not exist in input dictionary.") from err
def check_for_studio():
    """Helper function to determine if the run environment is studio.

    Reads the Studio resource metadata file when present and inspects its
    "AppType" field.

    Returns (bool): Returns True if valid Studio request.

    Raises:
        NotImplementedError:
            if run environment = Studio and AppType not in STUDIO_APP_TYPES
    """
    metadata_path = "/opt/ml/metadata/resource-metadata.json"
    if not os.path.exists(metadata_path):
        # No metadata file: not running inside Studio at all.
        return False
    with open(metadata_path, "r") as handle:
        app_type = json.load(handle).get("AppType")
    if app_type is None:
        # No AppType recorded: case of classic notebooks.
        return False
    if app_type not in STUDIO_APP_TYPES:
        raise NotImplementedError(
            f"AppType {app_type} in Studio does not support Local Mode."
        )
    # Execution was triggered from a supported Studio app.
    return True