# Source code for sagemaker.core.parameter
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import
import json
from typing import Union
from sagemaker.core.helper.pipeline_variable import PipelineVariable
from sagemaker.core.common_utils import to_string
class ParameterRange:
    """Common base for hyperparameter range definitions.

    A range describes which hyperparameter values an Amazon SageMaker
    hyperparameter tuning job should explore; it is also used to verify
    hyperparameters for Marketplace Algorithms.
    """

    # Labels matching the ``__name__`` of the concrete range classes in this module.
    __all_types__ = ("Continuous", "Categorical", "Integer")

    def __init__(
        self,
        min_value: Union[int, float, PipelineVariable],
        max_value: Union[int, float, PipelineVariable],
        scaling_type: Union[str, PipelineVariable] = "Auto",
    ):
        """Store the bounds and the scaling type of the range.

        Args:
            min_value (float or int or PipelineVariable): Lower bound of the range.
            max_value (float or int or PipelineVariable): Upper bound of the range.
            scaling_type (str or PipelineVariable): Scale used when searching the
                range during tuning (default: 'Auto'). Valid values: 'Auto',
                'Linear', 'Logarithmic' and 'ReverseLogarithmic'.
        """
        self.min_value = min_value
        self.max_value = max_value
        self.scaling_type = scaling_type

    def is_valid(self, value):
        """Check whether ``value`` lies within the inclusive bounds of the range.

        Args:
            value (float or int): The value to check.

        Returns:
            bool: True if ``min_value <= value <= max_value``, False otherwise.
        """
        # Same short-circuit semantics as the chained comparison: the upper
        # bound is only consulted when the lower-bound check is truthy.
        above_minimum = self.min_value <= value
        return above_minimum and value <= self.max_value

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``float``.

        Args:
            value: The value to convert.

        Returns:
            float: The result of ``float(value)``.
        """
        as_float = float(value)
        return as_float

    def as_tuning_range(self, name):
        """Build the dictionary form of this range.

        The result is suitable for a request to create an Amazon SageMaker
        hyperparameter tuning job.

        Args:
            name (str): The name of the hyperparameter.

        Returns:
            dict[str, str]: A dictionary holding the hyperparameter name, both
            bounds (each passed through ``to_string``) and the scaling type.
        """
        tuning_range = {"Name": name}
        tuning_range["MinValue"] = to_string(self.min_value)
        tuning_range["MaxValue"] = to_string(self.max_value)
        tuning_range["ScalingType"] = self.scaling_type
        return tuning_range
class ContinuousParameter(ParameterRange):
    """Hyperparameter range spanning a continuous interval of possible values.

    Args:
        min_value (float): The minimum value for the range.
        max_value (float): The maximum value for the range.
    """

    # Label identifying this kind of range.
    __name__ = "Continuous"

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``float``.

        Args:
            value: The value to convert.

        Returns:
            float: The result of ``float(value)``.
        """
        as_float = float(value)
        return as_float
class CategoricalParameter(ParameterRange):
    """Hyperparameter range made up of a discrete list of possible values."""

    # Label identifying this kind of range.
    __name__ = "Categorical"

    def __init__(self, values):  # pylint: disable=super-init-not-called
        """Initialize a ``CategoricalParameter``.

        The base-class initializer is not called, so only ``values`` is set on
        the instance.

        Args:
            values (list or object): The possible values for the hyperparameter.
                Anything that is not a ``list`` is wrapped in a single-element
                list; every element is then passed through ``to_string``.
        """
        candidates = values if isinstance(values, list) else [values]
        self.values = list(map(to_string, candidates))

    def as_tuning_range(self, name):
        """Build the dictionary form of this range.

        The result is suitable for a request to create an Amazon SageMaker
        hyperparameter tuning job.

        Args:
            name (str): The name of the hyperparameter.

        Returns:
            dict[str, list[str]]: A dictionary holding the hyperparameter name
            and its stored values.
        """
        tuning_range = {"Name": name}
        tuning_range["Values"] = self.values
        return tuning_range

    def as_json_range(self, name):
        """Build the dictionary form of this range with JSON-encoded values.

        The result is suitable for a request to create an Amazon SageMaker
        hyperparameter tuning job using one of the deep learning frameworks,
        whose images require hyperparameters to be serialized as JSON.

        Args:
            name (str): The name of the hyperparameter.

        Returns:
            dict[str, list[str]]: A dictionary holding the hyperparameter name
            and its stored values, each encoded with ``json.dumps``.
        """
        encoded_values = list(map(json.dumps, self.values))
        return {"Name": name, "Values": encoded_values}

    def is_valid(self, value):
        """Check whether ``value`` is one of the stored values.

        Args:
            value: The value to look up; it is compared against the stored
                values as they were saved at initialization.

        Returns:
            bool: True if ``value`` is among the stored values, False otherwise.
        """
        allowed = self.values
        return value in allowed

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``str``.

        Args:
            value: The value to convert.

        Returns:
            str: The result of ``str(value)``.
        """
        as_text = str(value)
        return as_text
class IntegerParameter(ParameterRange):
    """Hyperparameter range spanning an integer interval of possible values.

    Args:
        min_value (int): The minimum value for the range.
        max_value (int): The maximum value for the range.
    """

    # Label identifying this kind of range.
    __name__ = "Integer"

    @classmethod
    def cast_to_type(cls, value):
        """Convert ``value`` to ``int``.

        Args:
            value: The value to convert.

        Returns:
            int: The result of ``int(value)``.
        """
        as_integer = int(value)
        return as_integer