# Source code for sagemaker.train.evaluate.execution

"""SageMaker Evaluation Execution Module.

This module provides classes for managing evaluation executions.
"""
from __future__ import absolute_import

# Standard library imports
import json
import logging
import os
import time
import uuid
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional

# Third-party imports
from botocore.exceptions import ClientError
from pydantic import BaseModel, Field
from sagemaker.core.common_utils import TagsDict
from sagemaker.core.helper.session_helper import Session
from sagemaker.core.resources import Pipeline, PipelineExecution
from sagemaker.core.resources import Tag as ResourceTag  # For Tag.get_all()
from sagemaker.core.shapes import Tag  # For Pipeline.create() tags parameter
from sagemaker.core.telemetry.telemetry_logging import _telemetry_emitter
from sagemaker.core.telemetry.constants import Feature

# Local imports
from .constants import (
    _TAG_SAGEMAKER_MODEL_EVALUATION,
    EvalType,
    _get_pipeline_name,
    _get_pipeline_name_prefix,
)

logger = logging.getLogger(__name__)


def _create_evaluation_pipeline(
    eval_type: EvalType,
    role_arn: str,
    pipeline_definition: str,
    session: Optional[Any] = None,
    region: Optional[str] = None,
    tags: Optional[List[TagsDict]] = None,
) -> Any:
    """Create a new SageMaker pipeline for an evaluation type.

    The pipeline name is obtained from ``_get_pipeline_name(eval_type)`` and
    substituted into ``pipeline_definition`` (rendered as a Jinja2 template)
    before the pipeline is created. The pipeline is always tagged with
    ``_TAG_SAGEMAKER_MODEL_EVALUATION = "true"``; caller-supplied tags are
    appended after that tag.

    Args:
        eval_type (EvalType): Type of evaluation.
        role_arn (str): IAM role ARN for pipeline execution.
        pipeline_definition (str): JSON pipeline definition (Jinja2 template).
        session (Optional[Any]): Session object forwarded to ``Pipeline.create``.
        region (Optional[str]): AWS region forwarded to ``Pipeline.create``.
        tags (Optional[List[TagsDict]]): Extra tags for the pipeline. Each item
            may be an object whose class name contains ``Tag`` (used as-is) or
            a dict with ``key``/``value`` or ``Key``/``Value`` entries. Items
            that cannot be interpreted are skipped with a warning. ``None`` or
            an empty list means no extra tags.

    Returns:
        Any: Created Pipeline instance. It is returned even when waiting for
        the ``Active`` status fails; that failure is only logged.
    """
    from jinja2 import Template

    pipeline_name = _get_pipeline_name(eval_type)
    client_request_token = str(uuid.uuid4())

    logger.info(f"Creating new pipeline: {pipeline_name}")

    # Re-render pipeline definition with actual pipeline_name
    resolved_pipeline_definition = Template(pipeline_definition).render(
        pipeline_name=pipeline_name
    )

    # Tags must be Tag objects, not dicts, for Pydantic validation to pass.
    # The evaluation marker tag always comes first.
    tag_objects = [Tag(key=_TAG_SAGEMAKER_MODEL_EVALUATION, value="true")]

    for i, tag_item in enumerate(tags or []):
        try:
            if 'Tag' in tag_item.__class__.__name__:
                # Already a Tag object
                tag_objects.append(tag_item)
            elif isinstance(tag_item, dict):
                # Accept both lowercase and capitalized keys
                key = tag_item.get("key") or tag_item.get("Key")
                value = tag_item.get("value")
                if value is None:
                    value = tag_item.get("Value")
                # An empty string is a legal tag value, so only a missing
                # value (None) causes the tag to be rejected.
                if key and value is not None:
                    tag_objects.append(Tag(key=str(key), value=str(value)))
                else:
                    logger.warning(f"Skipping invalid tag at index {i}: {tag_item}")
            else:
                logger.warning(f"Skipping unsupported tag type at index {i}: {type(tag_item)}")
        except Exception as e:
            # Best effort: one bad tag must not prevent pipeline creation.
            logger.warning(f"Error processing tag at index {i}: {e}")

    logger.info(f"Creating pipeline with {len(tag_objects)} tags")

    pipeline = Pipeline.create(
        pipeline_name=pipeline_name,
        client_request_token=client_request_token,
        role_arn=role_arn,
        pipeline_definition=resolved_pipeline_definition,
        pipeline_display_name=f"EvaluationPipeline-{eval_type.value}",
        pipeline_description=f"Pipeline for {eval_type.value} evaluation jobs",
        tags=tag_objects,
        session=session,
        region=region
    )

    logger.info(f"Successfully created pipeline: {pipeline_name}")

    # Wait (up to 5 minutes) for the pipeline to be ready before returning.
    logger.info(f"Waiting for pipeline {pipeline_name} to be ready...")
    try:
        pipeline.wait_for_status(target_status="Active", poll=5, timeout=300)
        logger.info(f"Pipeline {pipeline_name} is now active and ready for execution")
    except Exception as e:
        # Deliberately non-fatal: the caller still gets the created pipeline.
        logger.warning(f"Failed to wait for pipeline status: {e}. Pipeline may still be initializing...")

    return pipeline


def _clean_unassigned_value(value: Any) -> Any:
    """Clean Unassigned object by converting to None.
    
    Args:
        value (Any): Value that may be an Unassigned object.
        
    Returns:
        Any: None if value is Unassigned, otherwise returns the value unchanged.
    """
    if value is not None and hasattr(value, '__class__'):
        if 'Unassigned' in value.__class__.__name__:
            return None
    return value


def _clean_unassigned_from_dict(data: Dict[str, Any]) -> Dict[str, Any]:
    """Clean Unassigned objects from nested dict before pydantic validation.
    
    Args:
        data (Dict[str, Any]): Dictionary that may contain Unassigned objects.
        
    Returns:
        Dict[str, Any]: Cleaned dictionary with Unassigned objects replaced with None.
    """
    if data.get('status', {}).get('failure_reason') is not None:
        data['status']['failure_reason'] = _clean_unassigned_value(data['status']['failure_reason'])
    return data


