Source code for sagemaker.core.config.config_manager

# sagemaker_config.py

import pathlib
import copy
import inspect
import os
from typing import List, Optional
import boto3
import yaml
import jsonschema
from platformdirs import site_config_dir, user_config_dir
from botocore.utils import merge_dicts
from six.moves.urllib.parse import urlparse
from sagemaker.core.config.config_schema import SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA
from sagemaker.core.config.config_utils import (
    non_repeating_log_factory,
    get_sagemaker_config_logger,
    _log_sagemaker_config_single_substitution,
    _log_sagemaker_config_merge,
)
from functools import lru_cache

# Module-level logger, obtained from the project's config_utils helper.
logger = get_sagemaker_config_logger()
# Info-level logging callable built on top of `logger`. Presumably it suppresses
# repeated messages (inferred from the factory name only -- confirm in config_utils).
log_info_function = non_repeating_log_factory(logger, "info")


class SageMakerConfig:
    """Loads, validates and queries the SageMaker Python SDK defaults config.

    Class attributes hold the default config file locations and the names of the
    environment variables that can override the admin and user locations.
    """

    _APP_NAME = "sagemaker"
    _CONFIG_FILE_NAME = "config.yaml"

    # Default locations: site-wide (admin) and per-user directories from platformdirs.
    _DEFAULT_ADMIN_CONFIG_FILE_PATH = os.path.join(site_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
    _DEFAULT_USER_CONFIG_FILE_PATH = os.path.join(user_config_dir(_APP_NAME), _CONFIG_FILE_NAME)
    # Local-mode config lives under ~/.sagemaker/.
    _DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH = os.path.join(
        os.path.expanduser("~"), ".sagemaker", _CONFIG_FILE_NAME
    )

    # Environment variables whose values replace the default admin/user paths.
    ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE = "SAGEMAKER_ADMIN_CONFIG_OVERRIDE"
    ENV_VARIABLE_USER_CONFIG_OVERRIDE = "SAGEMAKER_USER_CONFIG_OVERRIDE"

    S3_PREFIX = "s3://"

    def __init__(self):
        """Attach the config logger and an info-level logging callable to the instance."""
        config_logger = get_sagemaker_config_logger()
        self.logger = config_logger
        self.log_info_function = non_repeating_log_factory(config_logger, "info")
def load_sagemaker_config(
    self,
    additional_config_paths: Optional[List[str]] = None,
    s3_resource=None,
    repeat_log: bool = False,
) -> dict:
    """Load, validate and merge the SDK defaults config from all known locations.

    Locations are processed in order: the admin path (or its environment override),
    the user path (or its environment override), then ``additional_config_paths``.
    Each non-empty config is validated and merged into the result with ``merge_dicts``.

    Args:
        additional_config_paths: Extra local paths or ``s3://`` URIs to load after
            the admin and user locations.
        s3_resource: S3 resource handed to the S3 loader for ``s3://`` locations.
        repeat_log: When True, log through ``self.logger.info`` instead of
            ``self.log_info_function``.

    Returns:
        dict: The merged config; empty when no location yielded a config.

    Raises:
        ValueError: If a local location other than the built-in default admin/user
            paths cannot be loaded.
    """
    admin_location = os.getenv(
        self.ENV_VARIABLE_ADMIN_CONFIG_OVERRIDE, self._DEFAULT_ADMIN_CONFIG_FILE_PATH
    )
    user_location = os.getenv(
        self.ENV_VARIABLE_USER_CONFIG_OVERRIDE, self._DEFAULT_USER_CONFIG_FILE_PATH
    )

    candidate_locations = [admin_location, user_location]
    if additional_config_paths:
        candidate_locations.extend(additional_config_paths)
    candidate_locations = [
        location for location in candidate_locations if location is not None
    ]

    emit_info = self.logger.info if repeat_log else self.log_info_function

    # A missing file is tolerated only at these built-in default locations.
    optional_locations = (
        self._DEFAULT_ADMIN_CONFIG_FILE_PATH,
        self._DEFAULT_USER_CONFIG_FILE_PATH,
    )

    combined_config = {}
    for location in candidate_locations:
        loaded_config = {}
        if location.startswith(self.S3_PREFIX):
            loaded_config = self._load_config_from_s3(location, s3_resource)
        else:
            try:
                loaded_config = self._load_config_from_file(location)
            except ValueError as error:
                if location not in optional_locations:
                    raise
                self.logger.debug(error)

        if not loaded_config:
            emit_info("Not applying SDK defaults from location: %s", location)
            continue

        self.validate_sagemaker_config(loaded_config)
        merge_dicts(combined_config, loaded_config)
        emit_info("Fetched defaults config from location: %s", location)

    return combined_config
@staticmethod
def validate_sagemaker_config(sagemaker_config: Optional[dict] = None):
    """Validate a config dictionary against the SDK config JSON schema.

    Args:
        sagemaker_config: The config to check.

    Raises:
        Whatever ``jsonschema.validate`` raises when the config does not conform to
        ``SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA``.
    """
    jsonschema.validate(
        instance=sagemaker_config,
        schema=SAGEMAKER_PYTHON_SDK_CONFIG_SCHEMA,
    )
def load_local_mode_config(self) -> Optional[dict]:
    """Load the local-mode config file from its default location.

    Returns:
        The parsed file content, or None when the file loader raises ValueError
        (which it does when the path does not exist).
    """
    try:
        return self._load_config_from_file(self._DEFAULT_LOCAL_MODE_CONFIG_FILE_PATH)
    except ValueError:
        return None
def _load_config_from_file(self, file_path: str) -> dict:
    """Parse a YAML config from a local file or from a directory containing one.

    Args:
        file_path: Path to a config file, or to a directory that holds a file
            named ``self._CONFIG_FILE_NAME``.

    Returns:
        The result of ``yaml.safe_load`` on the file (None for an empty file).

    Raises:
        ValueError: If the resolved path does not exist.
    """
    inferred_file_path = file_path
    if os.path.isdir(file_path):
        inferred_file_path = os.path.join(file_path, self._CONFIG_FILE_NAME)
    if not os.path.exists(inferred_file_path):
        # The two sentences used to be glued together without a separator.
        raise ValueError(
            f"Unable to load the config file from the location: {file_path}. "
            "Provide a valid file path"
        )
    self.logger.debug("Fetching defaults config from location: %s", file_path)
    # Decode as UTF-8 explicitly, matching the S3 loader, instead of relying on
    # the platform default encoding.
    with open(inferred_file_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def _load_config_from_s3(self, s3_uri, s3_resource_for_config) -> dict:
    """Download and parse a YAML config from S3.

    Args:
        s3_uri: ``s3://`` URI of the config object, or of a prefix ("directory")
            that contains one.
        s3_resource_for_config: S3 resource to use. When falsy, one is created from
            the default boto3 session (or a new session).

    Returns:
        The result of ``yaml.safe_load`` on the UTF-8 decoded object body.

    Raises:
        ValueError: If no resource was given and the boto session has no region, or
            if the URI cannot be resolved to a config object.
    """
    if not s3_resource_for_config:
        boto_session = boto3.DEFAULT_SESSION or boto3.Session()
        boto_region_name = boto_session.region_name
        if boto_region_name is None:
            raise ValueError(
                "Must setup local AWS configuration with a region supported by SageMaker."
            )
        s3_resource_for_config = boto_session.resource("s3", region_name=boto_region_name)

    self.logger.debug("Fetching defaults config from location: %s", s3_uri)
    inferred_s3_uri = self._get_inferred_s3_uri(s3_uri, s3_resource_for_config)
    parsed_url = urlparse(inferred_s3_uri)
    bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/")
    s3_object = s3_resource_for_config.Object(bucket, key_prefix)
    s3_file_content = s3_object.get()["Body"].read()
    return yaml.safe_load(s3_file_content.decode("utf-8"))


def _get_inferred_s3_uri(self, s3_uri, s3_resource_for_config):
    """Resolve an S3 URI that may name either a config object or a prefix.

    Lists the objects under the URI's key prefix. With exactly one match the URI
    is returned unchanged. With several matches the URI is treated as a directory
    and ``self._CONFIG_FILE_NAME`` is appended.

    Raises:
        ValueError: If nothing matches the prefix, or several objects match but
            none of them is the expected config file.
    """
    parsed_url = urlparse(s3_uri)
    bucket, key_prefix = parsed_url.netloc, parsed_url.path.lstrip("/")
    s3_bucket = s3_resource_for_config.Bucket(name=bucket)
    s3_objects = s3_bucket.objects.filter(Prefix=key_prefix).all()
    s3_files_with_same_prefix = [
        f"{self.S3_PREFIX}{bucket}/{s3_object.key}" for s3_object in s3_objects
    ]
    if not s3_files_with_same_prefix:
        raise ValueError(f"Provide a valid S3 path instead of {s3_uri}")
    if len(s3_files_with_same_prefix) > 1:
        # PurePosixPath collapses the "//" after the scheme, so it is restored here.
        inferred_s3_uri = str(pathlib.PurePosixPath(s3_uri, self._CONFIG_FILE_NAME)).replace(
            "s3:/", "s3://"
        )
        if inferred_s3_uri not in s3_files_with_same_prefix:
            raise ValueError(
                f"Provide an S3 URI of a directory that has a {self._CONFIG_FILE_NAME} file."
            )
        return inferred_s3_uri
    return s3_uri
[docs] @staticmethod def get_config_value(key_path, config): """Placeholder Docstring""" if config is None: return None current_section = config for key in key_path.split("."): if key in current_section: current_section = current_section[key] else: return None return current_section
[docs] @staticmethod def get_nested_value(dictionary: dict, nested_keys: List[str]): """Returns a nested value from the given dictionary, and None if none present. Raises ValueError if the dictionary structure does not match the nested_keys """ if ( dictionary is not None and isinstance(dictionary, dict) and nested_keys is not None and len(nested_keys) > 0 ): current_section = dictionary for key in nested_keys[:-1]: current_section = current_section.get(key, None) if current_section is None: # means the full path of nested_keys doesnt exist in the dictionary # or the value was set to None return None if not isinstance(current_section, dict): raise ValueError( "Unexpected structure of dictionary.", "Expected value of type dict at key '{}' but got '{}' for dict '{}'".format( key, current_section, dictionary ), ) return current_section.get(nested_keys[-1], None) return None
[docs] @staticmethod def set_nested_value(dictionary: dict, nested_keys: List[str], value_to_set: object): """Sets a nested value in a dictionary. This sets a nested value inside the given dictionary and returns the new dictionary. Note: if provided an unintended list of nested keys, this can overwrite an unexpected part of the dict. Recommended to use after a check with get_nested_value first """ if dictionary is None: dictionary = {} if ( dictionary is not None and isinstance(dictionary, dict) and nested_keys is not None and len(nested_keys) > 0 ): current_section = dictionary for key in nested_keys[:-1]: if ( key not in current_section or current_section[key] is None or not isinstance(current_section[key], dict) ): current_section[key] = {} current_section = current_section[key] current_section[nested_keys[-1]] = value_to_set return dictionary
def resolve_value_from_config(
    self,
    direct_input=None,
    config_path: str = None,
    default_value=None,
    sagemaker_session=None,
    sagemaker_config: dict = None,
):
    """Pick the value the caller should use, consulting the sagemaker config.

    Priority: (1) ``direct_input``, (2) the config value at ``config_path``,
    (3) ``default_value``, (4) None.

    Args:
        direct_input: The value the caller already has, usually one of its own inputs.
        config_path (str): Key path used to look the value up in the sagemaker
            config. No lookup happens when it is falsy.
        default_value: Fallback when neither of the above yields a value.
        sagemaker_session (sagemaker.session.Session): Session whose
            ``sagemaker_config`` is consulted (default: None).
        sagemaker_config (dict): Config dictionary passed on to the lookup together
            with the session, for callers that have no Session yet (default: None).

    Returns:
        The value that should be used by the caller.
    """
    config_value = None
    if config_path:
        config_value = self.get_sagemaker_config_value(
            sagemaker_session, config_path, sagemaker_config=sagemaker_config
        )

    _log_sagemaker_config_single_substitution(direct_input, config_value, config_path)

    # The first candidate that is not None wins.
    for candidate in (direct_input, config_value):
        if candidate is not None:
            return candidate
    return default_value
def get_sagemaker_config_value(self, sagemaker_session, key, sagemaker_config: dict = None):
    """Return a copy of the config value stored under ``key``.

    Args:
        sagemaker_session (sagemaker.session.Session): When truthy and exposing a
            ``sagemaker_config`` attribute, that attribute is the config searched.
        key: Key path of the config entry.
        sagemaker_config (dict): Config used only when the session does not supply
            one as described above (default: None).

    Returns:
        object: A deep copy of the value found, or None when the selected config is
        empty or the key is not present.
    """
    session_supplies_config = sagemaker_session and hasattr(
        sagemaker_session, "sagemaker_config"
    )
    source_config = (
        sagemaker_session.sagemaker_config if session_supplies_config else sagemaker_config
    )
    if not source_config:
        return None

    self.validate_sagemaker_config(source_config)
    # Deep-copy so that callers mutating the result cannot alter the source config.
    return copy.deepcopy(self.get_config_value(key, source_config))
def resolve_class_attribute_from_config(
    self,
    clazz: Optional[type],
    instance: Optional[object],
    attribute: str,
    config_path: str,
    default_value=None,
    sagemaker_session=None,
):
    """Fill an unset attribute of a data-class instance from config or a default.

    Priority for the attribute's value: (1) its current value, (2) the config
    value, (3) ``default_value``, (4) leave it unset.

    Args:
        clazz (Optional[type]): Class of ``instance``. Used to build a new instance
            when ``instance`` is None; nothing is built if this is None or not a
            class. Its constructor should default attributes to None, otherwise a
            non-None constructor default wins over the config value.
        instance (Optional[object]): Object carrying ``attribute``.
        attribute (str): Name of the attribute to set if it is currently None.
        config_path (str): Key path of the value in the sagemaker config.
        default_value: Value used when the config has none.
        sagemaker_session (sagemaker.session.Session): Session used for the config
            lookup (default: None).

    Returns:
        The instance the caller should use in place of the one passed in.

    Raises:
        TypeError: If the instance has no attribute named ``attribute``.
    """
    config_value = self.get_sagemaker_config_value(sagemaker_session, config_path)

    nothing_to_apply = config_value is None and default_value is None
    if nothing_to_apply:
        # Hand back the instance as-is, whether it is None or populated.
        return instance

    if instance is None:
        if clazz is None or not inspect.isclass(clazz):
            return instance
        # No instance yet: build one so the value has somewhere to go.
        instance = clazz()

    if not hasattr(instance, attribute):
        raise TypeError(
            "Unexpected structure of object.",
            "Expected attribute {} to be present inside instance {} of class {}".format(
                attribute, instance, clazz
            ),
        )

    current_value = getattr(instance, attribute)
    if current_value is None:
        # At least one of the two is not None here; the config value wins.
        replacement = config_value if config_value is not None else default_value
        setattr(instance, attribute, replacement)

    _log_sagemaker_config_single_substitution(current_value, config_value, config_path)
    return instance
def resolve_nested_dict_value_from_config(
    self,
    dictionary: dict,
    nested_keys: List[str],
    config_path: str,
    default_value: object = None,
    sagemaker_session=None,
):
    """Fill an unset nested key of a dictionary from config or a default.

    Priority for the nested value: (1) its current value, (2) the config value,
    (3) ``default_value``, (4) leave it unset.

    Args:
        dictionary: The dict to update.
        nested_keys: Key path to check and, if needed, set.
        config_path (str): Key path of the value in the sagemaker config.
        default_value: Value used when the config has none.
        sagemaker_session (sagemaker.session.Session): Session used for the config
            lookup (default: None).

    Returns:
        The dictionary the caller should use in place of the one passed in.
    """
    config_value = self.get_sagemaker_config_value(sagemaker_session, config_path)
    if config_value is None and default_value is None:
        # Nothing to apply, so do not traverse or grow the dictionary.
        return dictionary

    try:
        existing_value = self.get_nested_value(dictionary, nested_keys)
    except ValueError as e:
        logger.error("Failed to check dictionary for applying sagemaker config: %s", e)
        return dictionary

    if existing_value is None:
        # At least one of the two is not None here; the config value wins.
        replacement = config_value if config_value is not None else default_value
        dictionary = self.set_nested_value(dictionary, nested_keys, replacement)

    _log_sagemaker_config_single_substitution(existing_value, config_value, config_path)
    return dictionary
def update_list_of_dicts_with_values_from_config(
    self,
    input_list,
    config_key_path,
    required_key_paths: List[str] = None,
    union_key_paths: List[List[str]] = None,
    sagemaker_session=None,
):
    """Fill in, position by position, values that ``input_list`` lacks but the config has.

    For each index present in both lists, the input dict is merged onto the config
    dict. The merged dict replaces the input entry only if it passes both checks
    below; otherwise that entry is left untouched.

    Args:
        input_list: The list of dicts provided by the caller. Updated in place.
        config_key_path: Key path in the config that corresponds to ``input_list``.
        required_key_paths (List[str]): Key paths that must be present in the merged
            dict, for cases where a config-introduced parameter needs a companion
            parameter.
        union_key_paths (List[List[str]]): Groups of key paths of which at most one
            may be set in the merged dict. For example, pass ``[['X1', 'X2']]`` when
            the result may have 'X1' or 'X2' or neither, but not both.
        sagemaker_session (sagemaker.session.Session): Session used for the config
            lookup (default: None).

    Returns:
        None. The merge happens in place.
    """
    if not input_list:
        return

    original_inputs = copy.deepcopy(input_list)
    config_items = self.get_sagemaker_config_value(sagemaker_session, config_key_path) or []
    original_config_items = copy.deepcopy(config_items)

    shared_length = min(len(input_list), len(config_items))
    for position in range(shared_length):
        merged_item = config_items[position]
        merge_dicts(merged_item, input_list[position])

        # Skip the merge when the config introduces a parameter whose required
        # companion is missing, or when a union group ends up with two members.
        merged_is_valid = self._validate_required_paths_in_a_dict(
            merged_item, required_key_paths
        ) and self._validate_union_key_paths_in_a_dict(merged_item, union_key_paths)
        if merged_is_valid:
            input_list[position] = merged_item

    _log_sagemaker_config_merge(
        source_value=original_inputs,
        config_value=original_config_items,
        merged_source_and_config_value=input_list,
        config_key_path=config_key_path,
    )
def _validate_required_paths_in_a_dict(
    self, source_dict, required_key_paths: List[str] = None
) -> bool:
    """Report whether every required key path resolves to a non-None value.

    Returns True when ``required_key_paths`` is empty or None.
    """
    if not required_key_paths:
        return True
    return all(
        self.get_config_value(key_path, source_dict) is not None
        for key_path in required_key_paths
    )


def _validate_union_key_paths_in_a_dict(
    self, source_dict, union_key_paths: List[List[str]] = None
) -> bool:
    """Report whether each group of key paths has at most one truthy value.

    Returns True when ``union_key_paths`` is empty or None.
    """
    if not union_key_paths:
        return True

    exhausted = object()
    for group in union_key_paths:
        # Lazily yields the paths of this group that hold a truthy value.
        populated_paths = (
            key_path for key_path in group if self.get_config_value(key_path, source_dict)
        )
        has_first = next(populated_paths, exhausted) is not exhausted
        if has_first and next(populated_paths, exhausted) is not exhausted:
            # A second populated member violates the union constraint.
            return False
    return True
def update_nested_dictionary_with_values_from_config(
    self, source_dict, config_key_path, sagemaker_session=None
) -> dict:
    """Merge a caller-supplied nested dictionary onto the matching config section.

    Args:
        source_dict: The nested dictionary provided by the caller (may be None).
        config_key_path: Key path in the config that corresponds to ``source_dict``.
        sagemaker_session (sagemaker.session.Session): Session used for the config
            lookup (default: None).

    Returns:
        dict: The config section with ``source_dict`` merged onto it, or
        ``source_dict`` itself, unchanged, when the config has nothing for this path.
    """
    merged = self.get_sagemaker_config_value(sagemaker_session, config_key_path) or {}
    config_snapshot = copy.deepcopy(merged)
    merge_dicts(merged, source_dict or {})

    if config_snapshot == {}:
        # The config contributed nothing, so return exactly what the caller passed.
        # Turning None into {} is unsafe: e.g. an empty VpcConfig makes some boto
        # calls fail with ParamValidationError because required members are
        # missing. Nothing is logged since no config value was used.
        return source_dict

    _log_sagemaker_config_merge(
        source_value=source_dict,
        config_value=config_snapshot,
        merged_source_and_config_value=merged,
        config_key_path=config_key_path,
    )
    return merged
@lru_cache(maxsize=None)
def load_default_configs_for_resource_name(self, resource_name: str):
    """Return the default config stored for one resource under SageMaker.PythonSDK.Resources.

    Results are memoized per ``(self, resource_name)``. NOTE: because the cache key
    includes ``self``, the cache holds a reference to every instance it has seen.

    Args:
        resource_name: Name of the resource whose defaults are wanted.

    Returns:
        The entry for ``resource_name`` (None if that resource has no entry), or an
        empty dict when no config is loaded or the config has no
        ``SageMaker.PythonSDK.Resources`` section.
    """
    configs_data = self.load_sagemaker_config()

    # Walk down defensively: a loaded config is not guaranteed to contain this
    # section, and plain indexing would raise KeyError.
    resources = configs_data
    for section in ("SageMaker", "PythonSDK", "Resources"):
        if not isinstance(resources, dict) or section not in resources:
            resources = None
            break
        resources = resources[section]

    if not isinstance(resources, dict):
        self.logger.debug("No default configurations found for resource: %s", resource_name)
        return {}
    return resources.get(resource_name)
def get_resolved_config_value(self, attribute, resource_defaults, global_defaults):
    """Resolve an attribute from resource-level defaults, then global defaults.

    Args:
        attribute: Key to look up.
        resource_defaults: Mapping checked first (may be None or empty).
        global_defaults: Mapping checked second (may be None or empty).

    Returns:
        The value from the first mapping that contains ``attribute``, or None
        (after a debug log) when neither does.
    """
    for defaults in (resource_defaults, global_defaults):
        if defaults and attribute in defaults:
            return defaults[attribute]

    logger.debug(
        f"Configurable value {attribute} not entered in parameters or present in the Config"
    )
    return None