LatentSC Llama 3.1 8B with Summary Tokens

This repository contains a Llama 3.1 8B Instruct backbone with LatentSC Summary-token embeddings attached. The base model weights are unchanged; only the Summary token embeddings are added so that LatentSC inference can use the trained Summary tokens.
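
As a quick sanity check (a minimal sketch, not part of the official LatentSC scripts), you can verify that the Summary tokens are registered in the tokenizer and that each one maps to a single dedicated token id:

from transformers import AutoTokenizer

repo = "jeongseokoh/LatentSC_llama3.1_8b_6SummaryTokens"
tok = AutoTokenizer.from_pretrained(repo)

summary_tokens = [f"<|Summary{i}|>" for i in range(1, 7)]
print(tok.convert_tokens_to_ids(summary_tokens))  # six distinct ids, none unknown
# each Summary token should encode to exactly one id
print([len(tok(t, add_special_tokens=False)["input_ids"]) for t in summary_tokens])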

Usage

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM

repo = "jeongseokoh/LatentSC_llama3.1_8b_6SummaryTokens"

tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForCausalLM.from_pretrained(
    repo, torch_dtype=torch.bfloat16, device_map="auto"
)

# Summary tokens (default: 6)
summary_tokens = [f"<|Summary{i}|>" for i in range(1, 7)]

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Solve: 17 * 23. Show the final answer only."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False)
prompt_with_summary = prompt + "".join(summary_tokens)

# The chat template string already contains the BOS token, so do not add it again.
inputs = tokenizer(prompt_with_summary, return_tensors="pt", add_special_tokens=False).to(model.device)
with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.9,
        top_p=0.95,
        num_return_sequences=10,
        pad_token_id=tokenizer.eos_token_id,
        return_dict_in_generate=True,
    )

# Decode candidates (only the generated continuation, not the prompt)
sequences = out.sequences
prompt_len = inputs["input_ids"].shape[1]
answers = tokenizer.batch_decode(sequences[:, prompt_len:], skip_special_tokens=True)

# Embeddings: last-layer hidden state of each candidate's final content token.
# generate() right-pads with EOS here, so locate the last non-EOS position and
# re-encode the candidates in one forward pass (right padding cannot affect
# earlier positions under causal attention).
positions = torch.arange(sequences.size(1), device=sequences.device)
not_pad = sequences != tokenizer.eos_token_id
last_idx = torch.where(not_pad, positions, torch.zeros_like(positions)).max(dim=1).values
with torch.no_grad():
    hs = model(sequences, output_hidden_states=True).hidden_states[-1]  # (N, seq, hidden)
row = torch.arange(hs.size(0), device=hs.device)
embs = hs[row, last_idx, :]  # (N, D)

# LSC selection (cosine similarity): pick the candidate whose embedding is,
# on average, closest to the other candidates.
embs = F.normalize(embs.float(), p=2, dim=1)
sim = embs @ embs.T
sim.fill_diagonal_(0.0)    # ignore self-similarity
avg_sim = sim.mean(dim=1)  # mean similarity to the other candidates
best_idx = int(torch.argmax(avg_sim))
best_answer = answers[best_idx]

# Dynamic TopK LSC: restrict to the k most mutually similar candidates, then
# pick the most central one within that subset.
def lsc_topk(embs, answers, k):
    embs = F.normalize(embs.float(), p=2, dim=1)
    sim = embs @ embs.T
    sim.fill_diagonal_(0.0)
    avg_sim = sim.mean(dim=1)
    topk_idx = torch.topk(avg_sim, k=k).indices
    sub = embs[topk_idx]
    sub_sim = sub @ sub.T
    sub_sim.fill_diagonal_(0.0)
    sub_avg = sub_sim.mean(dim=1)
    best_local = int(torch.argmax(sub_avg))
    return answers[int(topk_idx[best_local])], float(sub_avg.max())

best = None
best_score = -1e9
for k in [3, 5, 7]:
    cand, score = lsc_topk(embs, answers, k)
    if score > best_score:
        best_score = score
        best = cand
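
The two selections can then be compared directly: `best_answer` is the plain LSC pick, and `best` is the dynamic top-k pick.

print("LSC pick:", best_answer)
print("Dynamic TopK LSC pick:", best)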

Stored LatentSC config fields

The following config fields are saved (when present) to guide LatentSC inference:

lsc_num_special_tokens
lsc_special_token_prefix
lsc_aggr
lsc_remove_eos
lsc_temp
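
These fields live on the model config, so inference code can read them instead of hard-coding values. A minimal sketch follows; the fallback defaults are illustrative only (the repository itself implies lsc_num_special_tokens = 6):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("jeongseokoh/LatentSC_llama3.1_8b_6SummaryTokens")

num_summary = getattr(config, "lsc_num_special_tokens", 6)
prefix = getattr(config, "lsc_special_token_prefix", None)
aggr = getattr(config, "lsc_aggr", None)
remove_eos = getattr(config, "lsc_remove_eos", None)
temp = getattr(config, "lsc_temp", None)
print(num_summary, prefix, aggr, remove_eos, temp)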

For detailed training/inference scripts and full usage, see the GitHub repository: https://github.com/jeongseokO/LatentSC_official
