# Bonsai Diffusion LM - ModernBERT
A lightweight diffusion language model following the LLaDA approach (Large Language Diffusion with mAsking).
## Model Description
Unlike traditional autoregressive models (e.g., GPT) that generate text left to right, this model starts from a fully masked sequence and iteratively unmasks tokens until coherent text emerges.
| Property | Value |
|---|---|
| Architecture | ModernBERT-base |
| Parameters | 149M |
| Training Data | TinyStories (50,000 samples) |
| Context Length | 256 tokens |
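To build intuition for the iterative unmasking described above, here is a minimal model-free sketch (all names are illustrative, not part of this repository): positions are revealed in random groups over a few steps, following a linear schedule for the masked fraction, rather than strictly left to right.

```python
import random

def diffusion_order(seq_len=8, num_steps=4, seed=0):
    """Toy illustration (no model): reveal positions in random groups
    so the masked fraction follows s = 1 - (step + 1) / num_steps."""
    rng = random.Random(seed)
    masked = set(range(seq_len))
    order = []
    for step in range(num_steps):
        # How many positions should remain masked after this step.
        target = int(seq_len * (1 - (step + 1) / num_steps))
        reveal = rng.sample(sorted(masked), len(masked) - target)
        masked -= set(reveal)
        order.append(sorted(reveal))
    return order

# Each step reveals a batch of positions in arbitrary order,
# unlike left-to-right autoregressive decoding.
print(diffusion_order())
```

By the final step every position has been revealed exactly once; the real model replaces the random choice of positions with predictions from the denoiser.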
## Quick Usage
```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model = AutoModelForMaskedLM.from_pretrained("TASMAYU/bonsai-diffusionLM-modernbert")
tokenizer = AutoTokenizer.from_pretrained("TASMAYU/bonsai-diffusionLM-modernbert")

if tokenizer.mask_token is None:
    tokenizer.mask_token = "[MASK]"

def generate(prompt=None, num_steps=64, seq_len=256):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    # Start from a fully masked sequence; prompt tokens (if any) are
    # written in at the front and are never masked.
    input_ids = torch.full((1, seq_len), tokenizer.mask_token_id, device=device)
    if prompt:
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
        input_ids[0, :len(prompt_ids)] = torch.tensor(prompt_ids, device=device)

    for step in range(num_steps):
        # Linear noise schedule: t is the current masked fraction,
        # s is the target masked fraction after this step.
        t = 1.0 - step / num_steps
        s = 1.0 - (step + 1) / num_steps

        with torch.no_grad():
            outputs = model(input_ids)
        predictions = outputs.logits.argmax(dim=-1)

        # Fill every masked position with the model's prediction, then
        # remask a random subset with probability s/t so the expected
        # masked fraction follows the schedule.
        mask_positions = input_ids == tokenizer.mask_token_id
        remask_prob = s / t if t > 0 else 0.0
        remask = torch.rand_like(input_ids.float()) < remask_prob

        new_input_ids = input_ids.clone()
        new_input_ids[mask_positions] = predictions[mask_positions]
        new_input_ids[remask & mask_positions] = tokenizer.mask_token_id
        input_ids = new_input_ids

    return tokenizer.decode(input_ids[0].cpu().tolist(), skip_special_tokens=True)

# Example
print(generate("Once upon a time", num_steps=64))
```
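Why remask with probability `s/t`? The ratios telescope, so the expected masked fraction after step `k` is exactly `1 - (k + 1) / num_steps`, reaching zero at the final step. A small model-free check (the function name here is illustrative):

```python
def masked_fraction_schedule(num_steps=64):
    """Track the expected masked fraction when each step remasks
    surviving masks with probability s / t."""
    frac = 1.0  # everything starts masked
    fractions = []
    for step in range(num_steps):
        t = 1.0 - step / num_steps
        s = 1.0 - (step + 1) / num_steps
        frac *= s / t  # expected fraction that stays masked
        fractions.append(frac)
    return fractions

# Telescoping: after step k the fraction is exactly 1 - (k + 1) / 8.
print([round(f, 3) for f in masked_fraction_schedule(8)])
# → [0.875, 0.75, 0.625, 0.5, 0.375, 0.25, 0.125, 0.0]
```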