Source code for sagemaker.core.partner_app.auth_provider
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""The SageMaker partner application SDK auth module"""
from __future__ import absolute_import
import os
import re
from typing import Dict, Tuple
import boto3
from botocore.auth import SigV4Auth
from botocore.credentials import Credentials
from requests.auth import AuthBase
from requests.models import PreparedRequest
from sagemaker.core.partner_app.auth_utils import PartnerAppAuthUtils
SERVICE_NAME = "sagemaker"
AWS_PARTNER_APP_ARN_REGEX = r"arn:aws[a-z\-]*:sagemaker:[a-z0-9\-]*:[0-9]{12}:partner-app\/.*"
[docs]
class RequestsAuth(AuthBase):
"""Requests authentication class for SigV4 header generation.
This class is used to generate the SigV4 header and add it to the request headers.
"""
def __init__(self, sigv4: SigV4Auth, app_arn: str):
"""Initialize the RequestsAuth class.
Args:
sigv4 (SigV4Auth): SigV4Auth object
app_arn (str): Application ARN
"""
self.sigv4 = sigv4
self.app_arn = app_arn
def __call__(self, request: PreparedRequest) -> PreparedRequest:
"""Callback function to generate the SigV4 header and add it to the request headers.
Args:
request (PreparedRequest): PreparedRequest object
Returns:
PreparedRequest: PreparedRequest object with the SigV4 header added
"""
url, signed_headers = PartnerAppAuthUtils.get_signed_request(
sigv4=self.sigv4,
app_arn=self.app_arn,
url=request.url,
method=request.method,
headers=request.headers,
body=request.body,
)
request.url = url
request.headers.update(signed_headers)
return request
[docs]
class PartnerAppAuthProvider:
"""The SageMaker partner application SDK auth provider class"""
def __init__(self, credentials: Credentials = None):
"""Initialize the PartnerAppAuthProvider class.
Args:
credentials (Credentials, optional): AWS credentials. Defaults to None.
Raises:
ValueError: If the AWS_PARTNER_APP_ARN environment variable is not set or is invalid.
"""
self.app_arn = os.getenv("AWS_PARTNER_APP_ARN")
if self.app_arn is None:
raise ValueError("Must specify the AWS_PARTNER_APP_ARN environment variable")
app_arn_regex_match = re.search(AWS_PARTNER_APP_ARN_REGEX, self.app_arn)
if app_arn_regex_match is None:
raise ValueError("Must specify a valid AWS_PARTNER_APP_ARN environment variable")
split_arn = self.app_arn.split(":")
self.region = split_arn[3]
self.credentials = (
credentials if credentials is not None else boto3.Session().get_credentials()
)
self.sigv4 = SigV4Auth(self.credentials, SERVICE_NAME, self.region)
[docs]
def get_signed_request(
self, url: str, method: str, headers: dict, body: object
) -> Tuple[str, Dict[str, str]]:
"""Generate the SigV4 header and add it to the request headers.
Args:
url (str): Request URL
method (str): HTTP method
headers (dict): Request headers
body (object): Request body
Returns:
tuple: (url, headers)
"""
return PartnerAppAuthUtils.get_signed_request(
sigv4=self.sigv4,
app_arn=self.app_arn,
url=url,
method=method,
headers=headers,
body=body,
)
[docs]
def get_auth(self) -> RequestsAuth:
"""Returns the callback class (RequestsAuth) used for generating the SigV4 header.
Returns:
RequestsAuth: Callback Object which will calculate the header just before
request submission.
"""
return RequestsAuth(self.sigv4, os.environ["AWS_PARTNER_APP_ARN"])