TutlaytAI commited on
Commit
07bfc84
·
verified ·
1 Parent(s): 4ba2aed

Update modeling_upscaler.py

Browse files
Files changed (1) hide show
  1. modeling_upscaler.py +116 -117
modeling_upscaler.py CHANGED
@@ -1,118 +1,117 @@
1
- from dataclasses import dataclass
2
- from typing import Optional
3
-
4
- import torch
5
- import torch.nn as nn
6
- import torch.nn.functional as F
7
-
8
- from transformers import PreTrainedModel
9
- from transformers.utils import ModelOutput
10
-
11
- from configuration_upscaler import UpscalerConfig
12
-
13
-
14
- # -------------------------
15
- # Architecture (same as yours)
16
- # -------------------------
17
-
18
- class ResidualBlock(nn.Module):
19
- def __init__(self, channels: int):
20
- super().__init__()
21
- self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
22
- self.act = nn.ReLU(inplace=True)
23
- self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
24
-
25
- def forward(self, x):
26
- y = self.act(self.conv1(x))
27
- y = self.conv2(y)
28
- return x + y
29
-
30
-
31
- class RestorationNet(nn.Module):
32
- def __init__(self, in_channels=3, width=32, num_blocks=3):
33
- super().__init__()
34
- self.in_conv = nn.Conv2d(in_channels, width, 3, padding=1)
35
- self.blocks = nn.Sequential(*[ResidualBlock(width) for _ in range(num_blocks)])
36
- self.out_conv = nn.Conv2d(width, in_channels, 3, padding=1)
37
-
38
- def forward(self, lr):
39
- y = self.blocks(self.in_conv(lr))
40
- y = self.out_conv(y)
41
- return lr + y
42
-
43
-
44
- class ESPCNUpsampler(nn.Module):
45
- def __init__(self, in_channels=3, scale=2, feat1=64, feat2=32, use_refine=False):
46
- super().__init__()
47
- assert scale in (2, 3, 4)
48
- self.conv1 = nn.Conv2d(in_channels, feat1, 5, padding=2)
49
- self.act1 = nn.ReLU(inplace=True)
50
- self.conv2 = nn.Conv2d(feat1, feat2, 3, padding=1)
51
- self.act2 = nn.ReLU(inplace=True)
52
-
53
- # IMPORTANT: conv3 out_channels depends on scale (PixelShuffle constraint)
54
- self.conv3 = nn.Conv2d(feat2, in_channels * (scale ** 2), 3, padding=1)
55
- self.ps = nn.PixelShuffle(scale)
56
-
57
- self.refine = nn.Conv2d(in_channels, in_channels, 3, padding=1) if use_refine else None
58
-
59
- def forward(self, x):
60
- y = self.act1(self.conv1(x))
61
- y = self.act2(self.conv2(y))
62
- y = self.ps(self.conv3(y))
63
- if self.refine is not None:
64
- y = self.refine(y)
65
- return y
66
-
67
-
68
- class TwoStageSR(nn.Module):
69
- def __init__(self, in_channels=3, scale=2, width=32, num_blocks=3, feat1=64, feat2=32, use_refine=False):
70
- super().__init__()
71
- self.scale = scale
72
- self.restoration = RestorationNet(in_channels=in_channels, width=width, num_blocks=num_blocks)
73
- self.upsampler = ESPCNUpsampler(
74
- in_channels=in_channels, scale=scale, feat1=feat1, feat2=feat2, use_refine=use_refine
75
- )
76
-
77
- def forward(self, lr):
78
- lr_clean = self.restoration(lr)
79
- hr_pred = self.upsampler(lr_clean)
80
- return hr_pred
81
-
82
-
83
- # -------------------------
84
- # Transformers output
85
- # -------------------------
86
-
87
- @dataclass
88
- class UpscalerOutput(ModelOutput):
89
- sr: torch.FloatTensor
90
-
91
-
92
- class UpscalerModel(PreTrainedModel):
93
- config_class = UpscalerConfig
94
- main_input_name = "pixel_values"
95
-
96
- def __init__(self, config: UpscalerConfig):
97
- super().__init__(config)
98
-
99
- self.model = TwoStageSR(
100
- in_channels=config.in_channels,
101
- scale=config.scale,
102
- width=config.width,
103
- num_blocks=config.num_blocks,
104
- feat1=config.feat1,
105
- feat2=config.feat2,
106
- use_refine=config.use_refine,
107
- )
108
-
109
- # init weights (optional; usually weights will be loaded)
110
- self.post_init()
111
-
112
- def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
113
- """
114
- pixel_values: float tensor in [0,1], shape (B,3,H,W)
115
- returns: UpscalerOutput(sr=...)
116
- """
117
- sr = self.model(pixel_values)
118
  return UpscalerOutput(sr=sr)
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from transformers import PreTrainedModel
9
+ from transformers.utils import ModelOutput
10
+
11
+ from .configuration_upscaler import UpscalerConfig
12
+
13
+
14
+ # -------------------------
15
+ # Architecture (same as yours)
16
+ # -------------------------
17
+
18
+ class ResidualBlock(nn.Module):
19
+ def __init__(self, channels: int):
20
+ super().__init__()
21
+ self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
22
+ self.act = nn.ReLU(inplace=True)
23
+ self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
24
+
25
+ def forward(self, x):
26
+ y = self.act(self.conv1(x))
27
+ y = self.conv2(y)
28
+ return x + y
29
+
30
+
31
+ class RestorationNet(nn.Module):
32
+ def __init__(self, in_channels=3, width=32, num_blocks=3):
33
+ super().__init__()
34
+ self.in_conv = nn.Conv2d(in_channels, width, 3, padding=1)
35
+ self.blocks = nn.Sequential(*[ResidualBlock(width) for _ in range(num_blocks)])
36
+ self.out_conv = nn.Conv2d(width, in_channels, 3, padding=1)
37
+
38
+ def forward(self, lr):
39
+ y = self.blocks(self.in_conv(lr))
40
+ y = self.out_conv(y)
41
+ return lr + y
42
+
43
+
44
+ class ESPCNUpsampler(nn.Module):
45
+ def __init__(self, in_channels=3, scale=2, feat1=64, feat2=32, use_refine=False):
46
+ super().__init__()
47
+ assert scale in (2, 3, 4)
48
+ self.conv1 = nn.Conv2d(in_channels, feat1, 5, padding=2)
49
+ self.act1 = nn.ReLU(inplace=True)
50
+ self.conv2 = nn.Conv2d(feat1, feat2, 3, padding=1)
51
+ self.act2 = nn.ReLU(inplace=True)
52
+
53
+ # IMPORTANT: conv3 out_channels depends on scale (PixelShuffle constraint)
54
+ self.conv3 = nn.Conv2d(feat2, in_channels * (scale ** 2), 3, padding=1)
55
+ self.ps = nn.PixelShuffle(scale)
56
+
57
+ self.refine = nn.Conv2d(in_channels, in_channels, 3, padding=1) if use_refine else None
58
+
59
+ def forward(self, x):
60
+ y = self.act1(self.conv1(x))
61
+ y = self.act2(self.conv2(y))
62
+ y = self.ps(self.conv3(y))
63
+ if self.refine is not None:
64
+ y = self.refine(y)
65
+ return y
66
+
67
+
68
+ class TwoStageSR(nn.Module):
69
+ def __init__(self, in_channels=3, scale=2, width=32, num_blocks=3, feat1=64, feat2=32, use_refine=False):
70
+ super().__init__()
71
+ self.scale = scale
72
+ self.restoration = RestorationNet(in_channels=in_channels, width=width, num_blocks=num_blocks)
73
+ self.upsampler = ESPCNUpsampler(
74
+ in_channels=in_channels, scale=scale, feat1=feat1, feat2=feat2, use_refine=use_refine
75
+ )
76
+
77
+ def forward(self, lr):
78
+ lr_clean = self.restoration(lr)
79
+ hr_pred = self.upsampler(lr_clean)
80
+ return hr_pred
81
+
82
+
83
+ # -------------------------
84
+ # Transformers output
85
+ # -------------------------
86
+
87
+ @dataclass
88
+ class UpscalerOutput(ModelOutput):
89
+ sr: torch.FloatTensor
90
+
91
+
92
+ class UpscalerModel(PreTrainedModel):
93
+ config_class = UpscalerConfig
94
+ main_input_name = "pixel_values"
95
+
96
+ def __init__(self, config: UpscalerConfig):
97
+ super().__init__(config)
98
+
99
+ self.model = TwoStageSR(
100
+ in_channels=config.in_channels,
101
+ scale=config.scale,
102
+ width=config.width,
103
+ num_blocks=config.num_blocks,
104
+ feat1=config.feat1,
105
+ feat2=config.feat2,
106
+ use_refine=config.use_refine,
107
+ )
108
+
109
+ self.post_init()
110
+
111
+ def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> UpscalerOutput:
112
+ """
113
+ pixel_values: float tensor in [0,1], shape (B,3,H,W)
114
+ returns: UpscalerOutput(sr=...)
115
+ """
116
+ sr = self.model(pixel_values)
 
117
  return UpscalerOutput(sr=sr)