""" OKLab Color Space Utilities Perceptually uniform color space for semantic loss computation. OKLab ensures that equal distances in the color space correspond to equal perceived differences — critical for meaningful color-based encoding. Key functions: - srgb_to_oklab / oklab_to_srgb: Color space conversions - rotate_ab: Rotate hue in a-b plane (for domain/idiom shifts) - set_chroma: Set chroma magnitude (for purity encoding) - OKLabMSELoss: Perceptually uniform loss function - hsl_to_oklab_batch: Batch conversion for training """ import torch import torch.nn as nn import math from typing import Tuple def clamp(x: float, lo: float, hi: float) -> float: """Clamp a value to [lo, hi].""" return max(lo, min(hi, x)) # ── sRGB ↔ Linear RGB ── def srgb_to_linear(c: float) -> float: """sRGB gamma to linear.""" if c <= 0.04045: return c / 12.92 return ((c + 0.055) / 1.055) ** 2.4 def linear_to_srgb(c: float) -> float: """Linear to sRGB gamma.""" if c <= 0.0031308: return c * 12.92 return 1.055 * (c ** (1.0 / 2.4)) - 0.055 # ── sRGB ↔ OKLab ── def srgb_to_oklab(r: float, g: float, b: float) -> Tuple[float, float, float]: """Convert sRGB [0,1] to OKLab.""" r_lin = srgb_to_linear(r) g_lin = srgb_to_linear(g) b_lin = srgb_to_linear(b) l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin l_c = l_ ** (1.0 / 3.0) if l_ >= 0 else -((-l_) ** (1.0 / 3.0)) m_c = m_ ** (1.0 / 3.0) if m_ >= 0 else -((-m_) ** (1.0 / 3.0)) s_c = s_ ** (1.0 / 3.0) if s_ >= 0 else -((-s_) ** (1.0 / 3.0)) L = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c a = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c return (L, a, b_ok) def oklab_to_srgb(L: float, a: float, b_ok: float) -> Tuple[float, float, float]: """Convert OKLab to sRGB [0,1].""" l_c = L + 0.3963377774 * a + 0.2158037573 * b_ok m_c = L - 0.1055613458 * a - 0.0638541728 * b_ok s_c = L - 0.0894841775 * a - 1.2914855480 * b_ok l_ = l_c * l_c * l_c m_ = m_c * m_c * m_c s_ = s_c * s_c * s_c r_lin = +4.0767416621 * l_ - 3.3077115913 * m_ + 0.2309699292 * s_ g_lin = -1.2684380046 * l_ + 2.6097574011 * m_ - 0.3413193965 * s_ b_lin = -0.0041960863 * l_ - 0.7034186147 * m_ + 1.7076147010 * s_ r = clamp(linear_to_srgb(clamp(r_lin, 0, 1)), 0, 1) g = clamp(linear_to_srgb(clamp(g_lin, 0, 1)), 0, 1) b = clamp(linear_to_srgb(clamp(b_lin, 0, 1)), 0, 1) return (r, g, b) # ── HSL ↔ RGB ── def hsl_to_rgb(h_deg: float, s_pct: float, l_pct: float) -> Tuple[float, float, float]: """Convert HSL (degrees, percent, percent) to RGB [0,1].""" h = h_deg / 360.0 s = s_pct / 100.0 l = l_pct / 100.0 if s == 0: return (l, l, l) def hue_to_rgb(p, q, t): if t < 0: t += 1 if t > 1: t -= 1 if t < 1/6: return p + (q - p) * 6 * t if t < 1/2: return q if t < 2/3: return p + (q - p) * (2/3 - t) * 6 return p q = l * (1 + s) if l < 0.5 else l + s - l * s p = 2 * l - q r = hue_to_rgb(p, q, h + 1/3) g = hue_to_rgb(p, q, h) b = hue_to_rgb(p, q, h - 1/3) return (r, g, b) def rgb_to_hsl(r: float, g: float, b: float) -> Tuple[float, float, float]: """Convert RGB [0,1] to HSL (degrees, percent, percent).""" max_c = max(r, g, b) min_c = min(r, g, b) l = (max_c + min_c) / 2.0 if max_c == min_c: h = s = 0.0 else: d = max_c - min_c s = d / (2.0 - max_c - min_c) if l > 0.5 else d / (max_c + min_c) if max_c == r: h = (g - b) / d + (6 if g < b else 0) elif max_c == g: h = (b - r) / d + 2 else: h = (r - g) / d + 4 h /= 6.0 return (h * 360.0, s * 100.0, l * 100.0) # ── OKLab Operations ── def rotate_ab(a: float, b: float, degrees: float) -> Tuple[float, float]: """Rotate hue in OKLab a-b plane by given degrees.""" rad = math.radians(degrees) cos_r = math.cos(rad) sin_r = math.sin(rad) return (a * cos_r - b * sin_r, a * sin_r + b * cos_r) def set_chroma(a: float, b: float, target_c: float) -> Tuple[float, float]: """Set the chroma (magnitude in a-b plane) to target value.""" current_c = math.sqrt(a * a + b * b) if current_c < 1e-10: return (target_c, 0.0) # Default direction scale = target_c / current_c return (a * scale, b * scale) def get_chroma(a: float, b: float) -> float: """Get chroma magnitude from a-b values.""" return math.sqrt(a * a + b * b) def compute_delta_e_oklab( L1: float, a1: float, b1: float, L2: float, a2: float, b2: float, ) -> float: """Compute ΔE in OKLab space (perceptual color difference).""" return math.sqrt((L1 - L2) ** 2 + (a1 - a2) ** 2 + (b1 - b2) ** 2) # ── Batch Operations (PyTorch) ── def hsl_to_oklab_batch(hsl: torch.Tensor) -> torch.Tensor: """ Batch convert HSL [0,1] normalized to OKLab. Args: hsl: (..., 3) tensor with H,S,L in [0,1] Returns: (..., 3) tensor with L,a,b in OKLab """ h = hsl[..., 0] * 360.0 # Back to degrees s = hsl[..., 1] * 100.0 # Back to percent l = hsl[..., 2] * 100.0 # Back to percent # HSL to RGB (vectorized) h_norm = h / 360.0 q = torch.where(l / 100.0 < 0.5, (l / 100.0) * (1 + s / 100.0), (l / 100.0) + (s / 100.0) - (l / 100.0) * (s / 100.0)) p = 2 * (l / 100.0) - q def hue2rgb(p, q, t): t = t % 1.0 r = torch.where(t < 1/6, p + (q - p) * 6 * t, torch.where(t < 1/2, q, torch.where(t < 2/3, p + (q - p) * (2/3 - t) * 6, p))) return r r = hue2rgb(p, q, h_norm + 1/3) g = hue2rgb(p, q, h_norm) b = hue2rgb(p, q, h_norm - 1/3) # Handle achromatic (s == 0) achromatic = (s < 0.001) r = torch.where(achromatic, l / 100.0, r) g = torch.where(achromatic, l / 100.0, g) b = torch.where(achromatic, l / 100.0, b) # sRGB to linear r_lin = torch.where(r <= 0.04045, r / 12.92, ((r + 0.055) / 1.055) ** 2.4) g_lin = torch.where(g <= 0.04045, g / 12.92, ((g + 0.055) / 1.055) ** 2.4) b_lin = torch.where(b <= 0.04045, b / 12.92, ((b + 0.055) / 1.055) ** 2.4) # Linear RGB to OKLab l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin l_c = torch.sign(l_) * torch.abs(l_).pow(1/3) m_c = torch.sign(m_) * torch.abs(m_).pow(1/3) s_c = torch.sign(s_) * torch.abs(s_).pow(1/3) L_ok = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c a_ok = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c return torch.stack([L_ok, a_ok, b_ok], dim=-1) def denormalize_hsl(hsl_norm: torch.Tensor) -> torch.Tensor: """Convert normalized HSL [0,1] to degrees/percent format.""" result = hsl_norm.clone() result[..., 0] *= 360.0 # H: [0,1] → [0,360] result[..., 1] *= 100.0 # S: [0,1] → [0,100] result[..., 2] *= 100.0 # L: [0,1] → [0,100] return result class OKLabMSELoss(nn.Module): """ Perceptually uniform loss in OKLab space. Converts predicted and target HSL values to OKLab, then computes MSE. This handles hue circularity correctly (359° ≈ 1°) because OKLab represents hue as a-b coordinates, not an angle. """ def __init__(self): super().__init__() def forward( self, pred_hsl: torch.Tensor, # (B, 3) predicted HSL in [0,1] target_hsl: torch.Tensor, # (B, 3) target HSL in [0,1] ) -> torch.Tensor: """Compute perceptually uniform loss.""" pred_oklab = hsl_to_oklab_batch(pred_hsl) target_oklab = hsl_to_oklab_batch(target_hsl) return torch.nn.functional.mse_loss(pred_oklab, target_oklab)