Latent Self-Consistency
Collection
LSC for majority selection in short- and long-form generation (collection: 1 item)
This repository contains a Llama 3.1 8B Instruct backbone with trained LatentSC Summary-token embeddings attached. The base model weights are unchanged; only the embeddings for the new Summary tokens are added, so that LatentSC inference can use them.
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
repo = "jeongseokoh/LatentSC_llama3.1_8b_6SummaryTokens"
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.bfloat16, device_map="auto"
)

# Summary tokens (default: 6). They are appended AFTER each sampled response so
# that their hidden states summarize that specific response. Putting them in the
# prompt (before generation) would give every sample the identical prefix and
# therefore identical Summary-token states.
summary_tokens = [f"<|Summary{i}|>" for i in range(1, 7)]
summary_ids = tokenizer.convert_tokens_to_ids(summary_tokens)
if any(t is None or t == tokenizer.unk_token_id for t in summary_ids):
    raise ValueError(f"Summary tokens not found in the tokenizer: {summary_tokens}")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Solve: 17 * 23. Show the final answer only."},
]
# add_generation_prompt=True ends the prompt with the assistant header, so the
# model starts answering instead of having to emit a role header itself.
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# The rendered template already contains <|begin_of_text|>; adding special
# tokens again would produce a duplicated BOS.
inputs = tokenizer(prompt, return_tensors="pt", add_special_tokens=False).to(
    model.device
)
prompt_len = inputs["input_ids"].shape[1]

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        num_return_sequences=10,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )

# Decode candidates: keep only the generated continuation; decoding the whole
# sequence would prepend the system/user prompt text to every answer.
sequences = out.sequences
response_ids = sequences[:, prompt_len:]
answers = tokenizer.batch_decode(response_ids, skip_special_tokens=True)

# Embeddings: re-encode prompt + response + Summary tokens and mean-pool the
# last-layer hidden states at the Summary positions.
# NOTE(review): mean pooling, EOS removal and conditioning on the full prompt are
# assumptions; the saved config fields lsc_aggr / lsc_remove_eos describe the
# trained setting -- confirm against the official repository.
eos_cfg = model.generation_config.eos_token_id
stop_ids = set(eos_cfg if isinstance(eos_cfg, (list, tuple)) else [eos_cfg])
stop_ids.add(tokenizer.eos_token_id)
stop_ids.discard(None)
remove_eos = bool(getattr(model.config, "lsc_remove_eos", True))  # default assumed
summary = torch.tensor(summary_ids, device=sequences.device)
prompt_ids = inputs["input_ids"][0]

emb_list = []
with torch.no_grad():
    for resp in response_ids:
        # Strip trailing EOS/padding (padding reuses the EOS id above).
        end = resp.size(0)
        while end > 0 and int(resp[end - 1]) in stop_ids:
            end -= 1
        if not remove_eos and end < resp.size(0):
            end += 1  # keep the first end-of-turn token the model produced
        ids = torch.cat([prompt_ids, resp[:end], summary]).unsqueeze(0)
        hidden = model(input_ids=ids, output_hidden_states=True).hidden_states[-1]
        emb_list.append(hidden[0, -summary.size(0):, :].mean(dim=0))
embs = torch.stack(emb_list)  # (N, D)

# LSC selection (cosine similarity): choose the candidate whose embedding is,
# on average, most similar to all other candidates.
embs = F.normalize(embs.float(), p=2, dim=1)
sim = embs @ embs.T
sim.fill_diagonal_(0.0)  # ignore self-similarity
avg_sim = sim.mean(dim=1)
best_idx = int(torch.argmax(avg_sim))
best_answer = answers[best_idx]
# Dynamic TopK LSC
def lsc_topk(embs, answers, k):
    """Select an answer from the ``k`` most central candidates.

    Candidates are ranked by their mean cosine similarity to every other
    candidate. The top ``k`` are kept, and the same ranking is repeated
    inside that subset to pick the winner.

    Args:
        embs: 2-D tensor of candidate embeddings, one row per candidate
            (L2-normalized here, so unnormalized input is fine).
        answers: answers aligned row-for-row with ``embs``.
        k: subset size; clamped to the number of candidates because
            ``torch.topk`` rejects ``k`` larger than the input.

    Returns:
        ``(answer, score)``: the chosen answer and its mean similarity inside
        the subset. The zeroed self-similarity is included in the mean's
        denominator, so the score is scaled by ``(k - 1) / k``; keep this in
        mind when comparing scores across different ``k``.

    Raises:
        ValueError: if ``embs`` has no rows or ``k < 1``.
    """
    num = embs.size(0)
    if num == 0:
        raise ValueError("lsc_topk requires at least one candidate embedding")
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # A fixed sweep such as k in [3, 5, 7] would otherwise crash for small pools.
    k = min(k, num)

    embs = F.normalize(embs.float(), p=2, dim=1)
    sim = embs @ embs.T
    sim.fill_diagonal_(0.0)  # exclude self-similarity
    avg_sim = sim.mean(dim=1)
    topk_idx = torch.topk(avg_sim, k=k).indices

    sub = embs[topk_idx]
    sub_sim = sub @ sub.T
    sub_sim.fill_diagonal_(0.0)
    sub_avg = sub_sim.mean(dim=1)
    best_local = int(torch.argmax(sub_avg))
    return answers[int(topk_idx[best_local])], float(sub_avg.max())
# Sweep several subset sizes and keep the answer whose in-subset score is the
# highest; on ties the earliest k wins.
best_score = -1e9
best = None
for subset_size in [3, 5, 7]:
    candidate, candidate_score = lsc_topk(embs, answers, subset_size)
    if candidate_score <= best_score:
        continue
    best, best_score = candidate, candidate_score
The following config fields are saved (when present) to guide LatentSC inference:
- `lsc_num_special_tokens`
- `lsc_special_token_prefix`
- `lsc_aggr`
- `lsc_remove_eos`
- `lsc_temp`
For detailed training/inference scripts and full usage, see the GitHub repository: https://github.com/jeongseokO/LatentSC_official
Base model: `meta-llama/Llama-3.1-8B`