Contrastive Learning Mention Embedding
A BERT-base model with a linear projection head, fine-tuned with contrastive learning to produce embeddings that separate mentions of different social groups. It is designed for clustering social group mentions into qualitative categories.
This model is part of the group-appeal-detector package, which also provides group mention detection and stance classification.
Model Details
- Base model: bert-base-uncased
- Architecture: BERT-base + linear projection head (768 → 128 dimensions)
- Training objective: Triplet loss with hard negative mining
- Training data: Social group dictionary provided by Will Horne, Alona O. Dolinsky and Lena Maria Huber
How It Works
Each mention is fed into the model using the following prompt template:
```text
Social group of {mention} is: [MASK].
```
The final-layer hidden state at the [MASK] position is extracted, passed through the projection head, and L2-normalized. Training pulls mentions of the same social group category together in embedding space and pushes mentions of different categories apart.
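Because each embedding $z$ is L2-normalized, cosine similarity reduces to a dot product, and squared Euclidean distance is a decreasing function of it. Ranking neighbors by either measure therefore gives the same order:

$$
\|z_a - z_b\|_2^2 = \|z_a\|_2^2 + \|z_b\|_2^2 - 2\, z_a^\top z_b = 2 - 2\cos(z_a, z_b), \qquad \text{since } \|z_a\|_2 = \|z_b\|_2 = 1.
$$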
The model was trained with a triplet loss: each anchor is a term from one category of the social group dictionary, paired with a randomly sampled positive from the same category and a hard negative mined from a different category.
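The training objective in the paragraph above can be sketched in NumPy for a batch of precomputed embeddings. This is a hedged illustration, not the authors' training code: the function name `hard_negative_triplet_loss`, the Euclidean distance, the margin of 0.2, and the rule of taking the closest different-category candidate as the hard negative are all assumptions, since the card does not specify them.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    # Scale each row to unit length, like F.normalize(p=2, dim=1) in the model.
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def hard_negative_triplet_loss(
    anchors: np.ndarray,        # (B, D) anchor embeddings
    positives: np.ndarray,      # (B, D) one same-category positive per anchor
    anchor_labels: np.ndarray,  # (B,) category id of each anchor
    pool: np.ndarray,           # (N, D) candidate embeddings to mine negatives from
    pool_labels: np.ndarray,    # (N,) category id of each candidate
    margin: float = 0.2,        # ASSUMED margin; the card does not state one
) -> float:
    # Euclidean distance from each anchor to its positive: shape (B,)
    d_pos = np.linalg.norm(anchors - positives, axis=1)
    # Distance from every anchor to every candidate: shape (B, N)
    d_all = np.linalg.norm(anchors[:, None, :] - pool[None, :, :], axis=2)
    # Candidates from the anchor's own category are not valid negatives
    same_category = anchor_labels[:, None] == pool_labels[None, :]
    d_all = np.where(same_category, np.inf, d_all)
    # ASSUMED mining rule: hard negative = closest different-category candidate
    d_neg = d_all.min(axis=1)
    # Hinge: penalize anchors whose positive is not at least `margin` closer than the negative
    losses = np.maximum(0.0, d_pos - d_neg + margin)
    return float(losses.mean())
```

Taking the closest different-category candidate is the most aggressive form of hard negative mining; semi-hard variants instead pick a negative that is farther than the positive but still within the margin. In PyTorch, `torch.nn.TripletMarginLoss` applies the same hinge once the negatives have been chosen.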
Usage
Via group-appeal-detector package (recommended)
```bash
pip install group-appeal-detector
```

```python
from group_appeal_detector import GroupAppealDetector, GroupMentionClusterer

detector = GroupAppealDetector(device="cpu")

# collect mentions from a corpus
texts = [...]  # replace with your list of documents
all_mentions = detector.detect_mentions_batch(texts, batch_size=16, as_df=False)
mentions = [m["span"] for doc_mentions in all_mentions for m in doc_mentions]

# cluster into categories
clusterer = GroupMentionClusterer(mentions, device="cpu")
results_df = clusterer.cluster(n_clusters=5, as_df=True)
results_df.head()
```
To find the optimal number of clusters automatically:
```python
best_k, all_scores = clusterer.find_optimal_k(k_range=(2, 20), metric="silhouette", visualize=True)
results_df = clusterer.cluster(n_clusters=best_k, as_df=True)
```
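How `find_optimal_k(metric="silhouette")` works internally is not documented here. As a rough, hypothetical illustration of silhouette-based selection, the NumPy sketch below runs plain k-means (Lloyd's algorithm with deterministic farthest-point initialization) for each candidate k and keeps the k with the highest mean silhouette score, using Euclidean distance. The helpers `kmeans`, `silhouette`, and `select_k` are hypothetical, and the package may use a different clustering algorithm or distance; with real data, `x` would be the matrix of mention embeddings.

```python
import numpy as np

# Hypothetical helpers illustrating silhouette-based choice of k;
# this is NOT the group-appeal-detector implementation.


def pairwise_dist(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Euclidean distance between every row of a and every row of b: shape (len(a), len(b))
    return np.linalg.norm(a[:, None, :] - b[None, :, :], axis=2)


def kmeans(x: np.ndarray, k: int, n_iter: int = 100) -> np.ndarray:
    # Farthest-point initialization: start from x[0], then repeatedly add the
    # point farthest from all centers chosen so far (deterministic).
    centers = [x[0]]
    for _ in range(1, k):
        nearest = pairwise_dist(x, np.array(centers)).min(axis=1)
        centers.append(x[np.argmax(nearest)])
    centers = np.array(centers)
    # Lloyd's algorithm: alternate assignment and mean update until stable.
    for _ in range(n_iter):
        labels = pairwise_dist(x, centers).argmin(axis=1)
        new_centers = np.array([
            x[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)
        ])
        if np.allclose(new_centers, centers):
            break
        centers = new_centers
    return labels


def silhouette(x: np.ndarray, labels: np.ndarray) -> float:
    # Mean of s(i) = (b - a) / max(a, b), where a is the mean distance to the
    # point's own cluster and b the lowest mean distance to any other cluster.
    clusters = np.unique(labels)
    if len(clusters) < 2:
        return -1.0  # undefined for a single cluster; treat as the worst score
    d = pairwise_dist(x, x)
    scores = np.zeros(len(x))
    for i in range(len(x)):
        same = labels == labels[i]
        n_others = same.sum() - 1
        if n_others == 0:
            continue  # singleton clusters score 0, as in scikit-learn
        a = d[i, same].sum() / n_others  # d[i, i] == 0, so self adds nothing
        b = min(d[i, labels == c].mean() for c in clusters if c != labels[i])
        denom = max(a, b)
        scores[i] = (b - a) / denom if denom > 0 else 0.0
    return float(scores.mean())


def select_k(x: np.ndarray, k_values) -> tuple[int, dict[int, float]]:
    # Cluster once per candidate k and keep the k with the highest silhouette.
    scores = {k: silhouette(x, kmeans(x, k)) for k in k_values}
    best_k = max(scores, key=scores.get)
    return best_k, scores


if __name__ == "__main__":
    # Toy data: three tight groups of unit vectors, 120 degrees apart.
    angles = [base + off for base in (0.0, 2 * np.pi / 3, 4 * np.pi / 3)
              for off in (-0.06, -0.02, 0.02, 0.06)]
    x = np.array([[np.cos(t), np.sin(t)] for t in angles])
    best_k, scores = select_k(x, range(2, 6))
    print(best_k, scores)
```

Silhouette rewards clusters that are tight relative to their distance from the nearest other cluster, so splitting a true group or merging two groups both lower the score. For real workloads, scikit-learn's `KMeans` and `silhouette_score` are a more practical choice than these hand-rolled helpers.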
Direct usage
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

REPO_ID = "maxwlnd/cl_mention_embedding"


class ModelMask(nn.Module):
    """BERT encoder + linear projection head, pooled at the [MASK] token."""

    def __init__(self, tokenizer, pretrained_model_name="bert-base-uncased", proj_dim=128):
        super().__init__()
        # Build the architecture only; the fine-tuned weights are loaded below.
        config = AutoConfig.from_pretrained(pretrained_model_name)
        self.encoder = AutoModel.from_config(config)
        self.mask_id = tokenizer.mask_token_id
        self.projector = nn.Sequential(nn.Linear(config.hidden_size, proj_dim))

    def encode(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Average the final hidden states at the [MASK] position(s) of each sequence.
        mask_positions = input_ids == self.mask_id
        h = torch.stack([
            outputs.last_hidden_state[i][mask_positions[i]].mean(dim=0)
            for i in range(input_ids.size(0))
        ])
        # Project 768 -> 128 dimensions, then L2-normalize each embedding.
        z = self.projector(h)
        return F.normalize(z, p=2, dim=1)


tokenizer = AutoTokenizer.from_pretrained(REPO_ID)
model = ModelMask(tokenizer)
model.load_state_dict(load_file(hf_hub_download(REPO_ID, "model.safetensors")))
model.eval()


def embed(mention: str) -> torch.Tensor:
    """Return a (1, 128) unit-norm embedding for a single mention."""
    prompt = f"Social group of {mention} is: {tokenizer.mask_token}."
    inputs = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        return model.encode(inputs["input_ids"], inputs["attention_mask"])


emb_a = embed("farmers")
emb_b = embed("agricultural workers")
print(F.cosine_similarity(emb_a, emb_b).item())  # cosine similarity in [-1, 1]
```
Related Models
This model is one of three models in the group appeal detection pipeline:
| Model | Task |
|---|---|
| maxwlnd/roberta_group_mention_detector | Detect social group mentions |
| maxwlnd/socialgroup_stance_classification_nli | Classify stance toward a group as positive, negative, or neutral |
| maxwlnd/cl_mention_embedding | Embed mentions for clustering into qualitative categories (this model) |
License
MIT