import logging
import os
import subprocess
import sys
import time
from contextlib import asynccontextmanager
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel
|
|
| |
def install_packages():
    """Install required packages using pip.

    For each required distribution, probe whether its module is already
    importable and only shell out to ``pip install`` when it is not.
    """
    packages = [
        "fastapi",
        "uvicorn[standard]",
        "pillow",
        "huggingface_hub",
        "pydantic"
    ]

    # pip distribution name -> importable module name, where the two differ.
    # BUG FIX: the old probe tried ``__import__("pillow")``, but Pillow
    # installs under the module name ``PIL`` — so the probe always raised
    # ImportError and pip was re-invoked on every startup.
    import_names = {
        "uvicorn[standard]": "uvicorn",
        "pillow": "PIL",
    }

    for package in packages:
        module = import_names.get(package, package.replace("-", "_"))
        try:
            __import__(module)
            print(f"{package} already installed")
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
|
|
| |
| install_packages() |
|
|
| import uvicorn |
| from fastapi import FastAPI, HTTPException |
| from fastapi.staticfiles import StaticFiles |
| from fastapi.middleware.cors import CORSMiddleware |
| from fastapi.responses import JSONResponse |
|
|
| |
class ResponseFormat(str, Enum):
    """How generated images are returned: as a hosted URL or inline base64."""
    URL = "url"
    B64_JSON = "b64_json"
|
|
class ImageGenerationRequest(BaseModel):
    """Request body for POST /v1/images/generations (OpenAI-compatible)."""
    prompt: str
    model: str = "dall-e-3"  # OpenAI alias; remapped to a backend model in create_image
    n: int = 1  # number of images requested
    size: str = "1024x1024"  # "WIDTHxHEIGHT" string, as in the OpenAI API
    quality: str = "standard"
    response_format: ResponseFormat = ResponseFormat.URL
|
|
class ImageData(BaseModel):
    """One generated image in an OpenAI-style response.

    Depending on the requested ``response_format``, either ``url`` or
    ``b64_json`` is populated; the other stays ``None``.
    """
    # FIX: fields were annotated plain ``str`` with a ``None`` default — an
    # inconsistent annotation that strict pydantic v2 configurations reject.
    url: Optional[str] = None
    b64_json: Optional[str] = None
    revised_prompt: Optional[str] = None
|
|
class ImageGenerationResponse(BaseModel):
    """Top-level response for image generation (OpenAI-compatible)."""
    created: int  # Unix timestamp of the generation
    data: List[ImageData]
|
|
class ErrorResponse(BaseModel):
    """OpenAI-style error envelope used by the global exception handler."""
    # Untyped on purpose; the handler fills {"message", "type", "code"}.
    error: dict
|
|
class ModelInfo(BaseModel):
    """Single model entry for the /v1/models listing."""
    id: str
    created: int  # Unix timestamp (fixed placeholder value in list_models)
    owned_by: str
|
|
class ModelsResponse(BaseModel):
    """Response envelope for GET /v1/models."""
    data: List[ModelInfo]
|
|
| |
| from image_generator import ImageGenerator |
|
|
| |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)



