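# Source code for sagemaker.serve.model_server.torchserve.prepare

"""Prepare a local model directory for serving with TorchServe.

Copies the inference handler script and any shared libraries into the model
directory, captures the model's dependencies, and writes integrity metadata
(a hash of the serialized ``serve.pkl``) next to the code.
"""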

from __future__ import absolute_import
import os
from pathlib import Path
import shutil
from typing import List

from sagemaker.core.helper.session_helper import Session
from sagemaker.serve.spec.inference_spec import InferenceSpec
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
    compute_hash,
)
from sagemaker.serve.validations.check_image_uri import is_1p_image_uri
from sagemaker.core.remote_function.core.serialization import _MetaData


def prepare_for_torchserve(
    model_path: str,
    shared_libs: List[str],
    dependencies: dict,
    session: Session,
    image_uri: str,
    inference_spec: InferenceSpec = None,
) -> None:
    """Lay out ``model_path`` so it can be served by TorchServe.

    The function populates the model directory in place:

    * ``<model_path>/code/inference.py`` - the handler script shipped next to
      this module (the XGBoost-specific one for first-party XGBoost images).
    * ``<model_path>/shared_libs/`` - copies of every file in ``shared_libs``.
    * dependency artifacts written by ``capture_dependencies`` into ``code/``.
    * ``<model_path>/code/metadata.json`` - integrity metadata holding the
      hash of ``code/serve.pkl``.

    Args:
        model_path (str): Directory to prepare. It is created (including any
            missing parents) when it does not exist yet.
        shared_libs (List[str]): Paths of shared library files to copy into
            the ``shared_libs`` sub-directory. ``None`` is treated as empty.
        dependencies (dict): Dependency configuration forwarded unchanged to
            ``capture_dependencies``.
        session (Session): Not used by this function; kept so the signature
            stays compatible with existing callers.
        image_uri (str): Container image URI; decides which handler script is
            installed as ``inference.py``.
        inference_spec (InferenceSpec, optional): When given, its ``prepare``
            method is called with the model directory before anything else is
            written. Defaults to None.

    Returns:
        None. All results are written to disk under ``model_path``.

    Raises:
        NotADirectoryError: If ``model_path`` exists but is not a directory.
        FileNotFoundError: If ``code/serve.pkl`` is missing when the hash is
            computed. NOTE(review): nothing in this function creates
            ``serve.pkl``; presumably the caller or ``inference_spec.prepare``
            does - confirm.
    """
    model_path = Path(model_path)
    if model_path.exists() and not model_path.is_dir():
        raise NotADirectoryError("model_dir is not a valid directory")
    # exist_ok avoids a race between the check above and the creation;
    # parents lets callers pass a nested path that does not exist yet.
    model_path.mkdir(parents=True, exist_ok=True)

    if inference_spec:
        inference_spec.prepare(str(model_path))

    code_dir = model_path.joinpath("code")
    code_dir.mkdir(exist_ok=True)

    # First-party XGBoost images need a dedicated handler, see
    # https://github.com/aws/sagemaker-python-sdk/issues/4288
    if is_1p_image_uri(image_uri=image_uri) and "xgboost" in image_uri:
        handler_name = "xgboost_inference.py"
    else:
        handler_name = "inference.py"
    # Copy straight to the final name instead of copy + os.rename: os.rename
    # fails on Windows when inference.py is left over from a previous run.
    shutil.copy2(
        Path(__file__).parent.joinpath(handler_name),
        code_dir.joinpath("inference.py"),
    )

    shared_libs_dir = model_path.joinpath("shared_libs")
    shared_libs_dir.mkdir(exist_ok=True)
    for shared_lib in shared_libs or []:
        shutil.copy2(Path(shared_lib), shared_libs_dir)

    capture_dependencies(dependencies=dependencies, work_dir=code_dir)

    # Hash the serialized model so its integrity can be checked later.
    with open(code_dir.joinpath("serve.pkl"), "rb") as serialized:
        hash_value = compute_hash(buffer=serialized.read())

    # The file is opened in binary mode, so to_json() is expected to return bytes.
    with open(code_dir.joinpath("metadata.json"), "wb") as metadata:
        metadata.write(_MetaData(hash_value).to_json())