# Source code for sagemaker.serve.model_server.tensorflow_serving.prepare

"""Module for artifacts preparation for tensorflow_serving"""

from __future__ import absolute_import
from pathlib import Path
import shutil
from typing import List, Dict, Any

from sagemaker.serve.model_format.mlflow.utils import (
    _get_saved_model_path_for_tensorflow_and_keras_flavor,
    _move_contents,
)
from sagemaker.serve.detector.dependency_manager import capture_dependencies
from sagemaker.serve.validations.check_integrity import (
    compute_hash,
)
from sagemaker.core.remote_function.core.serialization import _MetaData


def prepare_for_tf_serving(
    model_path: str,
    shared_libs: List[str],
    dependencies: Dict[str, Any],
) -> None:
    """Prepare the model directory artifacts for TensorFlow Serving.

    The following is laid out under ``model_path``:

    * ``code/`` receives a copy of the ``inference.py`` that sits next to this
      module, plus whatever ``capture_dependencies`` writes into it.
    * ``shared_libs/`` receives a copy of every file listed in ``shared_libs``.
    * ``1/`` receives the contents of the SavedModel directory located by
      ``_get_saved_model_path_for_tensorflow_and_keras_flavor``.
    * ``code/metadata.json`` receives the serialized ``_MetaData`` built from
      the hash of ``code/serve.pkl``.

    ``code/serve.pkl`` is only read here; nothing in this function writes it,
    so it is presumably produced by the caller beforehand (TODO confirm).

    Args:
        model_path (str): Path to the model directory. It is created, together
            with any missing parent directories, when it does not exist.
        shared_libs (List[str]): Paths of shared library files to copy into
            ``<model_path>/shared_libs``.
        dependencies (Dict[str, Any]): Dependency configuration forwarded to
            ``capture_dependencies``.

    Returns:
        None. The computed hash is persisted to ``code/metadata.json`` rather
        than returned.

    Raises:
        Exception: If ``model_path`` exists but is not a directory.
        ValueError: If no SavedModel path is found for the TensorFlow or Keras
            flavor.
    """
    _model_path = Path(model_path)
    if _model_path.exists() and not _model_path.is_dir():
        raise Exception("model_dir is not a valid directory")
    # parents=True lets a nested, not-yet-existing model path be used;
    # exist_ok=True keeps the call a no-op when the directory is already there.
    _model_path.mkdir(parents=True, exist_ok=True)

    code_dir = _model_path.joinpath("code")
    code_dir.mkdir(exist_ok=True)
    shutil.copy2(Path(__file__).parent.joinpath("inference.py"), code_dir)

    shared_libs_dir = _model_path.joinpath("shared_libs")
    shared_libs_dir.mkdir(exist_ok=True)
    for shared_lib in shared_libs:
        shutil.copy2(Path(shared_lib), shared_libs_dir)

    capture_dependencies(dependencies=dependencies, work_dir=code_dir)

    # NOTE(review): "1" is presumably the numeric model-version directory that
    # TensorFlow Serving looks for - confirm against the serving container.
    saved_model_bundle_dir = _model_path.joinpath("1")
    saved_model_bundle_dir.mkdir(exist_ok=True)
    mlflow_saved_model_dir = _get_saved_model_path_for_tensorflow_and_keras_flavor(model_path)
    if not mlflow_saved_model_dir:
        raise ValueError("SavedModel is not found for Tensorflow or Keras flavor.")
    _move_contents(src_dir=mlflow_saved_model_dir, dest_dir=saved_model_bundle_dir)

    # Hash the serialized payload and store the hash next to it in
    # metadata.json.
    with open(str(code_dir.joinpath("serve.pkl")), "rb") as f:
        buffer = f.read()
    hash_value = compute_hash(buffer=buffer)
    with open(str(code_dir.joinpath("metadata.json")), "wb") as metadata:
        metadata.write(_MetaData(hash_value).to_json())