Source code for sagemaker.core.workflow.functions

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The step definitions for workflow."""
from __future__ import absolute_import

from typing import List, Union, Optional, TYPE_CHECKING

import attr

from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.workflow.execution_variables import ExecutionVariable
from sagemaker.core.workflow.parameters import Parameter
from sagemaker.core.workflow.properties import PropertyFile, Properties


[docs] def is_pipeline_variable(var: object) -> bool: """Check if the variable is a pipeline variable Args: var (object): The variable to be verified. Returns: bool: True if it is, False otherwise. """ return isinstance(var, PipelineVariable)
if TYPE_CHECKING: from sagemaker.mlops.workflow.steps import Step
[docs] @attr.s class Join(PipelineVariable): """Join together properties. Examples: Build a Amazon S3 Uri with bucket name parameter and pipeline execution Id and use it as training input:: bucket = ParameterString('bucket', default_value='my-bucket') TrainingInput( s3_data=Join( on='/', values=['s3:/', bucket, ExecutionVariables.PIPELINE_EXECUTION_ID] ), content_type="text/csv") Attributes: values (List[Union[PrimitiveType, Parameter, PipelineVariable]]): The primitive type values, parameters, step properties, expressions to join. on (str): The string to join the values on (Defaults to ""). """ on: str = attr.ib(factory=str) values: List = attr.ib(factory=list)
[docs] def to_string(self) -> PipelineVariable: """Prompt the pipeline to convert the pipeline variable to String in runtime As Join is treated as String in runtime, no extra actions are needed. """ return self
@property def expr(self): """The expression dict for a `Join` function.""" return { "Std:Join": { "On": self.on, "Values": [ value.expr if hasattr(value, "expr") else value for value in self.values ], }, } @property def _referenced_steps(self) -> List[Union["Step", str]]: """List of step names that this function depends on.""" steps = [] for value in self.values: if isinstance(value, PipelineVariable): steps.extend(value._referenced_steps) return steps
[docs] @attr.s class JsonGet(PipelineVariable): """Get JSON properties from PropertyFiles or S3 location. Attributes: step_name (str): The step name from which to get the property file. property_file (Optional[Union[PropertyFile, str]]): Either a PropertyFile instance or the name of a property file. json_path (str): The JSON path expression to the requested value. s3_uri (Optional[sagemaker.workflow.functions.Join]): The S3 location from which to fetch a Json file. The Json file is the output of a step defined with ``@step`` decorator. step (Step): The upstream step object which the s3_uri is associated to. """ # pylint: disable=W0613 def _check_property_file_s3_uri(self, attribute, value): """Validate mutually exclusive property file / s3uri""" if self.property_file and self.s3_uri: raise ValueError( "Please specify either a property file or s3 uri as an input, but not both." ) if not self.property_file and not self.s3_uri: raise ValueError( "Missing s3uri or property file as a required input to JsonGet." "Please specify either a property file or s3 uri as an input, but not both." ) if self.s3_uri: self._validate_json_get_s3_uri() step_name: str = attr.ib(default=None) property_file: Optional[Union[PropertyFile, str]] = attr.ib( default=None, validator=_check_property_file_s3_uri ) json_path: str = attr.ib(default=None) s3_uri: Optional[Join] = attr.ib(default=None, validator=_check_property_file_s3_uri) step: "Step" = attr.ib(default=None) # pylint: disable=R1710 @property def expr(self): """The expression dict for a `JsonGet` function.""" if self.property_file: if not isinstance(self.step_name, str) or not self.step_name: raise ValueError("Please give a valid step name as a string.") if isinstance(self.property_file, PropertyFile): name = self.property_file.name else: name = self.property_file return { "Std:JsonGet": { "PropertyFile": {"Get": f"Steps.{self.step_name}.PropertyFiles.{name}"}, "Path": self.json_path, } } # ConditionStep uses a JoinFunction to provide this non-static, built s3Uri in # the case of Lightsaber steps. if self.s3_uri: return { "Std:JsonGet": { "S3Uri": ( self.s3_uri.expr if isinstance(self.s3_uri, PipelineVariable) else self.s3_uri ), "Path": self.json_path, } } @property def _referenced_steps(self) -> List[Union["Step", str]]: """List of step that this function depends on.""" if self.step: return [self.step] if self.step_name: return [self.step_name] return [] def _validate_json_get_s3_uri(self): """Validate the s3 uri in JsonGet""" s3_uri = self.s3_uri if not isinstance(s3_uri, Join): raise ValueError( f"Invalid JsonGet function {self.expr}. JsonGet " "function's s3_uri can only be a sagemaker.workflow.functions.Join object." ) for join_arg in s3_uri.values: if not is_pipeline_variable(join_arg): continue if not isinstance(join_arg, (Parameter, ExecutionVariable, Properties)): raise ValueError( f"Invalid JsonGet function {self.expr}. " f"The Join values in JsonGet's s3_uri can only be a primitive object, " f"Parameter, ExecutionVariable or Properties." )