import asyncio
import io
import os
import sys
import threading
import traceback
from typing import Any, Dict, Union

from fastapi import FastAPI, HTTPException, Response
from pydantic import BaseModel

from handler import EndpointHandler
|
|
| |
def debug_log(message):
    """Print ``message`` to standard output with a ``DEBUG:`` prefix.

    The stream is flushed right after writing so each line shows up
    immediately, even when stdout is buffered (e.g. piped into a log
    collector).
    """
    stream = sys.stdout
    print(f"DEBUG: {message}", file=stream)
    stream.flush()
|
|
debug_log("Starting API initialization")


# ASGI application object; the route handlers below are registered on it.
app = FastAPI()


# Directory handed to EndpointHandler; overridable through the MODEL_DIR
# environment variable, defaulting to /code/diffsketcher.
model_dir = os.environ.get("MODEL_DIR", "/code/diffsketcher")
debug_log(f"Using model_dir: {model_dir}")
# Built once, at module import time, and shared by every request. Whatever
# setup EndpointHandler does (presumably model loading — TODO confirm in
# handler.py) therefore runs during import, and an exception here makes the
# import of this module fail.
handler = EndpointHandler(model_dir)
debug_log("Handler initialized")
|
|
class TextRequest(BaseModel):
    """Request body accepted by ``POST /``.

    The endpoint converts the whole model to a dict (``{"inputs": ...}``)
    and passes it to the ``EndpointHandler`` instance. Which keys a
    dict-valued ``inputs`` may carry is decided by the handler, not here —
    TODO confirm against handler.py.
    """

    # Either a string (presumably the text prompt) or a dict of
    # handler-specific parameters; pydantic tries the Union members in order.
    inputs: Union[str, Dict[str, Any]]
|
|
@app.get("/")
def read_root():
    """Handle ``GET /`` by returning a static JSON banner naming the service."""
    debug_log("Root endpoint called")
    banner = "DiffSketcher Vector Graphics Generation API"
    return {"message": banner}
|
|
# Serializes calls into the shared ``handler``. Inference now runs in a worker
# thread so the event loop stays free for other requests, but the previous
# implementation executed requests strictly one at a time on the event loop.
# Keep that guarantee, since the handler/model has not been verified to be
# safe for concurrent calls.
_handler_lock = threading.Lock()


def _call_handler(payload):
    """Invoke the module-level ``handler`` on *payload* while holding ``_handler_lock``."""
    with _handler_lock:
        return handler(payload)


@app.post("/")
async def generate(request: TextRequest):
    """Handle ``POST /``: run the model handler on the request body.

    The request model is converted to a dict and passed to ``handler``. That
    call is synchronous and potentially long-running, so it is executed in
    the default thread-pool executor instead of directly in this coroutine;
    calling it inline would block the event loop and stall every other
    request until generation finished.

    Returns:
        A ``Response`` with PNG bytes (``image/png``) when the handler's
        result has a ``save`` method (treated as a PIL-style image);
        otherwise the result object itself, which FastAPI serializes.

    Raises:
        HTTPException: status 500, with ``str(error)`` as the detail, when
            the handler or the PNG encoding raises.
    """
    try:
        debug_log(f"Generate endpoint called with request: {request}")

        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, _call_handler, request.dict())
        debug_log("Handler returned result")

        # Duck-typed image check: anything exposing .save(fp, format=...) is
        # encoded as PNG and returned as raw bytes.
        if hasattr(result, "save"):
            debug_log("Result is a PIL Image, converting to bytes")
            buffer = io.BytesIO()
            result.save(buffer, format="PNG")
            debug_log("Returning image response")
            # getvalue() returns the whole buffer regardless of position,
            # so no seek(0) is needed.
            return Response(content=buffer.getvalue(), media_type="image/png")

        debug_log(f"Returning JSON response: {result}")
        return result
    except Exception as e:
        debug_log(f"Error in generate endpoint: {e}")
        debug_log(traceback.format_exc())
        # Chain the original error so server-side tracebacks keep the cause.
        raise HTTPException(status_code=500, detail=str(e)) from e