Tabular Online Explainability with SageMaker Clarify
Introduction
Amazon SageMaker Clarify helps improve your machine learning models by detecting potential bias and helping explain how these models make predictions. The fairness and explainability functionality provided by SageMaker Clarify takes a step towards enabling AWS customers to build trustworthy and understandable machine learning models.
SageMaker Clarify currently supports explainability for SageMaker models as an offline processing job. This example notebook showcases a new feature for explainability on a SageMaker real-time inference endpoint, a.k.a. online explainability.
The notebook focuses on explaining the importance of the various input features on the model's decision, and shows how to use the EnableExplanations parameter of the InvokeEndpoint API to control which records are explained. In doing so, the notebook first trains a SageMaker XGBoost model on the training dataset, then uses SageMaker Clarify to analyze a testing dataset in CSV format.
General Setup
We recommend using the Python 3 (Data Science) kernel on SageMaker Studio or the conda_python3 kernel on a SageMaker Notebook Instance.
Install dependencies
Upgrade the SageMaker Python SDK, boto3, and botocore. Install shap and matplotlib, which are used to visualize the feature attributions.
[ ]:
!pip install sagemaker --upgrade
!pip install boto3 --upgrade
!pip install botocore --upgrade
!pip install shap --upgrade
!pip install matplotlib --upgrade
Import libraries
[ ]:
import boto3
import io
import os
import shap
import pprint
import pandas as pd
import numpy as np
from collections import OrderedDict
from sagemaker import get_execution_role, Session
from sagemaker.serializers import CSVSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.utils import unique_name_from_base
Set configurations
[ ]:
boto3_session = boto3.session.Session()
sagemaker_client = boto3.client("sagemaker")
sagemaker_runtime_client = boto3.client("sagemaker-runtime")
# Initialize sagemaker session
sagemaker_session = Session(
    boto_session=boto3_session,
    sagemaker_client=sagemaker_client,
    sagemaker_runtime_client=sagemaker_runtime_client,
)
region = sagemaker_session.boto_region_name
print(f"Region: {region}")
role = get_execution_role()
print(f"Role: {role}")
s3_client = boto3.client("s3")
prefix = unique_name_from_base("DEMO-Tabular-Adult")
s3_bucket = sagemaker_session.default_bucket()
s3_prefix = f"sagemaker/{prefix}"
s3_key = f"s3://{s3_bucket}/{s3_prefix}"
print(f"Demo S3 key: {s3_key}")
model_name = f"{prefix}-model"
print(f"Demo model name: {model_name}")
endpoint_config_name = f"{prefix}-endpoint-config"
print(f"Demo endpoint config name: {endpoint_config_name}")
endpoint_name = f"{prefix}-endpoint"
print(f"Demo endpoint name: {endpoint_name}")
# Instance type for training and hosting
instance_type = "ml.m5.xlarge"
Create serializer and deserializer
CSV serializer to serialize test data to string
[ ]:
csv_serializer = CSVSerializer()
JSON deserializer to deserialize the invoke_endpoint response
[ ]:
json_deserializer = JSONDeserializer()
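For reference, the request body sent to the endpoint is plain CSV text: one line per record, comma-separated values, no header row. The sketch below approximates that format with a hypothetical helper, to_csv_body (not part of the SageMaker SDK), so you can see what the endpoint receives without calling AWS; the exact number formatting produced by CSVSerializer may differ.

```python
from typing import Sequence, Union

Number = Union[int, float]


def to_csv_body(records: Sequence[Sequence[Number]]) -> str:
    """Hypothetical helper: render records as headerless CSV text.

    Approximates the format CSVSerializer sends for a 2-D input:
    one line per record, values separated by commas, no trailing newline.
    """
    return "\n".join(",".join(str(value) for value in row) for row in records)


# Two illustrative, already-encoded records (not taken from the dataset).
body = to_csv_body([[1, 2.5, 3], [4, 5, 6]])
print(body)
```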
For visualization
SHAP plots are useful tools for interpreting the explanations. For example, the SHAP additive force layout shows how each feature contributes to pushing the base value (also called the expected value; in this notebook, the model's average prediction over the SHAP baseline) to the corresponding prediction. Features that push the prediction higher are shown in red, and those that push it lower are shown in blue.
[ ]:
def force_plot(expected_value, shap_values, feature_data, feature_headers):
    """
    Visualize the given SHAP values with an additive force layout.
    For more information: https://shap.readthedocs.io/en/latest/example_notebooks/tabular_examples/tree_based_models/Force%20Plot%20Colors.html
    """
    shap.plots.force(
        base_value=expected_value,
        shap_values=shap_values,
        features=feature_data,
        feature_names=feature_headers,
        matplotlib=True,
    )


def display_plots(explanations, expected_value, request_records, predictions):
    """
    Display the Model Explainability plots
    """
    per_request_shap_values = OrderedDict()
    feature_headers = []
    for i, record_output in enumerate(explanations):
        per_record_shap_values = []
        # Records that were not explained have no entry to plot
        if record_output is not None:
            feature_headers = []
            for feature_attribution in record_output:
                per_record_shap_values.append(
                    feature_attribution["attributions"][0]["attribution"][0]
                )
                feature_headers.append(feature_attribution["feature_header"])
            per_request_shap_values[i] = per_record_shap_values
    for record_index, shap_values in per_request_shap_values.items():
        print(
            f"Visualize the SHAP values for Record number {record_index + 1} with Model Prediction: {predictions[record_index][0]}"
        )
        force_plot(
            expected_value,
            np.array(shap_values),
            request_records.iloc[record_index],
            feature_headers,
        )


def visualize_result(result, request_records, expected_value):
    """
    Visualize the output from the endpoint.
    """
    predictions = pd.read_csv(io.StringIO(result["predictions"]["data"]), header=None)
    predictions = predictions.values.tolist()
    print("Model Inference output: ")
    for i, model_output in enumerate(predictions):
        print(f"Record: {i + 1}\tModel Prediction: {model_output[0]}")
    if "kernel_shap" in result["explanations"]:
        explanations = result["explanations"]["kernel_shap"]
        display_plots(explanations, expected_value, request_records, predictions)
    else:
        print("No Clarify explanations for the record(s)")
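The force plot relies on SHAP's additivity property: the base value plus the sum of a record's SHAP values equals the model's prediction for that record. The sketch below illustrates this with exact Shapley values for a hypothetical three-feature linear model (toy_model), where an "absent" feature is replaced by its baseline value, which is the role the SHAP baseline plays later in this notebook. It enumerates every feature subset, which is only feasible for a handful of features; Clarify's kernel SHAP approximates these values rather than computing them exactly. The names exact_shap, toy_model, toy_record, and toy_baseline are illustrative only.

```python
from itertools import combinations
from math import factorial
from typing import Callable, List, Sequence


def exact_shap(
    model: Callable[[Sequence[float]], float],
    record: Sequence[float],
    baseline: Sequence[float],
) -> List[float]:
    """Exact Shapley values, using baseline substitution for absent features."""
    n = len(record)

    def value(subset) -> float:
        # Features in `subset` take the record's values; all others take the baseline's.
        mixed = [record[i] if i in subset else baseline[i] for i in range(n)]
        return model(mixed)

    shap_values = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        phi = 0.0
        for size in range(n):
            # Shapley weight |S|! (n - |S| - 1)! / n! for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi += weight * (value(s | {i}) - value(s))
        shap_values.append(phi)
    return shap_values


def toy_model(x: Sequence[float]) -> float:
    # Hypothetical linear model standing in for the XGBoost model.
    return 0.5 * x[0] - 2.0 * x[1] + 0.1 * x[2]


toy_record = [4.0, 1.0, 10.0]
toy_baseline = [2.0, 0.5, 0.0]
toy_phi = exact_shap(toy_model, toy_record, toy_baseline)
toy_base_value = toy_model(toy_baseline)  # the "expected value" for a one-row baseline
print(toy_phi)
print(toy_base_value + sum(toy_phi), toy_model(toy_record))
```

For a linear model each SHAP value reduces to weight times (feature value minus baseline value), which makes the result easy to check by hand.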
Prepare data
Download data
Data Source: https://archive.ics.uci.edu/ml/machine-learning-databases/adult/
Download the Adult dataset, originally published in the UCI Machine Learning Repository\(^{[2]}\), and save the files locally as adult.data and adult.test. The cell below copies them from the sagemaker-sample-files S3 bucket.
\(^{[2]}\)Dua, Dheeru, and Efi Karra Taniskidou. “UCI Machine Learning Repository”. Irvine, CA: University of California, School of Information and Computer Science (2017).
[ ]:
adult_columns = [
"Age",
"Workclass",
"fnlwgt",
"Education",
"Education-Num",
"Marital Status",
"Occupation",
"Relationship",
"Ethnic group",
"Sex",
"Capital Gain",
"Capital Loss",
"Hours per week",
"Country",
"Target",
]
if not os.path.isfile("adult.data"):
    s3_client.download_file(
        "sagemaker-sample-files", "datasets/tabular/uci_adult/adult.data", "adult.data"
    )
    print("adult.data saved!")
else:
    print("adult.data already on disk.")
if not os.path.isfile("adult.test"):
    s3_client.download_file(
        "sagemaker-sample-files", "datasets/tabular/uci_adult/adult.test", "adult.test"
    )
    print("adult.test saved!")
else:
    print("adult.test already on disk.")
Loading the data: Adult Dataset
From the UCI repository of machine learning datasets, this dataset contains 14 demographic features for 45,222 rows (32,561 for training and 12,661 for testing). The task is to predict whether a person's yearly income is more or less than $50,000.
Here are the features and their possible values:
1. Age: continuous.
2. Workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
3. Fnlwgt: continuous (the number of people the census takers believe that observation represents).
4. Education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
5. Education-num: continuous.
6. Marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
7. Occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
8. Relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
9. Ethnic group: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
10. Sex: Female, Male.
    * Note: this data is extracted from the 1994 Census and enforces a binary option on Sex.
11. Capital-gain: continuous.
12. Capital-loss: continuous.
13. Hours-per-week: continuous.
14. Native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
[ ]:
training_data = pd.read_csv(
    "adult.data", names=adult_columns, sep=r"\s*,\s*", engine="python", na_values="?"
).dropna()
# The first line of adult.test is not a data row, so skip it
testing_data = pd.read_csv(
    "adult.test", names=adult_columns, sep=r"\s*,\s*", engine="python", na_values="?", skiprows=1
).dropna()
training_data.head()
Data inspection
[ ]:
training_data["Sex"].value_counts().sort_values().plot(kind="bar", title="Counts of Sex", rot=0)
[ ]:
training_data["Sex"].where(training_data["Target"] == ">50K").value_counts().sort_values().plot(
    kind="bar", title="Counts of Sex earning >$50K", rot=0
)
Encode and Upload the Dataset
Here we encode the training and test data. Encoding input data is not necessary for SageMaker Clarify, but is necessary for the model.
[ ]:
from sklearn import preprocessing
def number_encode_features(df):
    """Label-encode every string (object) column; return the encoded copy and the fitted encoders."""
    result = df.copy()
    encoders = {}
    for column in result.columns:
        # np.object was removed in NumPy 1.24, so compare against the built-in object type
        if result.dtypes[column] == object:
            encoders[column] = preprocessing.LabelEncoder()
            result[column] = encoders[column].fit_transform(result[column].fillna("None"))
    return result, encoders
# SageMaker's built-in XGBoost reads CSV training data with the label in the first column and no header row
training_data = pd.concat([training_data["Target"], training_data.drop(["Target"], axis=1)], axis=1)
training_data, _ = number_encode_features(training_data)
training_data.to_csv("train_data.csv", index=False, header=False)
# Keep only the encoded test features (no label) for the inference requests
testing_data, _ = number_encode_features(testing_data)
test_features = testing_data.drop(["Target"], axis=1)
test_target = testing_data["Target"]
test_features.to_csv("test_features.csv", index=False, header=False)
A quick note about our encoding: the “Female” Sex value has been encoded as 0 and “Male” as 1.
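The encoding above follows from how LabelEncoder assigns codes: it numbers the distinct values it sees in sorted order, so "Female" sorts before "Male" and gets 0. The sketch below is a hypothetical pure-Python equivalent (label_encode is not a scikit-learn function) that reproduces this rule on a small illustrative column.

```python
from typing import Dict, List, Sequence, Tuple


def label_encode(values: Sequence[str]) -> Tuple[List[int], Dict[str, int]]:
    """Pure-Python sketch of LabelEncoder.fit_transform: codes follow sorted order."""
    classes = sorted(set(values))
    mapping = {cls: code for code, cls in enumerate(classes)}
    return [mapping[v] for v in values], mapping


codes, mapping = label_encode(["Male", "Female", "Male", "Female"])
print(mapping)
print(codes)
```

Because the codes depend only on the set of values seen, encoding the training and test sets separately yields matching codes only when both sets contain the same categories.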
[ ]:
training_data.head()
Get the feature names and the label names from the dataset
[ ]:
feature_headers = testing_data.columns.to_list()
label_header = feature_headers.pop()
print(f"Feature names: {feature_headers}")
print(f"Label name: {label_header}")
Lastly, upload the data to S3 so that the training job can use it.
[ ]:
from sagemaker.s3 import S3Uploader
from sagemaker.inputs import TrainingInput
train_uri = S3Uploader.upload("train_data.csv", s3_key)
train_input = TrainingInput(train_uri, content_type="csv")
test_uri = S3Uploader.upload("test_features.csv", s3_key)
Train XGBoost Model
Since our focus is on understanding how to use SageMaker Clarify, we keep it simple by using a standard XGBoost model.
[ ]:
from sagemaker.image_uris import retrieve
from sagemaker.estimator import Estimator
container = retrieve("xgboost", region, version="1.3-1")
xgb = Estimator(
    container,
    role,
    instance_count=1,
    instance_type=instance_type,
    disable_profiler=True,
    debugger_hook_config=False,
)
xgb.set_hyperparameters(
    max_depth=5,
    eta=0.2,
    gamma=4,
    min_child_weight=6,
    subsample=0.8,
    objective="binary:logistic",
    num_round=800,
)
xgb.fit({"train": train_input}, logs=False)
Create a new model object which will be used to create the SageMaker model.
[ ]:
model = xgb.create_model(name=model_name)
container_def = model.prepare_container_def()
container_def
Create endpoint
Create model
The following parameters are required to create a SageMaker model:
* ExecutionRoleArn: The ARN of the IAM role that Amazon SageMaker can assume to access the model artifacts and Docker images for deployment.
* ModelName: The name of the SageMaker model.
* PrimaryContainer: The location of the primary Docker image containing inference code, associated artifacts, and the custom environment map that the inference code uses when the model is deployed for predictions.
[ ]:
sagemaker_client.create_model(
    ExecutionRoleArn=role,
    ModelName=model_name,
    PrimaryContainer=container_def,
)
print(f"Model created: {model_name}")
Create endpoint config
Create an endpoint configuration by calling the create_endpoint_config API, supplying the same model_name used in the create_model call. The create_endpoint_config API now accepts an additional ExplainerConfig parameter, whose ClarifyExplainerConfig entry enables the Clarify explainer. The SHAP baseline is mandatory. It can be provided either as inline baseline data (the ShapBaseline parameter) or as an S3 baseline file (the ShapBaselineUri parameter). See the developer guide for the other parameters.
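The two ways of supplying the baseline differ only in the ShapBaselineConfig entry. The sketch below is a hypothetical helper, build_explainer_config (not a SageMaker API), that assembles the ExplainerConfig dictionary with either an inline CSV baseline or an S3 baseline URI. The key names mirror the ones used in the create_endpoint_config call in this notebook; treating the two options as mutually exclusive is an assumption of this sketch, and the S3 URI is a placeholder.

```python
from typing import Any, Dict, List, Optional, Sequence


def build_explainer_config(
    feature_headers: List[str],
    baseline: Optional[Sequence[float]] = None,
    baseline_uri: Optional[str] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble ExplainerConfig for create_endpoint_config.

    Assumes exactly one of `baseline` (inline values) or `baseline_uri` (S3 file).
    """
    if (baseline is None) == (baseline_uri is None):
        raise ValueError("Provide exactly one of baseline or baseline_uri")
    if baseline is not None:
        # Inline baseline as a single CSV row, similar to csv_serializer.serialize(baseline).
        shap_baseline_config = {"ShapBaseline": ",".join(str(v) for v in baseline)}
    else:
        shap_baseline_config = {"ShapBaselineUri": baseline_uri}
    return {
        "ClarifyExplainerConfig": {
            "InferenceConfig": {"FeatureHeaders": feature_headers},
            "ShapConfig": {"ShapBaselineConfig": shap_baseline_config},
        }
    }


inline_config = build_explainer_config(["Age", "Sex"], baseline=[38.5, 0.7])
s3_config = build_explainer_config(
    ["Age", "Sex"], baseline_uri="s3://example-bucket/baseline.csv"  # placeholder URI
)
```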
[ ]:
baseline = test_features.mean().to_list() # Inline baseline data
print(f"Use the mean of the test data as the SHAP baseline: {baseline}")
[ ]:
sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "TestVariant",
            "ModelName": model_name,
            "InitialInstanceCount": 1,
            "InstanceType": instance_type,
        }
    ],
    ExplainerConfig={
        "ClarifyExplainerConfig": {
            # "EnableExplanations": "`false`",  # By default explanations are enabled, but you can change the condition by this parameter.
            "InferenceConfig": {
                "FeatureHeaders": feature_headers,
            },
            "ShapConfig": {
                "ShapBaselineConfig": {
                    "ShapBaseline": csv_serializer.serialize(baseline),  # inline baseline data
                }
            },
        }
    },
)
Create endpoint
Once your model and endpoint configuration are ready, use the create_endpoint API to create your endpoint. The endpoint_name must be unique within an AWS Region in your AWS account. The create_endpoint call returns immediately, with the endpoint in the Creating state, and the endpoint is provisioned in the background.
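Because create_endpoint returns while the endpoint is still being provisioned, you must wait before invoking it. The sketch below shows the idea behind such a wait: poll the endpoint status until it reaches InService. The function wait_until_in_service is hypothetical; in practice you would pass sagemaker_client.describe_endpoint, whose response includes an EndpointStatus field. A fake describe function stands in for it here so the sketch runs without AWS.

```python
import time
from typing import Any, Callable, Dict


def wait_until_in_service(
    describe_endpoint: Callable[..., Dict[str, Any]],
    endpoint_name: str,
    delay_seconds: float = 30.0,
    max_attempts: int = 60,
) -> str:
    """Hypothetical poller: return once EndpointStatus is InService."""
    for _ in range(max_attempts):
        status = describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        time.sleep(delay_seconds)  # still Creating (or another transitional state)
    raise TimeoutError(f"Endpoint {endpoint_name} not InService after {max_attempts} checks")


# Fake describe function that reports Creating twice, then InService.
fake_statuses = iter(["Creating", "Creating", "InService"])


def fake_describe_endpoint(EndpointName: str) -> Dict[str, Any]:
    return {"EndpointName": EndpointName, "EndpointStatus": next(fake_statuses)}


final_status = wait_until_in_service(fake_describe_endpoint, "demo-endpoint", delay_seconds=0)
```

Against a real endpoint the call would look like wait_until_in_service(sagemaker_client.describe_endpoint, endpoint_name); the notebook itself uses sagemaker_session.wait_for_endpoint for this step.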
[ ]:
sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
Wait for the endpoint to be in “InService” state.
[ ]:
sagemaker_session.wait_for_endpoint(endpoint_name)
Invoke endpoint
Growing business needs and legislative regulations require explanations of why a model made the decision it did. SageMaker Clarify uses SHAP to explain the contribution that each input feature makes to the final decision.
The following subsections show several combinations of endpoint invocations. Run them one by one, and visualize the explanations by running the cell that follows each invocation.
Single record request
Put only one record in the request body, and then send the request to the endpoint to get its predictions and explanations.
[ ]:
request_records = test_features.iloc[:1, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
)
pprint.pprint(response)
Print the response body, which is JSON. See the developer guide for its schema.
[ ]:
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
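The helpers defined earlier read the predictions from result["predictions"]["data"] and, for each explained record, the per-feature SHAP values from result["explanations"]["kernel_shap"]. The sketch below flattens that structure into one mapping per record, using a small hypothetical response that contains only the fields those helpers access; consult the developer guide for the complete schema. As in display_plots, records without explanations (None entries) are skipped. shap_table and sample_result are illustrative names.

```python
from typing import Any, Dict


def shap_table(result: Dict[str, Any]) -> Dict[int, Dict[str, float]]:
    """Map record index -> {feature_header: SHAP value} for explained records."""
    table: Dict[int, Dict[str, float]] = {}
    explanations = result.get("explanations", {}).get("kernel_shap", [])
    for index, record_output in enumerate(explanations):
        if record_output is None:  # this record was not explained
            continue
        table[index] = {
            attribution["feature_header"]: attribution["attributions"][0]["attribution"][0]
            for attribution in record_output
        }
    return table


# Hypothetical response body with two records; only the second was explained.
sample_result = {
    "predictions": {"data": "0.12\n0.97"},
    "explanations": {
        "kernel_shap": [
            None,
            [
                {"feature_header": "Age", "attributions": [{"attribution": [0.08]}]},
                {"feature_header": "Sex", "attributions": [{"attribution": [0.03]}]},
            ],
        ]
    },
}
print(shap_table(sample_result))
```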
Use SHAP plots to visualize the result.
The expected value (the base value of the force plot) is the average of the model predictions over the baseline. Here, we send the baseline data to the endpoint and compute the expected value from the returned predictions. Only the predictions are needed, so the EnableExplanations parameter is used to disable the explanations.
[ ]:
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(baseline),
    EnableExplanations="`false`",  # Do not provide explanations
)
json_object = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
# One prediction per baseline record: average the first column over all records (rows)
expected_value = float(
    pd.read_csv(io.StringIO(json_object["predictions"]["data"]), header=None)
    .iloc[:, 0]
    .astype(float)
    .mean()
)
print(f"expected value: {expected_value}")
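To make the arithmetic concrete: the endpoint returns one prediction per baseline record as CSV text, and the expected value is their mean. With the single-row baseline used here, it is simply the model's prediction for that row. The hypothetical helper below, expected_value_from_csv, averages over records (rows), so it would also handle a baseline with several rows; the prediction strings are illustrative.

```python
import csv
import io


def expected_value_from_csv(predictions_csv: str) -> float:
    """Average the first column of a headerless CSV of predictions over all rows."""
    rows = [row for row in csv.reader(io.StringIO(predictions_csv)) if row]
    scores = [float(row[0]) for row in rows]
    return sum(scores) / len(scores)


single_row = expected_value_from_csv("0.25\n")
multi_row = expected_value_from_csv("0.2\n0.4\n0.6\n")
print(single_row, multi_row)
```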
[ ]:
visualize_result(result, request_records, expected_value)
Single record request, no explanation
Use the EnableExplanations parameter to disable the explanations for this request.
[ ]:
request_records = test_features.iloc[:1, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
    EnableExplanations="`false`",  # Do not provide explanations
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)
Batch request, explain both
Put two records in the request body, and then send the request to the endpoint to get their predictions and explanations.
[ ]:
request_records = test_features.iloc[:2, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)
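Batch request, no explanation
Put the same two records in the request body, and use the EnableExplanations parameter to disable the explanations for this request.
[ ]: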
request_records = test_features.iloc[:2, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
    EnableExplanations="`false`",  # Do not provide explanations
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)
Batch request with more records, explain some of the records
Put more records in the request body, and then use an EnableExplanations expression to select which records are explained based on their predictions.
[ ]:
request_records = test_features.iloc[:70, :]
response = sagemaker_runtime_client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="text/csv",
    Body=csv_serializer.serialize(request_records.to_numpy()),
    EnableExplanations="[0]>`0.95`",  # Explain a record only when its prediction is greater than the threshold
)
result = json_deserializer.deserialize(response["Body"], content_type=response["ContentType"])
pprint.pprint(result)
[ ]:
visualize_result(result, request_records, expected_value)
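The expression [0]>`0.95` is a JMESPath expression evaluated against each record's model output: [0] selects the first output value (here the probability produced by the binary:logistic objective), and the backticks mark a literal. The sketch below is a plain-Python emulation of this particular filter, not Clarify's implementation; records_to_explain is a hypothetical helper, and the prediction CSV is illustrative. It only shows which records such a threshold would select for explanation.

```python
import csv
import io
from typing import List


def records_to_explain(predictions_csv: str, threshold: float = 0.95) -> List[int]:
    """Emulate EnableExplanations="[0]>`0.95`": indices whose first output exceeds threshold."""
    selected = []
    rows = [row for row in csv.reader(io.StringIO(predictions_csv)) if row]
    for index, row in enumerate(rows):
        prediction = [float(value) for value in row]
        if prediction[0] > threshold:  # strict comparison, as in the expression
            selected.append(index)
    return selected


hypothetical_predictions = "0.03\n0.98\n0.51\n0.99\n"
print(records_to_explain(hypothetical_predictions))
```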
Cleanup
Finally, don’t forget to clean up the resources we set up and used for this demo!
[ ]:
sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
[ ]:
sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
[ ]:
sagemaker_client.delete_model(ModelName=model_name)