# Creating assets for model customization
from rich.pretty import pprint
from sagemaker.ai_registry.air_constants import REWARD_FUNCTION, REWARD_PROMPT
from sagemaker.ai_registry.dataset import DataSet, CustomizationTechnique
from sagemaker.ai_registry.evaluator import Evaluator
# Configure AWS credentials and region
#! ada credentials update --provider=isengard --account=<> --role=Admin --profile=default --once
#! aws configure set region us-west-2
# DataSets
# Create
# The DataSet input format depends on the customization technique.
# If no customization technique is provided, client-side validation is skipped.
# Provide a source (either a local file path or an S3 URL).
# Option 1: create a dataset whose source is an S3 object.
dataset = DataSet.create(
    name="sdkv3-gen-ds2",
    source="s3://sdk-air-test-bucket/datasets/training-data/jamjee-sft-ds1.jsonl",
    # A local file path may be passed as `source` instead of an S3 URL.
    # Optionally declare the technique the data is meant for:
    # customization_technique=CustomizationTechnique.SFT
)
# Refresh the local object's status from the hub, then show its fields.
dataset.refresh()
pprint(dataset.__dict__)

# Fetch the versions of this dataset and print the result as returned
# (the same way the other get_versions() results in this walkthrough are
# printed), rather than reaching for a `__dict__` attribute on it.
versions = dataset.get_versions()
pprint(versions)

# Delete one specific version.
dataset.delete(version="0.0.4")
# dataset.delete(version="use a version from versions")
# pprint(versions)
# The deleted version should not be part of the output.

# Without a `version` argument, all versions of this dataset are deleted.
dataset.delete()
# List DataSet
# `max_results` is optional and is used for pagination; leave it out to
# fall back to the default config.
datasets = DataSet.get_all(max_results=2)
for dataset in datasets:
    pprint(dataset)
# Use an existing DataSet
# Take the first dataset from a fresh listing and inspect its fields.
# The fetched object is printed directly: looping over the earlier
# `datasets` iterator here would rebind `dataset` (discarding the one just
# fetched) and, since that iterator was already consumed by the listing
# loop, would print nothing.
dataset = next(DataSet.get_all(max_results=2))
pprint(dataset.__dict__)
# Look up a dataset directly by its name.
dataset = DataSet.get(name="sdkv3-gen-ds2")
pprint(dataset)

# CRUD operations are available on the returned DataSet,
# e.g. dataset.delete()

# Create a new version of this dataset from another source.
dataset.create_version(source="s3://<bucket>/datasets/test_ds")

# Show the versions after adding the new one.
versions = dataset.get_versions()
pprint(versions)
# Evaluator
# Reward function, Lambda method: `source` is a Lambda function ARN.
evaluator = Evaluator.create(
    name="sdk-new-rf11",
    source="arn:aws:lambda:us-west-2:<>:function:<function-name>8",
    type=REWARD_FUNCTION,
)

# Reward function, BYOC method: `source` is a local Python file.
evaluator = Evaluator.create(
    name="eval-lambda-test",
    source="/path_to_local/eval_lambda_1.py",
    type=REWARD_FUNCTION,
)

# Reward prompt: `source` is a local .jinja prompt file.
evaluator = Evaluator.create(
    name="jamj-rp2",
    source="/path_to_local/custom_prompt.jinja",
    type=REWARD_PROMPT,
)

# Waiting explicitly is optional: the create call already uses wait = True
# by default.
evaluator.wait()
evaluator.refresh()
pprint(evaluator)
# List evaluators; `max_results` is optional and is used for pagination.
evaluators = Evaluator.get_all(max_results=2)
for evaluator in evaluators:
    pprint(evaluator)

# List only the evaluators of a given type.
evaluators = Evaluator.get_all(type='RewardPrompt', max_results=2)
for evaluator in evaluators:
    pprint(evaluator)
# Look up an evaluator directly by its name.
evaluator = Evaluator.get(name="sdk-new-rf11")
pprint(evaluator)

# Create a new version, reusing the evaluator's own reference as the source.
evaluator.create_version(source=evaluator.reference)

versions = evaluator.get_versions()
pprint(versions)

# Delete the evaluator: a `version` argument is optional; without it all
# versions are deleted.
evaluator.delete()