Source code for sagemaker.serve.detector.dependency_manager
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""SageMaker model builder dependency managing module.
This must be kept independent of SageMaker PySDK
"""
from __future__ import absolute_import
from pathlib import Path
import logging
import subprocess
import sys
import re
_SUPPORTED_SUFFIXES = [".txt"]
# TODO : Move PKL_FILE_NAME to common location
PKL_FILE_NAME = "serve.pkl"
logger = logging.getLogger(__name__)
[docs]
def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool = False):
"""Placeholder docstring"""
path = work_dir.joinpath("requirements.txt")
if "auto" in dependencies and dependencies["auto"]:
import site
pkl_path = work_dir.joinpath(PKL_FILE_NAME).resolve()
dest_path = path.resolve()
site_packages_dir = site.getsitepackages()[0]
pickle_command_dir = "/sagemaker/serve/detector"
command = [
sys.executable,
"-c",
]
if capture_all:
command.append(
f"from sagemaker.serve.detector.pickle_dependencies import get_all_requirements;"
f'get_all_requirements("{dest_path}")'
)
else:
command.append(
f"from sagemaker.serve.detector.pickle_dependencies import get_requirements_for_pkl_file;"
f'get_requirements_for_pkl_file("{pkl_path}", "{dest_path}")'
)
subprocess.run(
command,
env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"},
check=True,
cwd=site_packages_dir + pickle_command_dir,
)
with open(path, "r") as f:
autodetect_depedencies = f.read().splitlines()
# autodetect_depedencies.append("sagemaker[huggingface]>=2.199")
# autodetect_depedencies = []
else:
autodetect_depedencies = []
# autodetect_depedencies = ["sagemaker[huggingface]>=2.199"]
module_version_dict = _parse_dependency_list(autodetect_depedencies)
if "requirements" in dependencies:
module_version_dict = _process_customer_provided_requirements(
requirements_file=dependencies["requirements"], module_version_dict=module_version_dict
)
if "custom" in dependencies:
module_version_dict = _process_custom_dependencies(
custom_dependencies=dependencies.get("custom"), module_version_dict=module_version_dict
)
with open(path, "w") as f:
for module, version in module_version_dict.items():
f.write(f"{module}{version}\n")
def _process_custom_dependencies(custom_dependencies: list, module_version_dict: dict):
"""Placeholder docstring"""
custom_module_version_dict = _parse_dependency_list(custom_dependencies)
module_version_dict.update(custom_module_version_dict)
return module_version_dict
def _process_customer_provided_requirements(requirements_file: str, module_version_dict: dict):
"""Placeholder docstring"""
requirements_file = Path(requirements_file)
if not requirements_file.is_file() or not _is_valid_requirement_file(requirements_file):
raise Exception(f"Path: {requirements_file} to requirements.txt doesn't exist")
logger.debug("Packaging provided requirements.txt from %s", requirements_file)
with open(requirements_file, "r") as f:
custom_dependencies = f.read().splitlines()
module_version_dict.update(_parse_dependency_list(custom_dependencies))
return module_version_dict
def _is_valid_requirement_file(path):
"""Placeholder docstring"""
# In the future, we can also check the if the content of customer provided file has valid format
for suffix in _SUPPORTED_SUFFIXES:
if path.name.endswith(suffix):
return True
return False
def _parse_dependency_list(depedency_list: list) -> dict:
"""Placeholder docstring"""
# Divide a string into 2 part, first part is the module name
# and second part is its version constraint or the url
# checkout tests/unit/sagemaker/serve/detector/test_dependency_manager.py
# for examples
pattern = r"^([\w.-]+)(@[^,\n]+|((?:[<>=!~]=?[\w.*-]+,?)+)?)$"
module_version_dict = {}
for dependency in depedency_list:
if dependency.startswith("#"):
continue
match = re.match(pattern, dependency)
if match:
package = match.group(1)
# Group 2 is either a URL or version constraint, if present
url_or_version = match.group(2) if match.group(2) else ""
module_version_dict.update({package: url_or_version})
else:
module_version_dict.update({dependency: ""})
return module_version_dict