def _extract_eval_type_from_arn(arn: str) -> Optional[EvalType]:
    """Derive the evaluation type from a pipeline or pipeline-execution ARN.

    The pipeline name is taken from the ARN and compared against the name
    prefix of every ``EvalType`` (via ``_get_pipeline_name_prefix``); the first
    type whose prefix matches is returned.

    Args:
        arn (str): Pipeline ARN or Pipeline Execution ARN.
            Pipeline ARN format: arn:aws:sagemaker:region:account:pipeline/pipeline-name
            Execution ARN format: arn:aws:sagemaker:region:account:pipeline/pipeline-name/execution/execution-id

    Returns:
        Optional[EvalType]: Matching EvalType, or None when nothing matches or
        any error occurs while parsing (errors are logged, not raised).
    """
    try:
        segments = arn.split('/')
        if len(segments) >= 2:
            # Execution ARNs have four or more '/'-separated segments with the
            # pipeline name third from the end; pipeline ARNs end with the name.
            pipeline_name = segments[-1] if len(segments) < 4 else segments[-3]

            matched = next(
                (
                    candidate
                    for candidate in EvalType
                    if pipeline_name.startswith(_get_pipeline_name_prefix(candidate))
                ),
                None,
            )
            if matched is not None:
                logger.debug(f"Extracted eval_type: {matched.value} from ARN: {arn}")
                return matched

        logger.warning(f"Could not extract eval_type from ARN: {arn}")
        return None
    except Exception as e:
        logger.warning(f"Error extracting eval_type from ARN {arn}: {str(e)}")
        return None


def _get_or_create_pipeline(
    eval_type: EvalType,
    pipeline_definition: str,
    role_arn: str,
    session: Optional[Session] = None,
    region: Optional[str] = None,
    create_tags: Optional[List[TagsDict]] = None,
) -> Pipeline:
    """Return the evaluation pipeline for ``eval_type``, updating or creating it.

    Lists pipelines whose name starts with ``_get_pipeline_name_prefix(eval_type)``
    and picks the first one tagged ``_TAG_SAGEMAKER_MODEL_EVALUATION = "true"``.
    That pipeline is updated with ``pipeline_definition`` re-rendered for its
    actual name. If no such pipeline exists, or the search fails for a reason
    other than a non-"ResourceNotFound" ``ClientError``, a new pipeline is
    created through ``_create_evaluation_pipeline``.

    Creation is attempted exactly once and outside the search error handling,
    so an error raised by the creation itself propagates to the caller instead
    of triggering a second creation attempt.

    Args:
        eval_type: Type of evaluation
        pipeline_definition: JSON pipeline definition (Jinja2 template)
        role_arn: IAM role ARN for pipeline execution
        session: Session forwarded to the SageMaker resource calls (optional)
        region: AWS region (optional)
        create_tags (Optional[List[TagsDict]]): Tags applied only when a new
            pipeline is created. ``None`` or empty means no extra tags.

    Returns:
        Pipeline instance (existing updated or newly created)

    Raises:
        ClientError: If searching/updating fails with an error code that does
            not contain "ResourceNotFound", or if pipeline creation fails.
    """
    from jinja2 import Template

    pipeline_name_prefix = _get_pipeline_name_prefix(eval_type)

    try:
        # Use Pipeline.get_all with pipeline_name_prefix to find existing pipelines
        pipelines = Pipeline.get_all(
            pipeline_name_prefix=pipeline_name_prefix,
            session=session,
            region=region
        )

        for pipeline in pipelines:
            # Only pipelines carrying the evaluation tag are eligible for reuse.
            tags_list = ResourceTag.get_all(
                resource_arn=pipeline.pipeline_arn, session=session, region=region
            )
            existing_tags = {tag.key: tag.value for tag in tags_list}
            if existing_tags.get(_TAG_SAGEMAKER_MODEL_EVALUATION) != "true":
                continue

            pipeline_name = pipeline.pipeline_name
            logger.info(f"Found existing pipeline: {pipeline_name}")

            # Re-render pipeline definition with actual pipeline_name
            resolved_pipeline_definition = Template(pipeline_definition).render(
                pipeline_name=pipeline_name
            )

            logger.info(f"Updating pipeline {pipeline_name} with latest definition")
            pipeline.update(
                pipeline_definition=resolved_pipeline_definition,
                role_arn=role_arn,
                pipeline_description=f"Pipeline for {eval_type.value} evaluation jobs (updated)"
            )
            logger.info(f"Successfully updated pipeline: {pipeline_name}")
            return pipeline

    except ClientError as e:
        # "Not found" means there is nothing to reuse; anything else is a real failure.
        if "ResourceNotFound" not in e.response['Error']['Code']:
            raise

    except Exception as e:
        # If search fails for other reasons, fall through and try to create
        logger.info(f"Error searching for pipeline ({str(e)}), attempting to create new pipeline")

    else:
        # Search completed without finding a tagged pipeline.
        logger.info(f"No existing pipeline found with prefix {pipeline_name_prefix}, creating new one")

    return _create_evaluation_pipeline(eval_type, role_arn, pipeline_definition, session, region, create_tags)


def _start_pipeline_execution(
    pipeline_name: str,
    name: str,
    session: Optional[Session] = None,
    region: Optional[str] = None
) -> str:
    """Start an execution of ``pipeline_name`` through a SageMaker client.

    The client is built from ``session.client(...)`` when a truthy session is
    given, otherwise from ``boto3.client(...)``. The endpoint URL is read from
    the ``SAGEMAKER_ENDPOINT`` environment variable (``None`` when unset). The
    execution display name is ``name`` suffixed with the current Unix time in
    whole seconds.

    Kept as a separate function so tests can mock it independently.

    Args:
        pipeline_name: Name of the pipeline to execute
        name: Base name for the execution
        session: Object exposing ``client(...)`` (optional)
        region: AWS region (optional)

    Returns:
        ARN of the started pipeline execution

    Raises:
        ClientError: If AWS service call fails
    """
    import os
    import boto3

    display_name = f"{name}-{int(time.time())}"

    # A falsy session falls back to the boto3 module-level client factory.
    client_factory = session if session else boto3
    sm_client = client_factory.client(
        'sagemaker',
        region_name=region,
        endpoint_url=os.environ.get('SAGEMAKER_ENDPOINT'),
    )

    logger.info(f"Starting pipeline execution: {display_name}")

    request = {
        "PipelineName": pipeline_name,
        "PipelineExecutionDisplayName": display_name,
        "PipelineExecutionDescription": f"Evaluation execution: {name}",
        # Empty since all values are pre-substituted into the definition
        "PipelineParameters": [],
        "ClientRequestToken": str(uuid.uuid4()),
    }
    started = sm_client.start_pipeline_execution(**request)

    execution_arn = started['PipelineExecutionArn']
    logger.info(f"Pipeline execution started: {execution_arn}")
    return execution_arn


