Spaces:
Running
Running
import io
import os
import sys

import torch
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, Response
from PIL import Image
from safetensors.torch import load_file

# Add this file's parent directory to sys.path so the ``src`` package is importable.
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from src.pipeline.pipeline import TattooRemovalPipeline
from src.model import UNet  # assumes the model class is named UNet
# FastAPI application object exposed by this module.
app: FastAPI = FastAPI()
def load_model(weights_path="tattoo_remover.safetensors"):
    """Build the UNet and load its weights from a local safetensors file.

    The weights are NOT downloaded automatically: if the file is missing,
    instructions for fetching it are printed and ``None`` is returned so the
    service can still start (request handlers check for ``None``).

    Args:
        weights_path: Path to the ``.safetensors`` checkpoint. A relative path
            is resolved against the current working directory. Defaults to
            ``"tattoo_remover.safetensors"``.

    Returns:
        The UNet with the checkpoint loaded and switched to evaluation mode,
        or ``None`` if ``weights_path`` does not exist.
    """
    # Check for the weights first so we don't construct the network for nothing.
    if not os.path.exists(weights_path):
        print("请先下载模型权重文件")
        print("从 https://huggingface.co/erickillian/tattoo-removal/blob/main/tattoo_remover.safetensors 下载")
        print("并放置在当前目录")
        return None
    model = UNet()
    state_dict = load_file(weights_path)
    model.load_state_dict(state_dict)
    # The model is only used for inference: eval() disables dropout and makes
    # batch-norm (if the network has any) use its stored running statistics.
    model.eval()
    return model
# Load the model once at import time; every request reuses the same weights.
model = load_model()
if model is not None:
    tattoo_remover = TattooRemovalPipeline(model=model)
else:
    # Keep the name bound even on failure so any reference to it cannot raise
    # NameError; request handlers check ``model`` before using it.
    tattoo_remover = None
    print("模型加载失败")
# NOTE(review): the source had no route decorator, so this handler was never
# reachable. The path is derived from the function name — confirm it matches
# what existing clients call.
@app.post("/remove_tattoo")
async def remove_tattoo(file: UploadFile = File(...)):
    """Remove tattoos from an uploaded image and return the result as PNG.

    Args:
        file: The uploaded image; any format Pillow can decode is accepted.
            It is converted to RGB before processing.

    Returns:
        On success, a ``Response`` with ``media_type="image/png"`` containing
        the processed image. On failure, a JSON body ``{"error": ...}`` with
        status 503 if the model failed to load at startup, or status 400 if
        the upload could not be decoded as an image.
    """
    if model is None:
        return JSONResponse(status_code=503, content={"error": "模型未正确加载"})

    image_data = await file.read()
    try:
        image = Image.open(io.BytesIO(image_data)).convert('RGB')
    except (OSError, Image.DecompressionBombError):
        # Untrusted input: PIL raises UnidentifiedImageError (an OSError
        # subclass) for unrecognized data, OSError for truncated files, and
        # DecompressionBombError for oversized images. Report a client error
        # instead of an unhandled 500.
        return JSONResponse(status_code=400, content={"error": "无法解析上传的图片"})

    # NOTE(review): inference is a synchronous call inside this coroutine, so
    # the event loop cannot serve other requests while it runs. Consider
    # fastapi.concurrency.run_in_threadpool once the pipeline is confirmed
    # safe for concurrent calls.
    # Inference only: skip autograd bookkeeping (harmless if the pipeline
    # already disables it).
    with torch.no_grad():
        result = tattoo_remover(image)

    # Encode the result as PNG bytes for the response body.
    output = io.BytesIO()
    result.save(output, format='PNG')
    return Response(content=output.getvalue(), media_type="image/png")
# NOTE(review): the source had no route decorator, so this handler was never
# reachable; registering it at "/" as a liveness message.
@app.get("/")
async def root():
    """Report that the tattoo-removal API service is running.

    Returns:
        A JSON object with a single ``"message"`` key.
    """
    return {"message": "纹身移除API服务正在运行"}