# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""This module contains helper functions related to S3. You may want to use `s3.py` instead.
This has a subset of the functions available through s3.py. This module was initially created with
functions that were originally in `s3.py` so that those functions could be imported inside
`session.py` without circular dependencies. (`s3.py` imports Session as a dependency.)
"""
from __future__ import print_function, absolute_import
import logging
from functools import reduce
from typing import Optional
from six.moves.urllib.parse import urlparse
logger = logging.getLogger("sagemaker")
[docs]
def parse_s3_url(url):
    """Split an ``s3://`` URL into its bucket name and object key (or key prefix).

    Args:
        url (str): The URL to split. Its scheme must be exactly ``s3``.

    Returns:
        tuple: A tuple containing:

            - str: The S3 bucket name (the network-location part of the URL).
            - str: The S3 key, with every leading slash removed. Empty when the
              URL has no path component.

    Raises:
        ValueError: If the scheme of ``url`` is anything other than ``s3``.
    """
    parts = urlparse(url)
    scheme = parts.scheme
    if scheme == "s3":
        # The URL path always starts at the slash that follows the bucket; S3 keys
        # are returned without that leading separator.
        return parts.netloc, parts.path.lstrip("/")
    raise ValueError("Expecting 's3' scheme, got: {} in {}.".format(scheme, url))
[docs]
def is_s3_url(url):
    """Tell whether ``url`` uses the ``s3`` scheme.

    Only the scheme is inspected; the bucket and key parts are not validated.

    Args:
        url (str): The URL to check.

    Returns:
        bool: True if the parsed scheme of ``url`` equals ``"s3"``, False otherwise.
    """
    return urlparse(url).scheme == "s3"
[docs]
def s3_path_join(*args, with_end_slash: bool = False):
    """Join the given pieces with a single slash ("/"), like ``os.path.join()`` on Unix.

    Behavior of this function:

    - A leading "s3://" is kept intact.
    - By default the result has no slash at the beginning or at the end. The one
      exception is `with_end_slash`. For example, `s3_path_join("/foo", "bar/")` yields
      `"foo/bar"` and `s3_path_join("foo", "bar", with_end_slash=True)` yields `"foo/bar/"`.
    - Runs of repeated slashes are collapsed to one (except the "s3://" at the very
      beginning). For example, `s3_path_join("s3://", "//foo/", "/bar///baz")` yields
      `"s3://foo/bar/baz"`.
    - Arguments that are None or the empty string are ignored. For example,
      `s3_path_join("foo", "", None, "bar")` yields `"foo/bar"`.

    Alternatives that are NOT recommended for S3 paths:

    - `os.path.join(...)` behaves differently on Unix and non-Unix machines.
    - `pathlib.PurePosixPath(...)` may simplify single dots (".") and root directories in
      unintended ways (for example `pathlib.PurePosixPath("foo", "/bar/./", "baz")`
      yields `"/bar/baz"`).
    - `"{}/{}/{}".format(...)` and similar can produce unintended repeated slashes.

    Args:
        *args: The strings to join with a slash.
        with_end_slash (bool): (default: False) If true and the path is not empty, the
            result ends with a "/".

    Returns:
        str: The joined string, without a trailing slash unless with_end_slash is True.
    """
    delimiter = "/"
    s3_scheme = "s3://"

    # Step 1: concatenate the usable pieces, inserting a separator only where neither
    # side of the seam already supplies one.
    merged_path = ""
    for piece in args:
        if piece is None or piece == "":
            continue
        if merged_path and merged_path[-1] != delimiter and piece[0] != delimiter:
            merged_path += delimiter
        merged_path += piece

    if with_end_slash and merged_path and merged_path[-1] != delimiter:
        merged_path += delimiter

    # Step 2: normalise. The merged text can still carry leading/trailing slashes and
    # repeated slashes that came from inside the individual pieces. They are removed
    # for backwards compatibility; supporting repeated slashes in S3 paths would need
    # a separate option (or function) that stops after step 1.
    cleaned_path = merged_path
    if cleaned_path:
        # Seed the output with the part that must be kept verbatim: the whole
        # "s3://" scheme when present, otherwise just the first character.
        if cleaned_path.startswith(s3_scheme):
            seed, remainder = s3_scheme, cleaned_path[len(s3_scheme):]
        else:
            seed, remainder = cleaned_path[:1], cleaned_path[1:]

        kept = [seed]
        last_char = seed[-1]
        for char in remainder:
            if char == delimiter and last_char == delimiter:
                # A slash directly after a slash is a duplicate; drop it.
                continue
            kept.append(char)
            last_char = char
        cleaned_path = "".join(kept)

    # Leading slashes never survive.
    cleaned_path = cleaned_path.lstrip(delimiter)

    # Trailing slashes survive only on request, or when they belong to a bare "s3://".
    if not with_end_slash and cleaned_path != s3_scheme:
        cleaned_path = cleaned_path.rstrip(delimiter)

    return cleaned_path
[docs]
def determine_bucket_and_prefix(
    bucket: Optional[str] = None, key_prefix: Optional[str] = None, sagemaker_session=None
):
    """Pick the S3 bucket and key prefix to use, falling back to the Session defaults.

    Args:
        bucket (Optional[str]): S3 Bucket to use (if it exists)
        key_prefix (Optional[str]): S3 Object Key Prefix to use or append to (if it exists)
        sagemaker_session (sagemaker.session.Session): Session to fetch a default bucket and
            prefix from, if bucket doesn't exist. Expected to exist

    Returns:
        tuple: The S3 Bucket and the S3 Object Key Prefix that should be used. When
        ``bucket`` is given, both inputs are returned unchanged; otherwise the Session's
        default bucket is returned together with the Session's default bucket prefix
        joined with ``key_prefix``.
    """
    if bucket:
        # An explicit bucket is used as-is, and the default_bucket_prefix is NOT added,
        # even when the bucket happens to equal the Session's default_bucket, because
        # either:
        # (1) the user passed the bucket explicitly and it merely coincides with the
        #     default_bucket (so the user's input must not be changed), or
        # (2) the default_bucket was already fetched from the Session earlier (and the
        #     default prefix should have been fetched then as well) before this function
        #     was called with it; adding the prefix here would apply it more than once.
        return bucket, key_prefix

    # No bucket supplied: use the Session's default_bucket. This is the one and only
    # case in which default_bucket_prefix (if it exists) goes in front of key_prefix.
    default_bucket = sagemaker_session.default_bucket()
    combined_prefix = s3_path_join(sagemaker_session.default_bucket_prefix, key_prefix)
    return default_bucket, combined_prefix