Zero-Shot Classification
Machine Learning is commonly used to classify data into pre-defined categories.
Typically, to reach high classification performance, models need to be trained on context-specific labeled data. Zero-shot classification is a type of classification that uses a pre-trained model and does not require further training on context-specific data.
Tutorial Introduction
In this tutorial we want to detect dissatisfaction in an email dataset. Let's create a basic dataset:
from typing import List

import pandas as pd
from transformers import pipeline

from melusine.base import MelusineDetector


def create_dataset():
    # Two example emails: one thankful, one complaining
    df = pd.DataFrame(
        [
            {
                "header": "Dossier 123456",
                "body": "Merci beaucoup pour votre gentillesse et votre écoute !",
            },
            {
                "header": "Réclamation (Dossier 987654)",
                "body": (
                    "Bonjour, je ne suis pas satisfait de cette situation, "
                    "répondez-moi rapidement svp!"
                ),
            },
        ]
    )
    return df
| | header | body |
|---|---|---|
| 0 | Dossier 123456 | Merci beaucoup pour votre gentillesse et votre écoute ! |
| 1 | Réclamation (Dossier 987654) | Bonjour, je ne suis pas satisfait de cette situation, répondez-moi rapidement svp! |
Classify with Zero-Shot Classification
The transformers library, provided by Hugging Face, makes it simple to use pre-trained models for zero-shot classification.
model_name_or_path = "cmarkea/distilcamembert-base-nli"

sentences = [
    "Quelle belle journée aujourd'hui",
    "La marée est haute",
    "Ce film est une catastrophe, je suis en colère",
]

classifier = pipeline(task="zero-shot-classification", model=model_name_or_path, tokenizer=model_name_or_path)

result = classifier(
    sequences=sentences, candidate_labels=", ".join(["positif", "négatif"]), hypothesis_template="Ce texte est {}."
)
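To make the two text parameters in this call less opaque, here is a minimal sketch in plain Python of how a comma-separated label string and a hypothesis template can be expanded into one hypothesis per label. The helper name `build_hypotheses` is hypothetical and is not part of the transformers API; the sketch only mimics the string handling and runs no model.

```python
from typing import List


def build_hypotheses(candidate_labels: str, hypothesis_template: str) -> List[str]:
    """Hypothetical helper: expand a label string into one hypothesis per label."""
    # Split the comma-separated string and drop surrounding whitespace
    labels = [label.strip() for label in candidate_labels.split(",") if label.strip()]
    # Insert each label into the template placeholder
    return [hypothesis_template.format(label) for label in labels]


hypotheses = build_hypotheses("positif, négatif", "Ce texte est {}.")
print(hypotheses)  # ['Ce texte est positif.', 'Ce texte est négatif.']
```

Each hypothesis is then compared with the input text by the model, which is why the wording of the template and of the labels should match the language of the texts being classified.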
The classifier returns a score for the positive (positif in French) and negative (négatif in French) labels for each input text:
[
    {
        'sequence': "Quelle belle journée aujourd'hui",
        'labels': ['positif', 'négatif'],
        'scores': [0.95, 0.05]
    },
    {
        'sequence': 'La marée est haute',
        'labels': ['positif', 'négatif'],
        'scores': [0.76, 0.24]
    },
    {
        'sequence': 'Ce film est une catastrophe, je suis en colère',
        'labels': ['négatif', 'positif'],
        'scores': [0.97, 0.03]
    }
]
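Labels are returned sorted by decreasing score, so the position of a given label changes from one text to another (compare the first and last results above). The following sketch reuses values copied from that output, without calling a model, to show one way of looking up a score by label name rather than by position.

```python
# Results copied from the example output above (no model is run here)
results = [
    {
        "sequence": "Quelle belle journée aujourd'hui",
        "labels": ["positif", "négatif"],
        "scores": [0.95, 0.05],
    },
    {
        "sequence": "Ce film est une catastrophe, je suis en colère",
        "labels": ["négatif", "positif"],
        "scores": [0.97, 0.03],
    },
]

# Map each label to its score so lookups do not depend on label order
negative_scores = []
for item in results:
    score_by_label = dict(zip(item["labels"], item["scores"]))
    negative_scores.append(score_by_label["négatif"])

print(negative_scores)  # [0.05, 0.97]
```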
Implement a Dissatisfaction Detector
A full email processing pipeline may contain multiple models.
Melusine uses the MelusineDetector template class to standardize how models are integrated into a pipeline.
class DissatisfactionDetector(MelusineDetector):
    """
    Detect if the text expresses dissatisfaction.
    """

    # Dataframe column names
    OUTPUT_RESULT_COLUMN = "dissatisfaction_result"
    TMP_DETECTION_INPUT_COLUMN = "detection_input"
    TMP_DETECTION_OUTPUT_COLUMN = "detection_output"

    # Model inference parameters
    POSITIVE_LABEL = "positif"
    NEGATIVE_LABEL = "négatif"
    HYPOTHESIS_TEMPLATE = "Ce texte est {}."

    def __init__(self, model_name_or_path: str, text_columns: List[str], threshold: float):
        self.text_columns = text_columns
        self.threshold = threshold
        self.classifier = pipeline(
            task="zero-shot-classification", model=model_name_or_path, tokenizer=model_name_or_path
        )
        super().__init__(input_columns=text_columns, output_columns=[self.OUTPUT_RESULT_COLUMN], name="dissatisfaction")
The pre_detect method assembles the text that we want to use for classification.
    def pre_detect(self, row, debug_mode=False):
        # Assemble the text columns into a single text
        effective_text = ""
        for col in self.text_columns:
            effective_text += "\n" + row[col]
        row[self.TMP_DETECTION_INPUT_COLUMN] = effective_text

        # Store the effective detection text in the debug data
        if debug_mode:
            row[self.debug_dict_col] = {"detection_input": row[self.TMP_DETECTION_INPUT_COLUMN]}

        return row
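As an illustration of what this step produces, the same concatenation can be run on a plain dictionary standing in for a dataframe row. The function name `assemble_text` is hypothetical and exists only for this sketch. Note that the loop, as written, yields a leading newline.

```python
from typing import Dict, List


def assemble_text(row: Dict[str, str], text_columns: List[str]) -> str:
    """Hypothetical stand-alone version of the concatenation in pre_detect."""
    effective_text = ""
    for col in text_columns:
        # Each column is prefixed with a newline, including the first one
        effective_text += "\n" + row[col]
    return effective_text


row = {
    "header": "Dossier 123456",
    "body": "Merci beaucoup pour votre gentillesse et votre écoute !",
}
text = assemble_text(row, ["header", "body"])
print(repr(text))
```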
The detect method runs the classification model on the text.
    def detect(self, row, debug_mode=False):
        # Run the classifier on the text
        pipeline_result = self.classifier(
            sequences=row[self.TMP_DETECTION_INPUT_COLUMN],
            candidate_labels=", ".join([self.POSITIVE_LABEL, self.NEGATIVE_LABEL]),
            hypothesis_template=self.HYPOTHESIS_TEMPLATE,
        )

        # Format classification result
        result_dict = dict(zip(pipeline_result["labels"], pipeline_result["scores"]))
        row[self.TMP_DETECTION_OUTPUT_COLUMN] = result_dict

        # Store ML results in the debug data
        if debug_mode:
            row[self.debug_dict_col].update(result_dict)

        return row
The post_detect method applies a threshold on the prediction score to determine the detection result.
    def post_detect(self, row, debug_mode=False):
        # Compare classification score to the detection threshold
        if row[self.TMP_DETECTION_OUTPUT_COLUMN][self.NEGATIVE_LABEL] > self.threshold:
            row[self.OUTPUT_RESULT_COLUMN] = True
        else:
            row[self.OUTPUT_RESULT_COLUMN] = False

        return row
On top of that, the detector takes care of building debug data to make the result explainable.
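The thresholding rule in post_detect can also be sketched outside the detector. The function name `is_dissatisfied` is hypothetical, and the scores are illustrative values taken from the classification example earlier in this tutorial, not the detector's output on the email dataset.

```python
from typing import Dict


def is_dissatisfied(scores: Dict[str, float], threshold: float, negative_label: str = "négatif") -> bool:
    """Hypothetical helper mirroring the comparison in post_detect."""
    # Strict comparison: a score equal to the threshold is not flagged
    return scores[negative_label] > threshold


flag_angry = is_dissatisfied({"négatif": 0.97, "positif": 0.03}, threshold=0.7)
flag_happy = is_dissatisfied({"positif": 0.95, "négatif": 0.05}, threshold=0.7)
flag_edge = is_dissatisfied({"négatif": 0.7, "positif": 0.3}, threshold=0.7)
print(flag_angry, flag_happy, flag_edge)  # True False False
```

Raising the threshold makes the detector more conservative: fewer texts are flagged, at the cost of missing mildly negative ones.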
Run Detection
Putting it all together, we run the detector on the input dataset.
df = create_dataset()

detector = DissatisfactionDetector(
    model_name_or_path="cmarkea/distilcamembert-base-nli",
    text_columns=["header", "body"],
    threshold=0.7,
)

df = detector.transform(df)
As a result, we get a new column, dissatisfaction_result, containing the detection result.
Running the detector in debug mode would also provide the detection details.
| | header | body | dissatisfaction_result |
|---|---|---|---|
| 0 | Dossier 123456 | Merci beaucoup pour votre gentillesse et votre écoute ! | False |
| 1 | Réclamation (Dossier 987654) | Bonjour, je ne suis pas satisfait de cette situation, répondez-moi rapidement svp! | True |