def _create_execution_from_pipeline_execution(
    pe: PipelineExecution,
    eval_type: EvalType
) -> 'EvaluationPipelineExecution':
    """Wrap a PipelineExecution in an EvaluationPipelineExecution.

    Copies the ARN, status, failure reason and last-modified time, and keeps a
    reference to ``pe`` on the result. Shared by get() and get_all().

    Args:
        pe: PipelineExecution object from sagemaker_core
        eval_type: Type of evaluation

    Returns:
        EvaluationPipelineExecution with basic properties set. Its name is the
        last '/'-separated segment of the execution ARN, or 'unknown' when the
        ARN is falsy; its overall status is 'Unknown' when the source status
        is falsy.
    """
    execution_arn = pe.pipeline_execution_arn
    if execution_arn:
        execution_name = execution_arn.rsplit('/', 1)[-1]
    else:
        execution_name = 'unknown'

    # failure_reason might be an Unassigned placeholder; map that to None.
    status = PipelineExecutionStatus(
        overall_status=pe.pipeline_execution_status or 'Unknown',
        failure_reason=_clean_unassigned_value(pe.failure_reason),
    )

    wrapped = EvaluationPipelineExecution(
        arn=execution_arn,
        name=execution_name,
        status=status,
        last_modified_time=pe.last_modified_time,
        eval_type=eval_type,
    )

    # Store the internal pipeline execution reference for later refresh/stop.
    wrapped._pipeline_execution = pe
    return wrapped


def _extract_output_s3_location_from_steps(raw_steps: List[Any], session: Optional[Any] = None, region: Optional[str] = None) -> Optional[str]:
    """Find the S3 output path of the first describable evaluation training job.

    Walks ``raw_steps`` in order, considering steps whose name contains
    ``EvaluateCustomModel`` or ``EvaluateBaseModel`` and whose metadata carries
    a training job with an ARN. For each, DescribeTrainingJob is called and the
    first ``OutputDataConfig.S3OutputPath`` found is returned. A failed
    describe call is logged and the scan moves on to the next step.

    Args:
        raw_steps: List of PipelineExecutionStep objects from SageMaker
        session: Object exposing ``client(...)``; boto3 is used when falsy
        region: AWS region (optional)

    Returns:
        S3 output location from OutputDataConfig if found, None otherwise
        (including when any unexpected error occurs; it is logged).
    """
    try:
        import boto3
        import os

        # Endpoint override from the environment (for beta endpoint support)
        endpoint_url = os.environ.get('SAGEMAKER_ENDPOINT')

        client_factory = session if session else boto3
        sm_client = client_factory.client('sagemaker', region_name=region, endpoint_url=endpoint_url)

        for step in raw_steps:
            step_name = getattr(step, 'step_name', '')

            # Only evaluation training steps (custom or base) are of interest
            if 'EvaluateCustomModel' not in step_name and 'EvaluateBaseModel' not in step_name:
                continue

            metadata = getattr(step, 'metadata', None)
            if not metadata or not hasattr(metadata, 'training_job'):
                continue

            training_job_meta = metadata.training_job
            if not hasattr(training_job_meta, 'arn'):
                continue

            # The training job name is the last segment of its ARN
            training_job_name = training_job_meta.arn.split('/')[-1]

            try:
                # Use boto3 DescribeTrainingJob (avoids pydantic validation issues)
                response = sm_client.describe_training_job(TrainingJobName=training_job_name)

                if 'OutputDataConfig' in response and 'S3OutputPath' in response['OutputDataConfig']:
                    s3_output_path = response['OutputDataConfig']['S3OutputPath']
                    logger.info(f"Extracted s3_output_path from training job {training_job_name}: {s3_output_path}")
                    return s3_output_path
            except ClientError as e:
                logger.warning(f"Failed to describe training job {training_job_name}: {e}")
            except Exception as e:
                logger.warning(f"Error describing training job {training_job_name}: {e}")

        logger.debug("Could not extract s3_output_path from pipeline steps")
        return None
    except Exception as e:
        logger.warning(f"Error extracting s3_output_path from steps: {str(e)}")
        return None


class StepDetail(BaseModel):
    """Details of a single pipeline step, used to track execution progress.

    Parameters:
        name (str): Name of the pipeline step.
        status (str): Status of the step (Completed, Executing, Waiting, Failed).
        start_time (Optional[str]): Time the step started, as a string.
        end_time (Optional[str]): Time the step ended, as a string.
        display_name (Optional[str]): Human-readable display name for the step.
        failure_reason (Optional[str]): Reason for failure if the step failed.
        job_arn (Optional[str]): ARN of the underlying job (training,
            processing, transform, etc.).
    """

    name: str = Field(..., description="Name of the pipeline step")
    status: str = Field(..., description="Status of the step (Completed, Executing, Waiting, Failed)")
    start_time: Optional[str] = Field(default=None, description="Step start time")
    end_time: Optional[str] = Field(default=None, description="Step end time")
    display_name: Optional[str] = Field(default=None, description="Display name for the step")
    failure_reason: Optional[str] = Field(default=None, description="Reason for failure if step failed")
    job_arn: Optional[str] = Field(
        default=None,
        description="ARN of the underlying job (training, processing, transform, etc.)",
    )
class PipelineExecutionStatus(BaseModel):
    """Overall execution status combined with per-step details.

    Parameters:
        overall_status (str): Overall execution status (Starting, Executing,
            Completed, Failed, etc.).
        step_details (List[StepDetail]): Details of the individual pipeline
            steps; defaults to an empty list.
        failure_reason (Optional[str]): Reason for failure if the execution failed.
    """

    overall_status: str = Field(
        ...,
        description="Overall execution status (Starting, Running, Completed, Failed, etc.)",
    )
    step_details: List[StepDetail] = Field(
        default_factory=list,
        description="List of pipeline step details",
    )
    failure_reason: Optional[str] = Field(
        default=None,
        description="Reason for failure if execution failed",
    )
