Source code for sagemaker.serve.utils.task

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Accessors to retrieve task fallback input/output schema"""
from __future__ import absolute_import

import json
import os
from typing import Any, Tuple


def retrieve_local_schemas(task: str) -> Tuple[Any, Any]:
    """Retrieves task sample inputs and outputs locally.

    The schemas are read from ``schema/task.json``, resolved relative to the
    parent of the directory that contains this module.

    Args:
        task (str): Required, the task name

    Returns:
        Tuple[Any, Any]: A tuple that contains the ``properties`` entry of the
            task's ``sample_inputs``, at index 0, and of its
            ``sample_outputs``, at index 1.

    Raises:
        ValueError: If no tasks config found or the task does not exist in the
            local config.
    """
    config_dir = os.path.dirname(os.path.dirname(__file__))
    task_io_config_path = os.path.join(config_dir, "schema", "task.json")

    # Only the file access can raise FileNotFoundError, so keep the try body
    # limited to it. JSON is UTF-8 by specification; do not depend on the
    # platform's default locale encoding.
    try:
        with open(task_io_config_path, encoding="utf-8") as f:
            task_io_config = json.load(f)
    except FileNotFoundError as err:
        # Chain the original error so the missing path stays visible in the
        # traceback as the explicit cause.
        raise ValueError("Could not find tasks config file.") from err

    task_io_schemas = task_io_config.get(task)
    if task_io_schemas is None:
        raise ValueError(f"Could not find {task} I/O schema.")

    return (
        task_io_schemas["sample_inputs"]["properties"],
        task_io_schemas["sample_outputs"]["properties"],
    )