# Source code for sagemaker.train.remote_function.invoke_function

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""An entry point for invoking remote function inside a job."""

from __future__ import absolute_import

import argparse
import sys
import json
import os
from typing import TYPE_CHECKING

import boto3
from sagemaker.train.remote_function.job import (
    KEY_EXPERIMENT_NAME,
    KEY_RUN_NAME,
)

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.s3 import s3_path_join
from sagemaker.train.remote_function.errors import handle_error
from sagemaker.train.remote_function import logging_config
from sagemaker.train.remote_function.core.pipeline_variables import Context

if TYPE_CHECKING:
    from sagemaker.core.experiments.run import Run


SUCCESS_EXIT_CODE = 0


def _parse_args(args):
    """Parses CLI arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--region", type=str, required=True)
    parser.add_argument("--s3_base_uri", type=str, required=True)
    parser.add_argument("--s3_kms_key", type=str)
    parser.add_argument("--run_in_context", type=str)
    parser.add_argument("--pipeline_step_name", type=str)
    parser.add_argument("--pipeline_execution_id", type=str)
    parser.add_argument("--property_references", nargs="+", type=str, default=[])
    parser.add_argument(
        "--serialize_output_to_json", default=False, type=lambda x: (str(x).lower() == "true")
    )
    parser.add_argument("--func_step_s3_dir", type=str)

    args, _ = parser.parse_known_args(args)
    return args


def _get_sagemaker_session(region):
    """Return a SageMaker ``Session`` backed by a boto3 session for *region*.

    Args:
        region (str): AWS region name used to scope all service calls.

    Returns:
        Session: A SageMaker session for interacting with AWS services.
    """
    return Session(boto_session=boto3.session.Session(region_name=region))


def _load_run_object(run_in_context: str, sagemaker_session: Session) -> "Run":
    """Deserialize a JSON-encoded run context into a ``Run`` object.

    Args:
        run_in_context (str): JSON string carrying the experiment and run names.
        sagemaker_session (Session): Session the ``Run`` object will use.

    Returns:
        Run: The reconstructed experiment run.
    """
    # Imported lazily so the experiments package is only loaded when a run
    # context is actually present.
    from sagemaker.core.experiments.run import Run

    run_info = json.loads(run_in_context)
    experiment_name = run_info.get(KEY_EXPERIMENT_NAME)
    run_name = run_info.get(KEY_RUN_NAME)
    return Run(
        experiment_name=experiment_name,
        run_name=run_name,
        sagemaker_session=sagemaker_session,
    )


def _load_pipeline_context(args) -> Context:
    """Build the pipeline build/run ``Context`` from parsed CLI arguments.

    ``--property_references`` is a flat list of alternating key/value tokens
    (``key1 value1 key2 value2 ...``) which is folded into a dict.

    Args:
        args (argparse.Namespace): Output of ``_parse_args``.

    Returns:
        Context: The pipeline execution context.

    Raises:
        ValueError: If ``--property_references`` has an odd number of tokens,
            i.e. a key without a matching value. (Previously this surfaced as
            an opaque ``IndexError``.)
    """
    property_references = args.property_references
    if len(property_references) % 2 != 0:
        raise ValueError(
            "--property_references expects an even number of tokens "
            f"(key value pairs), got {len(property_references)}"
        )
    # Pair up even-indexed keys with odd-indexed values.
    property_references_dict = dict(
        zip(property_references[0::2], property_references[1::2])
    )
    return Context(
        step_name=args.pipeline_step_name,
        execution_id=args.pipeline_execution_id,
        property_references=property_references_dict,
        serialize_output_to_json=args.serialize_output_to_json,
        func_step_s3_dir=args.func_step_s3_dir,
    )


def _execute_remote_function(
    sagemaker_session, s3_base_uri, s3_kms_key, run_in_context, hmac_key, context
):
    """Load the serialized function from S3 and invoke it.

    When a run context is supplied, the invocation happens inside that
    experiment run (entered as a context manager); otherwise the function is
    invoked directly.

    Args:
        sagemaker_session: SageMaker session used for S3 access.
        s3_base_uri (str): Base S3 URI where the function payload is stored.
        s3_kms_key (str): Optional KMS key for S3 decryption.
        run_in_context (str): Optional JSON string describing an experiment run.
        hmac_key (str): Key used to verify the payload's integrity.
        context (Context): Pipeline build/run context.
    """
    # Imported lazily to keep module import light.
    from sagemaker.train.remote_function.core.stored_function import StoredFunction

    function = StoredFunction(
        sagemaker_session=sagemaker_session,
        s3_base_uri=s3_base_uri,
        s3_kms_key=s3_kms_key,
        hmac_key=hmac_key,
        context=context,
    )

    if not run_in_context:
        function.load_and_invoke()
        return

    # Enter the experiment run so metrics/artifacts logged by the user
    # function are attached to it.
    with _load_run_object(run_in_context, sagemaker_session):
        function.load_and_invoke()


def main(sys_args=None):
    """Entry point for the invoke-function script.

    Args:
        sys_args (list): List of arguments to parse. If not specified,
            sys.argv is used.
    """
    logger = logging_config.get_logger()
    exit_code = SUCCESS_EXIT_CODE
    # Pre-bind every name the except-handler reads. Previously, a failure in
    # argument parsing or session setup left these unbound, so the handler
    # raised NameError — which the sys.exit() in `finally` then silently
    # replaced, making a failed job exit with the SUCCESS code.
    s3_base_uri = None
    s3_kms_key = None
    hmac_key = None
    sagemaker_session = None
    pipeline_context = None
    try:
        args = _parse_args(sys_args)
        region = args.region
        s3_base_uri = args.s3_base_uri
        s3_kms_key = args.s3_kms_key
        run_in_context = args.run_in_context
        pipeline_context = _load_pipeline_context(args)

        hmac_key = os.getenv("REMOTE_FUNCTION_SECRET_KEY")

        sagemaker_session = _get_sagemaker_session(region)
        _execute_remote_function(
            sagemaker_session=sagemaker_session,
            s3_base_uri=s3_base_uri,
            s3_kms_key=s3_kms_key,
            run_in_context=run_in_context,
            hmac_key=hmac_key,
            context=pipeline_context,
        )
    except Exception as e:  # pylint: disable=broad-except
        logger.exception("Error encountered while invoking the remote function.")
        # Within a pipeline step, error artifacts go under
        # <base>/<execution_id>/<step_name>; otherwise directly under <base>.
        if pipeline_context is not None and pipeline_context.step_name:
            s3_uri = s3_path_join(
                s3_base_uri, pipeline_context.execution_id, pipeline_context.step_name
            )
        else:
            s3_uri = s3_base_uri
        # NOTE(review): if the failure happened before setup completed,
        # sagemaker_session/s3_base_uri may still be None here — handle_error
        # is assumed to tolerate that; confirm against its implementation.
        exit_code = handle_error(
            error=e,
            sagemaker_session=sagemaker_session,
            s3_base_uri=s3_uri,
            s3_kms_key=s3_kms_key,
            hmac_key=hmac_key,
        )
    finally:
        # Always exit explicitly so the training job reports the right status.
        sys.exit(exit_code)


if __name__ == "__main__":
    main(sys.argv[1:])