class EvaluationPipelineExecution(BaseModel):
    """Manages the lifecycle of a SageMaker pipeline-based evaluation execution.

    Wraps a SageMaker pipeline execution behind a simplified interface for
    running, monitoring, and managing evaluation jobs. Instances are normally
    obtained from evaluator classes rather than constructed directly.

    Example:

    .. code:: python

        from sagemaker.train.evaluate import BenchmarkEvaluator
        from sagemaker.train.evaluate.execution import EvaluationPipelineExecution

        # Start evaluation through evaluator
        evaluator = BenchmarkEvaluator(...)
        execution = evaluator.evaluate()

        # Monitor execution
        print(f"Status: {execution.status.overall_status}")
        print(f"Steps: {len(execution.status.step_details)}")

        # Wait for completion
        execution.wait()

        # Display results
        execution.show_results()

        # Retrieve past executions
        all_executions = list(EvaluationPipelineExecution.get_all())
        specific_execution = EvaluationPipelineExecution.get(arn="arn:...")

    Parameters:
        arn (Optional[str]): ARN of the pipeline execution.
        name (str): Name of the evaluation execution.
        status (PipelineExecutionStatus): Combined status with step details and
            failure reason. Defaults to a status of "Unknown".
        last_modified_time (Optional[datetime]): Last modification timestamp.
        eval_type (Optional[EvalType]): Type of evaluation (BENCHMARK,
            CUSTOM_SCORER, LLM_AS_JUDGE).
        s3_output_path (Optional[str]): S3 location where evaluation results
            are stored.
        steps (List[Dict[str, Any]]): Raw step information from SageMaker.
    """

    # Fields set by underlying SageMaker pipeline operations
    arn: Optional[str] = Field(default=None, description="ARN of the pipeline execution")
    name: str = Field(..., description="Name of the evaluation execution")
    status: PipelineExecutionStatus = Field(
        default_factory=lambda: PipelineExecutionStatus(overall_status="Unknown"),
        description="Combined status, step details, and failure reason",
    )
    last_modified_time: Optional[datetime] = Field(default=None, description="Last modification timestamp")
    eval_type: Optional[EvalType] = Field(default=None, description="Evaluation type")
    s3_output_path: Optional[str] = Field(
        default=None,
        description="S3 location where evaluation results are stored",
    )

    # Additional fields for internal use
    steps: List[Dict[str, Any]] = Field(
        default_factory=list,
        description="Raw step information from SageMaker",
    )

    class Config:
        arbitrary_types_allowed = True

    def __init__(self, **data):
        super().__init__(**data)
        # Handle to the underlying pipeline execution object; callers in this
        # module assign it after construction.
        self._pipeline_execution: Optional[Any] = None
