# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Contains the SageMaker Experiment class."""
from __future__ import absolute_import
import time
from botocore.exceptions import ClientError
from sagemaker.core.apiutils import _base_types
from sagemaker.core.experiments.trial import _Trial
from sagemaker.core.experiments.trial_component import _TrialComponent
from sagemaker.core.common_utils import format_tags
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature
class Experiment(_base_types.Record):
    """An Amazon SageMaker experiment, which is a collection of related trials.

    New experiments are created by calling `experiments.experiment.Experiment.create`.
    Existing experiments can be reloaded by calling `experiments.experiment.Experiment.load`.

    Attributes:
        experiment_name (str): The name of the experiment. The name must be unique
            within an account.
        display_name (str): Name of the experiment that will appear in UI,
            such as SageMaker Studio.
        description (str): A description of the experiment.
        tags (List[Dict[str, str]]): A list of tags to associate with the experiment.
    """

    experiment_name = None
    display_name = None
    description = None
    tags = None

    # API method names consumed by the inherited Record helpers
    # (`_construct` / `_invoke_api`) for each lifecycle operation.
    _boto_create_method = "create_experiment"
    _boto_load_method = "describe_experiment"
    _boto_update_method = "update_experiment"
    _boto_delete_method = "delete_experiment"

    # Attribute names handed to `_invoke_api` by `save` and `delete` respectively.
    _boto_update_members = ["experiment_name", "description", "display_name"]
    _boto_delete_members = ["experiment_name"]

    # Number of full delete passes `_delete_all` makes before giving up.
    _MAX_DELETE_ALL_ATTEMPTS = 3

    def save(self):
        """Save the state of this Experiment to SageMaker.

        Returns:
            dict: Update experiment API response.
        """
        return self._invoke_api(self._boto_update_method, self._boto_update_members)

    def delete(self):
        """Delete this Experiment from SageMaker.

        Deleting an Experiment does not delete associated Trials and their Trial Components.
        It requires that each Trial in the Experiment is first deleted.

        Returns:
            dict: Delete experiment API response.
        """
        return self._invoke_api(self._boto_delete_method, self._boto_delete_members)

    @classmethod
    def load(cls, experiment_name, sagemaker_session=None):
        """Load an existing experiment and return an `Experiment` object representing it.

        Args:
            experiment_name (str): Name of the experiment.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            experiments.experiment.Experiment: A SageMaker `Experiment` object
        """
        return cls._construct(
            cls._boto_load_method,
            experiment_name=experiment_name,
            sagemaker_session=sagemaker_session,
        )

    @classmethod
    @_telemetry_emitter(feature=Feature.MLOPS, func_name="experiment.create")
    def create(
        cls,
        experiment_name,
        display_name=None,
        description=None,
        tags=None,
        sagemaker_session=None,
    ):
        """Create a new experiment in SageMaker and return an `Experiment` object.

        Args:
            experiment_name (str): Name of the experiment. Must be unique. Required.
            display_name (str): Name of the experiment that will appear in UI,
                such as SageMaker Studio (default: None).
            description (str): Description of the experiment (default: None).
            tags (Optional[Tags]): A list of tags to associate with the experiment
                (default: None). Normalized through `format_tags` before use.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            experiments.experiment.Experiment: A SageMaker `Experiment` object
        """
        return cls._construct(
            cls._boto_create_method,
            experiment_name=experiment_name,
            display_name=display_name,
            description=description,
            tags=format_tags(tags),
            sagemaker_session=sagemaker_session,
        )

    @classmethod
    def _load_or_create(
        cls,
        experiment_name,
        display_name=None,
        description=None,
        tags=None,
        sagemaker_session=None,
    ):
        """Load an experiment by name and create a new one if it does not exist.

        Creation is attempted first; if SageMaker rejects it with a
        ``ValidationException`` whose message contains "already exists", the
        existing experiment is loaded instead.

        Args:
            experiment_name (str): Name of the experiment. Must be unique. Required.
            display_name (str): Name of the experiment that will appear in UI,
                such as SageMaker Studio (default: None). This is used only when the
                given `experiment_name` does not exist and a new experiment has to be created.
            description (str): Description of the experiment (default: None).
                This is used only when the given `experiment_name` does not exist and
                a new experiment has to be created.
            tags (Optional[Tags]): A list of tags to associate with the experiment
                (default: None). This is used only when the given `experiment_name` does not
                exist and a new experiment has to be created.
            sagemaker_session (sagemaker.core.helper.session_helper.Session): Session object which
                manages interactions with Amazon SageMaker APIs and any other
                AWS services needed. If not specified, one is created using the
                default AWS configuration chain.

        Returns:
            experiments.experiment.Experiment: A SageMaker `Experiment` object

        Raises:
            botocore.exceptions.ClientError: Any create failure other than
                "already exists".
        """
        try:
            # `create` already normalizes tags via `format_tags`; pass them through as-is.
            return cls.create(
                experiment_name=experiment_name,
                display_name=display_name,
                description=description,
                tags=tags,
                sagemaker_session=sagemaker_session,
            )
        except ClientError as ce:
            # Read the error defensively so a malformed response re-raises the
            # original ClientError rather than a KeyError/TypeError.
            error = ce.response.get("Error") or {}
            error_code = error.get("Code")
            error_message = error.get("Message") or ""
            if not (error_code == "ValidationException" and "already exists" in error_message):
                raise
        # The experiment already exists: load it instead.
        return cls.load(experiment_name, sagemaker_session)

    def list_trials(self, created_before=None, created_after=None, sort_by=None, sort_order=None):
        """List trials in this experiment matching the specified criteria.

        Args:
            created_before (datetime.datetime): Return trials created before this instant
                (default: None).
            created_after (datetime.datetime): Return trials created after this instant
                (default: None).
            sort_by (str): Which property to sort results by. One of 'Name', 'CreationTime'
                (default: None).
            sort_order (str): One of 'Ascending', or 'Descending' (default: None).

        Returns:
            collections.Iterator[experiments._api_types.TrialSummary] :
                An iterator over trials matching the criteria.
        """
        return _Trial.list(
            experiment_name=self.experiment_name,
            created_before=created_before,
            created_after=created_after,
            sort_by=sort_by,
            sort_order=sort_order,
            sagemaker_session=self.sagemaker_session,
        )

    def _delete_all(self, action):
        """Force to delete the experiment and associated trials, trial components.

        Each pass deletes every trial component (force-disassociated) of every
        trial, then each trial, then the experiment itself. A failing pass is
        retried from the beginning, up to `_MAX_DELETE_ALL_ATTEMPTS` passes.

        Args:
            action (str): The string '--force' is required to pass in to confirm recursively
                delete the experiments, and all its trials and trial components.

        Raises:
            ValueError: If `action` is not '--force'.
            Exception: If every attempt fails; chained from the last failure.
        """
        if action != "--force":
            raise ValueError(
                "Must confirm with string '--force' in order to delete the experiment and "
                "associated trials, trial components."
            )

        last_exception = None
        for _ in range(self._MAX_DELETE_ALL_ATTEMPTS):
            try:
                for trial_summary in self.list_trials():
                    trial = _Trial.load(
                        sagemaker_session=self.sagemaker_session,
                        trial_name=trial_summary.trial_name,
                    )
                    # pylint: disable=no-member
                    for trial_component_summary in trial.list_trial_components():
                        tc = _TrialComponent.load(
                            sagemaker_session=self.sagemaker_session,
                            trial_component_name=trial_component_summary.trial_component_name,
                        )
                        tc.delete(force_disassociate=True)
                        # to prevent throttling
                        time.sleep(1.2)
                    trial.delete()  # pylint: disable=no-member
                    # to prevent throttling
                    time.sleep(1.2)
                self.delete()
                return
            except Exception as ex:  # pylint: disable=broad-except
                # Remember the failure and retry the whole pass.
                last_exception = ex
        # pylint: disable=broad-exception-raised
        raise Exception("Failed to delete, please try again.") from last_exception