import os
import math
import copy
import json

import torch
import torch.nn.functional as F
from flask import Flask, request, jsonify, Response
from transformers import AutoTokenizer, AutoModelForMaskedLM

app = Flask(__name__)

# Populated by load_model() at startup.
model = None
tokenizer = None
device = None

def add_gumbel_noise(logits, temperature):
    # Exponential-race variant of the Gumbel-max trick, computed in float64
    # for numerical stability; temperature == 0 means greedy decoding.
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    g = (-torch.log(noise)) ** temperature
    return logits.exp() / g
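
# Why the argmax over the noised values samples correctly: at temperature 1.0,
# g = -log(u) is Exp(1) noise, and argmax_i exp(l_i) / g_i selects index i
# with probability softmax(l)_i (the "exponential race" form of the
# Gumbel-max trick). Raising g to `temperature` anneals the noise, approaching
# greedy argmax as the temperature goes to 0.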

def get_num_transfer_tokens(mask_index, steps):
    # Split each sample's masked-token count evenly across `steps`,
    # distributing the remainder over the earliest steps.
    mask_num = mask_index.sum(dim=1, keepdim=True)
    base = mask_num // steps
    rem = mask_num % steps
    out = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
    for i in range(mask_num.size(0)):
        out[i, : rem[i]] += 1
    return out
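
# Example with hypothetical numbers: 10 masked positions split over 4 steps
# gives a per-step schedule of [3, 3, 2, 2] (base 10 // 4 = 2 everywhere,
# remainder 10 % 4 = 2 spread over the first two steps).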

def build_staircase_attention_mask(x, block_size, pad_id):
    # Block-causal ("staircase") attention: a query attends to every token in
    # its own block and in all earlier blocks. Padding positions get block id
    # -1 and are excluded on both the query and key side. Also returns
    # position ids that count only non-pad tokens.
    B, T = x.shape
    device = x.device
    valid = x != pad_id
    pos_raw = torch.cumsum(valid.long(), dim=-1)
    position_ids = torch.where(valid, pos_raw - 1, torch.zeros_like(pos_raw)).long()
    col = torch.arange(T, device=device)
    block_ids = (col // block_size).view(1, T).expand(B, T)
    block_ids = torch.where(valid, block_ids, torch.full_like(block_ids, -1))
    q = block_ids.view(B, 1, T, 1)
    k = block_ids.view(B, 1, 1, T)
    attn = (k <= q) & (q >= 0) & (k >= 0)  # boolean mask of shape (B, 1, T, T)
    return attn, position_ids
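
# Illustration with T = 4, block_size = 2, no padding: queries in block 0
# (positions 0-1) see only block 0, queries in block 1 (positions 2-3) see
# both blocks, so the (T, T) slice of the mask is
#   [[1, 1, 0, 0],
#    [1, 1, 0, 0],
#    [1, 1, 1, 1],
#    [1, 1, 1, 1]]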

def diffusion_step_block(logits, x_block, mask_block, num_transfer, temperature, remasking):
    # One denoising step on a block: sample candidate tokens for the masked
    # positions, score them, and commit only the `num_transfer` most confident
    # candidates per sample; the rest stay masked for later steps.
    B, L, _ = logits.shape
    if not mask_block.any():
        return x_block
    noisy = add_gumbel_noise(logits, temperature)
    x0 = noisy.argmax(dim=-1)
    if remasking == "low_confidence":
        # Confidence = model probability of the sampled token.
        p = F.softmax(logits, dim=-1)
        conf = p.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
    elif remasking == "random":
        conf = torch.rand((B, L), device=logits.device)
    else:
        raise ValueError(f"unknown remasking strategy: {remasking}")
    # Keep already-committed tokens; only masked positions are candidates.
    x0 = torch.where(mask_block, x0, x_block)
    neg_inf = torch.full_like(conf, -float("inf"))
    conf = torch.where(mask_block, conf, neg_inf)
    commit = torch.zeros_like(x_block, dtype=torch.bool)
    for i in range(B):
        k = int(num_transfer[i].item())
        if k > 0:
            # Never ask topk for more entries than there are masked positions.
            valid = (conf[i] > -float("inf")).sum().item()
            k = min(k, valid)
            _, idx = torch.topk(conf[i], k)
            commit[i, idx] = True
    out = x_block.clone()
    out[commit] = x0[commit]
    return out

@torch.no_grad()  # inference only; avoids building autograd graphs each step
def generate(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    capture_interval=0,
):
    # Block-wise diffusion decoding: append one block of mask tokens at a
    # time and iteratively denoise it, reusing the cached prefix. When
    # capture_interval > 0, intermediate snapshots of x are also returned.
    device = model.device
    mask_id = tokenizer.mask_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.mask_token_id
    # Accept a tensor, a flat list of ids, or a batch of id lists (right-padded).
    if isinstance(prompt, torch.Tensor):
        x = prompt.to(device).long()
    else:
        if isinstance(prompt[0], (list, tuple)):
            max_len = max(len(p) for p in prompt)
            x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
            for i, p in enumerate(prompt):
                x[i, : len(p)] = torch.tensor(p, device=device)
        else:
            x = torch.tensor(prompt, device=device).long()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    B = x.size(0)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    num_blocks = math.ceil(max_new_tokens / block_size)
    steps_per_block = math.ceil(steps / num_blocks)
    generated = 0
    intermediates = []
    total_step = 0
    while generated < max_new_tokens:
        if finished.all():
            break
        T_prefix = x.size(1)
        # If the prefix ends mid-block, only fill out the rest of that block.
        offset = T_prefix % block_size
        room = block_size if offset == 0 else block_size - offset
        cur_len = min(room, max_new_tokens - generated)
        if cur_len <= 0:
            break
        # Encode the prefix once per block and reuse its KV cache below.
        attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
        out = model(x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
        cond_past = out.past_key_values
        if cfg_scale > 0:
            # Classifier-free guidance: unconditional pass over a fully masked prefix.
            un_x = x.clone()
            un_x[:] = mask_id
            out_un = model(un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
            uncond_past = out_un.past_key_values
        else:
            uncond_past = None
        # Append a fresh block of mask tokens (pad for finished samples).
        block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
        block[finished] = pad_id
        x = torch.cat([x, block], dim=1)
        T_total = x.size(1)
        block_mask = x[:, -cur_len:] == mask_id
        num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
        eff_steps = num_transfer.size(1)
        full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
        attn_blk = full_attn[:, :, T_prefix:T_total, :]
        pos_blk = full_pos[:, T_prefix:T_total]
        for t in range(eff_steps):
            x_blk = x[:, T_prefix:T_total]
            m_blk = x_blk == mask_id
            # Deep-copy the cached prefix so the forward pass cannot mutate it.
            cond_logits = model(
                x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                past_key_values=copy.deepcopy(cond_past), use_cache=False
            ).logits
            logits = cond_logits
            if cfg_scale > 0:
                un_logits = model(
                    x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                    past_key_values=copy.deepcopy(uncond_past), use_cache=False
                ).logits
                logits = un_logits + (cfg_scale + 1.0) * (cond_logits - un_logits)
            x_blk_new = diffusion_step_block(
                logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
            )
            x[:, T_prefix:T_total] = x_blk_new
            if capture_interval > 0 and total_step % capture_interval == 0:
                intermediates.append(x.clone())
            total_step += 1
            if tokenizer.eos_token_id is not None:
                finished |= (x_blk_new == tokenizer.eos_token_id).any(dim=1)
                if finished.all():
                    break
        generated += cur_len
        if finished.all():
            break
    if capture_interval > 0:
        return x, intermediates
    return x
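
# A minimal usage sketch (assumes load_model() has run; the prompt text and
# parameter values are illustrative):
#   ids = tokenizer("Once upon a time", return_tensors="pt").input_ids.to(device)
#   out = generate(model, tokenizer, ids, steps=64, max_new_tokens=64, block_size=32)
#   print(tokenizer.decode(out[0, ids.size(1):].tolist(), skip_special_tokens=True))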

@torch.no_grad()  # inference only
def generate_stream(
    model,
    tokenizer,
    prompt,
    steps=128,
    max_new_tokens=128,
    block_size=32,
    temperature=0.0,
    cfg_scale=0.0,
    remasking="low_confidence",
    capture_interval=10,
):
    # Same decoding loop as generate(), but yields decoded intermediate text
    # for the first batch element every capture_interval steps, followed by
    # one final event.
    device = model.device
    mask_id = tokenizer.mask_token_id
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.mask_token_id
    if isinstance(prompt, torch.Tensor):
        x = prompt.to(device).long()
    else:
        if isinstance(prompt[0], (list, tuple)):
            max_len = max(len(p) for p in prompt)
            x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
            for i, p in enumerate(prompt):
                x[i, : len(p)] = torch.tensor(p, device=device)
        else:
            x = torch.tensor(prompt, device=device).long()
    if x.dim() == 1:
        x = x.unsqueeze(0)
    B = x.size(0)
    finished = torch.zeros(B, dtype=torch.bool, device=device)
    num_blocks = math.ceil(max_new_tokens / block_size)
    steps_per_block = math.ceil(steps / num_blocks)
    generated = 0
    total_step = 0
    prompt_len = x.size(1)
    while generated < max_new_tokens:
        if finished.all():
            break
        T_prefix = x.size(1)
        offset = T_prefix % block_size
        room = block_size if offset == 0 else block_size - offset
        cur_len = min(room, max_new_tokens - generated)
        if cur_len <= 0:
            break
        attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
        out = model(x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
        cond_past = out.past_key_values
        if cfg_scale > 0:
            un_x = x.clone()
            un_x[:] = mask_id
            out_un = model(un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
            uncond_past = out_un.past_key_values
        else:
            uncond_past = None
        block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
        block[finished] = pad_id
        x = torch.cat([x, block], dim=1)
        T_total = x.size(1)
        block_mask = x[:, -cur_len:] == mask_id
        num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
        eff_steps = num_transfer.size(1)
        full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
        attn_blk = full_attn[:, :, T_prefix:T_total, :]
        pos_blk = full_pos[:, T_prefix:T_total]
        for t in range(eff_steps):
            x_blk = x[:, T_prefix:T_total]
            m_blk = x_blk == mask_id
            cond_logits = model(
                x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                past_key_values=copy.deepcopy(cond_past), use_cache=False
            ).logits
            logits = cond_logits
            if cfg_scale > 0:
                un_logits = model(
                    x_blk, attention_mask=attn_blk, position_ids=pos_blk,
                    past_key_values=copy.deepcopy(uncond_past), use_cache=False
                ).logits
                logits = un_logits + (cfg_scale + 1.0) * (cond_logits - un_logits)
            x_blk_new = diffusion_step_block(
                logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
            )
            x[:, T_prefix:T_total] = x_blk_new
            # Guard against capture_interval == 0, which would otherwise
            # raise a division-by-zero in the modulo below.
            if capture_interval > 0 and total_step % capture_interval == 0:
                new_tokens = x[0, prompt_len:prompt_len + max_new_tokens].tolist()
                text = tokenizer.decode(new_tokens, skip_special_tokens=True)
                yield {
                    "type": "intermediate",
                    "step": total_step,
                    "text": text,
                    "total_steps": steps  # planned step budget, for progress display
                }
            total_step += 1
            if tokenizer.eos_token_id is not None:
                finished |= (x_blk_new == tokenizer.eos_token_id).any(dim=1)
                if finished.all():
                    break
        generated += cur_len
        if finished.all():
            break
    new_tokens = x[0, prompt_len:prompt_len + max_new_tokens].tolist()
    final_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    yield {
        "type": "final",
        "text": final_text,
        "total_steps": total_step  # actual number of denoising steps taken
    }
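
# A minimal consumption sketch (illustrative values); each yielded dict is a
# progress event for the first batch element:
#   for event in generate_stream(model, tokenizer, ids, steps=64, max_new_tokens=64):
#       print(event["type"], event.get("step"), event["text"][:40])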

def load_model():
    global model, tokenizer, device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_name = os.getenv("MODEL_NAME", "dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1")
    print(f"Loading model {model_name} on {device}...")
    model = AutoModelForMaskedLM.from_pretrained(
        model_name,
        dtype=torch.bfloat16,
        trust_remote_code=True
    ).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        trust_remote_code=True
    )
    print("Model loaded successfully!")

# Flask needs route decorators for these handlers to be reachable; the paths
# used here are assumed, conventional choices.
@app.route('/health', methods=['GET'])
def health():
    return jsonify({"status": "healthy", "model_loaded": model is not None})

@app.route('/generate', methods=['POST'])
def generate_text():
    # Blocking endpoint: run the full diffusion loop and return the result.
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    data = request.get_json()
    if not data or 'prompt' not in data:
        return jsonify({"error": "Missing 'prompt' field"}), 400
    prompt = data['prompt']
    steps = data.get('steps', 256)
    max_new_tokens = data.get('max_new_tokens', 256)
    block_size = data.get('block_size', 32)
    temperature = data.get('temperature', 0.0)
    cfg_scale = data.get('cfg_scale', 0.0)
    remasking = data.get('remasking', 'low_confidence')
    system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
    output = generate(
        model,
        tokenizer,
        input_ids,
        steps=steps,
        max_new_tokens=max_new_tokens,
        block_size=block_size,
        temperature=temperature,
        cfg_scale=cfg_scale,
        remasking=remasking,
    )
    prompt_len = len(encoded)
    new_tokens = output[0, prompt_len:prompt_len + max_new_tokens].tolist()
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return jsonify({
        "prompt": prompt,
        "generated_text": generated_text,
        "parameters": {
            "steps": steps,
            "max_new_tokens": max_new_tokens,
            "block_size": block_size,
            "temperature": temperature,
            "cfg_scale": cfg_scale,
            "remasking": remasking
        }
    })

@app.route('/generate_states', methods=['POST'])
def generate_text_stream():
    # Blocking endpoint: like /generate, but also returns the captured
    # intermediate denoising states in a single JSON response.
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    data = request.get_json()
    if not data or 'prompt' not in data:
        return jsonify({"error": "Missing 'prompt' field"}), 400
    prompt = data['prompt']
    steps = data.get('steps', 256)
    max_new_tokens = data.get('max_new_tokens', 256)
    block_size = data.get('block_size', 32)
    temperature = data.get('temperature', 0.0)
    cfg_scale = data.get('cfg_scale', 0.0)
    remasking = data.get('remasking', 'low_confidence')
    system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')
    capture_interval = data.get('capture_interval', 10)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
    output, intermediates = generate(
        model,
        tokenizer,
        input_ids,
        steps=steps,
        max_new_tokens=max_new_tokens,
        block_size=block_size,
        temperature=temperature,
        cfg_scale=cfg_scale,
        remasking=remasking,
        capture_interval=capture_interval,
    )
    prompt_len = len(encoded)
    intermediate_states = []
    for i, intermediate in enumerate(intermediates):
        new_tokens = intermediate[0, prompt_len:prompt_len + max_new_tokens].tolist()
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        intermediate_states.append({
            "step": i * capture_interval,
            "text": text
        })
    new_tokens = output[0, prompt_len:prompt_len + max_new_tokens].tolist()
    generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return jsonify({
        "prompt": prompt,
        "generated_text": generated_text,
        "intermediate_states": intermediate_states,
        "parameters": {
            "steps": steps,
            "max_new_tokens": max_new_tokens,
            "block_size": block_size,
            "temperature": temperature,
            "cfg_scale": cfg_scale,
            "remasking": remasking,
            "capture_interval": capture_interval
        }
    })

@app.route('/generate_stream', methods=['POST'])
def generate_text_sse():
    # Streaming endpoint: emit intermediate states as Server-Sent Events.
    if model is None or tokenizer is None:
        return jsonify({"error": "Model not loaded"}), 503
    data = request.get_json()
    if not data or 'prompt' not in data:
        return jsonify({"error": "Missing 'prompt' field"}), 400
    prompt = data['prompt']
    steps = data.get('steps', 256)
    max_new_tokens = data.get('max_new_tokens', 256)
    block_size = data.get('block_size', 32)
    temperature = data.get('temperature', 0.0)
    cfg_scale = data.get('cfg_scale', 0.0)
    remasking = data.get('remasking', 'low_confidence')
    system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')
    capture_interval = data.get('capture_interval', 10)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        enable_thinking=False
    )
    input_ids = torch.tensor([encoded], dtype=torch.long, device=device)

    def stream():
        for state in generate_stream(
            model,
            tokenizer,
            input_ids,
            steps=steps,
            max_new_tokens=max_new_tokens,
            block_size=block_size,
            temperature=temperature,
            cfg_scale=cfg_scale,
            remasking=remasking,
            capture_interval=capture_interval,
        ):
            yield f"data: {json.dumps(state)}\n\n"

    return Response(
        stream(),
        mimetype='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'X-Accel-Buffering': 'no',  # disable proxy buffering so events flush
        }
    )

if __name__ == '__main__':
    load_model()
    app.run(host='0.0.0.0', port=int(os.getenv('PORT', 5000)))
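
# Example requests against the routes assumed above (payload values are
# illustrative):
#   curl http://localhost:5000/health
#   curl -X POST http://localhost:5000/generate \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "Hello!", "steps": 64, "max_new_tokens": 64}'
#   curl -N -X POST http://localhost:5000/generate_stream \
#        -H 'Content-Type: application/json' \
#        -d '{"prompt": "Hello!", "capture_interval": 5}'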