[docs] @classmethod @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="EvaluationPipelineExecution.start") def start( cls, eval_type: EvalType, name: str, pipeline_definition: str, role_arn: str, s3_output_path: Optional[str] = None, session: Optional[Session] = None, region: Optional[str] = None, tags: Optional[List[TagsDict]] = [], ) -> 'EvaluationPipelineExecution': """Create sagemaker pipeline execution. Optionally creates pipeline. Args: eval_type (EvalType): Type of evaluation (BENCHMARK, CUSTOM_SCORER, LLM_AS_JUDGE). name (str): Name for the evaluation execution. pipeline_definition (str): Complete rendered pipeline definition as JSON string. role_arn (str): IAM role ARN for pipeline execution. s3_output_path (Optional[str]): S3 location where evaluation results are stored. session (Optional[Session]): Boto3 session for API calls. region (Optional[str]): AWS region for the pipeline. tags (Optional[List[TagsDict]]): List of tags to include in pipeline Returns: EvaluationPipelineExecution: Started pipeline execution instance. Raises: ValueError: If pipeline_definition is not valid JSON. ClientError: If AWS service call fails. 
""" # Validate pipeline_definition is valid JSON import json try: json.loads(pipeline_definition) except json.JSONDecodeError as e: raise ValueError(f"Invalid pipeline definition JSON: {e}") # Create execution instance execution = cls( name=name, eval_type=eval_type, status=PipelineExecutionStatus(overall_status="Starting"), s3_output_path=s3_output_path ) try: # Get or create pipeline (handles update logic internally) pipeline = _get_or_create_pipeline( eval_type=eval_type, pipeline_definition=pipeline_definition, role_arn=role_arn, session=session, region=region, create_tags=tags, ) # Start pipeline execution via boto3 # Use the actual pipeline name from the created/updated pipeline object pipeline_name = pipeline.pipeline_name execution.arn = _start_pipeline_execution( pipeline_name=pipeline_name, name=name, session=session, region=region ) # Get the pipeline execution object for future operations execution._pipeline_execution = PipelineExecution.get( pipeline_execution_arn=execution.arn, session=session, region=region ) # Update execution with initial execution details execution.status.overall_status = execution._pipeline_execution.pipeline_execution_status or "Executing" execution.last_modified_time = execution._pipeline_execution.creation_time or datetime.now() except ClientError as e: error_code = e.response['Error']['Code'] error_message = e.response['Error']['Message'] logger.error(f"AWS service error when starting pipeline execution: {error_message}") execution.status.overall_status = "Failed" execution.status.failure_reason = f"AWS service error: {error_message}" except Exception as e: logger.error(f"Unexpected error when starting pipeline execution: {str(e)}") execution.status.overall_status = "Failed" execution.status.failure_reason = f"Unexpected error: {str(e)}" # Convert to appropriate subclass based on eval_type return execution._convert_to_subclass(eval_type)
[docs] @classmethod @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="EvaluationPipelineExecution.get_all") def get_all( cls, eval_type: Optional[EvalType] = None, session: Optional[Session] = None, region: Optional[str] = None ): """Get all pipeline executions, optionally filtered by evaluation type. Searches for existing pipelines using prefix and tag validation, then retrieves executions from those pipelines. Args: eval_type (Optional[EvalType]): Evaluation type to filter by (e.g., EvalType.BENCHMARK). If None, returns executions from all evaluation pipelines. session (Optional[Session]): Boto3 session. Will be inferred if not provided. region (Optional[str]): AWS region. Will be inferred if not provided. Yields: EvaluationPipelineExecution: Pipeline execution instances. Example: .. code:: python # Get all evaluation executions as iterator iter = EvaluationPipelineExecution.get_all() all_executions = list(iter) # Get only benchmark evaluations iter = EvaluationPipelineExecution.get_all(eval_type=EvalType.BENCHMARK) benchmark_executions = list(iter) """ try: # Determine which eval type(s) to search for eval_types_to_check = [eval_type] if eval_type else list(EvalType) for et in eval_types_to_check: pipeline_name_prefix = _get_pipeline_name_prefix(et) try: # Search for pipelines with the prefix pipelines = Pipeline.get_all( pipeline_name_prefix=pipeline_name_prefix, session=session, region=region ) # Check each pipeline for the required tag and get its executions for pipeline in pipelines: try: pipeline_arn = pipeline.pipeline_arn # Get tags using ResourceTag.get_all tags_list = ResourceTag.get_all(resource_arn=pipeline_arn, session=session, region=region) tags = {tag.key: tag.value for tag in tags_list} # Validate tag - only process evaluation pipelines if tags.get(_TAG_SAGEMAKER_MODEL_EVALUATION) != "true": logger.debug(f"Skipping pipeline {pipeline.pipeline_name} - missing required tag") continue pipeline_name = pipeline.pipeline_name 
logger.debug(f"Found evaluation pipeline: {pipeline_name}") # Get all executions for this pipeline pipeline_executions = PipelineExecution.get_all( pipeline_name=pipeline_name, session=session, region=region ) # Convert each PipelineExecution to EvaluationPipelineExecution for pe in pipeline_executions: # Create execution from pipeline execution execution = _create_execution_from_pipeline_execution(pe, et) # Enrich with step details and S3 path execution._enrich_with_step_details(session, region) # Convert to appropriate subclass based on eval_type execution = execution._convert_to_subclass(et) yield execution except Exception as e: logger.warning(f"Error processing pipeline {pipeline.pipeline_name}: {str(e)}") continue except ClientError as e: error_code = e.response['Error']['Code'] # If no pipelines found with prefix, skip to next eval type if "ResourceNotFound" in error_code or "ValidationException" in error_code: logger.debug(f"No pipelines found with prefix {pipeline_name_prefix}") continue else: logger.warning(f"Error searching for pipelines with prefix {pipeline_name_prefix}: {e}") continue except Exception as e: logger.warning(f"Error processing eval type {et.value}: {str(e)}") continue except Exception as e: logger.error(f"Unexpected error when listing pipeline executions: {str(e)}")
[docs] @classmethod @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="EvaluationPipelineExecution.get") def get( cls, arn: str, session: Optional[Session] = None, region: Optional[str] = None ) -> 'EvaluationPipelineExecution': """Get a sagemaker pipeline execution instance by ARN. Args: arn (str): ARN of the pipeline execution. session (Optional[Session]): Boto3 session. Will be inferred if not provided. region (Optional[str]): AWS region. Will be inferred if not provided. Returns: EvaluationPipelineExecution: Retrieved pipeline execution instance. Raises: ClientError: If AWS service call fails. Example: .. code:: python # Get execution by ARN arn = "arn:aws:sagemaker:us-west-2:123456789012:pipeline/eval-pipeline/execution/abc123" execution = EvaluationPipelineExecution.get(arn=arn) print(execution.status.overall_status) """ # Create execution instance with basic info name = arn.split('/')[-1] execution = cls( arn=arn, name=name, status=PipelineExecutionStatus(overall_status="Unknown") ) # Try to determine eval_type from execution ARN early (as fallback for error cases) execution.eval_type = _extract_eval_type_from_arn(arn) try: # Get pipeline execution details and store internally execution._pipeline_execution = PipelineExecution.get( pipeline_execution_arn=arn, session=session, region=region ) # Update execution with pipeline execution details execution.status.overall_status = execution._pipeline_execution.pipeline_execution_status or "Unknown" execution.status.failure_reason = _clean_unassigned_value(execution._pipeline_execution.failure_reason) execution.last_modified_time = execution._pipeline_execution.last_modified_time # Enrich with step details and S3 path execution._enrich_with_step_details(session, region) # Determine eval_type from pipeline ARN (preferred method) pipeline_arn = execution._pipeline_execution.pipeline_arn if execution._pipeline_execution else None determined_eval_type = execution._determine_eval_type(pipeline_arn) if 
determined_eval_type: execution.eval_type = determined_eval_type except ClientError as e: error_code = e.response['Error']['Code'] error_message = e.response['Error']['Message'] logger.error(f"AWS service error when getting pipeline execution: {error_message}") execution.status.overall_status = "Error" execution.status.failure_reason = f"AWS service error: {error_code}:{error_message}" # eval_type already set from execution ARN fallback above except Exception as e: logger.error(f"Unexpected error when getting pipeline execution details: {str(e)}") execution.status.overall_status = "Error" execution.status.failure_reason = f"Unexpected error: {str(e)}" # eval_type already set from execution ARN fallback above # Convert to appropriate subclass based on eval_type return execution._convert_to_subclass(execution.eval_type)
[docs] @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="EvaluationPipelineExecution.refresh") def refresh(self) -> None: """Describe a pipeline execution and update job status""" if not self._pipeline_execution: return try: # Refresh the pipeline execution instance self._pipeline_execution.refresh() # Update status from refreshed pipeline execution self.status.overall_status = self._pipeline_execution.pipeline_execution_status or "Unknown" self.status.failure_reason = _clean_unassigned_value(self._pipeline_execution.failure_reason) self.last_modified_time = self._pipeline_execution.last_modified_time # Get updated pipeline execution steps with proper session/region handling steps_iterator = self._pipeline_execution.get_all_steps() raw_steps = list(steps_iterator) self._update_step_details_from_raw_steps(raw_steps) except ClientError as e: error_code = e.response['Error']['Code'] error_message = e.response['Error']['Message'] logger.error(f"AWS service error when refreshing pipeline execution: {error_message}") except Exception as e: logger.error(f"Unexpected error when refreshing pipeline execution: {str(e)}")
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="EvaluationPipelineExecution.stop")
def stop(self) -> None:
    """Request that the pipeline execution be stopped.

    Does nothing when the execution has no ARN. On success the cached
    status is set to ``"Stopping"`` and then refreshed from the service.
    Errors are logged and swallowed.
    """
    if not self.arn:
        return

    try:
        # TODO: Move to sagemaker_core PipelineExecution.stop() when session handling is fixed
        # For now, use boto3 directly to stop the pipeline execution
        import boto3

        # Optional endpoint override; None lets boto3 pick its default.
        endpoint_url = os.environ.get('SAGEMAKER_ENDPOINT')

        # Prefer the session attached to the pipeline execution, if any.
        if self._pipeline_execution and hasattr(self._pipeline_execution, '_session'):
            session = self._pipeline_execution._session
            if hasattr(session, 'boto_session'):
                sm_client = session.boto_session.client('sagemaker', endpoint_url=endpoint_url)
            else:
                sm_client = session.client('sagemaker', endpoint_url=endpoint_url)
        else:
            # Fallback to default boto3 client
            sm_client = boto3.client('sagemaker', endpoint_url=endpoint_url)

        sm_client.stop_pipeline_execution(
            PipelineExecutionArn=self.arn
        )

        # Optimistically mark as stopping, then pull the real status.
        self.status.overall_status = "Stopping"
        logger.info(f"Stopping pipeline execution: {self.arn}")
        self.refresh()

    except ClientError as e:
        error_message = e.response['Error']['Message']
        logger.error(f"AWS service error when stopping pipeline execution: {error_message}")
    except Exception as e:
        logger.error(f"Unexpected error when stopping pipeline execution: {str(e)}")
