import os

# Select PyOpenGL's "osmesa" platform. This assignment deliberately runs before
# any other import: PyOpenGL is understood to read PYOPENGL_PLATFORM when it is
# first imported, and the `renderer` module imported further down presumably
# pulls in PyOpenGL (TODO confirm), so moving this line below those imports
# would likely have no effect.
os.environ['PYOPENGL_PLATFORM'] = 'osmesa'
|
|
| from fastapi import FastAPI, HTTPException |
| from fastapi.responses import StreamingResponse |
| from pydantic import BaseModel, Field |
| from typing import Optional |
| import io |
| import numpy as np |
| from PIL import Image |
| from pathlib import Path |
|
|
| from measurement_processor import process_measurements |
| from beta_regressor import predict_betas |
| from smpl_generator import generate_mesh |
| from renderer import render_avatar |
|
|
# Model path handed to generate_mesh by the endpoints; override with the
# SMPL_MODEL_PATH environment variable.
SMPL_MODEL_PATH = os.getenv("SMPL_MODEL_PATH", "smpl")

# Candidate directories checked, in order, at startup for SMPL model files.
# This check only reports what it finds: the endpoints pass SMPL_MODEL_PATH
# (not the directory located here) to generate_mesh.
smpl_models_dirs = [
    Path("smpl/smpl/models"),
    Path("smpl/models"),
    Path(SMPL_MODEL_PATH) / "models",
]


def find_smpl_models_dir(candidates, pattern="*.pkl"):
    """Return the first candidate directory that actually contains model files.

    Args:
        candidates: Iterable of ``Path`` objects to check, in priority order.
        pattern: Glob pattern that identifies model files (default ``*.pkl``).

    Returns:
        ``(directory, model_files)``, where ``model_files`` is the sorted list of
        paths in ``directory`` that match ``pattern``. Returns ``(None, [])``
        when no candidate is a directory containing a matching file. An
        existing but empty directory is skipped rather than counted as found.
    """
    for candidate in candidates:
        if not candidate.is_dir():
            continue
        model_files = sorted(candidate.glob(pattern))
        if model_files:
            return candidate, model_files
    return None, []


models_dir, model_files = find_smpl_models_dir(smpl_models_dirs)
found_models = models_dir is not None

if found_models:
    print(f"Found SMPL models in {models_dir}")
    for model_file in model_files:
        print(f" - {model_file.name}")
else:
    searched = ", ".join(str(candidate) for candidate in smpl_models_dirs)
    print("Warning: SMPL models not found in expected locations")
    print(f"Searched: {searched} (set SMPL_MODEL_PATH to change the last location)")
|
|
# ASGI application. The title and description are what FastAPI shows in the
# generated OpenAPI docs. The description covers both the PNG and OBJ endpoints
# registered below.
app = FastAPI(
    title="Avatar Generation Service",
    description=(
        "Generate 2D avatar images (PNG) and 3D avatar meshes (OBJ) "
        "from body measurements using SMPL"
    ),
)
|
|
|
|
class MeasurementRequest(BaseModel):
    """Request body shared by the avatar endpoints.

    Lengths are in centimetres and weight in kilograms. Every supplied
    measurement must be strictly positive (``gt=0``). ``gender`` is
    validated by the endpoints, which accept 'male', 'female' or 'neutral'.
    """

    height: float = Field(..., gt=0, description="Height in cm")
    weight: float = Field(..., gt=0, description="Weight in kg")
    chest: float = Field(..., gt=0, description="Chest measurement in cm")
    waist: float = Field(..., gt=0, description="Waist measurement in cm")
    hips: float = Field(..., gt=0, description="Hips measurement in cm")
    shoulder_width: Optional[float] = Field(None, gt=0, description="Shoulder width in cm")
    arm_length: Optional[float] = Field(None, gt=0, description="Arm length in cm")
    leg_length: Optional[float] = Field(None, gt=0, description="Leg length in cm")
    inseam: Optional[float] = Field(None, gt=0, description="Inseam in cm")
    gender: Optional[str] = Field(
        "male", description="Gender: 'male', 'female' or 'neutral'"
    )

    # Pydantic v2 configuration (the nested ``class Config`` form is the
    # deprecated v1 style). A plain dict is accepted because ConfigDict is a
    # TypedDict. The example populates the OpenAPI docs.
    model_config = {
        "json_schema_extra": {
            "example": {
                "height": 178,
                "weight": 74,
                "chest": 96,
                "waist": 82,
                "hips": 94,
                "shoulder_width": 47,
                "arm_length": 60,
                "leg_length": 98,
                "inseam": 81,
                "gender": "male",
            }
        }
    }
