Raju2024 committed on
Commit
4f1e93d
·
verified ·
1 Parent(s): b42238e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -19
app.py CHANGED
@@ -13,21 +13,28 @@ app = FastAPI()
13
  client = Client("CohereLabs/command-a-vision")
14
 
15
 
16
- # call gradio safely
17
- def call_gradio(message, max_tokens=12800, temperature=0.1, top_p=0.9):
18
 
19
  try:
20
- job = client.submit(
21
- message=message,
22
- max_tokens=max_tokens,
23
- temperature=temperature,
24
- top_p=top_p,
 
 
 
 
 
25
  api_name="/chat"
26
  )
27
 
28
- result = job.result()
 
 
29
 
30
- return result
31
 
32
  except Exception as e:
33
  print("Gradio API error:", e)
@@ -39,7 +46,7 @@ def format_openai_response(content):
39
  "id": f"chatcmpl-{uuid.uuid4().hex}",
40
  "object": "chat.completion",
41
  "created": int(time.time()),
42
- "model": "minimax-text-01",
43
  "choices": [
44
  {
45
  "index": 0,
@@ -61,23 +68,21 @@ async def chat(request: Request):
61
  messages = body.get("messages", [])
62
  stream = body.get("stream", False)
63
 
64
- max_tokens = body.get("max_tokens", 12800)
65
- temperature = body.get("temperature", 0.1)
66
- top_p = body.get("top_p", 0.9)
67
 
68
  user_message = messages[-1]["content"]
69
 
70
- # normal response
71
  if not stream:
72
 
73
- result = call_gradio(user_message, max_tokens, temperature, top_p)
74
 
75
  return JSONResponse(format_openai_response(result))
76
 
77
- # streaming response
78
  async def generate():
79
 
80
- result = call_gradio(user_message, max_tokens, temperature, top_p)
81
 
82
  words = result.split(" ")
83
 
@@ -87,7 +92,7 @@ async def chat(request: Request):
87
  "id": f"chatcmpl-{uuid.uuid4().hex}",
88
  "object": "chat.completion.chunk",
89
  "created": int(time.time()),
90
- "model": "minimax-text-01",
91
  "choices": [
92
  {
93
  "delta": {"content": word + " "},
@@ -98,9 +103,9 @@ async def chat(request: Request):
98
  }
99
 
100
  yield f"data: {json.dumps(chunk)}\n\n"
101
-
102
  await asyncio.sleep(0.02)
103
 
 
104
  end_chunk = {
105
  "id": f"chatcmpl-{uuid.uuid4().hex}",
106
  "object": "chat.completion.chunk",
 
13
  client = Client("CohereLabs/command-a-vision")
14
 
15
 
16
+ # ✅ FIXED: call gradio with positional args
17
+ def call_gradio(message, max_tokens=100):
18
 
19
  try:
20
+ # format input like Gradio expects
21
+ payload = {
22
+ "text": message,
23
+ "files": []
24
+ }
25
+
26
+ # IMPORTANT: positional inputs (NOT keyword args)
27
+ result = client.predict(
28
+ payload, # input 1
29
+ max_tokens, # input 2
30
  api_name="/chat"
31
  )
32
 
33
+ # result comes as dict sometimes
34
+ if isinstance(result, dict):
35
+ return json.dumps(result)
36
 
37
+ return str(result)
38
 
39
  except Exception as e:
40
  print("Gradio API error:", e)
 
46
  "id": f"chatcmpl-{uuid.uuid4().hex}",
47
  "object": "chat.completion",
48
  "created": int(time.time()),
49
+ "model": "command-a-vision",
50
  "choices": [
51
  {
52
  "index": 0,
 
68
  messages = body.get("messages", [])
69
  stream = body.get("stream", False)
70
 
71
+ max_tokens = body.get("max_tokens", 100)
 
 
72
 
73
  user_message = messages[-1]["content"]
74
 
75
+ # normal response
76
  if not stream:
77
 
78
+ result = call_gradio(user_message, max_tokens)
79
 
80
  return JSONResponse(format_openai_response(result))
81
 
82
+ # streaming response
83
  async def generate():
84
 
85
+ result = call_gradio(user_message, max_tokens)
86
 
87
  words = result.split(" ")
88
 
 
92
  "id": f"chatcmpl-{uuid.uuid4().hex}",
93
  "object": "chat.completion.chunk",
94
  "created": int(time.time()),
95
+ "model": "command-a-vision",
96
  "choices": [
97
  {
98
  "delta": {"content": word + " "},
 
103
  }
104
 
105
  yield f"data: {json.dumps(chunk)}\n\n"
 
106
  await asyncio.sleep(0.02)
107
 
108
+ # end
109
  end_chunk = {
110
  "id": f"chatcmpl-{uuid.uuid4().hex}",
111
  "object": "chat.completion.chunk",