@_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="EvaluationPipelineExecution.wait")
def wait(
    self,
    target_status: Literal["Executing", "Stopping", "Stopped", "Failed", "Succeeded"] = "Succeeded",
    poll: int = 5,
    timeout: Optional[int] = None
) -> None:
    """Wait for a pipeline execution to reach certain status.

    Polls ``refresh()`` every ``poll`` seconds. Three display modes are
    used: a rich panel re-rendered in place inside Jupyter, a rich live
    spinner in a terminal, and plain ``print`` output in a terminal when
    rich cannot be imported. Returns immediately when no underlying
    pipeline execution is attached.

    Args:
        target_status: The status to wait for
        poll: The number of seconds to wait between each poll
        timeout: The maximum number of seconds to wait before timing out

    Raises:
        FailedStatusError: If the current status contains "failed"
            (case-insensitive) and is not the target status.
        TimeoutExceededError: If ``timeout`` seconds elapse before the
            target status is reached.
    """
    if not self._pipeline_execution:
        return

    start_time = time.time()

    def _is_finished(current_status: str, final_message: str) -> bool:
        """Return True once the target is reached; raise on failure or timeout."""
        # The target check comes first so waiting for "Failed" returns
        # normally instead of raising.
        if target_status == current_status:
            logger.info(final_message)
            return True

        if "failed" in current_status.lower():
            from sagemaker.core.utils.exceptions import FailedStatusError
            raise FailedStatusError(
                resource_type="PipelineExecution",
                status=current_status,
                reason=self.status.failure_reason,
            )

        if timeout is not None and time.time() - start_time >= timeout:
            from sagemaker.core.utils.exceptions import TimeoutExceededError
            raise TimeoutExceededError(
                resource_type="PipelineExecution",
                status=current_status
            )
        return False

    # Detect if running in Jupyter. Best-effort: any failure simply means
    # "not Jupyter". `except Exception` (not bare) keeps Ctrl-C working.
    is_jupyter = False
    try:
        from IPython import get_ipython
        ipython = get_ipython()
        if ipython is not None and 'IPKernelApp' in ipython.config:
            # Import first so is_jupyter is only set when clear_output exists.
            from IPython.display import clear_output
            is_jupyter = True
    except Exception:
        pass

    if is_jupyter:
        # Jupyter notebook experience with rich library
        from rich.console import Console, Group
        from rich.panel import Panel
        from rich.table import Table
        from rich.text import Text

        # Create console with Jupyter support
        console = Console(force_jupyter=True)

        while True:
            clear_output(wait=True)

            self.refresh()
            current_status = self.status.overall_status
            elapsed = time.time() - start_time

            # Header table: job name and (optionally) Studio link
            header_table = Table(show_header=False, box=None, padding=(0, 1))
            header_table.add_column("Property", style="cyan bold", width=20)
            header_table.add_column("Value", style="dim", overflow="fold")

            # Extract pipeline name and exec_id from execution ARN
            pipeline_name = None
            exec_id = ''
            if self.arn:
                arn_parts = self.arn.split('/')
                if len(arn_parts) >= 4:
                    pipeline_name = arn_parts[-3]
                    exec_id = arn_parts[-1]

            # Use execution display name if available, fall back to self.name
            display_name = self.name
            if self._pipeline_execution:
                dn = getattr(self._pipeline_execution, 'pipeline_execution_display_name', None)
                if dn and not (hasattr(dn, '__class__') and 'Unassigned' in dn.__class__.__name__):
                    display_name = dn
            header_table.add_row("Evaluation Job", str(display_name))

            # Build links row. Purely cosmetic, so any failure is ignored.
            links = []
            try:
                from sagemaker.core.utils.utils import SageMakerClient
                from sagemaker.train.common_utils.metrics_visualizer import _is_in_studio, _get_studio_base_url
                if pipeline_name and _is_in_studio():
                    region = SageMakerClient().region_name
                    base = _get_studio_base_url(region)
                    if base:
                        pipeline_url = f"{base}/jobs/evaluation/detail?pipeline_name={pipeline_name}&execution_id={exec_id}"
                        links.append(f"[bright_blue underline][link={pipeline_url}]🔗 Pipeline Execution (Studio)[/link][/bright_blue underline]")
            except Exception:
                pass
            if links:
                header_table.add_row("Links", " | ".join(links))

            # Create main status table
            status_table = Table(show_header=False, box=None, padding=(0, 1))
            status_table.add_column("Property", style="cyan bold", width=20)
            status_table.add_column("Value", style="dim")

            status_table.add_row("Overall Status", f"[bold][orange3]{current_status}[/][/]")
            status_table.add_row("Target Status", f"[bold yellow]{target_status}[/bold yellow]")
            status_table.add_row("Elapsed Time", f"[bold bright_red]{elapsed:.1f}s[/bold bright_red]")

            if self.status.failure_reason:
                status_table.add_row("Failure Reason", f"[red]{self.status.failure_reason}[/red]")

            # Create steps table if steps exist
            if self.status.step_details:
                has_failures = any(step.failure_reason for step in self.status.step_details)

                steps_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1))
                steps_table.add_column("Step Name", style="cyan", width=30)
                steps_table.add_column("Status", style="yellow", width=15)
                steps_table.add_column("Duration", style="green", width=12)

                failed_steps = []
                job_arn_entries = []

                for step in self.status.step_details:
                    duration = ""
                    if step.start_time and step.end_time:
                        try:
                            start = datetime.fromisoformat(step.start_time.replace('Z', '+00:00'))
                            end = datetime.fromisoformat(step.end_time.replace('Z', '+00:00'))
                            duration_seconds = (end - start).total_seconds()
                            duration = f"{duration_seconds:.1f}s"
                        except Exception:
                            # Unparseable timestamps are a display problem only.
                            duration = "N/A"
                    elif step.start_time:
                        duration = "Running..."

                    status_lower = step.status.lower()
                    status_display = step.status
                    if "succeeded" in status_lower or "completed" in status_lower:
                        status_display = f"[green]{step.status}[/green]"
                    elif "failed" in status_lower:
                        status_display = f"[red]{step.status}[/red]"
                    elif "executing" in status_lower or "running" in status_lower:
                        status_display = f"[yellow]{step.status}[/yellow]"

                    if step.job_arn:
                        job_arn_entries.append({
                            'step_name': step.display_name or step.name,
                            'job_arn': step.job_arn,
                        })

                    row_data = [
                        step.display_name or step.name,
                        status_display,
                        duration
                    ]

                    # When any step failed, every row gets a fourth marker cell.
                    # NOTE(review): no matching add_column exists; presumably
                    # rich adds an unnamed column for the extra cell - confirm.
                    if has_failures:
                        if step.failure_reason:
                            row_data.append("❌")
                            failed_steps.append(step)
                        else:
                            row_data.append("")

                    steps_table.add_row(*row_data)

                content_parts = [
                    status_table,
                    Text(""),
                    Text("Pipeline Steps", style="bold magenta"),
                    steps_table
                ]

                if failed_steps:
                    content_parts.append(Text(""))
                    content_parts.append(Text("Step Failure Details", style="bold red"))
                    for step in failed_steps:
                        content_parts.append(Text(""))
                        content_parts.append(Text(f"• {step.display_name or step.name}:", style="bold red"))
                        content_parts.append(Text(f" {step.failure_reason}", style="red"))

                # Add job links table if any steps have ARNs
                if job_arn_entries:
                    links_table = Table(show_header=True, header_style="bold magenta", box=None, padding=(0, 1))
                    links_table.add_column("Step", style="cyan", width=20)
                    links_table.add_column("Console", style="dim")

                    from sagemaker.core.utils.utils import SageMakerClient
                    from sagemaker.train.common_utils.metrics_visualizer import (
                        _is_in_studio,
                        _parse_job_arn,
                        _get_studio_base_url,
                        get_console_job_url,
                        get_cloudwatch_logs_url,
                    )
                    in_studio = _is_in_studio()
                    studio_base = _get_studio_base_url(SageMakerClient().region_name) if in_studio else ""

                    # The Studio column only exists when running inside Studio.
                    if in_studio:
                        links_table.add_column("Studio", style="dim")
                    links_table.add_column("Logs", style="dim")
                    links_table.add_column("Job ARN", style="dim", overflow="fold")

                    # ARN resource prefix -> Studio URL path segment
                    studio_path_map = {
                        "training-job/": "jobs/train/",
                        "processing-job/": "jobs/processing/",
                        "transform-job/": "jobs/transform/",
                    }

                    for entry in job_arn_entries:
                        console_link = ""
                        logs_link = ""
                        studio_link = ""
                        try:
                            arn = entry['job_arn']
                            url = get_console_job_url(arn)
                            if url:
                                console_link = f"[bright_blue underline][link={url}]🔗 link[/link][/bright_blue underline]"
                            cw_url = get_cloudwatch_logs_url(arn)
                            if cw_url:
                                logs_link = f"[bright_blue underline][link={cw_url}]🔗 link[/link][/bright_blue underline]"
                            if in_studio and studio_base:
                                parsed = _parse_job_arn(arn)
                                if parsed:
                                    _, resource = parsed
                                    for prefix, path in studio_path_map.items():
                                        if resource.startswith(prefix):
                                            job_name = resource.split("/", 1)[1]
                                            s_url = f"{studio_base}/{path}{job_name}"
                                            studio_link = f"[bright_blue underline][link={s_url}]🔗 link[/link][/bright_blue underline]"
                                            break
                        except Exception:
                            # Link building is best-effort; leave cells blank.
                            pass

                        row = [entry['step_name'], console_link]
                        if in_studio:
                            row.append(studio_link)
                        row.extend([logs_link, entry['job_arn']])
                        links_table.add_row(*row)

                    content_parts.append(Text(""))
                    content_parts.append(Text("Job ARNs", style="bold magenta"))
                    content_parts.append(links_table)

                console.print(Panel(
                    Group(header_table, *content_parts),
                    title="[bold bright_blue]Pipeline Execution Status[/bold bright_blue]",
                    border_style="orange3"
                ))
            else:
                console.print(Panel(
                    Group(header_table, status_table),
                    title="[bold bright_blue]Pipeline Execution Status[/bold bright_blue]",
                    border_style="orange3"
                ))

            if _is_finished(current_status, f"Final Resource Status: {current_status}"):
                return

            time.sleep(poll)
    else:
        # Terminal experience with rich library. Only the imports are
        # guarded, so an ImportError raised while polling is not mistaken
        # for "rich is not installed".
        try:
            from rich.live import Live
            from rich.panel import Panel
            from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn
            from rich.console import Group
            from rich.status import Status
            from rich.style import Style
        except ImportError:
            rich_available = False
        else:
            rich_available = True

        if rich_available:
            progress = Progress(
                SpinnerColumn("bouncingBar"),
                TextColumn("{task.description}"),
                TimeElapsedColumn(),
            )
            progress.add_task(f"Waiting for PipelineExecution to reach [bold]{target_status}[/bold] status...")
            status = Status("Current status:")

            with Live(
                Panel(
                    Group(progress, status),
                    title="Wait Log Panel",
                    border_style=Style(color="blue"),
                ),
                transient=True,
            ):
                while True:
                    self.refresh()
                    current_status = self.status.overall_status
                    status.update(f"Current status: [bold]{current_status}[/bold]")

                    if _is_finished(current_status, f"Final Resource Status: [bold]{current_status}[/bold]"):
                        return

                    time.sleep(poll)
        else:
            # Fallback to simple print-based progress if rich is not available
            logger.info(f"Waiting for PipelineExecution to reach {target_status} status...")
            while True:
                self.refresh()
                current_status = self.status.overall_status
                elapsed = time.time() - start_time

                print(f"Current status: {current_status} (Elapsed: {elapsed:.1f}s)")

                if _is_finished(current_status, f"Final Resource Status: {current_status}"):
                    return

                time.sleep(poll)
