ParamDev committed on
Commit
54d2540
·
verified ·
1 Parent(s): 3262d11

Upload export_models.py

Browse files
Files changed (1) hide show
  1. export_models.py +294 -0
export_models.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ export_models.py
3
+ ----------------
4
+ Downloads publicly available pretrained weights for SRCNN and EDSR (HResNet-style)
5
+ and exports them as ONNX files into the ./model/ directory.
6
+
7
+ Run once before starting app.py:
8
+ pip install torch torchvision huggingface_hub basicsr
9
+ python export_models.py
10
+
11
+ After this script finishes you should have:
12
+ model/SRCNN_x4.onnx
13
+ model/HResNet_x4.onnx
14
+
15
+ Then upload both files to Google Drive, copy the file IDs into DRIVE_IDS in app.py,
16
+ OR set LOCAL_ONLY = True below to skip Drive entirely and load straight from disk.
17
+ """
18
+
19
+ import os
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.onnx
23
+ from pathlib import Path
24
+
# All downloaded checkpoints and exported ONNX files land here.
MODEL_DIR = Path("model")
MODEL_DIR.mkdir(exist_ok=True)  # idempotent — safe to re-run the script

# ---------------------------------------------------------------------------
# Set to True to skip Drive and have app.py load the ONNX files from disk
# directly. In app.py, remove the download_from_drive call for these keys
# (or just leave the placeholder Drive ID — the script already guards against
# missing files gracefully).
# ---------------------------------------------------------------------------
LOCAL_ONLY = True  # flip to False once you have Drive IDs
35
+
36
+
37
+ # ===========================================================================
38
+ # 1. SRCNN ×4
39
+ # Architecture: Dong et al. 2014 — 3 conv layers, no upsampling inside
40
+ # the network. Input is bicubic-upscaled LR; output is the refined HR.
41
+ # We bicubic-upsample inside a wrapper so the ONNX takes a raw LR image.
42
+ # ===========================================================================
43
+
class SRCNN(nn.Module):
    """SRCNN super-resolution network (Dong et al., 2014).

    Three-stage pipeline — patch extraction (9×9), non-linear mapping
    (5×5), reconstruction (5×5) — with 'same' padding throughout, so the
    output has the same spatial size as the input.
    """

    def __init__(self, num_channels: int = 3):
        super().__init__()
        # padding = kernel // 2 keeps H and W unchanged at every stage.
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, padding=4)
        self.conv2 = nn.Conv2d(64, 32, kernel_size=5, padding=2)
        self.conv3 = nn.Conv2d(32, num_channels, kernel_size=5, padding=2)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = self.relu(self.conv1(x))
        mapped = self.relu(self.conv2(features))
        # No activation on the final reconstruction layer.
        return self.conv3(mapped)
57
+
58
+
class SRCNNx4Wrapper(nn.Module):
    """
    Wraps SRCNN so the ONNX input is a LOW-resolution image.

    SRCNN itself does no upsampling; this wrapper bicubic-upsamples the
    input by ×scale first, so the exported graph matches the interface
    expected by app.py's tile_upscale_model.
    """

    def __init__(self, srcnn: SRCNN, scale: int = 4):
        super().__init__()
        self.srcnn = srcnn
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (1, 3, H, W) low-res tensor, float32 in [0, 1]
        upscaled = torch.nn.functional.interpolate(
            x,
            scale_factor=self.scale,
            mode="bicubic",
            align_corners=False,
        )
        # SRCNN refines the blurry bicubic estimate into the final HR image.
        return self.srcnn(upscaled)
76
+
77
+
def build_srcnn_x4() -> nn.Module:
    """
    Build an SRCNN ×4 wrapper, loading pretrained weights when available.

    Downloads the checkpoint once into MODEL_DIR and loads it into the
    SRCNN sub-module. Falls back to random init with a warning if the
    download or the weight load fails, so the exporter never hard-crashes.

    Returns:
        SRCNNx4Wrapper ready for ONNX export (takes a raw LR image).
    """
    srcnn = SRCNN(num_channels=3)
    wrapper = SRCNNx4Wrapper(srcnn, scale=4)

    # Pretrained weights from the basicsr / mmedit community
    # (original Caffe weights re-converted to PyTorch by https://github.com/yjn870/SRCNN-pytorch)
    SRCNN_WEIGHTS_URL = (
        "https://github.com/yjn870/SRCNN-pytorch/raw/master/models/"
        "srcnn_x4.pth"
    )
    weights_path = MODEL_DIR / "srcnn_x4.pth"

    if not weights_path.exists():
        print(" Downloading SRCNN ×4 weights …")
        try:
            import urllib.request
            urllib.request.urlretrieve(SRCNN_WEIGHTS_URL, weights_path)
            print(f" Saved → {weights_path}")
        except Exception as e:
            # urlretrieve can leave a truncated file behind on failure;
            # remove it so the next run retries the download instead of
            # trying to torch.load() a corrupt blob.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download SRCNN weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return wrapper

    try:
        state = torch.load(weights_path, map_location="cpu")
    except Exception as e:
        # Unreadable/corrupt checkpoint — keep random init rather than crash.
        print(f" [WARN] Could not read SRCNN checkpoint: {e}")
        print(" Continuing with random init (quality will be poor).")
        return wrapper

    # The yjn870 checkpoint uses keys conv1/conv2/conv3 matching our module
    try:
        srcnn.load_state_dict(state, strict=True)
        print(" SRCNN weights loaded ✓")
    except RuntimeError as e:
        print(f" [WARN] Weight mismatch: {e}\n Proceeding with partial load.")
        try:
            srcnn.load_state_dict(state, strict=False)
        except RuntimeError:
            # strict=False only tolerates missing/unexpected KEYS; it still
            # raises on tensor SHAPE mismatches. Keep random init in that case.
            print(" [WARN] Partial load failed too; keeping random init.")

    return wrapper
115
+
116
+
117
+ # ===========================================================================
118
+ # 2. EDSR (HResNet-style) ×4
119
+ # EDSR-baseline (Lim et al., 2017) is the canonical "deep residual" SR
120
+ # network. Pretrained weights from eugenesiow/torch-sr (HuggingFace).
121
+ # ===========================================================================
122
+
class ResBlock(nn.Module):
    """EDSR residual block: conv → ReLU → conv, plus a scaled identity skip.

    No batch norm, per the EDSR design. Layer order inside ``body`` is kept
    as (conv, relu, conv) so state-dict keys ``body.0``/``body.2`` stay
    compatible with published checkpoints.
    """

    def __init__(self, n_feats: int, res_scale: float = 1.0):
        super().__init__()
        first_conv = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        second_conv = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.body = nn.Sequential(first_conv, nn.ReLU(inplace=True), second_conv)
        self.res_scale = res_scale

    def forward(self, x):
        residual = self.body(x) * self.res_scale
        return x + residual
135
+
136
+
class Upsampler(nn.Sequential):
    """PixelShuffle upsampling head for ×2, ×3 and ×4 scales.

    ×2/×4 are built from one/two (conv, PixelShuffle(2)) stages; ×3 is a
    single (conv, PixelShuffle(3)) stage. Any other scale yields an empty
    (identity) Sequential, matching the original behavior.
    """

    def __init__(self, scale: int, n_feats: int):
        modules = []
        if scale == 3:
            modules.append(nn.Conv2d(n_feats, 9 * n_feats, 3, padding=1))
            modules.append(nn.PixelShuffle(3))
        elif scale in (2, 4):
            # ×4 is realized as two consecutive ×2 shuffle stages.
            for _ in range(scale // 2):
                modules.append(nn.Conv2d(n_feats, 4 * n_feats, 3, padding=1))
                modules.append(nn.PixelShuffle(2))
        super().__init__(*modules)
153
+
154
+
class EDSR(nn.Module):
    """
    EDSR-baseline: 16 residual blocks, 64 feature channels.

    Submodule names (head / body / body_tail / tail) are chosen to match
    the publicly released weights from eugenesiow/torch-sr.
    """

    def __init__(self, n_resblocks: int = 16, n_feats: int = 64,
                 scale: int = 4, num_channels: int = 3):
        super().__init__()
        self.head = nn.Conv2d(num_channels, n_feats, 3, padding=1)
        blocks = [ResBlock(n_feats) for _ in range(n_resblocks)]
        self.body = nn.Sequential(*blocks)
        self.body_tail = nn.Conv2d(n_feats, n_feats, 3, padding=1)
        self.tail = nn.Sequential(
            Upsampler(scale, n_feats),
            nn.Conv2d(n_feats, num_channels, 3, padding=1),
        )

    def forward(self, x):
        shallow = self.head(x)
        # Long skip connection around the whole residual trunk.
        deep = self.body_tail(self.body(shallow))
        return self.tail(shallow + deep)
177
+
178
+
def build_edsr_x4() -> nn.Module:
    """
    Build EDSR-baseline ×4 and load pretrained weights when available.

    Downloads the checkpoint once into MODEL_DIR. Falls back to random
    init with a warning if the download or the weight load fails, so the
    exporter never hard-crashes.
    Source: eugenesiow/torch-sr (Apache-2.0 licensed).
    """
    model = EDSR(n_resblocks=16, n_feats=64, scale=4)

    # Direct link to the EDSR-baseline ×4 checkpoint
    EDSR_WEIGHTS_URL = (
        "https://huggingface.co/eugenesiow/edsr-base/resolve/main/"
        "pytorch_model_4x.pt"
    )
    weights_path = MODEL_DIR / "edsr_x4.pt"

    if not weights_path.exists():
        print(" Downloading EDSR ×4 weights from HuggingFace …")
        try:
            import urllib.request
            urllib.request.urlretrieve(EDSR_WEIGHTS_URL, weights_path)
            print(f" Saved → {weights_path}")
        except Exception as e:
            # Drop any partially written file so a rerun retries the
            # download instead of torch.load()-ing a truncated blob.
            weights_path.unlink(missing_ok=True)
            print(f" [WARN] Could not download EDSR weights: {e}")
            print(" Continuing with random init (quality will be poor).")
            return model

    try:
        state = torch.load(weights_path, map_location="cpu")
    except Exception as e:
        # Unreadable/corrupt checkpoint — keep random init rather than crash.
        print(f" [WARN] Could not read EDSR checkpoint: {e}")
        print(" Continuing with random init (quality will be poor).")
        return model

    # eugenesiow checkpoints may wrap the state_dict under a 'model' or
    # 'state_dict' key. Guard with isinstance: `"model" in state` would
    # raise if the payload is not a mapping (e.g. a whole pickled module).
    if isinstance(state, dict) and "model" in state:
        state = state["model"]
    if isinstance(state, dict) and "state_dict" in state:
        state = state["state_dict"]
    if not isinstance(state, dict) and hasattr(state, "state_dict"):
        # A full pickled nn.Module — extract its parameters.
        state = state.state_dict()

    # Strip a LEADING 'module.' prefix left by DataParallel wrapping.
    # (A blind str.replace would also corrupt keys that merely contain
    # the substring, e.g. 'submodule.weight' → 'subweight'.)
    state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state.items()
    }

    try:
        model.load_state_dict(state, strict=True)
        print(" EDSR weights loaded ✓")
    except RuntimeError as e:
        print(f" [WARN] Weight mismatch ({e}). Trying strict=False …")
        try:
            model.load_state_dict(state, strict=False)
            print(" EDSR weights loaded (partial) ✓")
        except RuntimeError:
            # strict=False only tolerates missing/unexpected KEYS; it still
            # raises on tensor SHAPE mismatches. Keep random init then.
            print(" [WARN] Partial load failed too; keeping random init.")

    return model
224
+
225
+
226
+ # ===========================================================================
227
+ # ONNX export helper
228
+ # ===========================================================================
229
+
def export_onnx(model: nn.Module, out_path: Path, tile_h: int = 128, tile_w: int = 128):
    """Export *model* to ONNX with dynamic H/W axes.

    Traces the model with a (1, 3, tile_h, tile_w) zero tensor; batch and
    spatial dimensions are marked dynamic so app.py can feed tiles of any
    size at inference time.
    """
    model.eval()  # disable dropout/batch-norm training behavior for tracing
    dummy_input = torch.zeros(1, 3, tile_h, tile_w)
    dyn_axes = {
        "input": {0: "batch", 2: "H", 3: "W"},
        "output": {0: "batch", 2: "H_out", 3: "W_out"},
    }
    torch.onnx.export(
        model,
        dummy_input,
        str(out_path),
        opset_version=17,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes=dyn_axes,
    )
    size_mb = out_path.stat().st_size / 1_048_576
    print(f" Exported → {out_path} ({size_mb:.1f} MB)")
248
+
249
+
250
+ # ===========================================================================
251
+ # Main
252
+ # ===========================================================================
253
+
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("SpectraGAN — ONNX model exporter")
    print(banner)

    # Each job: (step label, builder function, target ONNX path).
    jobs = [
        ("[1/2] Building SRCNN ×4 …", build_srcnn_x4, MODEL_DIR / "SRCNN_x4.onnx"),
        ("[2/2] Building EDSR (HResNet) ×4 …", build_edsr_x4, MODEL_DIR / "HResNet_x4.onnx"),
    ]

    for label, builder, onnx_path in jobs:
        if onnx_path.exists():
            # Re-running the script never re-exports existing files.
            print(f"\n[SKIP] {onnx_path} already exists.")
        else:
            print(f"\n{label}")
            model = builder()
            print(" Exporting to ONNX …")
            export_onnx(model, onnx_path, tile_h=128, tile_w=128)

    print("\n" + banner)
    print("Done! Files created:")
    for _, _, onnx_path in jobs:
        marker = "✓" if onnx_path.exists() else "✗ MISSING"
        print(f" {marker} {onnx_path}")
    print()

    if LOCAL_ONLY:
        print("LOCAL_ONLY = True:")
        print(" app.py will load these files directly from disk.")
        print(" No Google Drive upload needed.")
    else:
        print("Next step:")
        print(" Upload the .onnx files to Google Drive and paste")
        print(" the file IDs into DRIVE_IDS in app.py.")
    print(banner)