import sys import os from PIL import Image import torch from safetensors.torch import load_file # 添加src目录到系统路径 sys.path.append(os.path.join(os.path.dirname(__file__), '.')) from src.model import UNet from src.pipeline.pipeline import TattooRemovalPipeline def test_single_image(image_path): """测试单张图片的纹身移除效果""" # 加载模型 model = UNet() weights_path = "tattoo_remover.safetensors" if not os.path.exists(weights_path): print("错误:找不到模型权重文件") print("请从以下地址下载模型权重:") print("https://huggingface.co/erickillian/tattoo-removal/blob/main/tattoo_remover.safetensors") return # 加载权重 state_dict = load_file(weights_path) model.load_state_dict(state_dict) # 创建pipeline pipeline = TattooRemovalPipeline(model=model) # 检查输入图片是否存在 if not os.path.exists(image_path): print(f"错误:找不到输入图片 {image_path}") return try: # 处理图片 input_image = Image.open(image_path).convert('RGB') output_image = pipeline(input_image) # 保存结果 output_path = os.path.splitext(image_path)[0] + "_removed.png" output_image.save(output_path) print(f"处理完成!结果已保存至:{output_path}") except Exception as e: print(f"处理过程中出现错误:{str(e)}") if __name__ == "__main__": if len(sys.argv) != 2: print("使用方法:python test_model.py <图片路径>") print("例如:python test_model.py test_image.jpg") sys.exit(1) test_single_image(sys.argv[1])