Source code for sagemaker.serve.detector.pickle_dependencies

"""Load a pickled object to detect the dependencies it requires"""

from __future__ import absolute_import
from pathlib import Path
from typing import List
import email.parser
import email.policy
import json
import inspect
import itertools
import subprocess
import sys
import tqdm

# non native imports. Ideally add as little as possible here
# because it will add to requirements.txt
import cloudpickle
import boto3

pipcmd = [sys.executable, "-m", "pip", "--disable-pip-version-check"]


[docs] def get_all_files_for_installed_packages_pip(packages: List[str]): """Placeholder docstring""" proc = subprocess.Popen(pipcmd + ["show", "-f"] + packages, stdout=subprocess.PIPE) with proc.stdout: lines = [] for line in iter(proc.stdout.readline, b""): if line == b"---\n": yield lines lines = [] else: lines.append(line) yield lines proc.wait(timeout=10) # wait for the subprocess to exit
[docs] def get_all_files_for_installed_packages(packages: List[str]): """Placeholder docstring""" ret = {} for rawmsg in get_all_files_for_installed_packages_pip(packages): parser = email.parser.BytesParser(policy=email.policy.default) msg = parser.parsebytes(b"".join(iter(rawmsg))) if not msg.get("Files"): continue ret[msg.get("Name")] = { Path(msg.get("Location")).joinpath(x) for x in msg.get("Files").split() } return ret
[docs] def batched(iterable, n): """Batch data into tuples of length n. The last batch may be shorter.""" # batched('ABCDEFG', 3) --> ABC DEF G if n < 1: raise ValueError("n must be at least one") it = iter(iterable) while True: batch = tuple(itertools.islice(it, n)) if not batch: break yield batch
[docs] def get_all_installed_packages(): """Placeholder docstring""" proc = subprocess.run(pipcmd + ["list", "--format", "json"], stdout=subprocess.PIPE, check=True) return json.loads(proc.stdout)
[docs] def map_package_names_to_files(package_names: List[str]): """Placeholder docstring""" m = {} batch_size = 20 with tqdm.tqdm(total=len(package_names), desc="Scanning for dependencies", ncols=100) as pbar: for pkg_names in batched(package_names, batch_size): m.update(get_all_files_for_installed_packages(list(pkg_names))) pbar.update(batch_size) return m
[docs] def get_currently_used_packages(): """Placeholder docstring""" all_installed_packages = get_all_installed_packages() package_to_file_names = map_package_names_to_files([x["name"] for x in all_installed_packages]) currently_used_files = { Path(m.__file__) for m in sys.modules.values() if inspect.ismodule(m) and hasattr(m, "__file__") and m.__file__ } currently_used_packages = set() for file in currently_used_files: for package in package_to_file_names: if file in package_to_file_names[package]: currently_used_packages.add(package) return currently_used_packages
[docs] def get_requirements_for_pkl_file(pkl_path: Path, dest: Path): """Placeholder docstring""" with open(pkl_path, mode="rb") as file: cloudpickle.load(file) currently_used_packages = get_currently_used_packages() with open(dest, mode="w+") as out: for x in get_all_installed_packages(): name = x["name"] version = x["version"] # skip only for dev if name == "boto3": boto3_version = boto3.__version__ out.write(f"boto3=={boto3_version}\n") elif name in currently_used_packages: out.write(f"{name}=={version}\n")
[docs] def get_all_requirements(dest: Path): """Placeholder docstring""" all_installed_packages = get_all_installed_packages() with open(dest, mode="w+") as out: for package_info in all_installed_packages: name = package_info.get("name") version = package_info.get("version") out.write(f"{name}=={version}\n")