# Source code for sagemaker.core.s3.client

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains Enums and helper methods related to S3."""
from __future__ import print_function, absolute_import

import logging
import io

from typing import Union
from functools import reduce
from typing import Optional

from six.moves.urllib.parse import urlparse
from sagemaker.core.helper.session_helper import Session

logger = logging.getLogger("sagemaker")


class S3Uploader(object):
    """Contains static methods for uploading directories or files to S3."""

    @staticmethod
    def _kms_extra_args(kms_key):
        """Return upload ``ExtraArgs`` requesting SSE-KMS with ``kms_key``, or None if no key."""
        if kms_key is None:
            return None
        return {"SSEKMSKeyId": kms_key, "ServerSideEncryption": "aws:kms"}

    @staticmethod
    def upload(local_path, desired_s3_uri, kms_key=None, sagemaker_session=None, callback=None):
        """Static method that uploads a given file or directory to S3.

        Args:
            local_path (str): Path (absolute or relative) of local file or directory to upload.
            desired_s3_uri (str): The desired S3 location to upload to. It is the prefix to
                which the local filename will be added.
            kms_key (str): The KMS key used for server-side encryption (``aws:kms``) of the
                uploaded object(s). No encryption arguments are sent when None.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.
            callback (callable): Optional callback forwarded unchanged to
                ``Session.upload_data`` (presumably invoked with upload progress — confirm
                against ``Session.upload_data``).

        Returns:
            The S3 uri of the uploaded file(s), as returned by ``Session.upload_data``.
        """
        sagemaker_session = sagemaker_session or Session()
        bucket, key_prefix = parse_s3_url(url=desired_s3_uri)
        return sagemaker_session.upload_data(
            path=local_path,
            bucket=bucket,
            key_prefix=key_prefix,
            callback=callback,
            extra_args=S3Uploader._kms_extra_args(kms_key),
        )

    @staticmethod
    def upload_string_as_file_body(
        body: str, desired_s3_uri=None, kms_key=None, sagemaker_session=None
    ):
        """Static method that uploads a string as the body of a single S3 object.

        Args:
            body (str): String representing the body of the file.
            desired_s3_uri (str): The full S3 uri (bucket and object key) to upload to.
                Although it defaults to None, a valid ``s3://`` uri must be supplied.
            kms_key (str): The KMS key to use to encrypt the object; forwarded as-is to
                ``Session.upload_string_as_file_body``.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            str: The S3 uri of the uploaded file (``desired_s3_uri`` unchanged).
        """
        sagemaker_session = sagemaker_session or Session()
        bucket, key = parse_s3_url(desired_s3_uri)
        sagemaker_session.upload_string_as_file_body(
            body=body, bucket=bucket, key=key, kms_key=kms_key
        )
        return desired_s3_uri

    @staticmethod
    def upload_bytes(b: Union[bytes, io.BytesIO], s3_uri, kms_key=None, sagemaker_session=None):
        """Static method that uploads in-memory bytes to a single S3 object.

        Args:
            b (bytes or io.BytesIO): The data to upload. Raw bytes are wrapped in an
                ``io.BytesIO``. An ``io.BytesIO`` is passed through as-is, so the upload
                presumably starts at its current stream position — rewind it first if needed.
            s3_uri (str): The full S3 uri (bucket and object key) to upload to.
            kms_key (str): The KMS key used for server-side encryption (``aws:kms``) of the
                uploaded object. No encryption arguments are sent when None.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created
                using the default AWS configuration chain.

        Returns:
            str: The S3 uri of the uploaded file (``s3_uri`` unchanged).
        """
        sagemaker_session = sagemaker_session or Session()
        bucket, object_key = parse_s3_url(s3_uri)
        extra_args = S3Uploader._kms_extra_args(kms_key)
        fileobj = b if isinstance(b, io.BytesIO) else io.BytesIO(b)
        sagemaker_session.s3_resource.Bucket(bucket).upload_fileobj(
            fileobj, object_key, ExtraArgs=extra_args
        )
        return s3_uri
class S3Downloader(object):
    """Static helpers that fetch objects (or whole prefixes) from S3."""

    @staticmethod
    def download(s3_uri, local_path, kms_key=None, sagemaker_session=None):
        """Download the object(s) at an S3 uri to the local machine.

        Args:
            s3_uri (str): ``s3://`` uri (object key or key prefix) to download from.
            local_path (str): Local path the file(s) are written to.
            kms_key (str): Key sent with the download as ``SSECustomerKey``; omitted when None.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session that manages
                interactions with Amazon SageMaker APIs and other AWS services. A default
                ``Session()`` is created when omitted.

        Returns:
            list[str]: Local paths of the downloaded files.
        """
        session = sagemaker_session or Session()
        bucket, prefix = parse_s3_url(url=s3_uri)
        # NOTE(review): the key is forwarded as an SSE-C "SSECustomerKey", not as a KMS key id;
        # SSE-KMS objects normally need no key to download. Confirm the intended semantics.
        extra_args = None if kms_key is None else {"SSECustomerKey": kms_key}
        return session.download_data(
            path=local_path, bucket=bucket, key_prefix=prefix, extra_args=extra_args
        )

    @staticmethod
    def read_file(s3_uri, sagemaker_session=None) -> str:
        """Return the body of a single S3 object as a string.

        Args:
            s3_uri (str): ``s3://`` uri of exactly one object.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session that manages
                interactions with Amazon SageMaker APIs and other AWS services. A default
                ``Session()`` is created when omitted.

        Returns:
            str: The object's body, as returned by ``Session.read_s3_file``.
        """
        session = sagemaker_session or Session()
        bucket, key = parse_s3_url(url=s3_uri)
        return session.read_s3_file(bucket=bucket, key_prefix=key)

    @staticmethod
    def read_bytes(s3_uri, sagemaker_session=None) -> bytes:
        """Return the raw contents of a single S3 object.

        Args:
            s3_uri (str): ``s3://`` uri of the object to read.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session that manages
                interactions with Amazon SageMaker APIs and other AWS services. A default
                ``Session()`` is created when omitted.

        Returns:
            bytes: The object's body.
        """
        session = sagemaker_session or Session()
        bucket, key = parse_s3_url(s3_uri)
        # Download into an in-memory buffer, then hand back its whole contents.
        buffer = io.BytesIO()
        session.s3_resource.Bucket(bucket).download_fileobj(key, buffer)
        return buffer.getvalue()

    @staticmethod
    def list(s3_uri, sagemaker_session=None):
        """List the objects found under an S3 uri.

        Args:
            s3_uri (str): ``s3://`` base uri (bucket plus optional key prefix) to list.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session that manages
                interactions with Amazon SageMaker APIs and other AWS services. A default
                ``Session()`` is created when omitted.

        Returns:
            [str]: One ``s3://bucket/key`` uri per key returned by ``Session.list_s3_files``.
        """
        session = sagemaker_session or Session()
        bucket, prefix = parse_s3_url(url=s3_uri)
        keys = session.list_s3_files(bucket=bucket, key_prefix=prefix)
        return [s3_path_join("s3://", bucket, key) for key in keys]
def parse_s3_url(url):
    """Returns an (s3 bucket, key name/prefix) tuple from a url with an s3 scheme.

    Everything after the bucket is treated as part of the key, including ``#`` and ``?``
    characters, which are valid in S3 object keys. Leading slashes on the key are stripped.

    Args:
        url (str): URL with an ``s3`` scheme, e.g. ``"s3://my-bucket/path/to/key"``.

    Returns:
        tuple: A tuple containing:

            - str: S3 bucket name
            - str: S3 key (empty string when the URL names only a bucket)

    Raises:
        ValueError: If the URL's scheme is not ``s3``.
    """
    # allow_fragments=False keeps a "#..." suffix in the path instead of discarding it
    # as a URL fragment.
    parsed_url = urlparse(url, allow_fragments=False)
    if parsed_url.scheme != "s3":
        raise ValueError("Expecting 's3' scheme, got: {} in {}.".format(parsed_url.scheme, url))
    key = parsed_url.path
    # urlparse always moves a "?..." suffix into the query component; S3 keys may contain
    # "?", so re-attach it to the key.
    if parsed_url.query:
        key = "{}?{}".format(key, parsed_url.query)
    return parsed_url.netloc, key.lstrip("/")
def is_s3_url(url):
    """Tell whether ``url`` uses the ``s3`` scheme.

    Args:
        url (str): The URL to inspect.

    Returns:
        bool: True when the parsed scheme is ``s3``, otherwise False.
    """
    scheme = urlparse(url).scheme
    return scheme == "s3"
def s3_path_join(*args, with_end_slash: bool = False):
    """Join path segments with "/", normalising slashes the way S3 paths expect.

    This behaves much like ``os.path.join()`` on Unix, with these rules:

    - When the joined path begins with ``"s3://"``, that prefix is kept intact.
    - Runs of repeated slashes collapse to a single slash, except for the ``"//"`` of a
      leading ``"s3://"``. For example, ``s3_path_join("s3://", "//foo/", "/bar///baz")``
      gives ``"s3://foo/bar/baz"``.
    - Leading and trailing slashes are removed, so ``s3_path_join("/foo", "bar/")`` gives
      ``"foo/bar"``. The exception is ``with_end_slash``:
      ``s3_path_join("foo", "bar", with_end_slash=True)`` gives ``"foo/bar/"``.
    - ``None`` and empty-string segments are ignored:
      ``s3_path_join("foo", "", None, "bar")`` gives ``"foo/bar"``.

    Why not the alternatives:

    - ``os.path.join(...)`` uses the host OS separator, so results differ off Unix.
    - ``pathlib.PurePosixPath(...)`` simplifies single dots (".") and restarts at any
      absolute segment; e.g. ``pathlib.PurePosixPath("foo", "/bar/./", "baz")`` yields
      ``"/bar/baz"``.
    - ``"{}/{}/{}".format(...)`` and similar can produce repeated slashes.

    Args:
        *args: The path segments to join.
        with_end_slash (bool): (default: False) When True and the joined path is not
            empty, the result ends with a single "/".

    Returns:
        str: The joined path. It has no trailing slash unless ``with_end_slash`` is True
        (or the result is exactly ``"s3://"``).
    """
    slash = "/"
    scheme_prefix = "s3://"
    segments = [arg for arg in args if arg is not None and arg != ""]

    # Concatenate, inserting a separator only where neither side already supplies one.
    joined = ""
    for position, segment in enumerate(segments):
        already_separated = (
            position == 0
            or (joined and joined[-1] == slash)
            or (segment and segment[0] == slash)
        )
        if already_separated:
            joined += segment
        else:
            joined += slash + segment

    if with_end_slash and joined and joined[-1] != slash:
        joined += slash

    # Nothing to normalise for an empty result.
    if not joined:
        return joined

    # The slash clean-up below exists for backwards compatibility with earlier callers.
    # Collapse runs of slashes. A leading "s3://" is kept verbatim and seeds the scan, so any
    # slashes immediately after it are dropped as well.
    if joined.startswith(scheme_prefix):
        kept, remainder = [scheme_prefix], joined[len(scheme_prefix):]
    else:
        kept, remainder = [joined[0]], joined[1:]
    previous = kept[0][-1]
    for char in remainder:
        if char == slash and previous == slash:
            continue
        kept.append(char)
        previous = char
    collapsed = "".join(kept)

    # Strip leading slashes (a leading "s3://" is unaffected), then trailing ones unless the
    # caller asked to keep them or the result is the bare scheme.
    collapsed = collapsed.lstrip(slash)
    if not with_end_slash and collapsed != scheme_prefix:
        collapsed = collapsed.rstrip(slash)
    return collapsed
def determine_bucket_and_prefix(
    bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
    """Pick the S3 bucket and key prefix to use, falling back to the session defaults.

    Args:
        bucket (Optional[str]): Bucket to use. When truthy it is returned unchanged,
            together with ``key_prefix`` as given.
        key_prefix (Optional[str]): Object key prefix to use, or to append to the
            session's default bucket prefix when no bucket is given.
        sagemaker_session (sagemaker.core.helper.session_helper.Session): Session supplying
            the default bucket and default bucket prefix when ``bucket`` is not given.
            It must be provided in that case.

    Returns:
        tuple: The ``(bucket, key_prefix)`` pair that should be used.
    """
    if bucket:
        # A caller-supplied bucket never gets the session's default prefix, even when it
        # equals the default bucket. Either the user chose it explicitly (don't alter their
        # input), or it was already resolved from the Session earlier together with its default
        # prefix, and prepending the prefix again here would duplicate it.
        return bucket, key_prefix

    # No bucket given: use the session's default bucket and prepend its default prefix.
    default_bucket = sagemaker_session.default_bucket()
    default_prefix = sagemaker_session.default_bucket_prefix
    return default_bucket, s3_path_join(default_prefix, key_prefix)