Spaces:
Running
Running
| import sys | |
| import os | |
| from PIL import Image | |
| import torch | |
| from safetensors.torch import load_file | |
| # 添加src目录到系统路径 | |
| sys.path.append(os.path.join(os.path.dirname(__file__), '.')) | |
| from src.model import UNet | |
| from src.pipeline.pipeline import TattooRemovalPipeline | |
| def test_single_image(image_path): | |
| """测试单张图片的纹身移除效果""" | |
| # 加载模型 | |
| model = UNet() | |
| weights_path = "tattoo_remover.safetensors" | |
| if not os.path.exists(weights_path): | |
| print("错误:找不到模型权重文件") | |
| print("请从以下地址下载模型权重:") | |
| print("https://huggingface.co/erickillian/tattoo-removal/blob/main/tattoo_remover.safetensors") | |
| return | |
| # 加载权重 | |
| state_dict = load_file(weights_path) | |
| model.load_state_dict(state_dict) | |
| # 创建pipeline | |
| pipeline = TattooRemovalPipeline(model=model) | |
| # 检查输入图片是否存在 | |
| if not os.path.exists(image_path): | |
| print(f"错误:找不到输入图片 {image_path}") | |
| return | |
| try: | |
| # 处理图片 | |
| input_image = Image.open(image_path).convert('RGB') | |
| output_image = pipeline(input_image) | |
| # 保存结果 | |
| output_path = os.path.splitext(image_path)[0] + "_removed.png" | |
| output_image.save(output_path) | |
| print(f"处理完成!结果已保存至:{output_path}") | |
| except Exception as e: | |
| print(f"处理过程中出现错误:{str(e)}") | |
| if __name__ == "__main__": | |
| if len(sys.argv) != 2: | |
| print("使用方法:python test_model.py <图片路径>") | |
| print("例如:python test_model.py test_image.jpg") | |
| sys.exit(1) | |
| test_single_image(sys.argv[1]) |