Source code for sagemaker.core.partner_app.auth_utils
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Partner App Auth Utils Module"""
from __future__ import absolute_import
from hashlib import sha256
import functools
from typing import Tuple, Dict
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
# --- HTTP header names -------------------------------------------------------
HEADER_AUTHORIZATION = "Authorization"
HEADER_CONNECTION = "Connection"
HEADER_PARTNER_APP_AUTHORIZATION = "X-Amz-Partner-App-Authorization"
HEADER_PARTNER_APP_SERVER_ARN = "X-SageMaker-Partner-App-Server-Arn"
HEADER_X_AMZ_CONTENT_SHA_256 = "X-Amz-Content-SHA256"
HEADER_X_AMZ_TARGET = "X-Amz-Target"

# --- IAM action, sent as the X-Amz-Target header value -----------------------
CALL_PARTNER_APP_API_ACTION = "SageMaker.CallPartnerAppApi"

# --- Payload hashing ---------------------------------------------------------
# Chunk size in bytes (1 MiB) used when hashing a streamed request body.
PAYLOAD_BUFFER = 1 << 20
# Hex SHA-256 digest of zero bytes (equal to hashlib.sha256(b"").hexdigest()).
EMPTY_SHA256_HASH = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
# Sent in place of a payload hash when the body is not signed.
UNSIGNED_PAYLOAD = "UNSIGNED-PAYLOAD"
class PartnerAppAuthUtils:
    """Helpers for SigV4-signing requests sent to a SageMaker Partner App."""

    @staticmethod
    def get_signed_request(
        sigv4: SigV4Auth, app_arn: str, url: str, method: str, headers: dict, body: object
    ) -> Tuple[str, Dict[str, str]]:
        """Generate the SigV4 header and add it to a copy of the request headers.

        The caller's ``headers`` mapping is not modified. All header changes are
        applied to a shallow copy, and the signed result is returned.

        Args:
            sigv4 (SigV4Auth): SigV4Auth object used to sign the request.
            app_arn (str): Application ARN, sent in the
                ``X-SageMaker-Partner-App-Server-Arn`` header.
            url (str): Request URL.
            method (str): HTTP method.
            headers (dict): Request headers. If an ``Authorization`` header (the
                API key) is present, its value is copied to
                ``X-Amz-Partner-App-Authorization`` before signing.
            body (object): Request body. See ``get_body_header`` for how each
                body type is hashed.

        Returns:
            tuple: ``(url, headers)``.
                ``url`` has every ``+`` replaced by ``%20``.
                ``headers`` contains the headers produced by ``sigv4.add_auth``,
                with any original ``Connection`` header restored.
        """
        # Work on a copy so the caller's mapping is left untouched; ``copy()``
        # preserves the mapping type (e.g. a case-insensitive dict stays one).
        headers = headers.copy()

        # Move API key to X-Amz-Partner-App-Authorization
        if HEADER_AUTHORIZATION in headers:
            headers[HEADER_PARTNER_APP_AUTHORIZATION] = headers[HEADER_AUTHORIZATION]
        # App Arn
        headers[HEADER_PARTNER_APP_SERVER_ARN] = app_arn
        # IAM Action
        headers[HEADER_X_AMZ_TARGET] = CALL_PARTNER_APP_API_ACTION
        # Payload hash (or UNSIGNED-PAYLOAD for unrecognized body types)
        headers[HEADER_X_AMZ_CONTENT_SHA_256] = PartnerAppAuthUtils.get_body_header(body)

        # The Connection header is excluded from the server-side signature
        # calculation, so strip it before signing and restore it afterwards.
        connection_header = headers.pop(HEADER_CONNECTION, None)

        # Spaces are encoded as %20. Note that every "+" in the URL is rewritten,
        # which assumes "+" only ever stands for a space here.
        url = url.replace("+", "%20")

        # Calculate SigV4 header
        aws_request = AWSRequest(method=method, url=url, headers=headers, data=body)
        sigv4.add_auth(aws_request)

        # Reassemble headers
        final_headers = dict(aws_request.headers.items())
        if connection_header is not None:
            final_headers[HEADER_CONNECTION] = connection_header
        return url, final_headers

    @staticmethod
    def get_body_header(body: object) -> str:
        """Calculate the body header for the SigV4 header.

        Args:
            body (object): Request body.

        Returns:
            str: Value for the ``X-Amz-Content-SHA256`` header, one of:
                ``EMPTY_SHA256_HASH`` for a falsy body (``None``, empty bytes/str);
                the hex SHA-256 of the remaining contents for a seekable stream;
                the hex SHA-256 of a ``bytes`` body;
                ``UNSIGNED-PAYLOAD`` for any other body type.
        """
        if not body:
            # Body is None or empty
            return EMPTY_SHA256_HASH

        if hasattr(body, "seek"):
            # Hash from the current position to EOF in fixed-size chunks, then
            # rewind so the same bytes are still available to be sent. The
            # rewind runs even if a read fails part-way through.
            position = body.tell()
            checksum = sha256()
            try:
                for chunk in iter(functools.partial(body.read, PAYLOAD_BUFFER), b""):
                    checksum.update(chunk)
            finally:
                body.seek(position)
            return checksum.hexdigest()

        if isinstance(body, bytes):
            return sha256(body).hexdigest()

        # Body is of a class we don't recognize, so don't sign the payload
        return UNSIGNED_PAYLOAD