# Global generator instance; populated by the lifespan handler at startup
# and remains None until then (endpoints guard against that).
image_generator = None
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan management.

    Startup: builds the global ImageGenerator, configures its base URL, and
    mounts its output directory as static files.  Shutdown (after ``yield``):
    asks the generator to clean up.
    """
    global image_generator

    logger.info("Starting TTI Frame API...")

    # Token is optional here; without it, generation presumably fails later.
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        logger.warning("HF_TOKEN environment variable not set. Image generation may fail.")

    image_generator = ImageGenerator(hf_token=hf_token)

    # Public base URL the generator uses when building image URLs in responses.
    base_url = os.getenv("BASE_URL", "http://localhost:8000")
    image_generator.set_config(base_url=base_url)

    # Serve generated files directly from the generator's output directory.
    app.mount("/images", StaticFiles(directory=image_generator.output_dir), name="images")

    logger.info(f"Image generator initialized with output directory: {image_generator.output_dir}")

    yield

    logger.info("Shutting down TTI Frame API...")
    if image_generator:
        image_generator.cleanup()
|
|
| |
# FastAPI application; lifespan handles generator init/teardown.
app = FastAPI(
    title="TTI Frame - OpenAI Compatible Text-to-Image API",
    description="A FastAPI wrapper providing OpenAI-compatible endpoints for text-to-image generation",
    version="1.0.0",
    lifespan=lifespan
)


# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is
# rejected by browsers for credentialed requests (the CORS spec forbids a
# wildcard origin with credentials) — confirm whether credentials are needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
@app.get("/")
async def root():
    """Service banner: name, version, docs link, and output directory."""
    output_dir = image_generator.output_dir if image_generator else "Not initialized"
    return {
        "message": "TTI Frame - OpenAI Compatible Text-to-Image API",
        "version": "1.0.0",
        "docs": "/docs",
        "output_dir": output_dir,
    }
|
|
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models():
    """List available models (OpenAI compatible)."""
    # Static catalog; every entry shares the same placeholder timestamp/owner.
    model_ids = ("dall-e-3", "dall-e-2", "black-forest-labs/flux-schnell")
    catalog = [
        ModelInfo(id=model_id, created=1677649963, owned_by="tti-frame")
        for model_id in model_ids
    ]
    return ModelsResponse(data=catalog)
|
|
@app.post("/v1/images/generations", response_model=ImageGenerationResponse)
async def create_image(request: ImageGenerationRequest):
    """
    Generate images from text prompts (OpenAI compatible).

    Validates the prompt, maps OpenAI model aliases onto the backing
    diffusion model, delegates to the generator, and wraps the result in an
    OpenAI-style response.  Raises 400 for bad prompts, 500 on failure.
    """
    if not image_generator:
        raise HTTPException(
            status_code=500,
            detail="Image generator not initialized. Check HF_TOKEN environment variable."
        )

    try:
        logger.info(f"Received image generation request: {request.prompt[:50]}...")

        # Guard clauses: reject empty / oversized prompts up front.
        prompt = request.prompt
        if not prompt or not prompt.strip():
            raise HTTPException(
                status_code=400,
                detail="Prompt cannot be empty"
            )
        if len(prompt) > 4000:
            raise HTTPException(
                status_code=400,
                detail="Prompt too long. Maximum 4000 characters allowed."
            )

        # Translate OpenAI aliases to the backing model; unknown names pass through.
        aliases = {
            "dall-e-3": "black-forest-labs/flux-schnell",
            "dall-e-2": "black-forest-labs/flux-schnell",
        }
        request.model = aliases.get(request.model, request.model)

        image_data = await image_generator.generate_images(request)

        logger.info(f"Successfully generated {len(image_data)} images")
        return ImageGenerationResponse(
            created=int(time.time()),
            data=image_data
        )

    except HTTPException:
        # Re-raise our own 4xx/5xx untouched.
        raise
    except Exception as e:
        logger.error(f"Image generation failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Image generation failed: {str(e)}"
        )
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: always healthy, plus generator readiness details."""
    payload = {
        "status": "healthy",
        "timestamp": int(time.time()),
    }
    payload["generator_initialized"] = image_generator is not None
    payload["output_dir"] = image_generator.output_dir if image_generator else None
    return payload
|
|
@app.get("/config")
async def get_config():
    """Report current generator settings (token exposed only as a boolean)."""
    gen = image_generator
    if not gen:
        return {"error": "Image generator not initialized"}

    return {
        "output_dir": gen.output_dir,
        "base_url": gen.base_url,
        "default_model": gen.default_model,
        "hf_token_set": bool(gen.hf_token),
    }
|
|
@app.post("/config")
async def update_config(hf_token: str = None, base_url: str = None, default_model: str = None):
    """Update configuration"""
    if not image_generator:
        raise HTTPException(status_code=500, detail="Image generator not initialized")

    # NOTE(review): hf_token arrives as a query parameter and may end up in
    # access logs / proxies — consider moving it into the request body.
    image_generator.set_config(
        hf_token=hf_token,
        base_url=base_url,
        default_model=default_model,
    )
    return {"message": "Configuration updated successfully"}
|
|
@app.exception_handler(Exception)
async def global_exception_handler(request, exc):
    """Catch-all handler: log the failure with its traceback and return a
    generic OpenAI-style error payload without leaking internals."""
    # FIX: pass exc_info so the stack trace is logged; the previous
    # logger.error(f"...") recorded only the message, making unhandled
    # server errors very hard to diagnose.
    logger.error(f"Unhandled exception: {exc}", exc_info=exc)
    return JSONResponse(
        status_code=500,
        content=ErrorResponse(
            error={
                "message": "Internal server error",
                "type": "server_error",
                "code": "internal_error"
            }
        ).dict()  # NOTE(review): .dict() is pydantic v1 API; v2 uses .model_dump()
    )
|
|
if __name__ == "__main__":

    # Warn early: the lifespan handler only logs a warning, so surface the
    # missing token before the server even starts.
    if not os.getenv("HF_TOKEN"):
        print("Warning: HF_TOKEN environment variable not set.")
        print("Please set it with: export HF_TOKEN=your_huggingface_token")

    # NOTE(review): reload=True is a development setting — disable in production.
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info"
    )