triflix committed on
Commit
55fb776
·
verified ·
1 Parent(s): d180462

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +381 -0
app.py ADDED
@@ -0,0 +1,381 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import asyncio
import json
import logging
import os
import re
import time
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer
16
+
17
# ─── Config ──────────────────────────────────────────────────
18
+
19
# Hugging Face repository the tokenizer and weights are loaded from.
MODEL_NAME = "Dream-org/Dream-v0-Instruct-1B"
# Model identifier reported to API clients in responses and /v1/models.
API_MODEL_ID = "dream-diffusion-1b"
# Listening port; overridable through the PORT environment variable.
PORT = int(os.environ.get("PORT", 7860))
# INT8 dynamic quantization switch. Only the value "true" (any case,
# surrounding whitespace ignored) enables it; anything else disables it.
QUANTIZE = os.environ.get("QUANTIZE", "true").strip().lower() == "true"

logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] %(levelname)s %(message)s",
    datefmt="%H:%M:%S",
)
# Shared logger for the whole service.
log = logging.getLogger("dream-api")
30
+
31
# ─── Global Model References ─────────────────────────────────
32
+
33
# Populated by load_model(); None until loading has finished.
model = None
tokenizer = None
# Flipped to True as the last step of load_model(); request handlers check
# this flag before touching `model` / `tokenizer`.
model_loaded = False
36
+
37
+
38
# ─── Model Loading ────────────────────────────────────────────
39
+
40
def load_model():
    """Load the tokenizer and model into the module-level globals.

    Fetches ``MODEL_NAME`` through ``AutoTokenizer`` / ``AutoModel``, puts the
    model in eval mode and, when ``QUANTIZE`` is set, tries INT8 dynamic
    quantization of its ``torch.nn.Linear`` layers. A quantization failure is
    logged and the unquantized model is kept (best effort). ``model_loaded``
    is set to True as the final step.
    """
    global model, tokenizer, model_loaded

    # NOTE(review): trust_remote_code=True runs Python code shipped in the
    # model repository - confirm the repository is trusted / pinned.
    log.info("Loading tokenizer: %s", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
    )

    log.info("Loading model: %s", MODEL_NAME)
    start = time.perf_counter()

    # Build into a local and publish to the global only once fully prepared.
    loaded = AutoModel.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float32,
        trust_remote_code=True,
    )
    loaded.eval()

    # INT8 Dynamic Quantization
    if QUANTIZE:
        try:
            from torch.ao.quantization import quantize_dynamic

            loaded = quantize_dynamic(
                loaded,
                {torch.nn.Linear},
                dtype=torch.qint8,
            )
            log.info("✅ INT8 quantization applied")
        except Exception as e:
            # Deliberate best effort: keep serving with the float32 model.
            log.warning("⚠️ Quantization failed: %s", e)

    model = loaded
    elapsed = time.perf_counter() - start
    log.info("✅ Model loaded in %.1fs", elapsed)
    model_loaded = True
