Jonttup commited on
Commit
27fdf85
Β·
verified Β·
1 Parent(s): e27c6bd

Upload models/oklab_utils.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/oklab_utils.py +263 -0
models/oklab_utils.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OKLab Color Space Utilities
3
+
4
+ Perceptually uniform color space for semantic loss computation.
5
+ OKLab ensures that equal distances in the color space correspond to
6
+ equal perceived differences β€” critical for meaningful color-based encoding.
7
+
8
+ Key functions:
9
+ - srgb_to_oklab / oklab_to_srgb: Color space conversions
10
+ - rotate_ab: Rotate hue in a-b plane (for domain/idiom shifts)
11
+ - set_chroma: Set chroma magnitude (for purity encoding)
12
+ - OKLabMSELoss: Perceptually uniform loss function
13
+ - hsl_to_oklab_batch: Batch conversion for training
14
+ """
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import math
19
+ from typing import Tuple
20
+
21
+
22
+ def clamp(x: float, lo: float, hi: float) -> float:
23
+ """Clamp a value to [lo, hi]."""
24
+ return max(lo, min(hi, x))
25
+
26
+
27
+ # ── sRGB ↔ Linear RGB ──
28
+
29
+ def srgb_to_linear(c: float) -> float:
30
+ """sRGB gamma to linear."""
31
+ if c <= 0.04045:
32
+ return c / 12.92
33
+ return ((c + 0.055) / 1.055) ** 2.4
34
+
35
+
36
+ def linear_to_srgb(c: float) -> float:
37
+ """Linear to sRGB gamma."""
38
+ if c <= 0.0031308:
39
+ return c * 12.92
40
+ return 1.055 * (c ** (1.0 / 2.4)) - 0.055
41
+
42
+
43
+ # ── sRGB ↔ OKLab ──
44
+
45
+ def srgb_to_oklab(r: float, g: float, b: float) -> Tuple[float, float, float]:
46
+ """Convert sRGB [0,1] to OKLab."""
47
+ r_lin = srgb_to_linear(r)
48
+ g_lin = srgb_to_linear(g)
49
+ b_lin = srgb_to_linear(b)
50
+
51
+ l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin
52
+ m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin
53
+ s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin
54
+
55
+ l_c = l_ ** (1.0 / 3.0) if l_ >= 0 else -((-l_) ** (1.0 / 3.0))
56
+ m_c = m_ ** (1.0 / 3.0) if m_ >= 0 else -((-m_) ** (1.0 / 3.0))
57
+ s_c = s_ ** (1.0 / 3.0) if s_ >= 0 else -((-s_) ** (1.0 / 3.0))
58
+
59
+ L = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c
60
+ a = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c
61
+ b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c
62
+
63
+ return (L, a, b_ok)
64
+
65
+
66
+ def oklab_to_srgb(L: float, a: float, b_ok: float) -> Tuple[float, float, float]:
67
+ """Convert OKLab to sRGB [0,1]."""
68
+ l_c = L + 0.3963377774 * a + 0.2158037573 * b_ok
69
+ m_c = L - 0.1055613458 * a - 0.0638541728 * b_ok
70
+ s_c = L - 0.0894841775 * a - 1.2914855480 * b_ok
71
+
72
+ l_ = l_c * l_c * l_c
73
+ m_ = m_c * m_c * m_c
74
+ s_ = s_c * s_c * s_c
75
+
76
+ r_lin = +4.0767416621 * l_ - 3.3077115913 * m_ + 0.2309699292 * s_
77
+ g_lin = -1.2684380046 * l_ + 2.6097574011 * m_ - 0.3413193965 * s_
78
+ b_lin = -0.0041960863 * l_ - 0.7034186147 * m_ + 1.7076147010 * s_
79
+
80
+ r = clamp(linear_to_srgb(clamp(r_lin, 0, 1)), 0, 1)
81
+ g = clamp(linear_to_srgb(clamp(g_lin, 0, 1)), 0, 1)
82
+ b = clamp(linear_to_srgb(clamp(b_lin, 0, 1)), 0, 1)
83
+
84
+ return (r, g, b)
85
+
86
+
87
+ # ── HSL ↔ RGB ──
88
+
89
+ def hsl_to_rgb(h_deg: float, s_pct: float, l_pct: float) -> Tuple[float, float, float]:
90
+ """Convert HSL (degrees, percent, percent) to RGB [0,1]."""
91
+ h = h_deg / 360.0
92
+ s = s_pct / 100.0
93
+ l = l_pct / 100.0
94
+
95
+ if s == 0:
96
+ return (l, l, l)
97
+
98
+ def hue_to_rgb(p, q, t):
99
+ if t < 0: t += 1
100
+ if t > 1: t -= 1
101
+ if t < 1/6: return p + (q - p) * 6 * t
102
+ if t < 1/2: return q
103
+ if t < 2/3: return p + (q - p) * (2/3 - t) * 6
104
+ return p
105
+
106
+ q = l * (1 + s) if l < 0.5 else l + s - l * s
107
+ p = 2 * l - q
108
+
109
+ r = hue_to_rgb(p, q, h + 1/3)
110
+ g = hue_to_rgb(p, q, h)
111
+ b = hue_to_rgb(p, q, h - 1/3)
112
+
113
+ return (r, g, b)
114
+
115
+
116
+ def rgb_to_hsl(r: float, g: float, b: float) -> Tuple[float, float, float]:
117
+ """Convert RGB [0,1] to HSL (degrees, percent, percent)."""
118
+ max_c = max(r, g, b)
119
+ min_c = min(r, g, b)
120
+ l = (max_c + min_c) / 2.0
121
+
122
+ if max_c == min_c:
123
+ h = s = 0.0
124
+ else:
125
+ d = max_c - min_c
126
+ s = d / (2.0 - max_c - min_c) if l > 0.5 else d / (max_c + min_c)
127
+
128
+ if max_c == r:
129
+ h = (g - b) / d + (6 if g < b else 0)
130
+ elif max_c == g:
131
+ h = (b - r) / d + 2
132
+ else:
133
+ h = (r - g) / d + 4
134
+
135
+ h /= 6.0
136
+
137
+ return (h * 360.0, s * 100.0, l * 100.0)
138
+
139
+
140
+ # ── OKLab Operations ──
141
+
142
+ def rotate_ab(a: float, b: float, degrees: float) -> Tuple[float, float]:
143
+ """Rotate hue in OKLab a-b plane by given degrees."""
144
+ rad = math.radians(degrees)
145
+ cos_r = math.cos(rad)
146
+ sin_r = math.sin(rad)
147
+ return (a * cos_r - b * sin_r, a * sin_r + b * cos_r)
148
+
149
+
150
+ def set_chroma(a: float, b: float, target_c: float) -> Tuple[float, float]:
151
+ """Set the chroma (magnitude in a-b plane) to target value."""
152
+ current_c = math.sqrt(a * a + b * b)
153
+ if current_c < 1e-10:
154
+ return (target_c, 0.0) # Default direction
155
+ scale = target_c / current_c
156
+ return (a * scale, b * scale)
157
+
158
+
159
+ def get_chroma(a: float, b: float) -> float:
160
+ """Get chroma magnitude from a-b values."""
161
+ return math.sqrt(a * a + b * b)
162
+
163
+
164
+ def compute_delta_e_oklab(
165
+ L1: float, a1: float, b1: float,
166
+ L2: float, a2: float, b2: float,
167
+ ) -> float:
168
+ """Compute Ξ”E in OKLab space (perceptual color difference)."""
169
+ return math.sqrt((L1 - L2) ** 2 + (a1 - a2) ** 2 + (b1 - b2) ** 2)
170
+
171
+
172
+ # ── Batch Operations (PyTorch) ──
173
+
174
+ def hsl_to_oklab_batch(hsl: torch.Tensor) -> torch.Tensor:
175
+ """
176
+ Batch convert HSL [0,1] normalized to OKLab.
177
+
178
+ Args:
179
+ hsl: (..., 3) tensor with H,S,L in [0,1]
180
+
181
+ Returns:
182
+ (..., 3) tensor with L,a,b in OKLab
183
+ """
184
+ h = hsl[..., 0] * 360.0 # Back to degrees
185
+ s = hsl[..., 1] * 100.0 # Back to percent
186
+ l = hsl[..., 2] * 100.0 # Back to percent
187
+
188
+ # HSL to RGB (vectorized)
189
+ h_norm = h / 360.0
190
+ q = torch.where(l / 100.0 < 0.5,
191
+ (l / 100.0) * (1 + s / 100.0),
192
+ (l / 100.0) + (s / 100.0) - (l / 100.0) * (s / 100.0))
193
+ p = 2 * (l / 100.0) - q
194
+
195
+ def hue2rgb(p, q, t):
196
+ t = t % 1.0
197
+ r = torch.where(t < 1/6, p + (q - p) * 6 * t,
198
+ torch.where(t < 1/2, q,
199
+ torch.where(t < 2/3, p + (q - p) * (2/3 - t) * 6, p)))
200
+ return r
201
+
202
+ r = hue2rgb(p, q, h_norm + 1/3)
203
+ g = hue2rgb(p, q, h_norm)
204
+ b = hue2rgb(p, q, h_norm - 1/3)
205
+
206
+ # Handle achromatic (s == 0)
207
+ achromatic = (s < 0.001)
208
+ r = torch.where(achromatic, l / 100.0, r)
209
+ g = torch.where(achromatic, l / 100.0, g)
210
+ b = torch.where(achromatic, l / 100.0, b)
211
+
212
+ # sRGB to linear
213
+ r_lin = torch.where(r <= 0.04045, r / 12.92, ((r + 0.055) / 1.055) ** 2.4)
214
+ g_lin = torch.where(g <= 0.04045, g / 12.92, ((g + 0.055) / 1.055) ** 2.4)
215
+ b_lin = torch.where(b <= 0.04045, b / 12.92, ((b + 0.055) / 1.055) ** 2.4)
216
+
217
+ # Linear RGB to OKLab
218
+ l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin
219
+ m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin
220
+ s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin
221
+
222
+ l_c = torch.sign(l_) * torch.abs(l_).pow(1/3)
223
+ m_c = torch.sign(m_) * torch.abs(m_).pow(1/3)
224
+ s_c = torch.sign(s_) * torch.abs(s_).pow(1/3)
225
+
226
+ L_ok = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c
227
+ a_ok = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c
228
+ b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c
229
+
230
+ return torch.stack([L_ok, a_ok, b_ok], dim=-1)
231
+
232
+
233
+ def denormalize_hsl(hsl_norm: torch.Tensor) -> torch.Tensor:
234
+ """Convert normalized HSL [0,1] to degrees/percent format."""
235
+ result = hsl_norm.clone()
236
+ result[..., 0] *= 360.0 # H: [0,1] β†’ [0,360]
237
+ result[..., 1] *= 100.0 # S: [0,1] β†’ [0,100]
238
+ result[..., 2] *= 100.0 # L: [0,1] β†’ [0,100]
239
+ return result
240
+
241
+
242
+ class OKLabMSELoss(nn.Module):
243
+ """
244
+ Perceptually uniform loss in OKLab space.
245
+
246
+ Converts predicted and target HSL values to OKLab, then computes MSE.
247
+ This handles hue circularity correctly (359Β° β‰ˆ 1Β°) because OKLab
248
+ represents hue as a-b coordinates, not an angle.
249
+ """
250
+
251
+ def __init__(self):
252
+ super().__init__()
253
+
254
+ def forward(
255
+ self,
256
+ pred_hsl: torch.Tensor, # (B, 3) predicted HSL in [0,1]
257
+ target_hsl: torch.Tensor, # (B, 3) target HSL in [0,1]
258
+ ) -> torch.Tensor:
259
+ """Compute perceptually uniform loss."""
260
+ pred_oklab = hsl_to_oklab_batch(pred_hsl)
261
+ target_oklab = hsl_to_oklab_batch(target_hsl)
262
+
263
+ return torch.nn.functional.mse_loss(pred_oklab, target_oklab)