Bonsai Diffusion LM - ModernBERT

A lightweight diffusion language model based on the masked-diffusion approach of the LLaDA paper (Large Language Diffusion with mAsking), using ModernBERT-base as the denoising network.

Model Description

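Unlike autoregressive models such as GPT, which generate text left to right one token at a time, this model starts from a fully masked sequence (the masked-diffusion counterpart of pure noise) and iteratively unmasks tokens, predicting every masked position in parallel at each step until coherent text emerges.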
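In this setting the "noise" is masking rather than Gaussian noise. The following is a minimal sketch of the forward (corruption) process used in LLaDA-style masked diffusion, assuming each token is masked independently with probability `t`; the helper `mask_tokens` and the word-level tokens are hypothetical and not part of this repository. At `t = 1.0` every token is masked, which is the state generation starts from, and at `t = 0.0` the text is left untouched.

```python
import random

MASK = "[MASK]"

def mask_tokens(tokens, t, rng=random):
    """Hypothetical forward-process sketch: replace each token with MASK
    independently with probability t (the noise level, in [0, 1])."""
    if not 0.0 <= t <= 1.0:
        raise ValueError("t must be in [0, 1]")
    # rng.random() is in [0, 1), so t=1.0 masks everything and t=0.0 masks nothing.
    return [MASK if rng.random() < t else tok for tok in tokens]

tokens = "Once upon a time there was a cat".split()
print(mask_tokens(tokens, 1.0))                   # every position is "[MASK]"
print(mask_tokens(tokens, 0.0))                   # original tokens, unchanged
print(mask_tokens(tokens, 0.5, random.Random(0)))  # roughly half masked
```

Generation runs this process in reverse: the model sees a partially masked sequence and predicts the hidden tokens.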

| Property | Value |
|---|---|
| Architecture | ModernBERT-base |
| Parameters | 149M |
| Training Data | TinyStories (50,000 samples) |
| Context Length | 256 tokens |
| Weights | Safetensors, F32 |

Quick Usage

```python
from transformers import AutoModelForMaskedLM, AutoTokenizer
import torch

model_id = "TASMAYU/bonsai-diffusionLM-modernbert"
model = AutoModelForMaskedLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Fallback in case the tokenizer config does not define a mask token.
if tokenizer.mask_token is None:
    tokenizer.mask_token = "[MASK]"

# Move the model to the device once, instead of on every generate() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

@torch.no_grad()
def generate(prompt=None, num_steps=64, seq_len=256):
    mask_id = tokenizer.mask_token_id

    # Start from a fully masked sequence and pin the (truncated) prompt at the front.
    input_ids = torch.full((1, seq_len), mask_id, dtype=torch.long, device=device)
    if prompt:
        prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)[:seq_len]
        input_ids[0, :len(prompt_ids)] = torch.tensor(prompt_ids, device=device)

    for step in range(num_steps):
        # Linear schedule: t is the current mask ratio, s the ratio after this step.
        t = 1.0 - step / num_steps
        s = 1.0 - (step + 1) / num_steps

        # Predict every position in parallel and keep the most likely token (greedy).
        predictions = model(input_ids).logits.argmax(dim=-1)

        # Fill all currently masked positions with their predictions...
        mask_positions = input_ids == mask_id
        new_input_ids = input_ids.clone()
        new_input_ids[mask_positions] = predictions[mask_positions]

        # ...then re-mask each of them with probability s / t (t > 0 inside the loop).
        # On the last step s == 0, so nothing is re-masked.
        remask = torch.rand(input_ids.shape, device=device) < (s / t)
        new_input_ids[remask & mask_positions] = mask_id
        input_ids = new_input_ids

    # skip_special_tokens drops [MASK] and other special tokens from the output.
    return tokenizer.decode(input_ids[0].cpu().tolist(), skip_special_tokens=True)

# Example
print(generate("Once upon a time", num_steps=64))
```
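Why re-mask with probability `s / t`? A position still masked at noise level `t` stays masked with probability `s / t`, so the expected number of masked tokens falls from `L * t` to `L * t * (s / t) = L * s`, where `L` is the number of non-prompt positions. With the linear schedule above this is a linear decrease to zero; with `seq_len=256`, `num_steps=64` and no prompt, about 4 tokens are revealed per step. The sketch below only evaluates that expectation; `expected_masked` is a hypothetical helper, not part of this repository.

```python
def expected_masked(seq_len, num_steps, prompt_len=0):
    """Expected number of [MASK] tokens after each step of the sampler above.

    Positions still masked at level t stay masked with probability s / t,
    so the expectation is multiplied by s / t on every step.
    """
    masked = float(seq_len - prompt_len)
    counts = []
    for step in range(num_steps):
        t = 1.0 - step / num_steps
        s = 1.0 - (step + 1) / num_steps
        masked *= s / t
        counts.append(masked)
    return counts

counts = expected_masked(seq_len=256, num_steps=64)
print([round(c, 6) for c in counts[:3]])  # [252.0, 248.0, 244.0]
print(counts[-1])                          # 0.0, because s == 0 on the final step
```

Because `s` is exactly 0 on the last step, every remaining masked position is filled by a prediction regardless of `num_steps`; lowering `num_steps` simply reveals more tokens per forward pass.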