
Zero Shot Classification

Machine Learning is commonly used to classify data into pre-defined categories.

---
title: Email classification
---
flowchart LR
    Input[[Email]] --> X(((Classifier)))
    X --> A(Car)
    X --> B(Boat)
    X --> C(Housing)
    X --> D(Health)

Typically, reaching high classification performance requires training models on context-specific labeled data. Zero-shot classification, by contrast, uses a pre-trained model as-is and does not require further training on context-specific data.
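
As a quick illustration, a pre-trained natural language inference (NLI) model can rank arbitrary candidate labels, such as the email categories in the diagram above, for a text it has never been trained on. The snippet below is a minimal sketch; the checkpoint name and the email text are illustrative assumptions:

from transformers import pipeline

# Any NLI-based zero-shot checkpoint can be used; this one is just an example
classifier = pipeline(task="zero-shot-classification", model="facebook/bart-large-mnli")

email = "Hello, could you give me an update on the repair of my kitchen after the water damage?"
result = classifier(sequences=email, candidate_labels=["car", "boat", "housing", "health"])

# result["labels"] contains the candidate labels ranked by decreasing score,
# even though the model was never trained on these specific categories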

Tutorial Introduction

In this tutorial we want to detect dissatisfaction in an email dataset. Let's create a basic dataset:

from typing import List

import pandas as pd
from transformers import pipeline

from melusine.base import MelusineDetector


def create_dataset():
    df = pd.DataFrame(
        [
            {
                "header": "Dossier 123456",
                "body": "Merci beaucoup pour votre gentillesse et votre écoute !",
            },
            {
                "header": "Réclamation (Dossier 987654)",
                "body": ("Bonjour, je ne suis pas satisfait de cette situation, " "répondez-moi rapidement svp!"),
            },
        ]
    )

    return df
                         header                                                                                 body
0                Dossier 123456                             Merci beaucoup pour votre gentillesse et votre écoute !
1  Réclamation (Dossier 987654)  Bonjour, je ne suis pas satisfait de cette situation, répondez-moi rapidement svp!

Classify with Zero-Shot Classification

The transformers library, provided by Hugging Face, makes it straightforward to use pre-trained models for zero-shot classification.

model_name_or_path = "cmarkea/distilcamembert-base-nli"

sentences = [
    "Quelle belle journée aujourd'hui",
    "La marée est haute",
    "Ce film est une catastrophe, je suis en colère",
]

classifier = pipeline(task="zero-shot-classification", model=model_name_or_path, tokenizer=model_name_or_path)

result = classifier(
    sequences=sentences, candidate_labels=", ".join(["positif", "négatif"]), hypothesis_template="Ce texte est {}."
)

The classifier returns a score for the positive (positif in French) and negative (négatif in French) labels for each input text:

[
    {
        'sequence': "Quelle belle journée aujourd'hui",
        'labels': ['positif', 'négatif'],
        'scores': [0.95, 0.05]
    },
    {
        'sequence': 'La marée est haute',
        'labels': ['positif', 'négatif'],
        'scores': [0.76, 0.24]
    },
    {
        'sequence': 'Ce film est une catastrophe, je suis en colère',
        'labels': ['négatif', 'positif'],
        'scores': [0.97, 0.03]
    }
]

Implement a Dissatisfaction Detector

A full email processing pipeline may contain multiple models. Melusine uses the MelusineDetector template class to standardize how models are integrated into a pipeline.

class DissatisfactionDetector(MelusineDetector):
    """
    Detect if the text expresses dissatisfaction.
    """

    # Dataframe column names
    OUTPUT_RESULT_COLUMN = "dissatisfaction_result"
    TMP_DETECTION_INPUT_COLUMN = "detection_input"
    TMP_DETECTION_OUTPUT_COLUMN = "detection_output"

    # Model inference parameters
    POSITIVE_LABEL = "positif"
    NEGATIVE_LABEL = "négatif"
    HYPOTHESIS_TEMPLATE = "Ce texte est {}."

    def __init__(self, model_name_or_path: str, text_columns: List[str], threshold: float):
        self.text_columns = text_columns
        self.threshold = threshold
        self.classifier = pipeline(
            task="zero-shot-classification", model=model_name_or_path, tokenizer=model_name_or_path
        )

        super().__init__(input_columns=text_columns, output_columns=[self.OUTPUT_RESULT_COLUMN], name="dissatisfaction")

The pre_detect method assembles the text that we want to use for classification.

def pre_detect(self, row, debug_mode=False):
    # Assemble the text columns into a single text
    effective_text = ""
    for col in self.text_columns:
        effective_text += "\n" + row[col]
    row[self.TMP_DETECTION_INPUT_COLUMN] = effective_text

    # Store the effective detection text in the debug data
    if debug_mode:
        row[self.debug_dict_col] = {"detection_input": row[self.TMP_DETECTION_INPUT_COLUMN]}

    return row
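
For illustration, this is the effective text that pre_detect assembles for the second email of our dataset (the values are copied from the dataset created above):

row = {
    "header": "Réclamation (Dossier 987654)",
    "body": "Bonjour, je ne suis pas satisfait de cette situation, répondez-moi rapidement svp!",
}

effective_text = ""
for col in ["header", "body"]:
    effective_text += "\n" + row[col]

# effective_text is the header and body joined by newlines:
# "\nRéclamation (Dossier 987654)\nBonjour, je ne suis pas satisfait de cette situation, répondez-moi rapidement svp!"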

The detect method runs the classification model on the text.

def detect(self, row, debug_mode=False):
    # Run the classifier on the text
    pipeline_result = self.classifier(
        sequences=row[self.TMP_DETECTION_INPUT_COLUMN],
        candidate_labels=", ".join([self.POSITIVE_LABEL, self.NEGATIVE_LABEL]),
        hypothesis_template=self.HYPOTHESIS_TEMPLATE,
    )
    # Format classification result
    result_dict = dict(zip(pipeline_result["labels"], pipeline_result["scores"]))
    row[self.TMP_DETECTION_OUTPUT_COLUMN] = result_dict

    # Store ML results in the debug data
    if debug_mode:
        row[self.debug_dict_col].update(result_dict)

    return row
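
The formatting step turns the raw pipeline output into a label-to-score mapping. Here is a small sketch with hypothetical scores:

pipeline_result = {
    "sequence": "Réclamation (Dossier 987654) ...",
    "labels": ["négatif", "positif"],
    "scores": [0.97, 0.03],  # hypothetical scores
}

result_dict = dict(zip(pipeline_result["labels"], pipeline_result["scores"]))
# result_dict == {"négatif": 0.97, "positif": 0.03}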

The post_detect method applies a threshold on the prediction score to determine the detection result.

def post_detect(self, row, debug_mode=False):
    # Compare classification score to the detection threshold
    if row[self.TMP_DETECTION_OUTPUT_COLUMN][self.NEGATIVE_LABEL] > self.threshold:
        row[self.OUTPUT_RESULT_COLUMN] = True
    else:
        row[self.OUTPUT_RESULT_COLUMN] = False

    return row

On top of that, the detector takes care of building debug data to make the result explainable.

Run Detection

Putting it all together, we run the detector on the input dataset.

df = create_dataset()

detector = DissatisfactionDetector(
    model_name_or_path="cmarkea/distilcamembert-base-nli",
    text_columns=["header", "body"],
    threshold=0.7,
)

df = detector.transform(df)

As a result, we get a new dissatisfaction_result column containing the detection result. Detection details can be obtained by running the detector in debug mode, as sketched after the result table below.

                         header                                                                                 body  dissatisfaction_result
0                Dossier 123456                             Merci beaucoup pour votre gentillesse et votre écoute !                   False
1  Réclamation (Dossier 987654)  Bonjour, je ne suis pas satisfait de cette situation, répondez-moi rapidement svp!                    True
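
The sketch below shows one way to obtain those details. It assumes the usual Melusine debug convention of setting a debug attribute on the input dataframe before calling transform; the exact attribute and the name of the resulting debug column should be checked against the MelusineDetector documentation:

df = create_dataset()

# Activate debug mode on the dataframe before running the detector
df.debug = True
df = detector.transform(df)

# Each row now carries the debug data built in pre_detect and detect,
# typically in a column named after the detector (e.g. debug_dissatisfaction)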