"""OpenAI-compatible chat-completions proxy in front of a HuggingFace Gradio Space.

Exposes POST /v1/chat/completions (non-streaming and SSE streaming) and
forwards the latest user message to the CohereLabs/command-a-vision Space.
"""

import asyncio
import json
import time
import uuid

from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, StreamingResponse
from gradio_client import Client

app = FastAPI()

# Client for the upstream HuggingFace Space (created once at import time).
client = Client("CohereLabs/command-a-vision")


def call_gradio(message, max_tokens=100):
    """Call the Space's /chat endpoint and return the model reply as a string.

    Blocking: run it via ``asyncio.to_thread`` from async code so the event
    loop is not stalled while the upstream model responds.
    """
    try:
        # The Space expects a multimodal payload: text plus a file list.
        payload = {"text": message, "files": []}
        # IMPORTANT: the /chat endpoint takes positional inputs, not kwargs.
        result = client.predict(
            payload,      # input 1: message payload
            max_tokens,   # input 2: token budget
            api_name="/chat",
        )
        # Some Spaces return a dict; serialize so callers always get a str.
        if isinstance(result, dict):
            return json.dumps(result)
        return str(result)
    except Exception as e:
        # Best-effort: surface a readable error instead of crashing the API.
        print("Gradio API error:", e)
        return "Error: upstream model failed."


def _extract_text(content):
    """Return the plain text of an OpenAI message ``content`` field.

    Accepts either a plain string or the multimodal list-of-parts form
    (``[{"type": "text", "text": ...}, ...]``); non-text parts are ignored.
    """
    if isinstance(content, list):
        return " ".join(
            part.get("text", "")
            for part in content
            if isinstance(part, dict) and part.get("type") == "text"
        )
    return str(content)


def format_openai_response(content):
    """Wrap ``content`` in an OpenAI ``chat.completion`` response envelope."""
    return {
        "id": f"chatcmpl-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": "command-a-vision",
        "choices": [
            {
                "index": 0,
                "message": {"role": "assistant", "content": content},
                "finish_reason": "stop",
            }
        ],
    }


@app.post("/v1/chat/completions")
async def chat(request: Request):
    """OpenAI-compatible chat endpoint (honors ``stream: true`` in the body)."""
    body = await request.json()
    messages = body.get("messages", [])
    stream = body.get("stream", False)
    max_tokens = body.get("max_tokens", 100)

    # Guard: an empty messages list would otherwise crash on messages[-1].
    if not messages:
        return JSONResponse(
            {"error": {"message": "messages must be a non-empty list"}},
            status_code=400,
        )
    user_message = _extract_text(messages[-1].get("content", ""))

    # Non-streaming response.
    if not stream:
        # Run the blocking gradio client off the event loop.
        result = await asyncio.to_thread(call_gradio, user_message, max_tokens)
        return JSONResponse(format_openai_response(result))

    # Streaming response (SSE, OpenAI chunk format).
    async def generate():
        # The upstream call is not incremental, so fetch the whole reply
        # first, then replay it word-by-word as chunks.
        result = await asyncio.to_thread(call_gradio, user_message, max_tokens)
        # One stable id/timestamp per stream, as the OpenAI format specifies.
        completion_id = f"chatcmpl-{uuid.uuid4().hex}"
        created = int(time.time())
        for word in result.split(" "):
            chunk = {
                "id": completion_id,
                "object": "chat.completion.chunk",
                "created": created,
                "model": "command-a-vision",
                "choices": [
                    {
                        "delta": {"content": word + " "},
                        "index": 0,
                        "finish_reason": None,
                    }
                ],
            }
            yield f"data: {json.dumps(chunk)}\n\n"
            await asyncio.sleep(0.02)
        # Terminal chunk carrying finish_reason, then the SSE sentinel.
        end_chunk = {
            "id": completion_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": "command-a-vision",
            "choices": [{"delta": {}, "index": 0, "finish_reason": "stop"}],
        }
        yield f"data: {json.dumps(end_chunk)}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(generate(), media_type="text/event-stream")