adowu commited on
Commit
af02d93
·
1 Parent(s): d7110b1
Files changed (3) hide show
  1. Dockerfile +22 -0
  2. main.py +305 -0
  3. requirements.txt +5 -0
Dockerfile ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
FROM python:3.11-slim

WORKDIR /app

# Install system dependencies (build-essential for any wheels that need
# compiling; curl for debugging/healthchecks) and trim the apt cache.
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential curl && \
    rm -rf /var/lib/apt/lists/*

# Copy requirements first so the pip layer is cached independently of
# application-code changes.
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy application code.  Listing main.py alongside the .env* glob makes
# this COPY succeed even when no .env file exists: COPY is not a shell
# command, so the original `COPY .env* ./ 2>/dev/null || true` treated
# `2>/dev/null`, `||` and `true` as source files and broke the build, and a
# lone glob matching nothing also fails.  NOTE(review): baking .env secrets
# into an image is discouraged -- prefer runtime environment variables.
COPY main.py .env* ./

# Expose port
EXPOSE 7860

# HuggingFace Spaces expects the app to run on port 7860
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Falcon H1R - OpenAI-compatible FastAPI wrapper.
2
+
3
+ Mimics the exact behavior of the working HTML chatbot:
4
+ 1. Client.connect(space_url)
5
+ 2. client.predict(api_name="/new_chat")
6
+ 3. client.predict(api_name="/add_message", input_value=msg, settings_form_value=params)
7
+ 4. Extract res.data[5]['value'][-1]['content']
8
+ """
9
+ from __future__ import annotations
10
+
11
+ import os, json, time, uuid, asyncio, logging
12
+ from typing import Any, AsyncGenerator
13
+ from contextlib import asynccontextmanager
14
+
15
+ from dotenv import load_dotenv
16
+ from fastapi import FastAPI, HTTPException, Request, Depends
17
+ from fastapi.middleware.cors import CORSMiddleware
18
+ from fastapi.responses import StreamingResponse, JSONResponse
19
+ from pydantic import BaseModel
20
+ from gradio_client import Client
21
+
22
+ load_dotenv()
23
+
24
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
# Optional bearer token; when empty, authentication is disabled entirely
# (see verify_key below).
API_KEY = os.getenv("API_KEY", "")
# Upstream Gradio Space that actually hosts the model.
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://tiiuae-falcon-h1r-playground.hf.space/")
# Model id reported through the OpenAI-compatible endpoints.
MODEL_ID = os.getenv("MODEL_ID", "tiiuae/Falcon-H1R-7B")
# Sampling defaults applied when a request omits the corresponding field.
DEFAULT_TEMP = float(os.getenv("DEFAULT_TEMPERATURE", "0.6"))
DEFAULT_TOP_P = float(os.getenv("DEFAULT_TOP_P", "0.95"))
DEFAULT_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "1024"))

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)
36
+
37
# ---------------------------------------------------------------------------
# Gradio client (singleton)
# ---------------------------------------------------------------------------
_client: Client | None = None
# Guards lazy creation of _client: get_client awaits between the None-check
# and the assignment, so two concurrent callers could otherwise both open a
# connection (one of which would leak).
_client_lock = asyncio.Lock()

async def get_client() -> Client:
    """Return the shared Gradio client, connecting lazily on first use.

    The connection handshake blocks, so it runs in a worker thread via
    asyncio.to_thread to keep the event loop responsive.  Propagates
    whatever gradio_client.Client raises when the Space is unreachable.
    """
    global _client
    if _client is None:
        async with _client_lock:
            # Re-check under the lock: another task may have connected
            # while we were waiting to acquire it.
            if _client is None:
                log.info("Connecting to %s", HF_SPACE_URL)
                _client = await asyncio.to_thread(Client, HF_SPACE_URL)
                log.info("Connected.")
    return _client
49
+
50
+ # ---------------------------------------------------------------------------
51
+ # Pydantic schemas
52
+ # ---------------------------------------------------------------------------
53
+
54
class Message(BaseModel):
    """One OpenAI-style chat message.

    `content` may be a plain string or a list of content-part dicts in the
    OpenAI multimodal format (e.g. {"type": "text", "text": ...}); see
    _content_str for how parts are flattened to text.
    """
    # "system", "user" or "assistant"; other roles are ignored by _build_prompt.
    role: str
    content: str | list[dict] = ""
    # Optional per-message author name; accepted for compatibility but not
    # referenced anywhere in this wrapper.
    name: str | None = None
58
+
59
class ChatCompletionRequest(BaseModel):
    """Request body for the OpenAI-compatible /v1/chat/completions endpoint.

    Only model, messages, temperature, top_p, max_tokens and stream are
    actually forwarded upstream; the remaining fields are accepted for
    client compatibility but not referenced elsewhere in this file.
    """
    model: str = MODEL_ID
    messages: list[Message]
    temperature: float = DEFAULT_TEMP
    top_p: float = DEFAULT_TOP_P
    max_tokens: int = DEFAULT_TOKENS
    # True => response is delivered as simulated SSE chunks (see _stream_sse).
    stream: bool = False
    frequency_penalty: float = 0
    presence_penalty: float = 0
    stop: str | list[str] | None = None
    seed: int | None = None
    user: str | None = None
71
+
72
+ # ---------------------------------------------------------------------------
73
+ # Auth
74
+ # ---------------------------------------------------------------------------
75
+
76
async def verify_key(request: Request) -> None:
    """FastAPI dependency enforcing optional bearer-token auth.

    When API_KEY is unset the service is open and every request passes.
    Otherwise the Authorization header must be exactly "Bearer <API_KEY>".

    Raises:
        HTTPException: 401 when the key is missing or wrong.
    """
    if not API_KEY:
        return
    header = request.headers.get("Authorization", "")
    scheme_ok = header.startswith("Bearer ")
    if not (scheme_ok and header[7:] == API_KEY):
        raise HTTPException(status_code=401, detail="Invalid or missing API key")
82
+
83
+ # ---------------------------------------------------------------------------
84
+ # Lifespan context manager (modern FastAPI pattern)
85
+ # ---------------------------------------------------------------------------
86
+
87
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Startup/shutdown hook: warm the Gradio connection before serving.

    Connecting at startup means the first user request doesn't pay the
    connection latency, and a bad Space URL fails fast instead of on the
    first call.
    """
    # Startup
    log.info("Starting up - connecting to Gradio client...")
    await get_client()
    log.info("Startup complete.")
    yield
    # Shutdown (if needed).  NOTE(review): the code after `yield` runs only
    # on clean shutdown; no explicit client teardown is performed here.
    log.info("Shutting down.")
96
+
97
+ # ---------------------------------------------------------------------------
98
+ # App
99
+ # ---------------------------------------------------------------------------
100
+
101
# FastAPI application; the lifespan hook pre-connects the Gradio client.
app = FastAPI(
    title="Falcon H1R API",
    version="3.1.0",
    lifespan=lifespan,
)

# Wide-open CORS so browser clients on any origin can call the API.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids the
# wildcard origin there) -- confirm whether credential support is actually
# needed, or pin concrete origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Business logic - EXACTLY like the HTML chatbot
117
+ # ---------------------------------------------------------------------------
118
+
119
+ def _content_str(m: Message) -> str:
120
+ if isinstance(m.content, str):
121
+ return m.content
122
+ return "".join(p.get("text", "") for p in m.content if p.get("type") == "text")
123
+
124
+ def _build_prompt(messages: list[Message]) -> str:
125
+ """Flatten messages into a single prompt string."""
126
+ system, parts = [], []
127
+ for m in messages:
128
+ c = _content_str(m)
129
+ if m.role == "system": system.append(c)
130
+ elif m.role == "user": parts.append(c)
131
+ elif m.role == "assistant": parts.append(f"[ASSISTANT]\n{c}")
132
+ prefix = "[SYSTEM]\n" + "\n".join(system) + "\n[/SYSTEM]\n" if system else ""
133
+ return prefix + "\n".join(parts)
134
+
135
+ def _extract_text(result) -> str:
136
+ """
137
+ HTML chatbot does:
138
+ const last = res.data[5].value.at(-1);
139
+ const text = Array.isArray(last.content)
140
+ ? last.content.filter(p => p.type === 'text').map(p => p.content.trim()).join('')
141
+ : last.content;
142
+ """
143
+ try:
144
+ # res.data is a list, index 5 contains the chatbot component
145
+ chatbot_data = result.data[5]
146
+ # chatbot_data is a dict with 'value' key
147
+ conversation = chatbot_data["value"]
148
+ # last message
149
+ last = conversation[-1]
150
+ content = last["content"]
151
+
152
+ if isinstance(content, list):
153
+ # Filter type='text' blocks
154
+ return "".join(
155
+ p["content"].strip()
156
+ for p in content
157
+ if p.get("type") == "text"
158
+ )
159
+ return str(content)
160
+ except Exception as e:
161
+ log.error("_extract_text failed: %s | raw data: %s", e, result.data)
162
+ raise ValueError(f"Failed to extract text: {e}") from e
163
+
164
# Serialises conversation turns on the shared Gradio client.  Each request
# performs /new_chat followed by /add_message; without this lock two
# concurrent requests could interleave and reset each other's conversation
# mid-flight, returning the wrong reply.
_falcon_lock = asyncio.Lock()

async def _call_falcon(prompt: str, req: ChatCompletionRequest) -> str:
    """Send one prompt to the Falcon Space and return the reply text.

    Exact replica of the HTML chatbot's submit() flow:
      1. client.predict(api_name="/new_chat")          -- reset for isolation
      2. client.predict(api_name="/add_message", ...)  -- send prompt + params
      3. extract res.data[5]["value"][-1]["content"]   -- via _extract_text

    Raises ValueError (from _extract_text) on an unexpected payload shape,
    or whatever gradio_client raises on transport errors.
    """
    client = await get_client()

    settings = {
        "model": req.model,
        "temperature": req.temperature,
        "max_new_tokens": req.max_tokens,
        "top_p": req.top_p,
    }

    # Both predict() calls block, so run them in worker threads; the lock
    # keeps the reset + message pair atomic per request.
    async with _falcon_lock:
        # Step 1: Reset chat (like boot() does once, but per request for isolation)
        await asyncio.to_thread(
            client.predict,
            api_name="/new_chat"
        )

        # Step 2: Send message - EXACTLY like the HTML chatbot
        result = await asyncio.to_thread(
            client.predict,
            input_value=prompt,
            settings_form_value=settings,
            api_name="/add_message"
        )

    return _extract_text(result)
194
+
195
def _make_response(text: str, req: ChatCompletionRequest) -> dict:
    """Wrap upstream text in an OpenAI chat.completion envelope.

    Token counts are rough estimates (character count // 4) because the
    Space reports no usage information.
    """
    prompt_tokens = sum(len(_content_str(m)) for m in req.messages) // 4
    completion_tokens = len(text) // 4
    choice = {
        "index": 0,
        "message": {
            "role": "assistant",
            "content": text,
            "tool_calls": None,
            "function_call": None,
        },
        "finish_reason": "stop",
        "logprobs": None,
    }
    usage = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": req.model,
        "system_fingerprint": f"fp_{uuid.uuid4().hex[:8]}",
        "choices": [choice],
        "usage": usage,
    }
221
+
222
async def _stream_sse(text: str, req: ChatCompletionRequest) -> AsyncGenerator[str, None]:
    """Fake token streaming: slice the finished reply into SSE chunks.

    The upstream call is not incremental; the completed text is emitted six
    characters at a time with a short delay so clients render progress,
    followed by a final usage chunk and the [DONE] sentinel.
    """
    completion_id = f"chatcmpl-{uuid.uuid4().hex}"
    created_at = int(time.time())
    step = 6

    offset = 0
    while offset < len(text):
        piece = text[offset:offset + step]
        payload = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created_at,
            "model": req.model,
            "choices": [{
                "index": 0,
                "delta": {"role": "assistant", "content": piece},
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(payload)}\n\n"
        # Small pause so chunks arrive spaced out rather than in one burst.
        await asyncio.sleep(0.01)
        offset += step

    # Closing chunk carries the (estimated, len // 4) usage numbers.
    prompt_tokens = sum(len(_content_str(m)) for m in req.messages) // 4
    completion_tokens = len(text) // 4
    closing = {
        "id": completion_id,
        "object": "chat.completion.chunk",
        "created": created_at,
        "model": req.model,
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
        "usage": {"prompt_tokens": prompt_tokens, "completion_tokens": completion_tokens, "total_tokens": prompt_tokens + completion_tokens},
    }
    yield f"data: {json.dumps(closing)}\n\n"
    yield "data: [DONE]\n\n"
256
+
257
+ # ---------------------------------------------------------------------------
258
+ # Routes
259
+ # ---------------------------------------------------------------------------
260
+
261
@app.get("/")
async def root():
    """Service index: names the API and lists its main endpoints."""
    endpoints = {
        "health": "/health",
        "models": "/v1/models",
        "chat": "/v1/chat/completions",
    }
    return {
        "service": "Falcon H1R OpenAI-compatible API",
        "version": "3.1.0",
        "endpoints": endpoints,
    }
272
+
273
@app.get("/health")
async def health():
    """Liveness probe: reports the configured model and upstream Space."""
    payload = {"status": "ok", "model": MODEL_ID, "space": HF_SPACE_URL}
    return payload
276
+
277
@app.get("/v1/models")
async def list_models(_: None = Depends(verify_key)):
    """OpenAI-compatible model listing: exactly one (proxied) model."""
    entry = {
        "id": MODEL_ID,
        "object": "model",
        # Fixed placeholder creation timestamp; the Space exposes none.
        "created": 1710000000,
        "owned_by": "tiiuae",
    }
    return {"object": "list", "data": [entry]}
285
+
286
@app.post("/v1/chat/completions")
async def chat_completions(req: ChatCompletionRequest, _: None = Depends(verify_key)):
    """OpenAI-compatible chat endpoint backed by the Gradio Space.

    Flattens the message list to one prompt, calls the Space, and returns
    either a full JSON completion or a simulated SSE stream depending on
    `req.stream`.  Upstream failures surface as HTTP 502.
    """
    prompt = _build_prompt(req.messages)
    log.info("Request | model=%s temp=%.2f tokens=%d stream=%s",
             req.model, req.temperature, req.max_tokens, req.stream)

    try:
        answer = await _call_falcon(prompt, req)
    except Exception as exc:
        log.exception("Falcon call failed")
        raise HTTPException(status_code=502, detail=f"Upstream error: {exc}") from exc

    if not req.stream:
        return JSONResponse(content=_make_response(answer, req))

    # SSE headers disable proxy/nginx buffering so chunks flush promptly.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        _stream_sse(answer, req),
        media_type="text/event-stream",
        headers=sse_headers,
    )
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ fastapi>=0.111.0
2
+ uvicorn[standard]>=0.29.0
3
+ gradio-client>=0.16.0
4
+ python-dotenv>=1.0.0
5
+ pydantic>=2.7.0