Created
October 21, 2021 20:01
-
-
Save drussellmrichie/8c198d2d234a96d7ea8bfeeae3be23e0 to your computer and use it in GitHub Desktop.
Example of Shapley explanations for a zero-shot classifier
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os

import numpy as np
import pandas as pd
import scipy as sp
import shap
import torch
from numpy import argmax
from transformers import RobertaConfig, RobertaModel
from transformers import RobertaTokenizer
from transformers import pipeline

# import datasets
# Alternative public dataset (needs `import datasets`):
# dataset = datasets.load_dataset("emotion", split="train")
# data = pd.DataFrame({"text": dataset["text"], "emotion": dataset["label"]})

# MNLI-finetuned model; the zero-shot pipeline scores each candidate label
# as an entailment hypothesis against the input text.
transformer_name = "roberta-large-mnli"
tokenizer = RobertaTokenizer.from_pretrained(transformer_name)

# Spreadsheet with `Text` and `Topic` columns. The location can be overridden
# with the SHAP_DATA_PATH environment variable; the default is the original
# hard-coded local file.
data_path = os.environ.get("SHAP_DATA_PATH", r"C:\Me\custom_data.xlsx")
dataset = pd.read_excel(data_path, engine="openpyxl")
data = pd.DataFrame({"text": dataset["Text"], "topic": dataset["Topic"]})

# Candidate labels for the zero-shot classifier (also given to SHAP as the
# output names) and the hypothesis each label is substituted into.
states_of_mind = ["happy", "sad", "angry", "neutral"]
hypothesis = "I am feeling {}."

# Run on the first GPU when CUDA is available, otherwise on CPU (-1); a
# hard-coded device=0 fails on machines without a GPU.
zero_shot_classifier = pipeline(
    "zero-shot-classification",
    tokenizer=tokenizer,
    model=transformer_name,
    device=0 if torch.cuda.is_available() else -1,
)

# Expose the model the pipeline actually uses. Loading a separate bare
# RobertaModel would download/load the weights a second time, and that class
# has no classification head, so it could not serve the pipeline anyway.
model = zero_shot_classifier.model
def predict_zeroshot(text, *, classifier=None, labels=None, template=None):
    """Return zero-shot class probabilities for *text*, one column per label.

    SHAP's text masker calls this with a batch of (masked) strings, typically
    a NumPy array, and reads column ``j`` of the result as the output named
    ``output_names[j]``. The zero-shot pipeline, however, returns each
    result's ``labels``/``scores`` sorted by descending score, so the raw
    ``scores`` list is in a different order for different inputs. Each result
    is therefore re-indexed by label so that column ``j`` is always the score
    of ``labels[j]``.

    Args:
        text: A single string, or an iterable (list, NumPy array, pandas
            Series) of strings.
        classifier: Zero-shot pipeline to call. Defaults to the module-level
            ``zero_shot_classifier``.
        labels: Candidate labels. Defaults to the module-level
            ``states_of_mind``. Also fixes the column order of the result.
        template: Hypothesis template. Defaults to the module-level
            ``hypothesis``.

    Returns:
        A float ``np.ndarray`` of shape ``(n_texts, len(labels))`` for an
        iterable input, or shape ``(len(labels),)`` for a single string.

    Raises:
        KeyError: if a result from *classifier* lacks one of *labels*.
    """
    if classifier is None:
        classifier = zero_shot_classifier
    if labels is None:
        labels = states_of_mind
    if template is None:
        template = hypothesis
    labels = list(labels)

    single = isinstance(text, str)
    # The pipeline expects str or list[str]; convert SHAP's array of np.str_
    # elements into plain Python strings.
    texts = [text] if single else [str(t) for t in text]
    if not texts:
        return np.zeros((0, len(labels)))

    results = classifier(texts, labels, hypothesis_template=template)
    # Some transformers versions return a bare dict rather than a one-element
    # list when only one sequence is passed.
    if isinstance(results, dict):
        results = [results]

    rows = []
    for result in results:
        # Undo the pipeline's sort-by-score so columns line up with `labels`.
        score_by_label = dict(zip(result["labels"], result["scores"]))
        rows.append([score_by_label[label] for label in labels])
    scores = np.array(rows, dtype=float)

    # NOTE(review): these are softmax probabilities; SHAP explanations are
    # often computed on logits instead. Convert here if that is preferred.
    return scores[0] if single else scores
# Build a SHAP explainer around the zero-shot scoring function. Passing the
# tokenizer lets SHAP mask the input text token by token.
explainer = shap.Explainer(
    predict_zeroshot,
    tokenizer,
    output_names=states_of_mind,
)

# Explain only the first three texts (positional slice) and render the
# per-token attributions.
texts_to_explain = data["text"].iloc[:3]
shap_values = explainer(texts_to_explain)
shap.plots.text(shap_values)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment