Contrastive Learning Mention Embedding

A BERT-base model with a linear projection head, fine-tuned via contrastive learning to produce embeddings that separate mentions of different social groups. It is designed for clustering social group mentions into qualitative categories.

This model is part of the group-appeal-detector package, which also provides group mention detection and stance classification.

Model Details

  • Base model: bert-base-uncased
  • Architecture: BERT-base + linear projection head (768 → 128 dimensions)
  • Training objective: Triplet loss with hard negative mining
  • Training data: Social group dictionary provided by Will Horne, Alona O. Dolinsky and Lena Maria Huber

How It Works

Each mention is fed into the model using the following prompt template:

Social group of {mention} is: [MASK].

The hidden state at the [MASK] position is extracted, passed through the projection layer, and L2-normalized. Mentions of the same social group category are pulled together in embedding space; mentions of different categories are pushed apart.

Training uses a triplet loss. Each anchor is a term from one category of the social group dictionary, paired with a randomly sampled positive from the same category and a hard negative mined from a different category.
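
The triplet objective described above can be illustrated with a small, dependency-free sketch. It is not the author's training code: the cosine distance, the margin of 0.5, the rule "hardest negative = closest embedding from another category", and the helper names (`l2_normalize`, `cosine_distance`, `hardest_negative`, `triplet_loss`) are all assumptions made for illustration, and the 2-D vectors stand in for the 128-dimensional embeddings.

```python
import math


def l2_normalize(v):
    """Scale a vector to unit length, as done to the model's embeddings."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def cosine_distance(u, v):
    """Cosine distance between two unit vectors (1 - dot product)."""
    return 1.0 - sum(a * b for a, b in zip(u, v))


def hardest_negative(anchor, candidates):
    """Assumed mining rule: pick the other-category embedding closest to the anchor."""
    return min(candidates, key=lambda c: cosine_distance(anchor, c))


def triplet_loss(anchor, positive, negative, margin=0.5):
    """Hinge loss: the positive should be closer than the negative by `margin`."""
    gap = cosine_distance(anchor, positive) - cosine_distance(anchor, negative)
    return max(0.0, gap + margin)


# Toy embeddings (hypothetical values, not model outputs).
anchor = l2_normalize([1.0, 0.0])
positive = l2_normalize([0.8, 0.6])
other_category = [l2_normalize(v) for v in ([0.6, 0.8], [0.0, 1.0], [-1.0, 0.0])]

negative = hardest_negative(anchor, other_category)
loss_hard = triplet_loss(anchor, positive, negative)
loss_easy = triplet_loss(anchor, positive, other_category[-1])
print(loss_hard, loss_easy)
```

In this toy setup the mined negative sits close to the anchor and yields a positive loss, while the distant negative already satisfies the margin and contributes zero. That is the usual motivation for mining hard negatives: easy triplets provide no gradient.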

Usage

Via group-appeal-detector package (recommended)

pip install group-appeal-detector

from group_appeal_detector import GroupAppealDetector, GroupMentionClusterer

detector = GroupAppealDetector(device="cpu")

# collect mentions from a corpus
texts = [...]
all_mentions = detector.detect_mentions_batch(texts, batch_size=16, as_df=False)
mentions = [m["span"] for doc_mentions in all_mentions for m in doc_mentions]

# cluster into categories
clusterer = GroupMentionClusterer(mentions, device="cpu")
results_df = clusterer.cluster(n_clusters=5, as_df=True)
results_df.head()

To find the optimal number of clusters automatically:

best_k, all_scores = clusterer.find_optimal_k(k_range=(2, 20), metric="silhouette", visualize=True)
results_df = clusterer.cluster(n_clusters=best_k, as_df=True)
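
For readers unfamiliar with the silhouette criterion, the following sketch shows how a "best k" can be chosen from candidate clusterings. It does not reproduce the internals of `find_optimal_k`, which are not documented here: the functions `mean_silhouette` and `pick_best_k` are hypothetical, Euclidean distance is an assumption, and the cluster labels are written by hand rather than produced by a clustering algorithm.

```python
import math


def mean_silhouette(points, labels):
    """Mean silhouette coefficient over all points (Euclidean distance)."""
    clusters = {}
    for idx, label in enumerate(labels):
        clusters.setdefault(label, []).append(idx)
    if len(clusters) < 2:
        raise ValueError("silhouette needs at least two clusters")

    scores = []
    for i, label in enumerate(labels):
        own = [j for j in clusters[label] if j != i]
        if not own:
            scores.append(0.0)  # common convention for singleton clusters
            continue
        # a: mean distance to the other members of the point's own cluster
        a = sum(math.dist(points[i], points[j]) for j in own) / len(own)
        # b: mean distance to the nearest other cluster
        b = min(
            sum(math.dist(points[i], points[j]) for j in members) / len(members)
            for other, members in clusters.items()
            if other != label
        )
        denom = max(a, b)
        scores.append((b - a) / denom if denom > 0 else 0.0)
    return sum(scores) / len(scores)


def pick_best_k(points, labelings):
    """Return the k whose labeling has the highest mean silhouette, plus all scores."""
    scores = {k: mean_silhouette(points, labels) for k, labels in labelings.items()}
    return max(scores, key=scores.get), scores


# Three well-separated toy groups (hypothetical 2-D points, not model embeddings).
points = [(0, 0), (0, 1), (5, 5), (5, 6), (10, 0), (10, 1)]
labelings = {
    2: [0, 0, 1, 1, 1, 1],  # merges two of the groups
    3: [0, 0, 1, 1, 2, 2],  # matches the true grouping
}
toy_best_k, toy_scores = pick_best_k(points, labelings)
print(toy_best_k, toy_scores)
```

A higher mean silhouette indicates tighter, better-separated clusters, so the labeling that matches the three toy groups scores above the one that merges two of them.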

Direct usage

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

REPO_ID = "maxwlnd/cl_mention_embedding"

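class ModelMask(nn.Module):
    """BERT encoder plus a linear projection applied to the [MASK] hidden state."""

    def __init__(self, tokenizer, pretrained_model_name="bert-base-uncased", proj_dim=128):
        super().__init__()
        # Build the architecture from its config only; the fine-tuned weights
        # are loaded from the checkpoint afterwards via load_state_dict.
        config = AutoConfig.from_pretrained(pretrained_model_name)
        self.encoder = AutoModel.from_config(config)
        self.mask_id = tokenizer.mask_token_id
        self.projector = nn.Sequential(nn.Linear(config.hidden_size, proj_dim))

    def encode(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Boolean mask marking the [MASK] token(s) in each sequence.
        mask_positions = (input_ids == self.mask_id)
        # Average the hidden states at the [MASK] position(s), one vector per sequence.
        h = torch.stack([
            outputs.last_hidden_state[i][mask_positions[i]].mean(dim=0)
            for i in range(input_ids.size(0))
        ])
        # Project to proj_dim dimensions and L2-normalize.
        z = self.projector(h)
        return F.normalize(z, p=2, dim=1)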

tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = ModelMask(tokenizer)
model.load_state_dict(load_file(hf_hub_download(REPO_ID, "model.safetensors")))
model.eval()

def embed(mention: str) -> torch.Tensor:
    prompt = f"Social group of {mention} is: {tokenizer.mask_token}."
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        return model.encode(inputs["input_ids"], inputs["attention_mask"])

emb_a = embed("farmers")
emb_b = embed("agricultural workers")
print(F.cosine_similarity(emb_a, emb_b).item())

Related Models

This model is one of three models in the group appeal detection pipeline:

Model | Task
maxwlnd/roberta_group_mention_detector | Detect social group mentions
maxwlnd/socialgroup_stance_classification_nli | Classify stance toward a group as positive, negative, or neutral
maxwlnd/cl_mention_embedding | Embed mentions for clustering into qualitative categories (this model)

License

MIT
