tattoo-removal-app / test_model.py
guanwei1225's picture
清理代码,移除敏感信息
a92fb7a
import sys
import os
from PIL import Image
import torch
from safetensors.torch import load_file
# 添加src目录到系统路径
sys.path.append(os.path.join(os.path.dirname(__file__), '.'))
from src.model import UNet
from src.pipeline.pipeline import TattooRemovalPipeline
def test_single_image(image_path):
"""测试单张图片的纹身移除效果"""
# 加载模型
model = UNet()
weights_path = "tattoo_remover.safetensors"
if not os.path.exists(weights_path):
print("错误:找不到模型权重文件")
print("请从以下地址下载模型权重:")
print("https://huggingface.co/erickillian/tattoo-removal/blob/main/tattoo_remover.safetensors")
return
# 加载权重
state_dict = load_file(weights_path)
model.load_state_dict(state_dict)
# 创建pipeline
pipeline = TattooRemovalPipeline(model=model)
# 检查输入图片是否存在
if not os.path.exists(image_path):
print(f"错误:找不到输入图片 {image_path}")
return
try:
# 处理图片
input_image = Image.open(image_path).convert('RGB')
output_image = pipeline(input_image)
# 保存结果
output_path = os.path.splitext(image_path)[0] + "_removed.png"
output_image.save(output_path)
print(f"处理完成!结果已保存至:{output_path}")
except Exception as e:
print(f"处理过程中出现错误:{str(e)}")
if __name__ == "__main__":
if len(sys.argv) != 2:
print("使用方法:python test_model.py <图片路径>")
print("例如:python test_model.py test_image.jpg")
sys.exit(1)
test_single_image(sys.argv[1])