Creating assets for model customization

Creating assets for model customization#

from rich.pretty import pprint

from sagemaker.ai_registry.air_constants import REWARD_FUNCTION, REWARD_PROMPT
from sagemaker.ai_registry.dataset import DataSet, CustomizationTechnique
from sagemaker.ai_registry.evaluator import Evaluator
# Configure AWS credentials and region
#! ada credentials update --provider=isengard --account=<> --role=Admin --profile=default --once
#! aws configure set region us-west-2

DataSets#

Create#

  • DataSet input format depends on Customization technique

  • If no customization technique is provided, client-side validation is skipped

  • Provide a source (it can be a local file path or an S3 URL)

# 1. Create a DataSet from an S3 data source.
dataset = DataSet.create(
            name="sdkv3-gen-ds2",
            source="s3://sdk-air-test-bucket/datasets/training-data/jamjee-sft-ds1.jsonl",
                # `source` may also be a local file path.
            # Uncomment to declare the customization technique for this dataset:
            # customization_technique=CustomizationTechnique.SFT
        )
# Refresh the dataset's status from the hub, then inspect its attributes.
dataset.refresh()
pprint(dataset.__dict__)
# Fetch the versions of this dataset and print them.
versions = dataset.get_versions()
# NOTE(review): a later cell prints `versions` directly; if get_versions()
# returns a plain list, `.__dict__` will raise AttributeError -- confirm.
pprint(versions.__dict__)
# Delete one specific version.
# NOTE(review): "0.0.4" is hard-coded; presumably it must be one of the
# versions returned above -- confirm before running.
dataset.delete(version="0.0.4")
#dataset.delete(version="use a version from versions")
#pprint(versions)
# The deleted version should no longer appear in the versions output.
# With no `version` argument, delete() removes every version of this dataset.
dataset.delete()

List DataSet#

# List datasets. `max_results` is optional and is used for pagination;
# when it is omitted the default config applies.
datasets = DataSet.get_all(max_results=2)
# Print each listed dataset (this rebinds the name `dataset` on every pass).
for dataset in datasets:
    pprint(dataset)

Use an existing DataSet#

# Use a dataset from the iterator: take the first item that get_all() yields.
# next() raises StopIteration if no dataset is returned.
dataset = next(DataSet.get_all(max_results=2))
# Inspect the dataset that was just fetched. (Looping over the `datasets`
# iterator from the listing cell would print nothing once it is consumed,
# and would otherwise overwrite the `dataset` picked above.)
pprint(dataset.__dict__)
# Look up an existing dataset by name and print it.
dataset = DataSet.get(name="sdkv3-gen-ds2")
pprint(dataset)

# CRUD operations are available on the retrieved DataSet,
# e.g. dataset.delete()
# Create a new version of this dataset from another source
# (replace the <bucket> placeholder with a real bucket name).
dataset.create_version(source="s3://<bucket>/datasets/test_ds")
# List the dataset's versions after creating the new one.
versions = dataset.get_versions()
pprint(versions)

Evaluator#

# Method: Lambda -- create a reward-function evaluator from a Lambda function ARN.
# Fill in the `<>` (account id) and `<function-name>` placeholders before running.
# NOTE(review): the trailing "8" after <function-name> looks like a leftover -- confirm.
evaluator = Evaluator.create(
    name = "sdk-new-rf11",
    source="arn:aws:lambda:us-west-2:<>:function:<function-name>8",
    type=REWARD_FUNCTION

)
# Method: BYOC -- create a reward-function evaluator from a local Python file.

evaluator = Evaluator.create(
    name = "eval-lambda-test",
    source="/path_to_local/eval_lambda_1.py",
    type = REWARD_FUNCTION
)
# Reward Prompt -- create an evaluator from a local Jinja prompt template.
evaluator = Evaluator.create(
    name = "jamj-rp2",
    source="/path_to_local/custom_prompt.jinja",
    type = REWARD_PROMPT
)
# `evaluator` now refers to the most recently created evaluator ("jamj-rp2").
# Optional wait: by default the create call already uses wait = True.
evaluator.wait()
# Refresh the evaluator and print it.
evaluator.refresh()
pprint(evaluator)
# List evaluators; `max_results` is optional and is used for pagination.
evaluators = Evaluator.get_all(max_results=2)
for evaluator in evaluators:
    pprint(evaluator)
# List evaluators filtered by type.
# NOTE(review): the literal 'RewardPrompt' presumably matches the imported
# REWARD_PROMPT constant -- confirm.
evaluators = Evaluator.get_all(type='RewardPrompt', max_results=2)
for evaluator in evaluators:
    pprint(evaluator)
# Get an evaluator by name and print it.
evaluator = Evaluator.get(name="sdk-new-rf11")
pprint(evaluator)
# Create a new version, reusing the evaluator's current reference as the source.
evaluator.create_version(source=evaluator.reference)
# List and print the evaluator's versions.
versions = evaluator.get_versions()
pprint(versions)
# Delete the evaluator: pass the optional `version` argument to delete one
# version, or omit it to delete all versions.
evaluator.delete()