HF Deploy Script committed
Commit a919dff · 0 parent(s)

Initial deployment: diffusion-chatbot

Files changed (5):
  1. .gitignore +7 -0
  2. Dockerfile +22 -0
  3. README.md +68 -0
  4. app.py +564 -0
  5. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,7 @@
+ __pycache__/
+ *.pyc
+ .env
+ .venv
+ venv/
+ .git
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,22 @@
+ FROM pytorch/pytorch:2.6.0-cuda12.4-cudnn9-devel
+
+ WORKDIR /app
+
+ RUN apt-get update && apt-get install -y \
+     git \
+     && rm -rf /var/lib/apt/lists/*
+
+ RUN git clone --recurse-submodules https://github.com/ZHZisZZ/dllm.git /tmp/dllm && \
+     pip install --no-cache-dir -e /tmp/dllm && \
+     pip install --no-cache-dir -e /tmp/dllm/lm-evaluation-harness
+
+ RUN pip install --no-cache-dir flask
+
+ COPY app.py .
+
+ ENV MODEL_NAME=dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1
+ ENV PORT=7860
+
+ EXPOSE 7860
+
+ CMD ["python", "app.py"]
README.md ADDED
@@ -0,0 +1,68 @@
+ ---
+ title: Diffusion Chatbot
+ emoji: 🤖
+ colorFrom: blue
+ colorTo: purple
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ # Diffusion Chatbot
+
+ Flask server hosting the Qwen3-0.6B-diffusion-bd3lm-v0.1 model with real-time streaming inference.
+
+ ## Features
+
+ - **Real-time streaming**: watch the diffusion model generate text step by step
+ - **Three generation endpoints**: simple generation (`/generate`), batched intermediate states (`/generate_stream`), and real-time SSE streaming (`/generate_sse`)
+ - **GPU support**: uses a GPU automatically when available and falls back to CPU otherwise
+
+ ## API Endpoints
+
+ ### Health Check
+ ```
+ GET /health
+ ```
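+
+ For example:
+ ```bash
+ curl https://YOUR_USERNAME-diffusion-chatbot.hf.space/health
+ ```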
+
+ ### Generate Text
+ ```
+ POST /generate
+ Content-Type: application/json
+
+ {
+   "prompt": "Your question here",
+   "max_new_tokens": 256
+ }
+ ```
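+
+ ### Generate with Intermediate States
+ Defined as `/generate_stream` in app.py: runs generation to completion, then returns the final text together with the intermediate states captured every `capture_interval` steps in one JSON response.
+ ```
+ POST /generate_stream
+ Content-Type: application/json
+
+ {
+   "prompt": "Your question here",
+   "max_new_tokens": 100,
+   "capture_interval": 10
+ }
+ ```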
+
+ ### Generate with Real-time Streaming (SSE)
+ ```
+ POST /generate_sse
+ Content-Type: application/json
+
+ {
+   "prompt": "Your question here",
+   "max_new_tokens": 100,
+   "capture_interval": 10
+ }
+ ```
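+
+ Responses arrive as Server-Sent Events: `data:`-prefixed JSON objects with `"type": "intermediate"`, followed by a final `"type": "final"` event. With curl, `-N` disables output buffering so events print as they arrive:
+ ```bash
+ curl -N -X POST https://YOUR_USERNAME-diffusion-chatbot.hf.space/generate_sse \
+   -H "Content-Type: application/json" \
+   -d '{"prompt": "Hello!", "max_new_tokens": 64, "capture_interval": 10}'
+ ```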
+
+ ## Example Usage
+
+ ```bash
+ curl -X POST https://YOUR_USERNAME-diffusion-chatbot.hf.space/generate \
+   -H "Content-Type: application/json" \
+   -d '{"prompt": "Hello, how are you?", "max_new_tokens": 50}'
+ ```
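+
+ The response echoes the prompt and the sampling parameters alongside the generated text:
+ ```
+ {
+   "prompt": "Hello, how are you?",
+   "generated_text": "...",
+   "parameters": {"steps": 256, "max_new_tokens": 50, "block_size": 32, ...}
+ }
+ ```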
+
+ ## Technical Details
+
+ - **Model**: [dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1](https://huggingface.co/dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1)
+ - **Framework**: Flask + PyTorch
+ - **Diffusion Method**: Block Diffusion Language Model (BD3LM)
+
+ ## Environment Variables
+
+ - `MODEL_NAME`: Hugging Face model name (default: `dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1`)
+ - `PORT`: Server port (default: 7860, the port HF Spaces expects)
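+
+ ## Running Locally
+
+ A minimal sketch of a local run (assumes Docker; `--gpus all` additionally requires the NVIDIA Container Toolkit):
+ ```bash
+ docker build -t diffusion-chatbot .
+ docker run --gpus all -p 7860:7860 diffusion-chatbot
+ curl http://localhost:7860/health
+ ```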
app.py ADDED
@@ -0,0 +1,564 @@
+ import os
+ import math
+ import copy
+ import json
+ import torch
+ import torch.nn.functional as F
+ from flask import Flask, request, jsonify, Response
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
+
+ app = Flask(__name__)
+
+ # Populated once by load_model() at startup.
+ model = None
+ tokenizer = None
+ device = None
+
+
+ def add_gumbel_noise(logits, temperature):
+     # Gumbel-max sampling; float64 keeps the noise numerically stable.
+     if temperature == 0:
+         return logits
+     logits = logits.to(torch.float64)
+     noise = torch.rand_like(logits, dtype=torch.float64)
+     g = (-torch.log(noise)) ** temperature
+     return logits.exp() / g
+
+
+ def get_num_transfer_tokens(mask_index, steps):
+     # Spread the masked positions evenly across the denoising steps;
+     # the remainder gives the earliest steps one extra token each.
+     mask_num = mask_index.sum(dim=1, keepdim=True)
+     base = mask_num // steps
+     rem = mask_num % steps
+     out = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.long) + base
+     for i in range(mask_num.size(0)):
+         out[i, : rem[i]] += 1
+     return out
+
+
+ def build_staircase_attention_mask(x, block_size, pad_id):
+     # Block-causal ("staircase") attention: each query attends to every key
+     # in its own block and in all earlier blocks; padding is masked out.
+     B, T = x.shape
+     device = x.device
+
+     valid = x != pad_id
+     pos_raw = torch.cumsum(valid.long(), dim=-1)
+     position_ids = torch.where(valid, pos_raw - 1, torch.zeros_like(pos_raw)).long()
+
+     col = torch.arange(T, device=device)
+     block_ids = (col // block_size).view(1, T).expand(B, T)
+     block_ids = torch.where(valid, block_ids, torch.full_like(block_ids, -1))
+
+     q = block_ids.view(B, 1, T, 1)
+     k = block_ids.view(B, 1, 1, T)
+     attn = (k <= q) & (q >= 0) & (k >= 0)
+
+     return attn, position_ids
+
+
+ def diffusion_step_block(logits, x_block, mask_block, num_transfer, temperature, remasking):
+     # One denoising step: sample a candidate token for every masked position,
+     # then commit only the num_transfer most confident candidates per sequence.
+     B, L, _ = logits.shape
+     if not mask_block.any():
+         return x_block
+
+     noisy = add_gumbel_noise(logits, temperature)
+     x0 = noisy.argmax(dim=-1)
+
+     if remasking == "low_confidence":
+         p = F.softmax(logits, dim=-1)
+         conf = p.gather(-1, x0.unsqueeze(-1)).squeeze(-1)
+     elif remasking == "random":
+         conf = torch.rand((B, L), device=logits.device)
+     else:
+         raise ValueError(remasking)
+
+     x0 = torch.where(mask_block, x0, x_block)
+     neg_inf = torch.full_like(conf, -float("inf"))
+     conf = torch.where(mask_block, conf, neg_inf)
+
+     commit = torch.zeros_like(x_block, dtype=torch.bool)
+     for i in range(B):
+         k = int(num_transfer[i].item())
+         if k > 0:
+             valid = (conf[i] > -float("inf")).sum().item()
+             k = min(k, valid)
+             _, idx = torch.topk(conf[i], k)
+             commit[i, idx] = True
+
+     out = x_block.clone()
+     out[commit] = x0[commit]
+     return out
+
+
+ @torch.no_grad()
+ def generate(
+     model,
+     tokenizer,
+     prompt,
+     steps=128,
+     max_new_tokens=128,
+     block_size=32,
+     temperature=0.0,
+     cfg_scale=0.0,
+     remasking="low_confidence",
+     capture_interval=0,
+ ):
+     device = model.device
+     mask_id = tokenizer.mask_token_id
+     pad_id = tokenizer.pad_token_id
+     if pad_id is None:
+         pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.mask_token_id
+
+     # Accept a tensor, a list of token ids, or a batch of token-id lists.
+     if isinstance(prompt, torch.Tensor):
+         x = prompt.to(device).long()
+     else:
+         if isinstance(prompt[0], (list, tuple)):
+             max_len = max(len(p) for p in prompt)
+             x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
+             for i, p in enumerate(prompt):
+                 x[i, : len(p)] = torch.tensor(p, device=device)
+         else:
+             x = torch.tensor(prompt, device=device).long()
+             if x.dim() == 1:
+                 x = x.unsqueeze(0)
+
+     B = x.size(0)
+     finished = torch.zeros(B, dtype=torch.bool, device=device)
+
+     num_blocks = math.ceil(max_new_tokens / block_size)
+     steps_per_block = math.ceil(steps / num_blocks)
+     generated = 0
+
+     intermediates = []
+     total_step = 0
+
+     while generated < max_new_tokens:
+         if finished.all():
+             break
+         # Extend to the next block boundary (the first block may be partial).
+         T_prefix = x.size(1)
+         offset = T_prefix % block_size
+         room = block_size if offset == 0 else block_size - offset
+         cur_len = min(room, max_new_tokens - generated)
+         if cur_len <= 0:
+             break
+
+         attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
+
+         # Cache the prefix once per block; every denoising step reuses it.
+         out = model(x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
+         cond_past = out.past_key_values
+
+         if cfg_scale > 0:
+             un_x = x.clone()
+             un_x[:] = mask_id
+             out_un = model(un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
+             uncond_past = out_un.past_key_values
+         else:
+             uncond_past = None
+
+         block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
+         block[finished] = pad_id
+         x = torch.cat([x, block], dim=1)
+         T_total = x.size(1)
+
+         block_mask = x[:, -cur_len:] == mask_id
+         num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
+         eff_steps = num_transfer.size(1)
+
+         full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
+         attn_blk = full_attn[:, :, T_prefix:T_total, :]
+         pos_blk = full_pos[:, T_prefix:T_total]
+
+         for t in range(eff_steps):
+             x_blk = x[:, T_prefix:T_total]
+             m_blk = x_blk == mask_id
+
+             # deepcopy: the forward pass mutates the KV cache in place.
+             cond_logits = model(
+                 x_blk, attention_mask=attn_blk, position_ids=pos_blk,
+                 past_key_values=copy.deepcopy(cond_past), use_cache=False
+             ).logits
+
+             logits = cond_logits
+             if cfg_scale > 0:
+                 un_logits = model(
+                     x_blk, attention_mask=attn_blk, position_ids=pos_blk,
+                     past_key_values=copy.deepcopy(uncond_past), use_cache=False
+                 ).logits
+                 logits = un_logits + (cfg_scale + 1.0) * (cond_logits - un_logits)
+
+             x_blk_new = diffusion_step_block(
+                 logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
+             )
+             x[:, T_prefix:T_total] = x_blk_new
+
+             if capture_interval > 0 and total_step % capture_interval == 0:
+                 intermediates.append(x.clone())
+
+             total_step += 1
+
+             if tokenizer.eos_token_id is not None:
+                 finished |= (x_blk_new == tokenizer.eos_token_id).any(dim=1)
+                 if finished.all():
+                     break
+
+         generated += cur_len
+         if finished.all():
+             break
+
+     if capture_interval > 0:
+         return x, intermediates
+     return x
+
+
+ @torch.no_grad()
+ def generate_stream(
+     model,
+     tokenizer,
+     prompt,
+     steps=128,
+     max_new_tokens=128,
+     block_size=32,
+     temperature=0.0,
+     cfg_scale=0.0,
+     remasking="low_confidence",
+     capture_interval=10,
+ ):
+     # Same loop as generate(), but yields decoded intermediate states as it runs.
+     device = model.device
+     mask_id = tokenizer.mask_token_id
+     pad_id = tokenizer.pad_token_id
+     if pad_id is None:
+         pad_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else tokenizer.mask_token_id
+
+     if isinstance(prompt, torch.Tensor):
+         x = prompt.to(device).long()
+     else:
+         if isinstance(prompt[0], (list, tuple)):
+             max_len = max(len(p) for p in prompt)
+             x = torch.full((len(prompt), max_len), pad_id, device=device, dtype=torch.long)
+             for i, p in enumerate(prompt):
+                 x[i, : len(p)] = torch.tensor(p, device=device)
+         else:
+             x = torch.tensor(prompt, device=device).long()
+             if x.dim() == 1:
+                 x = x.unsqueeze(0)
+
+     B = x.size(0)
+     finished = torch.zeros(B, dtype=torch.bool, device=device)
+
+     num_blocks = math.ceil(max_new_tokens / block_size)
+     steps_per_block = math.ceil(steps / num_blocks)
+     generated = 0
+     total_step = 0
+
+     prompt_len = x.size(1)
+
+     while generated < max_new_tokens:
+         if finished.all():
+             break
+         T_prefix = x.size(1)
+         offset = T_prefix % block_size
+         room = block_size if offset == 0 else block_size - offset
+         cur_len = min(room, max_new_tokens - generated)
+         if cur_len <= 0:
+             break
+
+         attn_pfx, pos_pfx = build_staircase_attention_mask(x, block_size, pad_id)
+
+         out = model(x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
+         cond_past = out.past_key_values
+
+         if cfg_scale > 0:
+             un_x = x.clone()
+             un_x[:] = mask_id
+             out_un = model(un_x, attention_mask=attn_pfx, position_ids=pos_pfx, use_cache=True)
+             uncond_past = out_un.past_key_values
+         else:
+             uncond_past = None
+
+         block = torch.full((B, cur_len), mask_id, device=device, dtype=torch.long)
+         block[finished] = pad_id
+         x = torch.cat([x, block], dim=1)
+         T_total = x.size(1)
+
+         block_mask = x[:, -cur_len:] == mask_id
+         num_transfer = get_num_transfer_tokens(block_mask, steps_per_block)
+         eff_steps = num_transfer.size(1)
+
+         full_attn, full_pos = build_staircase_attention_mask(x, block_size, pad_id)
+         attn_blk = full_attn[:, :, T_prefix:T_total, :]
+         pos_blk = full_pos[:, T_prefix:T_total]
+
+         for t in range(eff_steps):
+             x_blk = x[:, T_prefix:T_total]
+             m_blk = x_blk == mask_id
+
+             cond_logits = model(
+                 x_blk, attention_mask=attn_blk, position_ids=pos_blk,
+                 past_key_values=copy.deepcopy(cond_past), use_cache=False
+             ).logits
+
+             logits = cond_logits
+             if cfg_scale > 0:
+                 un_logits = model(
+                     x_blk, attention_mask=attn_blk, position_ids=pos_blk,
+                     past_key_values=copy.deepcopy(uncond_past), use_cache=False
+                 ).logits
+                 logits = un_logits + (cfg_scale + 1.0) * (cond_logits - un_logits)
+
+             x_blk_new = diffusion_step_block(
+                 logits, x_blk, m_blk, num_transfer[:, t], temperature, remasking
+             )
+             x[:, T_prefix:T_total] = x_blk_new
+
+             # Guard against capture_interval == 0, which would divide by zero.
+             if capture_interval > 0 and total_step % capture_interval == 0:
+                 new_tokens = x[0, prompt_len:prompt_len + max_new_tokens].tolist()
+                 text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+                 yield {
+                     "type": "intermediate",
+                     "step": total_step,
+                     "text": text,
+                     "total_steps": steps
+                 }
+
+             total_step += 1
+
+             if tokenizer.eos_token_id is not None:
+                 finished |= (x_blk_new == tokenizer.eos_token_id).any(dim=1)
+                 if finished.all():
+                     break
+
+         generated += cur_len
+         if finished.all():
+             break
+
+     new_tokens = x[0, prompt_len:prompt_len + max_new_tokens].tolist()
+     final_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+     yield {
+         "type": "final",
+         "text": final_text,
+         "total_steps": total_step
+     }
+
+
+ def load_model():
+     global model, tokenizer, device
+
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model_name = os.getenv("MODEL_NAME", "dllm-hub/Qwen3-0.6B-diffusion-bd3lm-v0.1")
+
+     print(f"Loading model {model_name} on {device}...")
+     model = AutoModelForMaskedLM.from_pretrained(
+         model_name,
+         dtype=torch.bfloat16,
+         trust_remote_code=True
+     ).to(device).eval()
+
+     tokenizer = AutoTokenizer.from_pretrained(
+         model_name,
+         trust_remote_code=True
+     )
+     print("Model loaded successfully!")
+
+
+ @app.route('/health', methods=['GET'])
+ def health():
+     return jsonify({"status": "healthy", "model_loaded": model is not None})
+
+
+ @app.route('/generate', methods=['POST'])
+ def generate_text():
+     if model is None or tokenizer is None:
+         return jsonify({"error": "Model not loaded"}), 503
+
+     data = request.get_json()
+
+     if not data or 'prompt' not in data:
+         return jsonify({"error": "Missing 'prompt' field"}), 400
+
+     prompt = data['prompt']
+     steps = data.get('steps', 256)
+     max_new_tokens = data.get('max_new_tokens', 256)
+     block_size = data.get('block_size', 32)
+     temperature = data.get('temperature', 0.0)
+     cfg_scale = data.get('cfg_scale', 0.0)
+     remasking = data.get('remasking', 'low_confidence')
+     system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": prompt}
+     ]
+
+     encoded = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         enable_thinking=False
+     )
+
+     input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
+
+     output = generate(
+         model,
+         tokenizer,
+         input_ids,
+         steps=steps,
+         max_new_tokens=max_new_tokens,
+         block_size=block_size,
+         temperature=temperature,
+         cfg_scale=cfg_scale,
+         remasking=remasking,
+     )
+
+     prompt_len = len(encoded)
+     new_tokens = output[0, prompt_len:prompt_len + max_new_tokens].tolist()
+     generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+     return jsonify({
+         "prompt": prompt,
+         "generated_text": generated_text,
+         "parameters": {
+             "steps": steps,
+             "max_new_tokens": max_new_tokens,
+             "block_size": block_size,
+             "temperature": temperature,
+             "cfg_scale": cfg_scale,
+             "remasking": remasking
+         }
+     })
+
+
+ @app.route('/generate_stream', methods=['POST'])
+ def generate_text_stream():
+     # Non-SSE variant: runs to completion, then returns the final text plus
+     # every captured intermediate state in a single JSON response.
+     if model is None or tokenizer is None:
+         return jsonify({"error": "Model not loaded"}), 503
+
+     data = request.get_json()
+
+     if not data or 'prompt' not in data:
+         return jsonify({"error": "Missing 'prompt' field"}), 400
+
+     prompt = data['prompt']
+     steps = data.get('steps', 256)
+     max_new_tokens = data.get('max_new_tokens', 256)
+     block_size = data.get('block_size', 32)
+     temperature = data.get('temperature', 0.0)
+     cfg_scale = data.get('cfg_scale', 0.0)
+     remasking = data.get('remasking', 'low_confidence')
+     system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')
+     capture_interval = data.get('capture_interval', 10)
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": prompt}
+     ]
+
+     encoded = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         enable_thinking=False
+     )
+
+     input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
+
+     output, intermediates = generate(
+         model,
+         tokenizer,
+         input_ids,
+         steps=steps,
+         max_new_tokens=max_new_tokens,
+         block_size=block_size,
+         temperature=temperature,
+         cfg_scale=cfg_scale,
+         remasking=remasking,
+         capture_interval=capture_interval,
+     )
+
+     prompt_len = len(encoded)
+
+     intermediate_states = []
+     for i, intermediate in enumerate(intermediates):
+         new_tokens = intermediate[0, prompt_len:prompt_len + max_new_tokens].tolist()
+         text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+         intermediate_states.append({
+             "step": i * capture_interval,
+             "text": text
+         })
+
+     new_tokens = output[0, prompt_len:prompt_len + max_new_tokens].tolist()
+     generated_text = tokenizer.decode(new_tokens, skip_special_tokens=True)
+
+     return jsonify({
+         "prompt": prompt,
+         "generated_text": generated_text,
+         "intermediate_states": intermediate_states,
+         "parameters": {
+             "steps": steps,
+             "max_new_tokens": max_new_tokens,
+             "block_size": block_size,
+             "temperature": temperature,
+             "cfg_scale": cfg_scale,
+             "remasking": remasking,
+             "capture_interval": capture_interval
+         }
+     })
+
+
+ @app.route('/generate_sse', methods=['POST'])
+ def generate_text_sse():
+     if model is None or tokenizer is None:
+         return jsonify({"error": "Model not loaded"}), 503
+
+     data = request.get_json()
+
+     if not data or 'prompt' not in data:
+         return jsonify({"error": "Missing 'prompt' field"}), 400
+
+     prompt = data['prompt']
+     steps = data.get('steps', 256)
+     max_new_tokens = data.get('max_new_tokens', 256)
+     block_size = data.get('block_size', 32)
+     temperature = data.get('temperature', 0.0)
+     cfg_scale = data.get('cfg_scale', 0.0)
+     remasking = data.get('remasking', 'low_confidence')
+     system_prompt = data.get('system_prompt', 'You are a helpful AI assistant.')
+     capture_interval = data.get('capture_interval', 10)
+
+     messages = [
+         {"role": "system", "content": system_prompt},
+         {"role": "user", "content": prompt}
+     ]
+
+     encoded = tokenizer.apply_chat_template(
+         messages,
+         add_generation_prompt=True,
+         tokenize=True,
+         enable_thinking=False
+     )
+
+     input_ids = torch.tensor([encoded], dtype=torch.long, device=device)
+
+     def stream():
+         for state in generate_stream(
+             model,
+             tokenizer,
+             input_ids,
+             steps=steps,
+             max_new_tokens=max_new_tokens,
+             block_size=block_size,
+             temperature=temperature,
+             cfg_scale=cfg_scale,
+             remasking=remasking,
+             capture_interval=capture_interval,
+         ):
+             yield f"data: {json.dumps(state)}\n\n"
+
+     return Response(
+         stream(),
+         mimetype='text/event-stream',
+         headers={
+             'Cache-Control': 'no-cache',
+             'X-Accel-Buffering': 'no',
+         }
+     )
+
+
+ if __name__ == '__main__':
+     load_model()
+     app.run(host='0.0.0.0', port=int(os.getenv('PORT', 5000)))
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ flask
+ torch
+ transformers
+ accelerate
+ numpy<2.0
+ git+https://github.com/ZHZisZZ/dllm.git