75
+
76
+
77
# ─── Lifespan ─────────────────────────────────────────────────
78
+
79
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan hook: load the model on startup, log on shutdown.

    ``load_model`` is blocking, so it runs in a worker thread to keep the
    event loop free. The load is awaited before yielding, so the startup
    phase does not finish until the model is ready.
    """
    # Startup
    await asyncio.to_thread(load_model)
    yield
    # Shutdown
    log.info("Shutting down")
87
+
88
+
89
# ─── FastAPI App ──────────────────────────────────────────────
90
+
91
# ASGI application; `lifespan` loads the model during startup.
app = FastAPI(
    title="Dream Diffusion LLM API",
    version="1.0.0",
    lifespan=lifespan,
)

# NOTE(review): CORS is wide open (any origin, method and header) - confirm
# this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
103
+
104
+
105
# ─── Pydantic Models ─────────────────────────────────────────
106
+
107
class Message(BaseModel):
    """One chat message of the incoming conversation."""

    # Speaker role; passed through unchanged to the tokenizer's chat template.
    role: str
    # Message text.
    content: str
110
+
111
+
112
class ChatCompletionRequest(BaseModel):
    """Request body of ``POST /v1/chat/completions``."""

    # Requested model id; defaults to the id this service reports.
    model: str = API_MODEL_ID
    # Conversation so far.
    messages: list[Message]
    # Number of new tokens to generate (1..1024).
    max_tokens: Optional[int] = Field(default=256, le=1024, ge=1)
    # Sampling temperature (0.0..2.0).
    temperature: Optional[float] = Field(default=0.35, ge=0.0, le=2.0)
    # Nucleus sampling threshold (0.0..1.0).
    top_p: Optional[float] = Field(default=0.95, ge=0.0, le=1.0)
    # When true the reply is sent as server-sent events.
    stream: Optional[bool] = False
    # Diffusion-specific: number of denoising steps (1..256).
    steps: Optional[int] = Field(default=64, le=256, ge=1)
121
+
122
+
123
class ChatCompletionMessage(BaseModel):
    """Assistant message returned inside a completion choice."""

    role: str = "assistant"
    # Generated text.
    content: str
126
+
127
+
128
class Choice(BaseModel):
    """Single completion choice of a non-streaming response."""

    index: int = 0
    message: ChatCompletionMessage
    finish_reason: str = "stop"
132
+
133
+
134
class Usage(BaseModel):
    """Token accounting block of a completion response."""

    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
138
+
139
+
140
class ChatCompletionResponse(BaseModel):
    """Body of a non-streaming chat completion response."""

    # Request identifier.
    id: str
    object: str = "chat.completion"
    # Creation time as integer Unix seconds.
    created: int
    model: str
    choices: list[Choice]
    usage: Usage
147
+
148
+
149
# ─── Inference Function ──────────────────────────────────────
150
+
151
def run_inference(
    messages: list[Message],
    max_tokens: int,
    steps: int,
    temperature: float,
    top_p: float,
) -> tuple[str, float]:
    """Run diffusion generation. Returns (text, elapsed_ms).

    Blocking call that uses the module-level ``tokenizer`` and ``model``.

    Args:
        messages: Conversation to render through the tokenizer's chat template.
        max_tokens: Passed to the generator as ``max_new_tokens``.
        steps: Number of diffusion steps.
        temperature: Sampling temperature.
        top_p: Nucleus sampling threshold.

    Returns:
        The decoded completion (special tokens skipped, whitespace stripped)
        and the generation time in milliseconds.
    """
    # Build chat prompt
    msgs = [{"role": m.role, "content": m.content} for m in messages]
    input_text = tokenizer.apply_chat_template(
        msgs,
        tokenize=False,
        add_generation_prompt=True,
    )

    input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"]
    # A single un-padded sequence: every position is attended to.
    attention_mask = torch.ones_like(input_ids)
    prompt_len = input_ids.shape[1]

    # Generate (timed with a monotonic clock so clock adjustments cannot
    # distort the measurement).
    start = time.perf_counter()
    with torch.no_grad():
        output = model.diffusion_generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            output_history=False,
            steps=steps,
            temperature=temperature,
            top_p=top_p,
            alg="entropy",
            alg_temp=0.1,
        )
    elapsed_ms = (time.perf_counter() - start) * 1000

    # Decode only what follows the prompt.
    generated_ids = output[0, prompt_len:]
    text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return text, elapsed_ms
193
+
194
+
195
# ─── Token Estimator ─────────────────────────────────────────
196
+
197
def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*.

    Heuristic: whitespace-separated word count divided by 0.75, truncated to
    an int, and never less than 1 (even for empty text).
    """
    word_count = len(text.split())
    return max(int(word_count / 0.75), 1)
200
+
201
+
202
# ─── Generate Request ID ─────────────────────────────────────
203
+
204
def gen_id() -> str:
    """Return a fresh request id: ``chatcmpl-`` plus 12 random hex characters."""
    return f"chatcmpl-{uuid.uuid4().hex[:12]}"
206
+
207
+
208
# ─── SSE Streaming Generator ─────────────────────────────────
209
+
210
def _sse_chunk(
    req_id: str,
    created: int,
    delta: dict,
    finish_reason: Optional[str],
) -> str:
    """Serialize one ``chat.completion.chunk`` payload as an SSE ``data:`` frame."""
    payload = {
        "id": req_id,
        "object": "chat.completion.chunk",
        "created": created,
        "model": API_MODEL_ID,
        "choices": [{
            "index": 0,
            "delta": delta,
            "finish_reason": finish_reason,
        }],
    }
    return f"data: {json.dumps(payload)}\n\n"


