"""FastAPI service that removes tattoos from uploaded images.

Endpoints:
    POST /remove-tattoo  -- upload an image, receive the processed image as PNG.
    GET  /               -- liveness message.
"""
import io
import os
import sys

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response
from PIL import Image
from safetensors.torch import load_file

# Add the parent of this file's directory to sys.path so the `src` package is importable.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

from src.pipeline.pipeline import TattooRemovalPipeline  # noqa: E402
from src.model import UNet  # noqa: E402  # assumes the model class is named UNet -- TODO confirm

app = FastAPI()


def load_model(weights_path="tattoo_remover.safetensors"):
    """Build a UNet and load its weights from a safetensors file.

    Args:
        weights_path: Path to the safetensors weights, resolved relative to the
            current working directory when not absolute.

    Returns:
        The UNet with weights loaded and switched to eval mode, or None if the
        weights file does not exist (download instructions are printed).
    """
    model = UNet()
    # Weights are published on Hugging Face and must be downloaded manually.
    if not os.path.exists(weights_path):
        print("请先下载模型权重文件")
        print("从 https://huggingface.co/erickillian/tattoo-removal/blob/main/tattoo_remover.safetensors 下载")
        print("并放置在当前目录")
        return None
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    # Inference only: disable training-mode behavior (dropout, batch-norm stat updates).
    model.eval()
    return model


# Load the model once at import time; the endpoint refuses requests if this failed.
model = load_model()
tattoo_remover = None
if model is not None:
    tattoo_remover = TattooRemovalPipeline(model=model)
else:
    print("模型加载失败")


def _decode_image(image_data):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if the bytes are not a decodable image (PIL's
            UnidentifiedImageError is a subclass of OSError).
    """
    with Image.open(io.BytesIO(image_data)) as img:
        # convert() returns a new, fully loaded image, so closing `img` is safe.
        return img.convert('RGB')


def _run_pipeline(image):
    """Run the tattoo-removal pipeline on `image` and return the result as PNG bytes."""
    result = tattoo_remover(image)
    output = io.BytesIO()
    result.save(output, format='PNG')
    return output.getvalue()


@app.post("/remove-tattoo")
async def remove_tattoo(file: UploadFile = File(...)):
    """Remove tattoos from the uploaded image and return it as PNG.

    Returns 503 with an {"error": ...} body when the model failed to load, and
    400 with an {"error": ...} body when the upload is not a decodable image.
    """
    if tattoo_remover is None:
        return JSONResponse(status_code=503, content={"error": "模型未正确加载"})

    image_data = await file.read()

    # Decoding and inference are CPU-bound; run them off the event loop so
    # other requests are not blocked while one image is processed.
    try:
        image = await run_in_threadpool(_decode_image, image_data)
    except OSError:
        return JSONResponse(status_code=400, content={"error": "无法识别上传的图片"})

    png_bytes = await run_in_threadpool(_run_pipeline, image)
    return Response(content=png_bytes, media_type="image/png")


@app.get("/")
async def root():
    """Liveness check."""
    return {"message": "纹身移除API服务正在运行"}