# Last commit by guanwei1225 (a92fb7a): 清理代码,移除敏感信息
# (clean up code, remove sensitive information)
import io
import os
import sys

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import JSONResponse, Response
from PIL import Image
from safetensors.torch import load_file
# Put the project root (the parent of this file's directory, which holds the
# ``src`` package imported below) on sys.path so ``src.*`` imports resolve.
# abspath() normalizes the entry and pins it down even when __file__ is
# relative, so it stays valid if the working directory changes later.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from src.pipeline.pipeline import TattooRemovalPipeline
from src.model import UNet # 假设您的模型类名为UNet
# ASGI application; the @app.post / @app.get decorators in this file register
# the HTTP routes on it.
app: FastAPI = FastAPI()
def load_model(weights_path="tattoo_remover.safetensors"):
    """Build a UNet and load its weights from a safetensors file.

    Args:
        weights_path: Path to the ``.safetensors`` weight file. A relative
            path is resolved against the current working directory.
            Defaults to ``"tattoo_remover.safetensors"``.

    Returns:
        The UNet with the weights loaded and switched to evaluation mode,
        or ``None`` if ``weights_path`` does not exist (download
        instructions are printed in that case).
    """
    # Check for the weights first so a missing file does not pay for
    # constructing the network only to throw it away.
    if not os.path.exists(weights_path):
        print("请先下载模型权重文件")
        print("从 https://huggingface.co/erickillian/tattoo-removal/blob/main/tattoo_remover.safetensors 下载")
        print("并放置在当前目录")
        return None
    model = UNet()
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    # This model is only used for inference: switch layers such as dropout /
    # batch-norm to eval behaviour. Assumes UNet is a torch.nn.Module (it
    # exposes load_state_dict) — confirm against src.model.
    model.eval()
    return model
# Load the network once at import time; every request then reuses it.
model = load_model()
if model is None:
    # Weights unavailable: keep serving, /remove-tattoo reports the failure.
    print("模型加载失败")
else:
    # Wrap the network in the inference pipeline used by /remove-tattoo.
    tattoo_remover = TattooRemovalPipeline(model=model)
def _decode_upload(image_data):
    """Decode raw upload bytes into an RGB PIL image.

    Raises:
        OSError: if Pillow cannot identify or fully read the data
            (``PIL.UnidentifiedImageError`` subclasses ``OSError``).
        PIL.Image.DecompressionBombError: if the image exceeds Pillow's
            decompression-bomb pixel limit.
    """
    with Image.open(io.BytesIO(image_data)) as img:
        # convert() forces a full decode and returns a new image, so the
        # result stays usable after the opened image is closed.
        return img.convert('RGB')


def _render_png(image):
    """Run the tattoo-removal pipeline on *image* and return PNG bytes."""
    result = tattoo_remover(image)
    buffer = io.BytesIO()
    result.save(buffer, format='PNG')
    return buffer.getvalue()


@app.post("/remove-tattoo")
async def remove_tattoo(file: UploadFile = File(...)):
    """Remove tattoos from an uploaded image and return the result as PNG.

    Responses:
        200: ``image/png`` body containing the processed image.
        400: ``{"error": ...}`` when the upload is not a decodable image.
        503: ``{"error": ...}`` when the model failed to load at startup.
    """
    if model is None:
        # A server-side failure, not success: report 503 while keeping the
        # original {"error": ...} body shape.
        return JSONResponse(status_code=503, content={"error": "模型未正确加载"})
    image_data = await file.read()
    # Decoding and inference are synchronous CPU work; run them in the
    # thread pool so they do not block the event loop (and other requests).
    try:
        image = await run_in_threadpool(_decode_upload, image_data)
    except (OSError, Image.DecompressionBombError):
        # Empty, corrupt, non-image, or oversized uploads are client errors.
        return JSONResponse(status_code=400, content={"error": "上传的文件不是有效的图片"})
    png_bytes = await run_in_threadpool(_render_png, image)
    return Response(content=png_bytes, media_type="image/png")
@app.get("/")
async def root():
    # Fixed JSON payload confirming the service process is up and serving.
    running_notice = "纹身移除API服务正在运行"
    return {"message": running_notice}