async def stream_generator(text: str, req_id: str):
    """Yield SSE chunks word-by-word from the generated text.

    Emits a role chunk, one content chunk per word, a stop chunk and the
    final ``[DONE]`` marker. Each content chunk carries the word together
    with the exact whitespace that follows it, so concatenating all deltas
    reproduces *text* unchanged (newlines and repeated spaces included).

    Args:
        text: Fully generated completion to replay to the client.
        req_id: Identifier placed in every chunk.
    """
    now = int(time.time())

    # 1) Role chunk
    yield _sse_chunk(req_id, now, {"role": "assistant"}, None)

    # 2) Content chunks - word by word. `\S+\s*` keeps each word with its
    #    trailing whitespace; the `\s+` alternative covers leading whitespace.
    for piece in re.findall(r"\S+\s*|\s+", text):
        yield _sse_chunk(req_id, now, {"content": piece}, None)
        await asyncio.sleep(0.015)  # typing effect

    # 3) Stop chunk
    yield _sse_chunk(req_id, now, {}, "stop")
    yield "data: [DONE]\n\n"
260
+
261
+
262
# ─── Routes ──────────────────────────────────────────────────
263
+
264
@app.get("/")
async def root():
    """Return static service metadata and the list of available endpoints."""
    return {
        "name": "Dream Diffusion LLM API",
        "model": API_MODEL_ID,
        "version": "1.0.0",
        "openai_compatible": True,
        "endpoints": {
            "chat": "POST /v1/chat/completions",
            "models": "GET /v1/models",
            "health": "GET /health",
        },
    }
277
+
278
+
279
@app.get("/health")
async def health():
    """Report readiness: 503 with status "loading" until the model is loaded."""
    if not model_loaded:
        return JSONResponse(
            status_code=503,
            content={"status": "loading", "model": MODEL_NAME},
        )
    return {"status": "healthy", "model": MODEL_NAME}
287
+
288
+
289
@app.get("/v1/models")
async def list_models():
    """Return the single model served by this API as a model list."""
    return {
        "object": "list",
        "data": [
            {
                "id": API_MODEL_ID,
                "object": "model",
                # Hard-coded creation timestamp.
                "created": 1700000000,
                "owned_by": "dream-org",
            }
        ],
    }
302
+
303
+
304
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest):
    """Generate a chat completion, streamed (SSE) or as a single JSON body.

    The whole completion is generated first; with ``stream=true`` the finished
    text is then replayed as server-sent events.

    Raises:
        HTTPException: 503 while the model is loading, 400 when ``messages``
            is empty, 500 when inference fails.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading")

    if not req.messages:
        raise HTTPException(status_code=400, detail="messages array is required")

    # The sampling fields are Optional, so a client may send an explicit JSON
    # null. Fall back to the declared field defaults instead of passing None
    # on to the generator.
    defaults = ChatCompletionRequest(messages=[])
    max_tokens = defaults.max_tokens if req.max_tokens is None else req.max_tokens
    steps = defaults.steps if req.steps is None else req.steps
    temperature = (
        defaults.temperature if req.temperature is None else req.temperature
    )
    top_p = defaults.top_p if req.top_p is None else req.top_p

    log.info(
        "Request: steps=%s, max_tokens=%s, temp=%s, stream=%s",
        steps,
        max_tokens,
        temperature,
        req.stream,
    )

    # Run inference in a worker thread (blocking call)
    try:
        text, elapsed_ms = await asyncio.to_thread(
            run_inference,
            req.messages,
            max_tokens,
            steps,
            temperature,
            top_p,
        )
    except Exception as e:
        log.error(f"Inference error: {e}", exc_info=True)
        raise HTTPException(
            status_code=500, detail=f"Inference failed: {str(e)}"
        ) from e

    log.info("Generated %d chars in %.0fms", len(text), elapsed_ms)

    req_id = gen_id()

    # ── Streaming Response ──
    if req.stream:
        return StreamingResponse(
            stream_generator(text, req_id),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "Connection": "keep-alive",
                "X-Accel-Buffering": "no",
            },
        )

    # ── Non-Streaming Response ──
    # Token counts are word-based estimates, not tokenizer counts.
    prompt_tokens = sum(estimate_tokens(m.content) for m in req.messages)
    completion_tokens = estimate_tokens(text)

    return ChatCompletionResponse(
        id=req_id,
        created=int(time.time()),
        model=API_MODEL_ID,
        choices=[
            Choice(
                message=ChatCompletionMessage(content=text),
            )
        ],
        usage=Usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )
368
+
369
+
370
# ─── Run ──────────────────────────────────────────────────────
371
+
372
# Script entry point: serve the app with a single uvicorn worker on all
# interfaces at PORT.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=PORT,
        workers=1,
        log_level="info",
    )