Source code for sagemaker.serve.builder.requirements_manager
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Requirements Manager class to pull in client dependencies from a .txt or .yml file"""
from __future__ import absolute_import
import logging
import os
import subprocess
from typing import Optional
logger = logging.getLogger(__name__)
[docs]
class RequirementsManager:
"""Manages dependency installation by detecting file types"""
[docs]
def capture_and_install_dependencies(self, dependencies: Optional[str] = None) -> str:
"""Detects the type of file dependencies will be installed from
If a req.txt or conda.yml file is provided, it verifies their existence and
returns the local file path
Args:
dependencies (str): Local path where dependencies file exists.
Returns:
file path of the existing or generated dependencies file
"""
_dependencies = dependencies or self._detect_conda_env_and_local_dependencies
# Dependencies specified as either req.txt or conda_env.yml
if _dependencies.endswith(".txt"):
self._install_requirements_txt()
elif _dependencies.endswith(".yml"):
self._update_conda_env_in_path()
else:
raise ValueError(f'Invalid dependencies provided: "{_dependencies}"')
def _install_requirements_txt(self):
"""Install requirements.txt file using pip"""
logger.info("Running command to pip install")
subprocess.run("pip install -r in_process_requirements.txt", shell=True, check=True)
logger.info("Command ran successfully")
def _update_conda_env_in_path(self):
"""Update conda env using conda yml file"""
logger.info("Updating conda env")
subprocess.run("conda env update -f conda_in_process.yml", shell=True, check=True)
logger.info("Conda env updated successfully")
def _get_active_conda_env_name(self) -> str:
"""Returns the conda environment name from the set environment variable. None otherwise."""
return os.getenv("CONDA_DEFAULT_ENV")
def _get_active_conda_env_prefix(self) -> str:
"""Returns the conda prefix from the set environment variable. None otherwise."""
return os.getenv("CONDA_PREFIX")
def _detect_conda_env_and_local_dependencies(self) -> str:
"""Generates dependencies list from the user's local runtime.
Raises RuntimeEnvironmentError if not able to.
Currently supports: conda environments
"""
# Try to capture dependencies from the conda environment, if any.
conda_env_name = self._get_active_conda_env_name()
logger.info("Found conda_env_name: '%s'", conda_env_name)
conda_env_prefix = None
if conda_env_name is None:
conda_env_prefix = self._get_active_conda_env_prefix()
if conda_env_name is None and conda_env_prefix is None:
local_dependencies_path = os.path.join(os.getcwd(), "in_process_requirements.txt")
logger.info(local_dependencies_path)
return local_dependencies_path
if conda_env_name == "base":
logger.warning(
"We recommend using an environment other than base to "
"isolate your project dependencies from conda dependencies"
)
local_dependencies_path = os.path.join(os.getcwd(), "conda_in_process.yml")
logger.info(local_dependencies_path)
return local_dependencies_path