Source code for sagemaker.core.remote_function.spark_config

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module is used to define the Spark job config to remote function."""
from __future__ import absolute_import

from typing import Optional, List, Dict, Union
from urllib.parse import urlparse

import attr

from sagemaker.core.workflow import is_pipeline_variable


def _validate_configuration(instance, attribute, configuration):
    # pylint: disable=unused-argument
    """This is the helper method to validate the spark configuration"""
    if configuration:
        SparkConfigUtils.validate_configuration(configuration=configuration)


def _validate_s3_uri(instance, attribute, s3_uri):
    # pylint: disable=unused-argument
    """This is the helper method to validate the s3 uri"""
    if s3_uri:
        SparkConfigUtils.validate_s3_uri(s3_uri)


@attr.s(frozen=True)
class SparkConfig:
    """Configuration class that holds the Spark settings for a remote function.

    Attributes:
        submit_jars (Optional[List[str]]): A list of paths to the jars to be submitted
            to the Spark job. Each location can be a valid S3 URI or a local path to
            the jar. Defaults to ``None``.
        submit_py_files (Optional[List[str]]): A list of paths to the Python files to
            be submitted to the Spark job. Each location can be a valid S3 URI or a
            local path to the Python file. Defaults to ``None``.
        submit_files (Optional[List[str]]): A list of paths to the files to be
            submitted to the Spark job. Each location can be a valid S3 URI or a local
            path to the file. Defaults to ``None``.
        configuration (list[dict] or dict): Configuration for Hadoop, Spark, or Hive,
            given as a list or dictionary of EMR-style classifications.
            https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
        spark_event_logs_uri (str): S3 path to which Spark application events will be
            published. Defaults to ``None``.
    """

    submit_jars: Optional[List[str]] = attr.ib(default=None)
    submit_py_files: Optional[List[str]] = attr.ib(default=None)
    submit_files: Optional[List[str]] = attr.ib(default=None)
    configuration: Optional[Union[List[Dict], Dict]] = attr.ib(
        default=None, validator=_validate_configuration
    )
    spark_event_logs_uri: Optional[str] = attr.ib(default=None, validator=_validate_s3_uri)
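
# Illustrative usage sketch (not part of this module): constructing a SparkConfig with an
# EMR-style classification and an event-log S3 URI. The bucket name, property values, and
# the ``remote`` decorator import path are assumptions made for the example, not values
# taken from this module.
#
#   from sagemaker.remote_function import remote
#
#   config = SparkConfig(
#       submit_jars=["s3://my-bucket/jars/my-dependency.jar"],
#       configuration={
#           "Classification": "spark-defaults",
#           "Properties": {"spark.executor.memory": "2g"},
#       },
#       spark_event_logs_uri="s3://my-bucket/spark-events/",
#   )
#
#   @remote(spark_config=config)
#   def my_spark_job(data):
#       ...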
class SparkConfigUtils:
    """Util class for Spark configurations"""

    _valid_configuration_keys = ["Classification", "Properties", "Configurations"]
    _valid_configuration_classifications = [
        "core-site",
        "hadoop-env",
        "hadoop-log4j",
        "hive-env",
        "hive-log4j",
        "hive-exec-log4j",
        "hive-site",
        "spark-defaults",
        "spark-env",
        "spark-log4j",
        "spark-hive-site",
        "spark-metrics",
        "yarn-env",
        "yarn-site",
        "export",
    ]
    @staticmethod
    def validate_configuration(configuration: Dict):
        """Validates the user-provided Hadoop/Spark/Hive configuration.

        This ensures that the list or dictionary the user provides will serialize to
        JSON matching the schema of EMR's application configuration.

        Args:
            configuration (Dict): A dict that contains the configuration overrides to
                the default values. For more information, please visit:
                https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html
        """
        emr_configure_apps_url = (
            "https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-configure-apps.html"
        )
        if isinstance(configuration, dict):
            keys = configuration.keys()
            if "Classification" not in keys or "Properties" not in keys:
                raise ValueError(
                    f"Missing one or more required keys in configuration dictionary "
                    f"{configuration}. Please see {emr_configure_apps_url} for more information."
                )

            for key in keys:
                if key not in SparkConfigUtils._valid_configuration_keys:
                    raise ValueError(
                        f"Invalid key: {key}. "
                        f"Must be one of {SparkConfigUtils._valid_configuration_keys}. "
                        f"Please see {emr_configure_apps_url} for more information."
                    )
                if key == "Classification":
                    if (
                        configuration[key]
                        not in SparkConfigUtils._valid_configuration_classifications
                    ):
                        raise ValueError(
                            f"Invalid classification: {configuration[key]}. Must be one of "
                            f"{SparkConfigUtils._valid_configuration_classifications}"
                        )

        if isinstance(configuration, list):
            for item in configuration:
                SparkConfigUtils.validate_configuration(item)
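
    # Illustrative sketch (not part of this module): a configuration that passes validation
    # and one that fails because "foo-site" is not a recognized classification. The property
    # names and values are placeholders, not values taken from this module.
    #
    #   SparkConfigUtils.validate_configuration(
    #       {"Classification": "spark-defaults", "Properties": {"spark.driver.memory": "4g"}}
    #   )  # passes
    #
    #   SparkConfigUtils.validate_configuration(
    #       {"Classification": "foo-site", "Properties": {}}
    #   )  # raises ValueError: invalid classification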
    # TODO (guoqioa@): method only checks urlparse scheme, need to perform deep S3 validation
    @staticmethod
    def validate_s3_uri(spark_output_s3_path):
        """Validate whether the URI uses an S3 scheme.

        In the future, this validation will perform deeper S3 validation.

        Args:
            spark_output_s3_path (str): The URI of the Spark output S3 path.
        """
        if is_pipeline_variable(spark_output_s3_path):
            return

        if urlparse(spark_output_s3_path).scheme != "s3":
            raise ValueError(
                f"Invalid s3 path: {spark_output_s3_path}. Please enter something like "
                "s3://bucket-name/folder-name"
            )
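
    # Illustrative sketch (not part of this module): the bucket name below is a placeholder.
    #
    #   SparkConfigUtils.validate_s3_uri("s3://my-bucket/spark-events/")  # passes
    #   SparkConfigUtils.validate_s3_uri("/tmp/spark-events")             # raises ValueError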