Tabular Online Explainability with SageMaker Clarify

Introduction

Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models.

SageMaker Clarify currently supports explainability for SageMaker models as an offline processing job. This example notebook showcases a new feature for explainability on a SageMaker real-time inference endpoint, a.k.a. online explainability.

This example notebook walks you through:
1. Key terms and concepts needed to understand SageMaker Clarify.
2. Training the model on a training dataset.
3. Creating a model from the trained model artifacts, creating an endpoint configuration with the new SageMaker Clarify explainer configuration, and creating an endpoint from that configuration.
4. Invoking the endpoint with single and batch requests, using different EnableExplanations queries.
5. Explaining the importance of the various input features on the model’s decision.

In doing so, the notebook first trains a SageMaker XGBoost model on the training dataset, then uses SageMaker Clarify to explain the model's predictions for records from the testing dataset, sent to the endpoint in CSV format.

General Setup

We recommend using the Python 3 (Data Science) kernel on SageMaker Studio or the conda_python3 kernel on a SageMaker Notebook Instance.

Install dependencies

Upgrade the SageMaker Python SDK and the AWS SDK for Python (boto3 and botocore). Install shap and matplotlib, which are used to visualize the feature attributions.

[ ]:
!pip install sagemaker --upgrade
!pip install boto3 --upgrade
!pip install botocore --upgrade
!pip install shap --upgrade
!pip install matplotlib --upgrade

Import libraries

[ ]:
import boto3
import io
import os
import shap
import pprint
import pandas as pd
import numpy as np
from collections import OrderedDict
from sagemaker import get_execution_role, Session
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.utils import unique_name_from_base

Set configurations

[ ]:
boto3_session = boto3.session.Session()
sagemaker_client = boto3.client("sagemaker")
sagemaker_runtime_client = boto3.client("sagemaker-runtime")

# Initialize sagemaker session
sagemaker_session = Session(
    boto_session=boto3_session,
    sagemaker_client=sagemaker_client,
    sagemaker_runtime_client=sagemaker_runtime_client,
)

region = sagemaker_session.boto_region_name
print(f"Region: {region}")

role = get_execution_role()
print(f"Role: {role}")

s3_client = boto3.client("s3")

prefix = unique_name_from_base("DEMO-Tabular-Adult")

s3_bucket = sagemaker_session.default_bucket()
s3_prefix = f"sagemaker/{prefix}"
s3_key = f"s3://{s3_bucket}/{s3_prefix}"
print(f"Demo S3 key: {s3_key}")

model_name = f"{prefix}-model"
print(f"Demo model name: {model_name}")
endpoint_config_name = f"{prefix}-endpoint-config"
print(f"Demo endpoint config name: {endpoint_config_name}")
endpoint_name = f"{prefix}-endpoint"
print(f"Demo endpoint name: {endpoint_name}")

# Instance type for training and hosting
instance_type = "ml.m5.xlarge"

Create serializer and deserializer

CSV serializer to serialize test data to string

[ ]:
csv_serializer = CSVSerializer()

JSON deserializer to deserialize the invoke_endpoint response

[ ]:
json_deserializer = JSONDeserializer()

For visualization

SHAP plots are useful tools for interpreting the explanations. For example, the SHAP additive force layout shows how each feature contributes to pushing the base value (also called the expected value, which in this notebook is the model's average prediction over the SHAP baseline) to the corresponding prediction. Features that push the prediction higher are shown in red, while those that push it lower are shown in blue.

[ ]:
def force_plot(expected_value, shap_values, feature_data, feature_headers):
    """
    Visualize the given SHAP values with an additive force layout.

    For more information: https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Force%20Plot%20Colors.html
    """
    shap.plots.force(
        base_value=expected_value,
        shap_values=shap_values,
        features=feature_data,
        feature_names=feature_headers,
        matplotlib=True,
    )


def display_plots(explanations, expected_value, request_records, predictions):
    """
    Display the Model Explainability plots
    """
    per_request_shap_values = OrderedDict()
    feature_headers = []
    for i, record_output in enumerate(explanations):
        per_record_shap_values = []
        if record_output is not None:
            feature_headers = []
            for feature_attribution in record_output:
                per_record_shap_values.append(
                    feature_attribution["attributions"][0]["attribution"][0]
                )
                feature_headers.append(feature_attribution["feature_header"])
            per_request_shap_values[i] = per_record_shap_values

    for record_index, shap_values in per_request_shap_values.items():
        print(
            f"Visualize the SHAP values for Record number {record_index + 1} with Model Prediction: {predictions[record_index][0]}"
        )
        force_plot(
            expected_value,
            np.array(shap_values),
            request_records.iloc[record_index],
            feature_headers,
        )


def visualize_result(result, request_records, expected_value):
    """
    Visualize the output from the endpoint.
    """
    predictions = pd.read_csv(io.StringIO(result["predictions"]["data"]), header=None)
    predictions = predictions.values.tolist()
    print(f"Model Inference output: ")
    for i, model_output in enumerate(predictions):
        print(f"Record: {i + 1}\tModel Prediction: {model_output[0]}")

    if "kernel_shap" in result["explanations"]:
        explanations = result["explanations"]["kernel_shap"]
        display_plots(explanations, expected_value, request_records, predictions)
    else:
        print(f"No Clarify explanations for the record(s)")
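
The force plot relies on SHAP's additivity: the base value plus the sum of a record's SHAP values equals the model's prediction for that record. The stand-alone sketch below illustrates this with exact Shapley values for a hypothetical three-feature toy model, where a feature that is "absent" from a coalition takes its value from a single baseline record. Kernel SHAP, which produces the kernel_shap explanations read above, approximates these values by sampling coalitions rather than enumerating all of them. The function name and toy model are illustrative assumptions, not Clarify's implementation.

```python
from itertools import combinations
from math import factorial


def exact_shapley_values(model, record, baseline):
    """Exact Shapley values for `record`; features outside a coalition use `baseline` values."""
    n = len(record)

    def coalition_value(coalition):
        # Hybrid input: coalition features come from the record, the rest from the baseline
        mixed = [record[i] if i in coalition else baseline[i] for i in range(n)]
        return model(mixed)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):  # coalitions of size 0..n-1 drawn from the other features
            for subset in combinations(others, size):
                coalition = set(subset)
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (coalition_value(coalition | {i}) - coalition_value(coalition))
    return phi


# Hypothetical toy model with an interaction between features 1 and 2
def toy_model(x):
    return 2 * x[0] + x[1] * x[2]


toy_record, toy_baseline = [1, 2, 3], [0, 0, 0]
phi = exact_shapley_values(toy_model, toy_record, toy_baseline)
toy_expected_value = toy_model(toy_baseline)
print([round(p, 6) for p in phi])  # [2.0, 3.0, 3.0]
# Additivity: expected value + sum of SHAP values == prediction
print(round(toy_expected_value + sum(phi), 6), toy_model(toy_record))  # 8.0 8
```

The linear term is attributed entirely to feature 0, while the interaction term is split equally between features 1 and 2. Exact enumeration grows exponentially with the number of features, which is why sampling-based Kernel SHAP is used in practice.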

Prepare data

Download data

Data Source: https://archive.ics.uci.edu/ml/machine-learning-databases/adult/

Let’s download the data and save it in the local folder as adult.data and adult.test. The files are a copy of the dataset from the UCI Machine Learning Repository [2], hosted in the sagemaker-sample-files S3 bucket.

[2] Dua Dheeru and Efi Karra Taniskidou. “UCI Machine Learning Repository”. Irvine, CA: University of California, School of Information and Computer Science (2017).

[ ]:
adult_columns = [
    "Age",
    "Workclass",
    "fnlwgt",
    "Education",
    "Education-Num",
    "Marital Status",
    "Occupation",
    "Relationship",
    "Ethnic group",
    "Sex",
    "Capital Gain",
    "Capital Loss",
    "Hours per week",
    "Country",
    "Target",
]
if not os.path.isfile("adult.data"):
    s3_client.download_file(
        "sagemaker-sample-files", "datasets/tabular/uci_adult/adult.data", "adult.data"
    )
    print(f"adult.data saved!")
else:
    print(f"adult.data already on disk.")

if not os.path.isfile("adult.test"):
    s3_client.download_file(
        "sagemaker-sample-files", "datasets/tabular/uci_adult/adult.test", "adult.test"
    )
    print(f"adult.test saved!")
else:
    print(f"adult.test already on disk.")

Loading the data: Adult Dataset

From the UCI repository of machine learning datasets, this database contains 14 features concerning demographic characteristics, with 45,222 rows once rows with missing values are removed (30,162 for training and 15,060 for testing). The task is to predict whether a person has a yearly income above or below $50,000.

Here are the features and their possible values:
1. Age: continuous.
2. Workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
3. Fnlwgt: continuous (the number of people the census takers believe that observation represents).
4. Education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
5. Education-num: continuous.
6. Marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
7. Occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
8. Relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
9. Ethnic group: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
10. Sex: Female, Male.
    * Note: this data is extracted from the 1994 Census and enforces a binary option on Sex.
11. Capital-gain: continuous.
12. Capital-loss: continuous.
13. Hours-per-week: continuous.
14. Native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.

Next, we specify our binary prediction task:
15. Target: <=50K, >50K (a yearly income of at most $50,000, or above $50,000).
[ ]:
training_data = pd.read_csv(
    "adult.data", names=adult_columns, sep=r"\s*,\s*", engine="python", na_values="?"
).dropna()

testing_data = pd.read_csv(
    "adult.test", names=adult_columns, sep=r"\s*,\s*", engine="python", na_values="?", skiprows=1
).dropna()

training_data.head()

Data inspection

Plotting the distributions of the different features is a good way to visualize the data. Let’s plot a few of the features that can be considered sensitive.
Let’s take a look specifically at the Sex feature of a census respondent. The first plot shows that there are fewer Female respondents overall; the second plot shows that the gap is even larger among the positive outcomes (>$50K), where Female respondents form roughly 1/7 of the total.
[ ]:
training_data["Sex"].value_counts().sort_values().plot(kind="bar", title="Counts of Sex", rot=0)
[ ]:
training_data["Sex"].where(training_data["Target"] == ">50K").value_counts().sort_values().plot(
    kind="bar", title="Counts of Sex earning >$50K", rot=0
)

Encode and Upload the Dataset

Here we encode the training and test data. Encoding input data is not necessary for SageMaker Clarify, but is necessary for the model.

[ ]:
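from sklearn import preprocessing


def number_encode_features(df, encoders=None):
    """Label-encode every non-numeric column of df.

    If encoders is None, fit a new LabelEncoder per column and return them. Otherwise
    reuse the given fitted encoders, so each category maps to the same integer as in
    the data they were fit on.
    """
    result = df.copy()
    fit = encoders is None
    if fit:
        encoders = {}
    for column in result.columns:
        # String columns are the categorical features and the label
        if not pd.api.types.is_numeric_dtype(result[column]):
            values = result[column].fillna("None")
            if fit:
                encoders[column] = preprocessing.LabelEncoder()
                result[column] = encoders[column].fit_transform(values)
            else:
                result[column] = encoders[column].transform(values)
    return result, encoders


# Built-in XGBoost expects the label in the first column of CSV training data
training_data = pd.concat([training_data["Target"], training_data.drop(["Target"], axis=1)], axis=1)
training_data, encoders = number_encode_features(training_data)
training_data.to_csv("train_data.csv", index=False, header=False)

# adult.test labels end with "." (e.g. "<=50K."); strip it so the training encoders recognize them
testing_data["Target"] = testing_data["Target"].str.rstrip(".")
# Reuse the training encoders so categories such as Country get the same codes as in training
testing_data, _ = number_encode_features(testing_data, encoders)
test_features = testing_data.drop(["Target"], axis=1)
test_target = testing_data["Target"]
test_features.to_csv("test_features.csv", index=False, header=False)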

A quick note about our encoding: the “Female” Sex value has been encoded as 0 and “Male” as 1.
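
LabelEncoder assigns integer codes in the sorted order of the distinct values it was fit on, which is why "Female" becomes 0 and "Male" becomes 1. The same rule means that two encoders fit separately on different datasets disagree whenever one dataset lacks a category. The stand-alone sketch below mimics that behavior with a plain dictionary; fit_label_mapping is a hypothetical helper, and the short value lists are made-up examples rather than the real columns.

```python
def fit_label_mapping(values):
    """Mimic LabelEncoder: codes follow the sorted order of the distinct values."""
    return {value: code for code, value in enumerate(sorted(set(values)))}


sex_mapping = fit_label_mapping(["Male", "Female", "Male"])
print(sex_mapping)  # {'Female': 0, 'Male': 1}

# Made-up example: a category present in only one of the two datasets
train_map = fit_label_mapping(["United-States", "Holand-Netherlands", "Honduras"])
test_map = fit_label_mapping(["United-States", "Honduras"])
print(train_map["Honduras"], test_map["Honduras"])  # 1 0
```

Because the missing category shifts every code after it, fitting encoders on the training data and reusing them on the test data keeps the two encodings consistent.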

[ ]:
training_data.head()

Get the feature names and the label name from the dataset

[ ]:
feature_headers = testing_data.columns.to_list()
label_header = feature_headers.pop()
print(f"Feature names: {feature_headers}")
print(f"Label name: {label_header}")

Lastly, let’s upload the data to S3 so that they can be used by the training job.

[ ]:
from sagemaker.s3 import S3Uploader
from sagemaker.inputs import TrainingInput

train_uri = S3Uploader.upload("train_data.csv", "s3://{}/{}".format(s3_bucket, prefix))
train_input = TrainingInput(train_uri, content_type="csv")
test_uri = S3Uploader.upload("test_features.csv", "s3://{}/{}".format(s3_bucket, prefix))

Train XGBoost Model

Since our focus is on understanding how to use SageMaker Clarify, we keep it simple by using a standard XGBoost model.

[ ]:
from sagemaker.image_uris import retrieve
from sagemaker.estimator import Estimator

container = retrieve("xgboost", region, version="1.3-1")
xgb = Estimator(
    container,
    role,
    instance_count=1,
    instance_type=instance_type,
    disable_profiler=True,
    debugger_hook_config=False,
)

xgb.set_hyperparameters(
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.8,
    objective="binary:logistic",
    num_round=800,
)

xgb.fit({"train": train_input}, logs=False)

Create a new model object which will be used to create the SageMaker model.

[ ]:
model = xgb.create_model(name=model_name)
container_def = model.prepare_container_def()
container_def

Create endpoint

Create model

The following parameters are required to create a SageMaker model:

  • ExecutionRoleArn: The ARN of the IAM role that Amazon SageMaker can assume to access the model artifacts and Docker images for deployment.

  • ModelName: The name of the SageMaker model.

  • PrimaryContainer: The location of the primary Docker image containing the inference code, associated artifacts, and the custom environment map that the inference code uses when the model is deployed for predictions.

[ ]:
sagemaker_client.create_model(
    ExecutionRoleArn=role,
    ModelName=model_name,
    PrimaryContainer=container_def,
)
print(f"Model created: {model_name}")

Create endpoint config

Create an endpoint configuration by calling the create_endpoint_config API, supplying the same model_name used in the create_model call. The create_endpoint_config API now supports the additional ExplainerConfig parameter, whose ClarifyExplainerConfig field enables the Clarify explainer. The SHAP baseline is mandatory; it can be provided either as inline baseline data (the ShapBaseline parameter) or as an S3 baseline file (the ShapBaselineUri parameter). Please see the developer guide for the other parameters.

[ ]:
baseline = test_features.mean().to_list()  # Inline baseline data
print(f"Use the mean of the test data as the SHAP baseline: {baseline}")
[ ]:
sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "TestVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        }
    ],
    ExplainerConfig={
        "ClarifyExplainerConfig": {
            # "EnableExplanations": "`false`",  # By default explanations are enabled, but you can change the condition by this parameter.
            "InferenceConfig": {
                "FeatureHeaders": feature_headers,
            },
            "ShapConfig": {
                "ShapBaselineConfig": {
                    "ShapBaseline": csv_serializer.serialize(baseline),  # inline baseline data
                }
            },
        }
    },
)
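
Because ShapBaseline and ShapBaselineUri are alternatives, it can help to build the ExplainerConfig in one place. The following is a minimal sketch that uses only the fields from the cell above plus ShapBaselineUri; build_clarify_explainer_config is a hypothetical helper whose return value is meant to be passed as the ExplainerConfig argument of create_endpoint_config. It assumes that an inline baseline is CSV text with one baseline record per line.

```python
import csv
import io


def build_clarify_explainer_config(feature_headers, baseline_rows=None, baseline_uri=None):
    """Sketch of an ExplainerConfig dict for create_endpoint_config.

    Pass exactly one of baseline_rows (inline CSV baseline) or baseline_uri (S3 object URI).
    """
    if (baseline_rows is None) == (baseline_uri is None):
        raise ValueError("Provide exactly one of baseline_rows or baseline_uri")
    if baseline_rows is not None:
        # Serialize the baseline records as CSV text, one record per line
        buffer = io.StringIO()
        csv.writer(buffer, lineterminator="\n").writerows(baseline_rows)
        baseline_config = {"ShapBaseline": buffer.getvalue().rstrip("\n")}
    else:
        baseline_config = {"ShapBaselineUri": baseline_uri}
    return {
        "ClarifyExplainerConfig": {
            "InferenceConfig": {"FeatureHeaders": list(feature_headers)},
            "ShapConfig": {"ShapBaselineConfig": baseline_config},
        }
    }


# Made-up two-feature example with a single inline baseline record
example_config = build_clarify_explainer_config(["Age", "Sex"], baseline_rows=[[38.5, 0.7]])
print(example_config["ClarifyExplainerConfig"]["ShapConfig"])
# {'ShapBaselineConfig': {'ShapBaseline': '38.5,0.7'}}
```

Raising an error when both or neither baseline is given surfaces the "exactly one baseline" requirement before the API call is made.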

Create endpoint

Once your model and endpoint configuration are ready, use the create_endpoint API to create your endpoint. The endpoint_name must be unique within an AWS Region in your AWS account. The create_endpoint call returns immediately with the endpoint in the Creating state, while SageMaker provisions it in the background.

[ ]:
sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)

Wait for the endpoint to be in “InService” state.

[ ]:
sagemaker_session.wait_for_endpoint(endpoint_name)

Invoke endpoint

There are expanding business needs and legislative regulations that require explanations of why a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.

Below are several different combinations of endpoint invocations. Run them one by one, and visualize the explanations by running the cell that follows each invocation.

Single record request

Put only one record in the request body, and then send the request to the endpoint to get its predictions and explanations.

[ ]:
request_records = test_features.iloc[:1, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
)
pprint.pprint(response)

Print the response body, which is JSON. Please see the developer guide for its schema.

[ ]:
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
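
The response body combines the predictions, as CSV text under predictions.data, with the per-record attributions under explanations.kernel_shap. The sketch below extracts both from a response shaped the way the visualization helpers above read it. parse_clarify_response is a hypothetical helper, the sample dict and its numbers are made up, and the developer guide remains the authoritative description of the schema.

```python
import csv
import io


def parse_clarify_response(result):
    """Return (predictions, attributions) from a deserialized Clarify response (sketch).

    attributions has one entry per record: a {feature_header: SHAP value} dict,
    or None when the record was not explained.
    """
    predictions = [
        [float(value) for value in row]
        for row in csv.reader(io.StringIO(result["predictions"]["data"]))
        if row
    ]
    attributions = []
    for record_output in result.get("explanations", {}).get("kernel_shap", []):
        if record_output is None:
            attributions.append(None)
            continue
        attributions.append(
            {
                feature["feature_header"]: feature["attributions"][0]["attribution"][0]
                for feature in record_output
            }
        )
    return predictions, attributions


# Made-up response for one record with two features
sample_result = {
    "predictions": {"data": "0.12\n"},
    "explanations": {
        "kernel_shap": [
            [
                {"feature_header": "Age", "attributions": [{"attribution": [-0.03]}]},
                {"feature_header": "Sex", "attributions": [{"attribution": [0.01]}]},
            ]
        ]
    },
}
print(parse_clarify_response(sample_result))
# ([[0.12]], [{'Age': -0.03, 'Sex': 0.01}])
```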

Use SHAP plots to visualize the result.

The expected value is the average of the model predictions over the baseline. Here we predict the baseline data and then compute the expected value. Only the predictions are needed, so the EnableExplanations parameter is used to disable the explanations.

[ ]:
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(baseline),
    EnableExplanations="`false`",  # Do not provide explanations
)
json_object = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
expected_value = float(
    pd.read_csv(io.StringIO(json_object["predictions"]["data"]), header=None)
    .astype(float)
    .iloc[:, 0]  # one prediction per baseline record
    .mean()
)
print(f"expected value: {expected_value}")
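
Because the expected value is the average prediction over the baseline records, a baseline with several rows returns one prediction per row, and it is those rows that should be averaged. The stand-alone sketch below averages the first column of a CSV predictions string; expected_value_from_predictions is a hypothetical helper, and the prediction values are made up.

```python
import csv
import io


def expected_value_from_predictions(predictions_csv):
    """Average the first column (one prediction per baseline record) of CSV predictions."""
    rows = [row for row in csv.reader(io.StringIO(predictions_csv)) if row]
    return sum(float(row[0]) for row in rows) / len(rows)


# Made-up predictions for a three-record baseline
print(round(expected_value_from_predictions("0.2\n0.4\n0.9\n"), 6))  # 0.5
```

With the single-record baseline used in this notebook, the average is simply that record's prediction.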
[ ]:
visualize_result(result, request_records, expected_value)

Single record request, no explanation

Use the EnableExplanations parameter to disable the explanations for this request.

[ ]:
request_records = test_features.iloc[:1, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
    EnableExplanations="`false`",  # Do not provide explanations
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)

Batch request, explain both

Put two records in the request body, and then send the request to the endpoint to get their predictions and explanations.

[ ]:
request_records = test_features.iloc[:2, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)
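
With CSV content, a batch request is simply one record per line. The sketch below shows what such a body looks like for two made-up, already-encoded records. to_csv_body is a hypothetical stand-in for CSVSerializer.serialize, assumed here to join the values of a row with commas and the rows with newlines.

```python
def to_csv_body(rows):
    """Hypothetical stand-in for CSVSerializer: comma-separated values, one record per line."""
    return "\n".join(",".join(str(value) for value in row) for row in rows)


# Two made-up, already-encoded records with four features each
example_body = to_csv_body([[25, 4, 226802, 1], [38, 4, 89814, 11]])
print(example_body)
# 25,4,226802,1
# 38,4,89814,11
```

The endpoint returns one prediction line, and one kernel_shap entry, per record in the same order.
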
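
Batch request, explain none

Use the EnableExplanations parameter to disable the explanations for both records.

[ ]: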
request_records = test_features.iloc[:2, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
    EnableExplanations="`false`",  # Do not provide explanations
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)

Batch request with more records, explain some of the records

Put a few more records into the request body, and then use the EnableExplanations expression to select which records to explain based on their predictions.

[ ]:
request_records = test_features.iloc[:70, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
    EnableExplanations="[0]>`0.95`",  # Explain a record only when its prediction is greater than the threshold
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)
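
The EnableExplanations value is a JMESPath boolean expression evaluated against each record's predictions: "`false`" disables explanations, and "[0]>`0.95`" selects records whose first prediction value (here, the predicted probability of earning >$50K) exceeds 0.95. The pure-Python stand-in below mimics that particular expression so you can preview which records will be explained. It does not evaluate general JMESPath, the helper name is hypothetical, and the predictions are made up.

```python
def preview_explained_records(predictions, threshold=0.95):
    """Mimic EnableExplanations="[0]>`threshold`": True where a record would be explained."""
    return [row[0] > threshold for row in predictions]


# Made-up predictions for four records (one probability per record)
example_predictions = [[0.02], [0.97], [0.40], [0.99]]
example_mask = preview_explained_records(example_predictions)
print(example_mask)  # [False, True, False, True]
print([i for i, explained in enumerate(example_mask) if explained])  # [1, 3]
```

Records for which the expression is false are not explained, which is why display_plots above skips entries that are None.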

Cleanup

Finally, don’t forget to clean up the resources we set up and used for this demo!

[ ]:
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
[ ]:
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
[ ]:
sagemaker_client.delete_model(ModelName=model_name)