Source code for sagemaker.core.interactive_apps.base_interactive_app

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A base class for starting/accessing apps hosted on Amazon SageMaker Studio"""

from __future__ import absolute_import

import abc
import base64
import json
import logging
import os
import re
import webbrowser

from typing import Optional
import boto3
from sagemaker.core.helper.session_helper import Session, NOTEBOOK_METADATA_FILE

logger = logging.getLogger(__name__)


class BaseInteractiveApp(abc.ABC):
    """BaseInteractiveApp is a base class for creating/accessing apps hosted on SageMaker."""

    def __init__(
        self,
        region: Optional[str] = None,
    ):
        """Initialize a BaseInteractiveApp object.

        Resolves the AWS Region, creates a SageMaker boto3 client for that Region, and
        attempts to detect a SageMaker Studio environment (domain id / user profile name).

        Args:
            region (str): Optional. The AWS Region, e.g. us-east-1. If not specified,
                it is resolved from the default AWS configuration chain through a
                SageMaker ``Session``. Default: ``None``

        Raises:
            ValueError: If ``region`` is not given and the Region cannot be determined
                from the default AWS configuration.
        """
        if isinstance(region, str):
            self.region = region
        else:
            try:
                self.region = Session().boto_region_name
            except ValueError as err:
                raise ValueError(
                    "Failed to get the Region information from the default config. Please either "
                    "pass your Region manually as an input argument or set up the local AWS"
                    " configuration."
                ) from err

        self._sagemaker_client = boto3.client("sagemaker", region_name=self.region)
        # Used to store domain and user profile info retrieved from Studio environment.
        self._domain_id = None
        self._user_profile_name = None
        self._in_studio_env = False
        self._get_domain_and_user()

    def __str__(self):
        """Return str(self)."""
        return f"{type(self).__name__}(region={self.region})"

    def __repr__(self):
        """Return repr(self)."""
        return self.__str__()

    @staticmethod
    def _load_notebook_metadata(path: str) -> Optional[dict]:
        """Read and parse the Studio notebook metadata file at ``path``.

        Args:
            path (str): Path of the JSON metadata file.

        Returns:
            dict: The parsed metadata, or ``None`` when the file does not exist, cannot
            be read or decoded, or does not contain a JSON object.
        """
        if not os.path.isfile(path):
            return None
        try:
            with open(path, "rb") as metadata_file:
                metadata = json.loads(metadata_file.read())
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError: a corrupt
            # metadata file must not prevent constructing the app object, since Studio
            # detection is only best-effort.
            logger.warning("Could not load metadata due to unexpected error. %s", err)
            return None
        # Anything other than a JSON object cannot carry DomainId/UserProfileName, so
        # treat it like "not a Studio environment" instead of failing on ``.get``.
        if not isinstance(metadata, dict):
            return None
        return metadata

    def _get_domain_and_user(self):
        """Get domain id and user profile from Studio environment.

        To verify Studio environment, we check if NOTEBOOK_METADATA_FILE exists and domain
        id and user profile name are present in the file. On success, sets
        ``_in_studio_env`` to True and stores ``_domain_id`` / ``_user_profile_name``;
        otherwise leaves them unchanged.
        """
        metadata = self._load_notebook_metadata(NOTEBOOK_METADATA_FILE)
        if metadata is None:
            return

        if "DomainId" in metadata and "UserProfileName" in metadata:
            self._in_studio_env = True
            self._domain_id = metadata.get("DomainId")
            self._user_profile_name = metadata.get("UserProfileName")

    def _get_presigned_url(
        self,
        create_presigned_url_kwargs: dict,
        redirect: Optional[str] = None,
        state: Optional[str] = None,
    ):
        """Generate a presigned URL to access a user's domain / user profile.

        Optional state and redirect parameters can be used to have the presigned URL
        automatically redirect to a specific app and provide modifying data.

        Args:
            create_presigned_url_kwargs (dict): Required. Keyword arguments passed
                unchanged to the boto3 SageMaker client's ``create_presigned_domain_url``
                call. At a minimum, this should include the "DomainId" and
                "UserProfileName" parameters as defined by create_presigned_domain_url's
                documentation.
            redirect (str): Optional. This value will be appended (as-is, not
                URL-encoded) to the resulting presigned URL in the format
                "&redirect=<redirect parameter>". This is used to automatically redirect
                the user into a specific Studio app. Default: ``None``
            state (str): Optional. This value will be appended to the resulting presigned
                URL in the format "&state=<state parameter base64 encoded>". This is used
                to automatically apply a state to the given app. Should be used in
                conjunction with the redirect parameter. Default: ``None``

        Returns:
            str: A presigned URL.

        Raises:
            ValueError: If the response's HTTP status code is not 200.
        """
        response = self._sagemaker_client.create_presigned_domain_url(**create_presigned_url_kwargs)
        if response["ResponseMetadata"]["HTTPStatusCode"] == 200:
            url = response["AuthorizedUrl"]
        else:
            raise ValueError(
                "An invalid status code was returned when creating a presigned URL."
                f" See response for more: {response}"
            )

        if redirect:
            url += f"&redirect={redirect}"

        if state:
            url += f"&state={base64.b64encode(bytes(state, 'utf-8')).decode('utf-8')}"

        logger.warning(
            "A presigned domain URL was generated. This is sensitive and should not be shared with"
            " others."
        )

        return url

    def _open_url_in_web_browser(self, url: str):
        """Open a URL in the default web browser.

        Args:
            url (str): The URL to open.
        """
        webbrowser.open(url)

    def _validate_job_name(self, job_name: str):
        """Validate training job name format.

        Args:
            job_name (str): The job name to validate.

        Raises:
            ValueError: If ``job_name`` is not a string that fully matches
                ``^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
        """
        job_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
        # Check the type first so None / non-strings raise the documented ValueError
        # instead of a TypeError from re.fullmatch.
        if not isinstance(job_name, str) or not re.fullmatch(job_name_regex, job_name):
            raise ValueError(
                f"Invalid job name. Job name must match regular expression {job_name_regex}"
            )

    def _validate_domain_id(self, domain_id: str):
        """Validate domain id format.

        Args:
            domain_id (str): Required. The domain ID to validate.

        Returns:
            bool: False if ``domain_id`` is None or longer than 63 characters,
            True otherwise.
        """
        if domain_id is None or len(domain_id) > 63:
            return False
        return True

    def _validate_user_profile_name(self, user_profile_name: str):
        """Validate user profile name format.

        Args:
            user_profile_name (str): Required. The user profile name to validate.

        Returns:
            bool: Whether the supplied user profile name is not None and fully matches
            ``^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}``.
        """
        user_profile_name_regex = "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}"
        if user_profile_name is None or not re.fullmatch(
            user_profile_name_regex, user_profile_name
        ):
            return False
        return True

    @abc.abstractmethod
    def get_app_url(self):
        """Abstract method to generate a URL to help access the application in Studio.

        Classes that inherit from BaseInteractiveApp should implement and override with what
        parameters are needed for its specific use case.
        """