ynyg commited on
Commit
ad308d4
·
verified ·
1 Parent(s): 0d080bf

feat: init

Browse files
Files changed (7) hide show
  1. .gitattributes +50 -35
  2. .gitignore +36 -0
  3. README.md +116 -3
  4. best.ckpt +3 -0
  5. config.json +24 -0
  6. configuration.json +1 -0
  7. model.safetensors +3 -0
.gitattributes CHANGED
@@ -1,35 +1,50 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *.tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ *.db* filter=lfs diff=lfs merge=lfs -text
29
+ *.ark* filter=lfs diff=lfs merge=lfs -text
30
+ **/*ckpt*data* filter=lfs diff=lfs merge=lfs -text
31
+ **/*ckpt*.meta filter=lfs diff=lfs merge=lfs -text
32
+ **/*ckpt*.index filter=lfs diff=lfs merge=lfs -text
33
+
34
+
35
+ *.gguf* filter=lfs diff=lfs merge=lfs -text
36
+ *.ggml filter=lfs diff=lfs merge=lfs -text
37
+ *.llamafile* filter=lfs diff=lfs merge=lfs -text
38
+ *.pt2 filter=lfs diff=lfs merge=lfs -text
39
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
40
+ *.npy filter=lfs diff=lfs merge=lfs -text
41
+ *.npz filter=lfs diff=lfs merge=lfs -text
42
+ *.pickle filter=lfs diff=lfs merge=lfs -text
43
+ *.pkl filter=lfs diff=lfs merge=lfs -text
44
+ *.tar filter=lfs diff=lfs merge=lfs -text
45
+ *.wasm filter=lfs diff=lfs merge=lfs -text
46
+ *.zst filter=lfs diff=lfs merge=lfs -text
47
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
48
+
49
+ model.safetensors filter=lfs diff=lfs merge=lfs -text
50
+ best.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[oc]
4
+ *.pyc
5
+ *.pyo
6
+ *.pyd
7
+
8
+ # IDE
9
+ .vscode/
10
+ .idea/
11
+ *.swp
12
+ *.swo
13
+
14
+ # OS
15
+ .DS_Store
16
+ Thumbs.db
17
+ desktop.ini
18
+
19
+ # Training outputs
20
+ checkpoints/
21
+ output/
22
+ logs/
23
+ tb_logs/
24
+ wandb/
25
+ runs/
26
+
27
+ # Temporary files
28
+ *.tmp
29
+ *.temp
30
+
31
+ # Secrets
32
+ .env
33
+ *.env
34
+ secrets.json
35
+ credentials.json
36
+ api_keys.json
README.md CHANGED
@@ -1,3 +1,116 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ink-eraser-latest(手写墨迹擦除模型)
2
+
3
+ 本目录是一个用于“手写墨迹擦除 / 文档去涂写”的模型导出包(Hugging Face 兼容格式)。模型输入为带墨迹的 RGB 图像,输出为去除墨迹后的 RGB 图像。
4
+
5
+ ## 模型信息
6
+
7
+ - 架构:U-Net++(`segmentation-models-pytorch`)+ ResNet50 编码器
8
+ - 任务:图像到图像(去除手写笔迹/墨迹)
9
+ - 输入:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`
10
+ - 输出:RGB,形状 `[B, 3, H, W]`,数值范围 `[0, 1]`(末端 `sigmoid`)
11
+
12
+ ## 文件说明
13
+
14
+ - `config.json`:模型结构与训练超参数(导出时写入)
15
+ - `model.safetensors`:推理用权重(推荐)
16
+ - `best.ckpt`:原始 PyTorch Lightning checkpoint(用于继续训练/复现实验)
17
+ - `configuration.json`:简要元数据(framework/task)
18
+
19
+ ## 快速推理(SafeTensors,推荐)
20
+
21
+ 依赖:`torch`、`torchvision`、`segmentation-models-pytorch`、`safetensors`,以及 `Pillow`(读写图片可选)。
22
+
23
+ ```bash
24
+ pip install torch torchvision segmentation-models-pytorch safetensors pillow
25
+ ```
26
+
27
+ ```python
28
+ import json
29
+ from pathlib import Path
30
+
31
+ import torch
32
+ import segmentation_models_pytorch as smp
33
+ from safetensors.torch import load_file
34
+ from PIL import Image
35
+ import torchvision.transforms.functional as TF
36
+
37
+ device = "cuda" if torch.cuda.is_available() else "cpu"
38
+
39
+ # 1) 读取配置
40
+ cfg = json.loads(Path("config.json").read_text(encoding="utf-8"))
41
+
42
+ # 2) 构建网络(与导出配置保持一致)
43
+ model = smp.UnetPlusPlus(
44
+ encoder_name=cfg["encoder_name"],
45
+ encoder_weights=None, # 权重来自 model.safetensors
46
+ in_channels=cfg["in_channels"],
47
+ classes=cfg["classes"],
48
+ decoder_attention_type=cfg.get("decoder_attention_type"),
49
+ activation=cfg.get("activation"), # 通常为 "sigmoid"
50
+ ).to(device)
51
+
52
+ # 3) 加载权重
53
+ # 说明:导出时可能混入非网络权重(例如 `edge_loss.kx/ky`),推理只需要 Unet++ 本体参数,过滤掉即可。
54
+ state_dict = load_file("model.safetensors")
55
+ model_keys = set(model.state_dict().keys())
56
+ state_dict = {k: v for k, v in state_dict.items() if k in model_keys}
57
+ model.load_state_dict(state_dict, strict=True)
58
+ model.eval()
59
+
60
+ # 4) 准备输入(训练时仅做 0~1 归一化;如需更贴近训练分布可 resize 到 512x512)
61
+ img = Image.open("input.png").convert("RGB")
62
+ x = TF.to_tensor(img).unsqueeze(0).to(device) # [1,3,H,W] in [0,1]
63
+
64
+ with torch.no_grad():
65
+ y = model(x).clamp(0, 1) # [1,3,H,W]
66
+
67
+ out = TF.to_pil_image(y.squeeze(0).cpu())
68
+ out.save("output.png")
69
+ ```
70
+
71
+ 提示:若输入尺寸不是 32 的倍数,部分编码器结构可能要求先 `pad/resize` 到合适尺寸(例如 `512x512`)。
72
+
73
+ 也可以直接使用本项目提供的高清切块推理脚本(自动对大图切块并融合回原图),从项目根目录运行:
74
+
75
+ ```bash
76
+ python infer_hd.py --model-dir assets/InkErase --input input.png --output output.png
77
+ ```
78
+
79
+ ## 使用 `best.ckpt`(继续训练/复现实验)
80
+
81
+ `best.ckpt` 是 PyTorch Lightning checkpoint,通常需要配合本项目的 `InkEraserModel` 代码使用,并提供 ResNet50 预训练权重文件(例如 `pretrained_weights/resnet50-0676ba61.pth`)。
82
+
83
+ ```python
84
+ import torch
85
+ from model import InkEraserModel
86
+
87
+ model = InkEraserModel.load_from_checkpoint(
88
+ "best.ckpt",
89
+ weight="pretrained_weights/resnet50-0676ba61.pth",
90
+ )
91
+ model.eval()
92
+
93
+ with torch.no_grad():
94
+ y = model(x)
95
+ ```
96
+
97
+ ## 训练超参数(来自 `config.json`)
98
+
99
+ 以下参数主要用于训练/复现,推理不必关心:
100
+
101
+ ```json
102
+ {
103
+ "lr": 0.0001,
104
+ "weight_decay": 0.01,
105
+ "loss_w_charb": 0.78,
106
+ "loss_w_ssim": 0.16,
107
+ "loss_w_edge": 0.06,
108
+ "use_mask_loss": true,
109
+ "loss_mask_weight": 10.0,
110
+ "charbonnier_eps": 0.001
111
+ }
112
+ ```
113
+
114
+ ## 许可证
115
+
116
+ MIT
best.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03a24c5cf35357a002b86a22250a2e165bf0e97a868bfb3492eb9ab0843798ec
3
+ size 614277544
config.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "InkEraserModel"
4
+ ],
5
+ "model_type": "unet_plus_plus",
6
+ "encoder_name": "resnet50",
7
+ "in_channels": 3,
8
+ "classes": 3,
9
+ "decoder_attention_type": "scse",
10
+ "activation": "sigmoid",
11
+ "framework": "pytorch-lightning",
12
+ "training_config": {
13
+ "lr": 0.0001,
14
+ "weight_decay": 0.01,
15
+ "loss_w_charb": 0.78,
16
+ "loss_w_ssim": 0.16,
17
+ "loss_w_edge": 0.06,
18
+ "use_mask_loss": true,
19
+ "loss_mask_weight": 10.0,
20
+ "charbonnier_eps": 0.001
21
+ },
22
+ "description": "Ink Erasure Model - Handwritten ink removal using U-Net++",
23
+ "license": "MIT"
24
+ }
configuration.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"framework":"Pytorch","task":"image-to-image"}
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78acd1c02b274179ae975134eb827ed0bdb34d4e5890b7981f8dde08bb35eb7a
3
+ size 204806056