Source code for sagemaker.train.container_drivers.distributed_drivers.basic_script_driver

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module is the entry point for the Basic Script Driver."""
from __future__ import absolute_import

import os
import sys
import json
import shlex

from pathlib import Path
from typing import List

sys.path.insert(0, str(Path(__file__).parent.parent))

from common.utils import (  # noqa: E402 # pylint: disable=C0413,E0611
    logger,
    get_python_executable,
    write_failure_file,
    hyperparameters_to_cli_args,
    execute_commands,
)


def create_commands() -> List[str]:
    """Create the command used to launch the user's entry script.

    The entry script path is read from the ``SM_ENTRY_SCRIPT`` environment
    variable and the hyperparameters are read as JSON from ``SM_HPS``. The
    hyperparameters are turned into CLI arguments with
    ``hyperparameters_to_cli_args`` and appended to the command.

    Returns:
        List[str]: The command as an argv-style list.
            * ``.py`` entry script: ``[python_executable, entry_script, *args]``.
            * ``.sh`` entry script: ``["/bin/sh", "-c", "<shell string>"]`` where
              the shell string marks the script executable and then runs it
              with the shell-quoted arguments.

    Raises:
        KeyError: If ``SM_ENTRY_SCRIPT`` or ``SM_HPS`` is not set.
        ValueError: If the entry script does not end in ``.py`` or ``.sh``.
    """
    entry_script = os.environ["SM_ENTRY_SCRIPT"]
    hyperparameters = json.loads(os.environ["SM_HPS"])
    python_executable = get_python_executable()
    args = hyperparameters_to_cli_args(hyperparameters)

    if entry_script.endswith(".py"):
        # Passed as an argv list, so no shell quoting is needed here.
        commands = [python_executable, entry_script]
        commands += args
    elif entry_script.endswith(".sh"):
        # Everything in this branch is interpolated into a single string that
        # /bin/sh parses, so the script path must be quoted just like the
        # arguments; otherwise a path with spaces or shell metacharacters is
        # split or interpreted by the shell. shlex.quote leaves ordinary
        # names (e.g. "train.sh", "src/train.sh") unchanged.
        quoted_script = shlex.quote(entry_script)
        quoted_invocation = shlex.quote(f"./{entry_script}")
        args_str = " ".join(shlex.quote(arg) for arg in args)
        commands = [
            "/bin/sh",
            "-c",
            f"chmod +x {quoted_script} && {quoted_invocation} {args_str}",
        ]
    else:
        raise ValueError(
            f"Unsupported entry script type: {entry_script}. Only .py and .sh are supported."
        )
    return commands
def main():
    """Run the Basic Script Driver from start to finish.

    Execution Lifecycle:
        1. Build the launch command with ``create_commands``.
        2. Log the command that is about to run.
        3. Run it with ``execute_commands``.
        4. On a non-zero exit code, hand the returned traceback to
           ``write_failure_file`` and exit the process with that code.
    """
    launch_cmd = create_commands()
    logger.info(f"Executing command: {' '.join(launch_cmd)}")

    status, error_trace = execute_commands(launch_cmd)
    if status == 0:
        return

    # Record the failure details before terminating with the same exit code.
    write_failure_file(error_trace)
    sys.exit(status)
# Run the driver only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()