bahi-bh commited on
Commit
0f4cc05
·
verified ·
1 Parent(s): 01b4702

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +192 -222
main.py CHANGED
@@ -1,9 +1,12 @@
1
  import os
2
  import json
3
  import time
 
4
  import asyncio
5
  import logging
6
- import inspect
 
 
7
 
8
  from fastapi import FastAPI, HTTPException, Header
9
  from fastapi.responses import StreamingResponse
@@ -11,25 +14,10 @@ from fastapi.middleware.cors import CORSMiddleware
11
  from pydantic import BaseModel
12
  from typing import List, Optional
13
 
14
- import g4f
15
- import g4f.Provider as Provider
16
- import litellm
17
-
18
- # ----------------------------
19
- # Logging
20
- # ----------------------------
21
-
22
  logging.basicConfig(level=logging.INFO)
23
  logger=logging.getLogger(__name__)
24
 
25
- # ----------------------------
26
- # App
27
- # ----------------------------
28
-
29
- app=FastAPI(
30
- title="AI Gateway",
31
- version="2.0"
32
- )
33
 
34
  app.add_middleware(
35
  CORSMiddleware,
@@ -40,48 +28,36 @@ app.add_middleware(
40
 
41
  API_KEY=os.getenv(
42
  "API_KEY",
43
- "your_secret"
44
  )
45
 
46
- DEFAULT_MODEL=os.getenv(
47
- "DEFAULT_MODEL",
48
- "groq/llama-3.3-70b-versatile"
49
- )
50
-
51
- # ----------------------------
52
  # Models
53
- # ----------------------------
54
 
55
- class Message(BaseModel):
56
  role:str
57
  content:str
58
 
59
 
60
  class ChatRequest(BaseModel):
61
 
62
- model:str=DEFAULT_MODEL
63
 
64
- messages:List[Message]
65
 
66
  stream:bool=False
67
 
68
  provider:Optional[str]=None
69
 
70
 
71
- # ----------------------------
72
- # Auth
73
- # ----------------------------
74
 
75
  def verify(auth):
76
 
77
- if not auth:
78
-
79
- raise HTTPException(
80
- status_code=401,
81
- detail="Missing token"
82
- )
83
-
84
- if auth != f"Bearer {API_KEY}":
85
 
86
  raise HTTPException(
87
  status_code=401,
@@ -89,30 +65,41 @@ def verify(auth):
89
  )
90
 
91
 
92
- # ----------------------------
93
- # g4f provider discovery
94
- # ----------------------------
95
 
96
  SKIP={
97
 
98
  "BaseProvider",
99
  "RetryProvider",
100
- "AsyncProvider"
 
 
 
101
 
102
  }
103
 
104
 
 
 
 
 
105
  def collect_models(cls):
106
 
107
- result=[]
108
 
109
- for attr in [
110
 
111
  "default_model",
 
112
  "models",
113
- "model"
 
 
 
114
 
115
- ]:
116
 
117
  v=getattr(
118
  cls,
@@ -125,137 +112,151 @@ def collect_models(cls):
125
 
126
  if isinstance(v,str):
127
 
128
- result.append(v)
129
 
130
- elif isinstance(
131
- v,
132
- (list,tuple)
133
- ):
134
 
135
- result.extend(
136
  [str(x) for x in v]
137
  )
138
 
139
- return list(
140
- set(result)
141
- )
142
-
143
 
144
- # ----------------------------
145
- # health
146
- # ----------------------------
147
 
148
- @app.get("/")
 
 
149
 
150
- async def root():
151
 
152
- return {
153
 
154
- "status":"online",
 
155
 
156
- "default":DEFAULT_MODEL
 
157
 
158
- }
159
 
 
 
160
 
161
- # ----------------------------
162
- # models
163
- # ----------------------------
164
-
165
- @app.get("/v1/models")
166
-
167
- async def models(
168
- authorization:str=Header(None)
169
- ):
170
 
171
- verify(authorization)
172
 
173
- data=[]
 
 
 
174
 
175
- try:
 
176
 
177
- # LiteLLM models
 
 
 
 
 
 
 
178
 
179
- ll_models=[
 
 
 
 
 
 
 
180
 
181
- "groq/llama-3.3-70b-versatile",
 
 
182
 
183
- "groq/llama-3.1-8b-instant",
 
184
 
185
- "openrouter/qwen/qwen-2.5-72b-instruct",
186
 
187
- "huggingface/Qwen/Qwen2.5-72B-Instruct",
188
 
189
- "openrouter/deepseek/deepseek-chat",
190
 
191
- "openai/gpt-4o",
192
 
193
- "openai/gpt-4o-mini"
 
194
 
195
- ]
196
 
197
- for m in ll_models:
198
 
199
- data.append({
200
 
201
- "id":m,
 
 
202
 
203
- "object":"model",
204
 
205
- "owned_by":"litellm"
206
 
207
- })
208
 
209
- # g4f dynamic providers
210
 
211
- for name in dir(Provider):
212
 
213
- if name.startswith("_"):
214
- continue
215
 
216
- if name in SKIP:
217
- continue
218
 
219
- cls=getattr(
220
- Provider,
221
- name
222
- )
223
 
224
- if not inspect.isclass(cls):
225
- continue
 
226
 
227
- models=collect_models(
228
- cls
229
- )
230
 
231
- for m in models:
 
 
232
 
233
- data.append({
 
 
234
 
235
- "id":m,
236
 
237
- "object":"model",
238
 
239
- "owned_by":name
240
 
241
- })
242
 
243
- except Exception as e:
244
 
245
- logger.error(e)
246
 
247
- return {
248
 
249
- "object":"list",
250
 
251
- "data":data
 
 
252
 
 
253
  }
254
 
255
 
256
- # ----------------------------
257
- # Chat
258
- # ----------------------------
259
 
260
  @app.post("/v1/chat/completions")
261
 
@@ -284,23 +285,42 @@ authorization:str=Header(None)
284
  for m in body.messages
285
  ]
286
 
 
287
 
288
- # =====================
289
- # Streaming
290
- # =====================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  if body.stream:
293
 
294
- async def generate():
295
 
296
  try:
297
 
298
- # LiteLLM first
299
-
300
- response=litellm.completion(
301
 
302
  model=body.model,
303
 
 
 
304
  messages=messages,
305
 
306
  stream=True
@@ -309,132 +329,75 @@ authorization:str=Header(None)
309
 
310
  for chunk in response:
311
 
312
- content=""
313
-
314
- try:
315
-
316
- content=chunk.choices[0].delta.content
317
- except:
318
- pass
319
-
320
- if content:
321
-
322
- payload={
323
-
324
- "id":"chatcmpl",
325
-
326
- "object":"chat.completion.chunk",
327
-
328
- "created":int(time.time()),
329
-
330
- "model":body.model,
331
-
332
- "choices":[{
333
-
334
- "delta":{
335
-
336
- "content":content
337
-
338
- },
339
-
340
- "index":0
341
-
342
- }]
343
- }
344
-
345
- yield f"data:{json.dumps(payload)}\n\n"
346
-
347
- except:
348
-
349
- logger.info(
350
- "Fallback g4f"
351
- )
352
-
353
- response=g4f.ChatCompletion.create(
354
-
355
- model=body.model,
356
 
357
- messages=messages,
358
 
359
- stream=True
360
- )
361
 
362
- for chunk in response:
 
 
363
 
364
- payload={
365
 
366
  "choices":[{
367
 
368
  "delta":{
369
 
370
- "content":str(chunk)
371
-
 
372
  }
 
373
  }]
374
  }
375
 
376
  yield f"data:{json.dumps(payload)}\n\n"
377
 
378
- yield "data:[DONE]\n\n"
379
 
380
- return StreamingResponse(
381
 
382
- generate(),
383
 
384
- media_type="text/event-stream"
385
- )
386
-
387
-
388
-
389
- # =====================
390
- # Normal
391
- # =====================
392
 
393
- try:
394
-
395
- response=await asyncio.to_thread(
396
 
397
- litellm.completion,
398
 
399
- model=body.model,
400
 
401
- messages=messages
402
 
403
  )
404
 
405
- content=response.choices[0].message.content
406
-
407
-
408
- except Exception:
409
 
410
- logger.info(
411
- "Using g4f fallback"
412
- )
413
 
414
- content=await asyncio.to_thread(
415
 
416
  g4f.ChatCompletion.create,
417
 
418
  model=body.model,
419
 
420
- messages=messages
421
- )
422
-
423
 
 
424
 
425
- return {
426
 
427
- "id":"chatcmpl",
428
 
429
- "object":"chat.completion",
430
 
431
- "created":int(time.time()),
432
 
433
- "model":body.model,
434
 
435
- "choices":[
436
 
437
- {
438
 
439
  "index":0,
440
 
@@ -442,14 +405,21 @@ authorization:str=Header(None)
442
 
443
  "role":"assistant",
444
 
445
- "content":str(content)
 
 
446
 
447
- },
448
 
449
- "finish_reason":"stop"
450
 
451
- }
452
 
453
- ]
454
 
455
- }
 
 
 
 
 
 
1
  import os
2
  import json
3
  import time
4
+ import inspect
5
  import asyncio
6
  import logging
7
+
8
+ import g4f
9
+ import g4f.Provider as Provider
10
 
11
  from fastapi import FastAPI, HTTPException, Header
12
  from fastapi.responses import StreamingResponse
 
14
  from pydantic import BaseModel
15
  from typing import List, Optional
16
 
 
 
 
 
 
 
 
 
17
  logging.basicConfig(level=logging.INFO)
18
  logger=logging.getLogger(__name__)
19
 
20
+ app=FastAPI(title="G4F Dynamic API")
 
 
 
 
 
 
 
21
 
22
  app.add_middleware(
23
  CORSMiddleware,
 
28
 
29
  API_KEY=os.getenv(
30
  "API_KEY",
31
+ "secret"
32
  )
33
 
34
+ # ======================
 
 
 
 
 
35
  # Models
36
+ # ======================
37
 
38
+ class ChatMessage(BaseModel):
39
  role:str
40
  content:str
41
 
42
 
43
  class ChatRequest(BaseModel):
44
 
45
+ model:str="gpt-4o"
46
 
47
+ messages:List[ChatMessage]
48
 
49
  stream:bool=False
50
 
51
  provider:Optional[str]=None
52
 
53
 
54
+ # ======================
55
+ # auth
56
+ # ======================
57
 
58
  def verify(auth):
59
 
60
+ if auth!=f"Bearer {API_KEY}":
 
 
 
 
 
 
 
61
 
62
  raise HTTPException(
63
  status_code=401,
 
65
  )
66
 
67
 
68
+ # ======================
69
+ # provider discovery
70
+ # ======================
71
 
72
  SKIP={
73
 
74
  "BaseProvider",
75
  "RetryProvider",
76
+ "AsyncProvider",
77
+ "IterListProvider",
78
+ "ProviderType",
79
+ "CreateResult"
80
 
81
  }
82
 
83
 
84
+ PROVIDERS={}
85
+ MODEL_MAP={}
86
+
87
+
88
  def collect_models(cls):
89
 
90
+ found=[]
91
 
92
+ attrs=[
93
 
94
  "default_model",
95
+ "model",
96
  "models",
97
+ "text_models",
98
+ "vision_models"
99
+
100
+ ]
101
 
102
+ for attr in attrs:
103
 
104
  v=getattr(
105
  cls,
 
112
 
113
  if isinstance(v,str):
114
 
115
+ found.append(v)
116
 
117
+ elif isinstance(v,(list,tuple,set)):
 
 
 
118
 
119
+ found.extend(
120
  [str(x) for x in v]
121
  )
122
 
123
+ elif isinstance(v,dict):
 
 
 
124
 
125
+ found.extend(
126
+ list(v.keys())
127
+ )
128
 
129
+ return list(
130
+ set(found)
131
+ )
132
 
 
133
 
134
+ def build():
135
 
136
+ global PROVIDERS
137
+ global MODEL_MAP
138
 
139
+ PROVIDERS={}
140
+ MODEL_MAP={}
141
 
142
+ for name in dir(Provider):
143
 
144
+ if name.startswith("_"):
145
+ continue
146
 
147
+ if name in SKIP:
148
+ continue
 
 
 
 
 
 
 
149
 
150
+ try:
151
 
152
+ cls=getattr(
153
+ Provider,
154
+ name
155
+ )
156
 
157
+ if not inspect.isclass(cls):
158
+ continue
159
 
160
+ if not bool(
161
+ getattr(
162
+ cls,
163
+ "working",
164
+ False
165
+ )
166
+ ):
167
+ continue
168
 
169
+ if bool(
170
+ getattr(
171
+ cls,
172
+ "needs_auth",
173
+ False
174
+ )
175
+ ):
176
+ continue
177
 
178
+ models=collect_models(
179
+ cls
180
+ )
181
 
182
+ if not models:
183
+ continue
184
 
185
+ PROVIDERS[name]=models
186
 
187
+ for m in models:
188
 
189
+ if m not in MODEL_MAP:
190
 
191
+ MODEL_MAP[m]=name
192
 
193
+ except:
194
+ pass
195
 
 
196
 
197
+ build()
198
 
 
199
 
200
+ # ======================
201
+ # health
202
+ # ======================
203
 
204
+ @app.get("/")
205
 
206
+ async def health():
207
 
208
+ return{
209
 
210
+ "status":"online",
211
 
212
+ "providers":len(PROVIDERS),
213
 
214
+ "models":len(MODEL_MAP)
 
215
 
216
+ }
 
217
 
 
 
 
 
218
 
219
+ # ======================
220
+ # models
221
+ # ======================
222
 
223
+ @app.get("/v1/models")
 
 
224
 
225
+ async def models(
226
+ authorization:str=Header(None)
227
+ ):
228
 
229
+ verify(
230
+ authorization
231
+ )
232
 
233
+ return{
234
 
235
+ "object":"list",
236
 
237
+ "data":[
238
 
239
+ {
240
 
241
+ "id":m,
242
 
243
+ "object":"model",
244
 
245
+ "owned_by":MODEL_MAP[m]
246
 
247
+ }
248
 
249
+ for m in sorted(
250
+ MODEL_MAP.keys()
251
+ )
252
 
253
+ ]
254
  }
255
 
256
 
257
+ # ======================
258
+ # chat
259
+ # ======================
260
 
261
  @app.post("/v1/chat/completions")
262
 
 
285
  for m in body.messages
286
  ]
287
 
288
+ provider=None
289
 
290
+ if body.provider:
291
+
292
+ provider=getattr(
293
+ Provider,
294
+ body.provider,
295
+ None
296
+ )
297
+
298
+ elif body.model in MODEL_MAP:
299
+
300
+ provider=getattr(
301
+ Provider,
302
+ MODEL_MAP[
303
+ body.model
304
+ ],
305
+ None
306
+ )
307
+
308
+ # ==================
309
+ # stream
310
+ # ==================
311
 
312
  if body.stream:
313
 
314
+ def generate():
315
 
316
  try:
317
 
318
+ response=g4f.ChatCompletion.create(
 
 
319
 
320
  model=body.model,
321
 
322
+ provider=provider,
323
+
324
  messages=messages,
325
 
326
  stream=True
 
329
 
330
  for chunk in response:
331
 
332
+ payload={
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
333
 
334
+ "id":"chatcmpl",
335
 
336
+ "object":"chat.completion.chunk",
 
337
 
338
+ "created":int(
339
+ time.time()
340
+ ),
341
 
342
+ "model":body.model,
343
 
344
  "choices":[{
345
 
346
  "delta":{
347
 
348
+ "content":str(
349
+ chunk
350
+ )
351
  }
352
+
353
  }]
354
  }
355
 
356
  yield f"data:{json.dumps(payload)}\n\n"
357
 
358
+ yield "data:[DONE]\n\n"
359
 
360
+ except Exception as e:
361
 
362
+ logger.error(e)
363
 
364
+ yield f"data:{json.dumps({'error':str(e)})}\n\n"
 
 
 
 
 
 
 
365
 
 
 
 
366
 
367
+ return StreamingResponse(
368
 
369
+ generate(),
370
 
371
+ media_type="text/event-stream"
372
 
373
  )
374
 
 
 
 
 
375
 
376
+ try:
 
 
377
 
378
+ response=await asyncio.to_thread(
379
 
380
  g4f.ChatCompletion.create,
381
 
382
  model=body.model,
383
 
384
+ provider=provider,
 
 
385
 
386
+ messages=messages
387
 
388
+ )
389
 
390
+ return{
391
 
392
+ "id":"chatcmpl",
393
 
394
+ "object":"chat.completion",
395
 
396
+ "created":int(time.time()),
397
 
398
+ "model":body.model,
399
 
400
+ "choices":[{
401
 
402
  "index":0,
403
 
 
405
 
406
  "role":"assistant",
407
 
408
+ "content":str(
409
+ response
410
+ )
411
 
412
+ }
413
 
414
+ }]
415
 
416
+ }
417
 
418
+ except Exception as e:
419
 
420
+ logger.exception(e)
421
+
422
+ raise HTTPException(
423
+ status_code=500,
424
+ detail=str(e)
425
+ )