|
|
|
|
@app.get("/")
async def root():
    """Describe the service and list its public endpoints."""
    endpoint_docs = {
        "/generate-avatar": "POST - Generate 2D avatar image (PNG) from measurements",
        "/generate-avatar-3d": "POST - Generate 3D avatar mesh (OBJ) from measurements",
        "/health": "GET - Health check",
    }
    return {"service": "Avatar Generation Service", "endpoints": endpoint_docs}
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe that unconditionally reports the service as healthy."""
    status = "healthy"
    return dict(status=status)
|
|
|
|
def _rendered_image_to_png(img_np):
    """Encode a rendered avatar image array as RGB PNG bytes.

    Non-uint8 arrays whose maximum is <= 1.0 are treated as normalized
    intensities and scaled by 255. Any other non-uint8 array is taken to be on
    a 0-255 scale already. Values are rounded and clipped to [0, 255], so
    out-of-range pixels saturate instead of wrapping around when cast to uint8.
    """
    img_np = np.asarray(img_np)
    if img_np.dtype != np.uint8:
        scale = 255.0 if img_np.max() <= 1.0 else 1.0
        img_np = np.clip(np.rint(img_np * scale), 0, 255).astype(np.uint8)

    # Let Pillow infer the mode from the array shape (e.g. RGB, RGBA, L) and
    # then normalize to RGB. Forcing mode='RGB' onto an array with a different
    # channel count misinterprets the pixel buffer.
    img = Image.fromarray(img_np)
    if img.mode != "RGB":
        img = img.convert("RGB")

    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


@app.post("/generate-avatar")
async def generate_avatar(measurements: MeasurementRequest):
    """Render a 2D PNG avatar from body measurements.

    The pipeline is: validate gender, then process_measurements, predict_betas,
    generate_mesh (using SMPL_MODEL_PATH), render_avatar, and finally encode
    the image as PNG.

    Raises:
        HTTPException(400): invalid gender, or a ValueError raised by the
            measurement / beta / mesh / render pipeline.
        HTTPException(500): SMPL model files missing (FileNotFoundError), any
            other pipeline failure, or a failure while encoding the PNG.
    """
    # Local import keeps this block self-contained. A plain Response sends the
    # fully encoded PNG as a single body, with a Content-Length header.
    from fastapi.responses import Response

    # NOTE(review): the steps below are synchronous and run directly inside an
    # ``async def``, so they block the event loop for their duration. Before
    # offloading them to a worker thread, confirm how render_avatar creates
    # and owns its OSMesa/OpenGL context.
    try:
        measurements_dict = measurements.model_dump(exclude_none=True)
        gender = measurements_dict.pop("gender", "male")

        if gender not in ("male", "female", "neutral"):
            raise ValueError("Gender must be 'male', 'female', or 'neutral'")

        normalized = process_measurements(measurements_dict)
        betas = predict_betas(normalized)
        vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)
        img_np = render_avatar(vertices, faces)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail=f"SMPL model not found: {str(e)}. "
                   f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
                   f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable."
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}") from e

    # Encoding problems are server-side faults, so they must not surface as 400s.
    try:
        png_bytes = _rendered_image_to_png(img_np)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating avatar: {str(e)}") from e

    return Response(content=png_bytes, media_type="image/png")
|
|
|
|
@app.post("/generate-avatar-3d")
async def generate_avatar_3d(measurements: MeasurementRequest):
    """Generate an SMPL body mesh from measurements and return it as an OBJ file.

    The pipeline is: validate gender, then process_measurements, predict_betas,
    generate_mesh (using SMPL_MODEL_PATH), and finally export to OBJ with
    trimesh. The response is sent as an attachment named ``avatar.obj``.

    Raises:
        HTTPException(400): invalid gender, or a ValueError raised while
            processing measurements or building / exporting the mesh.
        HTTPException(500): SMPL model files missing (FileNotFoundError) or any
            other failure, including trimesh being unavailable.
    """
    # Local import keeps this block self-contained. A plain Response sends the
    # exported OBJ as a single body, with a Content-Length header.
    from fastapi.responses import Response

    try:
        measurements_dict = measurements.model_dump(exclude_none=True)
        gender = measurements_dict.pop("gender", "male")

        if gender not in ("male", "female", "neutral"):
            raise ValueError("Gender must be 'male', 'female', or 'neutral'")

        normalized = process_measurements(measurements_dict)
        betas = predict_betas(normalized)
        vertices, faces = generate_mesh(betas, model_path=SMPL_MODEL_PATH, gender=gender)

        # Imported lazily (inside the try) so a missing trimesh only breaks this
        # endpoint and is reported as a 500.
        import trimesh

        # process=False exports the vertices/faces exactly as generated. The
        # default processing may merge vertices, which would break the vertex
        # correspondence that SMPL consumers presumably rely on.
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)

        buf = io.BytesIO()
        mesh.export(file_obj=buf, file_type='obj')
        obj_bytes = buf.getvalue()
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except FileNotFoundError as e:
        raise HTTPException(
            status_code=500,
            detail=f"SMPL model not found: {str(e)}. "
                   f"Please ensure SMPL model files are in {SMPL_MODEL_PATH}. "
                   f"Download from https://smpl.is.tue.mpg.de/ or set SMPL_MODEL_PATH environment variable."
        ) from e
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error generating 3D avatar: {str(e)}") from e

    return Response(
        content=obj_bytes,
        media_type="model/obj",
        headers={"Content-Disposition": "attachment; filename=avatar.obj"},
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: serve the app with uvicorn on all interfaces, port 7860.
    # uvicorn is imported here because it is only needed when run as a script.
    from uvicorn import run as run_server

    run_server(app, host="0.0.0.0", port=7860)
|
|
|
|