Source code for sagemaker.mlops.feature_store.athena_query

import os
import tempfile
from dataclasses import dataclass, field
from typing import Any, Dict, Optional
from urllib.parse import urlparse

import pandas as pd
from pandas import DataFrame

from sagemaker.mlops.feature_store.feature_utils import (
    start_query_execution,
    get_query_execution,
    wait_for_athena_query,
    download_athena_query_result,
)

from sagemaker.core.helper.session_helper import Session
from sagemaker.core.telemetry import Feature, _telemetry_emitter

@dataclass
class AthenaQuery:
    """Class to manage querying of feature store data with AWS Athena.

    This class instantiates an AthenaQuery object that is used to retrieve data
    from feature store via standard SQL queries. The object tracks only the most
    recent query started through :meth:`run`; ``wait``, ``get_query_execution``
    and ``as_dataframe`` all operate on that query.

    Attributes:
        catalog (str): name of the data catalog.
        database (str): name of the database.
        table_name (str): name of the table.
        sagemaker_session (Session): instance of the Session class to perform boto calls.
    """

    catalog: str
    database: str
    table_name: str
    sagemaker_session: Session
    # Bookkeeping for the most recent query; all three stay None until run() is called.
    _current_query_execution_id: Optional[str] = field(default=None, init=False)
    _result_bucket: Optional[str] = field(default=None, init=False)
    _result_file_prefix: Optional[str] = field(default=None, init=False)

    @_telemetry_emitter(Feature.FEATURE_STORE, "AthenaQuery.run")
    def run(
        self,
        query_string: str,
        output_location: str,
        kms_key: Optional[str] = None,
        workgroup: Optional[str] = None,
    ) -> str:
        """Execute a SQL query given a query string, output location and kms key.

        This method executes the SQL query using Athena, outputs the results to
        ``output_location`` and returns the execution id of the query. The
        execution id and the bucket/prefix parsed from ``output_location`` are
        remembered on the instance for the other methods to use.

        Args:
            query_string: SQL query string.
            output_location: S3 URI of the query result.
            kms_key: KMS key id. If set, will be used to encrypt the query result file.
            workgroup: The name of the workgroup in which the query is being started.

        Returns:
            Execution id of the query.
        """
        response = start_query_execution(
            session=self.sagemaker_session,
            catalog=self.catalog,
            database=self.database,
            query_string=query_string,
            output_location=output_location,
            kms_key=kms_key,
            workgroup=workgroup,
        )
        self._current_query_execution_id = response["QueryExecutionId"]

        # Split "s3://bucket/some/prefix/" into bucket and prefix so the result
        # file can be located later; fragments are kept as part of the path.
        parsed_result = urlparse(output_location, allow_fragments=False)
        self._result_bucket = parsed_result.netloc
        self._result_file_prefix = parsed_result.path.strip("/")
        return self._current_query_execution_id

    def wait(self):
        """Wait for the current query to finish."""
        wait_for_athena_query(self.sagemaker_session, self._current_query_execution_id)

    def get_query_execution(self) -> Dict[str, Any]:
        """Get execution status of the current query.

        Returns:
            Response dict from Athena.
        """
        # Inside a method body this name resolves to the module-level helper,
        # not to this method, so there is no recursion here.
        return get_query_execution(self.sagemaker_session, self._current_query_execution_id)

    @_telemetry_emitter(Feature.FEATURE_STORE, "AthenaQuery.as_dataframe")
    def as_dataframe(self, **kwargs) -> DataFrame:
        """Download the result of the current query and load it into a DataFrame.

        The result file is downloaded into the system temporary directory, parsed
        with ``pandas.read_csv`` and then removed again.

        Args:
            **kwargs (object): key arguments used for the method pandas.read_csv to be
                able to have a better tuning on data. For more info read:
                https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html
                ``delimiter`` and its alias ``sep`` are ignored: the result file is
                always parsed as comma-separated.

        Returns:
            A pandas DataFrame contains the query result.

        Raises:
            RuntimeError: if the current query has not reached the ``SUCCEEDED`` state.
        """
        state = self.get_query_execution()["QueryExecution"]["Status"]["State"]
        if state != "SUCCEEDED":
            if state in ("QUEUED", "RUNNING"):
                raise RuntimeError(f"Query {self._current_query_execution_id} still executing.")
            raise RuntimeError(f"Query {self._current_query_execution_id} failed.")

        # The separator is forced to "," below. Drop both spellings from the
        # caller's kwargs: a leftover ``sep`` next to ``delimiter`` makes
        # pandas raise "Specified a sep and a delimiter".
        kwargs.pop("delimiter", None)
        kwargs.pop("sep", None)

        output_file = os.path.join(tempfile.gettempdir(), f"{self._current_query_execution_id}.csv")
        keep_file = False
        try:
            download_athena_query_result(
                session=self.sagemaker_session,
                bucket=self._result_bucket,
                prefix=self._result_file_prefix,
                query_execution_id=self._current_query_execution_id,
                filename=output_file,
            )
            result = pd.read_csv(output_file, delimiter=",", **kwargs)
            # With ``chunksize``/``iterator`` pandas returns a lazy reader that
            # still needs the file on disk, so only clean up once the data is
            # fully materialised.
            keep_file = not isinstance(result, DataFrame)
            return result
        finally:
            if not keep_file:
                try:
                    os.remove(output_file)
                except OSError:
                    # Best-effort cleanup: the file may never have been created,
                    # and a cleanup failure must not mask the real error.
                    pass