def _enrich_with_step_details(
    self,
    session: Optional[Session] = None,
    region: Optional[str] = None
) -> None:
    """Populate step details and the S3 output path from the pipeline steps.

    Mutates this execution in place. Any exception is logged as a warning
    and swallowed. Internal helper for get() and get_all().

    Args:
        session: Boto3 session (optional)
        region: AWS region (optional)
    """
    if not self._pipeline_execution:
        return

    try:
        raw_steps = list(
            self._pipeline_execution.get_all_steps(session=session, region=region)
        )
        self._update_step_details_from_raw_steps(raw_steps)

        # Leave an already-known output path alone.
        if self.s3_output_path:
            return
        self.s3_output_path = _extract_output_s3_location_from_steps(raw_steps, session, region)
    except Exception as e:
        logger.warning(f"Failed to fetch step details for execution {self.name}: {str(e)}")


def _determine_eval_type(self, pipeline_arn: Optional[str] = None) -> Optional[EvalType]:
    """Work out the eval type, preferring the pipeline ARN over the execution ARN.

    Internal helper for get().

    Args:
        pipeline_arn: Optional pipeline ARN to check first

    Returns:
        EvalType if found, None otherwise
    """
    # Preferred source: the pipeline ARN, when one was supplied.
    from_pipeline = _extract_eval_type_from_arn(pipeline_arn) if pipeline_arn else None
    if from_pipeline:
        return from_pipeline

    # Otherwise fall back to this execution's own ARN, if it has one.
    return _extract_eval_type_from_arn(self.arn) if self.arn else None


