# Source: Hugging Face Space "3d" / app.py (author: vish85521),
# commit 6496b60 "Update app.py" (verified).
import io
import os
import subprocess
import sys
import threading
import uuid

from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from PIL import Image
# Add the cloned repository to the system path so Python can locate the modules
sys.path.append("/app/Hunyuan3D-2")
from hy3dgen.shapegen import Hunyuan3DDiTFlowMatchingPipeline
from hy3dgen.texgen import Hunyuan3DPaintPipeline
app = FastAPI(title="Hunyuan3D-2 Multi-View Textured API")

# CORS is left fully open so a standalone HTML page (served from any origin)
# can call this API directly from the browser.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Model pipelines, populated by the startup hook; None means "not loaded".
shape_pipeline = paint_pipeline = None

# Directory that generated .glb files are exported into.
OUTPUT_DIR = "/app/outputs"
def start_tmate(socket_path="/tmp/tmate.sock", timeout=60):
    """Start a detached tmate session and print its SSH connection string.

    Best-effort: any failure is printed and swallowed so the caller can carry
    on. Failures include tmate not being installed, a tmate command exiting
    non-zero, or a command exceeding ``timeout`` seconds.

    NOTE(review): this opens a remote shell into the container for anyone who
    holds the printed connection string. That string is written to stdout and
    therefore to any captured logs. Consider making this opt-in (e.g. behind an
    environment flag).

    Args:
        socket_path: tmate control socket shared by all three commands.
        timeout: per-command limit in seconds. Without a limit,
            ``tmate wait tmate-ready`` blocks forever if the session never
            becomes ready, and it is called synchronously during startup.
    """
    print("Starting tmate for SSH access...")
    tmate = ["tmate", "-S", socket_path]
    try:
        subprocess.run(tmate + ["new-session", "-d"], check=True, timeout=timeout)
        subprocess.run(tmate + ["wait", "tmate-ready"], check=True, timeout=timeout)
        result = subprocess.run(
            tmate + ["display", "-p", "#{tmate_ssh}"],
            capture_output=True, text=True, check=True, timeout=timeout,
        )
        ssh_command = result.stdout.strip()
        print("\n" + "=" * 60)
        print("🚀 TMATE SSH CONNECTION STRING READY 🚀")
        print(f"Run this command in your local terminal:\n\n{ssh_command}")
        # Flush so the banner appears promptly even when stdout is a pipe
        # (block-buffered), as it typically is inside a container.
        print("=" * 60 + "\n", flush=True)
    except Exception as e:
        # Deliberately broad: SSH access is optional and must never stop the
        # server from starting.
        print(f"Failed to start tmate: {e}")
@app.on_event("startup")
async def startup_event():
    """Server startup hook: open the tmate SSH session, then load both models.

    Model-loading errors are printed rather than raised, so the server still
    starts. In that case at least one of the pipeline globals stays None and
    /generate-3d answers 503.
    """
    global shape_pipeline, paint_pipeline

    start_tmate()

    print("Loading Hunyuan3D-2 Shape and Paint models...")
    device = 'cuda'
    try:
        # The multi-view shape model is loaded from the 'hunyuan3d-dit-v2-mv'
        # subfolder of the Hunyuan3D-2mv repository.
        shape_pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained(
            'tencent/Hunyuan3D-2mv', subfolder='hunyuan3d-dit-v2-mv', device=device
        )
        paint_pipeline = Hunyuan3DPaintPipeline.from_pretrained(
            'tencent/Hunyuan3D-2', device=device
        )
        print("Both models loaded successfully.")
    except Exception as load_error:
        print(f"Error loading models: {load_error}")
# Serializes use of the shared pipelines. Generation runs in a worker thread
# (see generate_3d_model). Without this lock, concurrent requests would drive
# the same pipeline objects and GPU memory at the same time, which is
# presumably not safe for these pipelines.
_GENERATION_LOCK = threading.Lock()


@app.post("/generate-3d")
async def generate_3d_model(
    front: UploadFile = File(...),
    back: UploadFile = File(...),
    left: UploadFile = File(...),
    right: UploadFile = File(...)
):
    """
    Endpoint that accepts 4 structural views, generates the shape,
    paints the texture, and returns a colored .glb 3D file.

    All four views go to the multi-view shape pipeline. Only the front view is
    passed to the paint pipeline. Shape generation, texturing and export run
    in a worker thread, one job at a time, so the event loop keeps serving
    other requests while a model is being generated.

    Raises:
        HTTPException: 503 if either pipeline is not loaded; 400 if an upload
            cannot be decoded as an image; 500 if generation, texturing or
            export fails.
    """
    from fastapi.concurrency import run_in_threadpool

    if shape_pipeline is None or paint_pipeline is None:
        raise HTTPException(status_code=503, detail="Models are still loading or failed to load.")

    images_dict = {}
    for view, upload in (("front", front), ("back", back), ("left", left), ("right", right)):
        contents = await upload.read()
        try:
            images_dict[view] = Image.open(io.BytesIO(contents)).convert("RGB")
        except Exception as e:
            # An undecodable upload is a client error, not a server failure.
            raise HTTPException(
                status_code=400, detail=f"Invalid image for '{view}' view: {e}"
            ) from e

    job_id = str(uuid.uuid4())
    output_path = os.path.join(OUTPUT_DIR, f"{job_id}.glb")

    def _generate():
        """Blocking GPU work: build the mesh, texture it, export to output_path."""
        with _GENERATION_LOCK:
            print("Generating shape geometry...")
            mesh = shape_pipeline(image=images_dict)[0]
            print("Applying textures...")
            mesh_with_texture = paint_pipeline(mesh, image=images_dict["front"])
            # Ensure the export directory exists; export fails otherwise.
            os.makedirs(OUTPUT_DIR, exist_ok=True)
            mesh_with_texture.export(output_path)

    try:
        await run_in_threadpool(_generate)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e

    # NOTE(review): exported files are never deleted, so OUTPUT_DIR grows with
    # every request. If outputs need not persist, consider removing the file
    # after it is sent (e.g. a FileResponse background task).
    return FileResponse(
        path=output_path,
        media_type="model/gltf-binary",
        filename=f"{job_id}.glb"
    )