SageMaker V3 PyTorch Processing

from sagemaker.core.helper.session_helper import Session, get_execution_role

# Create a SageMaker session and resolve the execution role it will use.
sess = Session()
role = get_execution_role()

# Echo the account context so the job's role, bucket, and region are visible.
bucket = sess.default_bucket()
region = sess.boto_region_name
print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {bucket}")
print(f"sagemaker session region: {region}")

Download Data

from datasets import load_dataset

# Download the IMDB review dataset and keep the train/test splits separately.
train_dataset, test_dataset = load_dataset("imdb", split=["train", "test"])

# These were bare cell expressions in the original notebook, which are no-ops
# in a plain script; print them so the dataset summaries and a sample record
# are actually shown.
print(train_dataset, test_dataset)
print(train_dataset[10])

Use FrameworkProcessor with the PyTorch image

from sagemaker.core.image_uris import get_training_image_uri
from sagemaker.core.processing import FrameworkProcessor

# One instance type is used both to resolve the container image and to run
# the processing job, so name it once.
_instance_type = "ml.m5.xlarge"

# Look up the regional PyTorch 1.13 / Python 3.9 training container URI.
image_uri = get_training_image_uri(
    region=sess.boto_region_name,
    framework="pytorch",
    framework_version="1.13",
    py_version="py39",
    instance_type=_instance_type,
)

# Processor that runs our preprocessing script inside that PyTorch image.
pytorch_processor = FrameworkProcessor(
    image_uri=image_uri,
    role=role,
    instance_type=_instance_type,
    instance_count=1,
)
from sagemaker.core.shapes import ProcessingOutput, ProcessingS3Output
from time import gmtime, strftime
import os

# Job naming and output location: a shared S3 prefix, a timestamp-suffixed
# (unique) job name, and the S3 destination the job writes under.
# Uses f-strings for consistency with the rest of the file.
s3_prefix = "huggingface-text-classification"
processing_job_name = f"{s3_prefix}-{strftime('%d-%H-%M-%S', gmtime())}"
output_destination = f"s3://{sess.default_bucket()}/{s3_prefix}"

# Launch the preprocessing job asynchronously (wait=False returns right away).
# The container runs scripts/preprocess/preprocessing.py and, when the job
# ends (s3_upload_mode="EndOfJob"), uploads the train/test outputs from the
# container-local paths to S3 under output_destination.
pytorch_processor.run(
    code="preprocessing.py",
    source_dir=os.path.abspath("scripts/preprocess"),
    job_name=processing_job_name,
    outputs=[
        ProcessingOutput(
            output_name="train",
            s3_output=ProcessingS3Output(
                s3_uri=f"{output_destination}/train",
                local_path="/opt/ml/processing/train",
                s3_upload_mode="EndOfJob",
            ),
        ),
        ProcessingOutput(
            output_name="test",
            s3_output=ProcessingS3Output(
                s3_uri=f"{output_destination}/test",
                local_path="/opt/ml/processing/test",
                s3_upload_mode="EndOfJob",
            ),
        ),
    ],
    wait=False,
)

# The original bare expression was a no-op in a script; print the refreshed
# job status so the launch result is visible.
print(pytorch_processor.latest_job.refresh().processing_job_status)