def _convert_to_subclass(self, eval_type: EvalType) -> 'EvaluationPipelineExecution':
    """Rebuild this execution as the subclass matching ``eval_type``.

    Internal helper for start(), get(), and get_all().

    Args:
        eval_type: Type of evaluation to determine subclass

    Returns:
        A new subclass instance for benchmark / custom-scorer / LLM-as-judge
        types, or this same instance for any other value. In both cases the
        internal pipeline execution reference is carried over.
    """
    # Captured first: the reference is re-attached to the result below.
    saved_pipeline_execution = self._pipeline_execution
    field_values = _clean_unassigned_from_dict(self.dict())

    if eval_type == EvalType.BENCHMARK or eval_type == EvalType.CUSTOM_SCORER:
        target_cls = BenchmarkEvaluationExecution
    elif eval_type == EvalType.LLM_AS_JUDGE:
        target_cls = LLMAJEvaluationExecution
    else:
        target_cls = None

    converted = self if target_cls is None else target_cls(**field_values)

    converted._pipeline_execution = saved_pipeline_execution
    return converted


@staticmethod
def _extract_job_arn_from_metadata(step) -> Optional[str]:
    """Return the ARN of the job behind a pipeline step, or None if absent.

    Checks the step metadata's job attributes in a fixed order and returns
    the first usable ARN as a string.
    """
    from sagemaker.train.common_utils.trainer_wait import _is_unassigned_attribute

    metadata = getattr(step, 'metadata', None)
    if metadata is None or _is_unassigned_attribute(metadata):
        return None

    job_attributes = (
        'training_job',
        'processing_job',
        'transform_job',
        'tuning_job',
        'auto_ml_job',
        'compilation_job',
    )
    for attribute in job_attributes:
        job_meta = getattr(metadata, attribute, None)
        if job_meta is None or _is_unassigned_attribute(job_meta):
            continue
        arn = getattr(job_meta, 'arn', None)
        if arn and not _is_unassigned_attribute(arn):
            return str(arn)
    return None


def _update_step_details_from_raw_steps(self, raw_steps: List[Any]) -> None:
    """Rebuild ``self.status.step_details`` from raw pipeline execution steps.

    Steps that fail to convert are logged and skipped.

    Args:
        raw_steps: List of PipelineExecutionStep objects from SageMaker
    """

    def _as_time_string(value):
        # datetime-like values become ISO strings; anything else is str()'d.
        return value.isoformat() if hasattr(value, 'isoformat') else str(value)

    def _none_if_unassigned(value):
        # sagemaker_core marks unset fields with an "Unassigned" object.
        if value is not None and 'Unassigned' in value.__class__.__name__:
            return None
        return value

    collected = []
    for step in raw_steps:
        try:
            started = getattr(step, 'start_time', None)
            start_time = _as_time_string(started) if started else None

            ended = getattr(step, 'end_time', None)
            end_time = _as_time_string(ended) if ended else None

            step_display_name = _none_if_unassigned(getattr(step, 'step_display_name', None))
            failure_reason = _none_if_unassigned(getattr(step, 'failure_reason', None))

            collected.append(
                StepDetail(
                    name=getattr(step, 'step_name', 'Unknown Step'),
                    status=getattr(step, 'step_status', 'Unknown'),
                    start_time=start_time,
                    end_time=end_time,
                    display_name=step_display_name,
                    failure_reason=failure_reason,
                    job_arn=self._extract_job_arn_from_metadata(step)
                )
            )
        except Exception as e:
            # One bad step must not hide the others.
            logger.warning(f"Failed to process pipeline step: {str(e)}")

    # Store in reverse order so the earliest step appears first.
    self.status.step_details = collected[::-1]
# ============================================================================ # Eval-Type-Specific Subclasses # ============================================================================
class BenchmarkEvaluationExecution(EvaluationPipelineExecution):
    """Execution subclass for benchmark evaluations.

    Adds a benchmark-specific ``show_results()`` that compares custom model
    performance against a base model.
    """

    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="BenchmarkEvaluationExecution.show_results")
    def show_results(self) -> None:
        """Show benchmark results for the custom model versus the base model.

        Refreshes the execution first, then displays aggregate metrics
        with detailed S3 artifact locations.

        Raises:
            ValueError: If execution hasn't succeeded.

        Example:
            .. code:: python

                execution = evaluator.evaluate()
                execution.wait()
                execution.show_results()
        """
        # Results only exist for a succeeded run, so check fresh status first.
        self.refresh()

        if self.status.overall_status == "Succeeded":
            # Rendering lives in the shared utility module.
            from ..common_utils.show_results_utils import _show_benchmark_results
            _show_benchmark_results(self)
            return

        raise ValueError(
            f"Cannot show results. Execution status is '{self.status.overall_status}'. "
            f"Results are only available after successful execution. "
            f"Use execution.wait() to wait for completion or check execution.status for details."
        )
class LLMAJEvaluationExecution(EvaluationPipelineExecution):
    """LLM As Judge evaluation execution subclass with type-specific show_results().

    Provides LLM-as-Judge-specific result display functionality with pagination
    and detailed judge explanations.
    """

    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="LLMAJEvaluationExecution.show_results")
    def show_results(
        self,
        limit: Optional[int] = 5,
        offset: int = 0,
        show_explanations: bool = False
    ) -> None:
        """Display LLM As Judge evaluation results with pagination.

        Refreshes the execution first, then shows per-evaluation results
        with prompt, response, and scores.

        Args:
            limit (Optional[int]): Number of evaluation prompts to display.
                Set to None for all. Defaults to 5.
            offset (int): Starting index for pagination. Defaults to 0.
            show_explanations (bool): Whether to show judge explanations.
                Defaults to False.

        Raises:
            ValueError: If execution hasn't succeeded.

        Example:
            .. code:: python

                execution = evaluator.evaluate()
                execution.wait()

                # Show first 5 evaluations
                execution.show_results()

                # Show next 5
                execution.show_results(limit=5, offset=5)

                # Show all with explanations
                execution.show_results(limit=None, show_explanations=True)
        """
        # Refresh and validate status: results only exist for succeeded runs.
        self.refresh()

        if self.status.overall_status != "Succeeded":
            raise ValueError(
                f"Cannot show results. Execution status is '{self.status.overall_status}'. "
                f"Results are only available after successful execution. "
                f"Use execution.wait() to wait for completion or check execution.status for details."
            )

        # Delegate to utility
        from ..common_utils.show_results_utils import _show_llmaj_results
        _show_llmaj_results(self, limit=limit, offset=offset, show_explanations=show_explanations)