| """
|
| OKLab Color Space Utilities
|
|
|
| Perceptually uniform color space for semantic loss computation.
|
| OKLab ensures that equal distances in the color space correspond to
|
equal perceived differences — critical for meaningful color-based encoding.
|
|
|
| Key functions:
|
| - srgb_to_oklab / oklab_to_srgb: Color space conversions
|
| - rotate_ab: Rotate hue in a-b plane (for domain/idiom shifts)
|
| - set_chroma: Set chroma magnitude (for purity encoding)
|
| - OKLabMSELoss: Perceptually uniform loss function
|
| - hsl_to_oklab_batch: Batch conversion for training
|
| """
|
|
|
| import torch
|
| import torch.nn as nn
|
| import math
|
| from typing import Tuple
|
|
|
|
|
def clamp(x: float, lo: float, hi: float) -> float:
    """Restrict ``x`` to the closed interval [lo, hi] (same as max(lo, min(hi, x)))."""
    # Upper bound first: keep x only if it is strictly below hi.
    capped = x if x < hi else hi
    # Then the lower bound: keep the capped value only if strictly above lo.
    return capped if capped > lo else lo
|
|
|
|
|
|
|
|
|
def srgb_to_linear(c: float) -> float:
    """Decode one sRGB gamma-encoded channel value to linear light."""
    if c > 0.04045:
        # Power-law segment of the sRGB transfer curve.
        return ((c + 0.055) / 1.055) ** 2.4
    # Linear toe segment near black.
    return c / 12.92
|
|
|
|
|
def linear_to_srgb(c: float) -> float:
    """Encode one linear-light channel value with the sRGB transfer curve."""
    if c > 0.0031308:
        # Power-law segment.
        return 1.055 * (c ** (1.0 / 2.4)) - 0.055
    # Linear toe segment near black.
    return c * 12.92
|
|
|
|
|
|
|
|
|
def srgb_to_oklab(r: float, g: float, b: float) -> Tuple[float, float, float]:
    """
    Convert a gamma-encoded sRGB colour to OKLab.

    Pipeline: sRGB decode -> linear sRGB -> LMS (matrix) -> signed cube
    root -> OKLab (matrix).

    Args:
        r, g, b: sRGB channel values, nominally in [0, 1].

    Returns:
        (L, a, b) OKLab coordinates.
    """
    to_lms = (
        (0.4122214708, 0.5363325363, 0.0514459929),
        (0.2119034982, 0.6806995451, 0.1073969566),
        (0.0883024619, 0.2817188376, 0.6299787005),
    )
    to_lab = (
        (0.2104542553, 0.7936177850, -0.0040720468),
        (1.9779984951, -2.4285922050, 0.4505937099),
        (0.0259040371, 0.7827717662, -0.8086757660),
    )

    def dot3(row, vec):
        # Three-term dot product, summed left to right.
        return row[0] * vec[0] + row[1] * vec[1] + row[2] * vec[2]

    def signed_cbrt(v):
        # Real cube root that keeps the sign; ``v ** (1/3)`` on a negative
        # float would produce a complex number in Python 3.
        if v >= 0:
            return v ** (1.0 / 3.0)
        return -((-v) ** (1.0 / 3.0))

    linear = [srgb_to_linear(channel) for channel in (r, g, b)]
    lms_root = [signed_cbrt(dot3(row, linear)) for row in to_lms]
    return tuple(dot3(row, lms_root) for row in to_lab)
|
|
|
|
|
def oklab_to_srgb(L: float, a: float, b_ok: float) -> Tuple[float, float, float]:
    """
    Convert OKLab coordinates back to gamma-encoded sRGB.

    No gamut mapping is done: each linear channel is clamped to [0, 1]
    before sRGB encoding, and the encoded value is clamped to [0, 1] again.

    Args:
        L: OKLab lightness.
        a: OKLab ``a`` coordinate.
        b_ok: OKLab ``b`` coordinate.

    Returns:
        (r, g, b) sRGB channels in [0, 1].
    """
    # Inverse of the OKLab matrix: cube-rooted LMS responses.
    l_root = L + 0.3963377774 * a + 0.2158037573 * b_ok
    m_root = L - 0.1055613458 * a - 0.0638541728 * b_ok
    s_root = L - 0.0894841775 * a - 1.2914855480 * b_ok

    # Undo the cube root.
    lms = [v * v * v for v in (l_root, m_root, s_root)]

    # LMS -> linear sRGB.
    to_rgb = (
        (4.0767416621, -3.3077115913, 0.2309699292),
        (-1.2684380046, 2.6097574011, -0.3413193965),
        (-0.0041960863, -0.7034186147, 1.7076147010),
    )
    linear_rgb = [
        row[0] * lms[0] + row[1] * lms[1] + row[2] * lms[2] for row in to_rgb
    ]

    return tuple(clamp(linear_to_srgb(clamp(v, 0, 1)), 0, 1) for v in linear_rgb)
|
|
|
|
|
|
|
|
|
def hsl_to_rgb(h_deg: float, s_pct: float, l_pct: float) -> Tuple[float, float, float]:
    """
    Convert HSL (degrees, percent, percent) to RGB in [0, 1].

    Args:
        h_deg: Hue in degrees. Any real value is accepted and wrapped
            modulo 360 (e.g. 840 is treated as 120, -30 as 330).
        s_pct: Saturation in percent, nominally [0, 100].
        l_pct: Lightness in percent, nominally [0, 100].

    Returns:
        (r, g, b) channel values, nominally in [0, 1].
    """
    h = h_deg / 360.0
    s = s_pct / 100.0
    l = l_pct / 100.0

    # Zero saturation is grey: every channel equals the lightness.
    if s == 0:
        return (l, l, l)

    def hue_to_rgb(p: float, q: float, t: float) -> float:
        # Wrap fully into [0, 1) (as hsl_to_oklab_batch does with ``% 1.0``).
        # A single +/-1 correction is not enough once the hue is more than
        # one turn out of range; for t already in [-1, 2) the result is
        # identical to that correction.
        t %= 1.0
        if t < 1 / 6:
            return p + (q - p) * 6 * t
        if t < 1 / 2:
            return q
        if t < 2 / 3:
            return p + (q - p) * (2 / 3 - t) * 6
        return p

    q = l * (1 + s) if l < 0.5 else l + s - l * s
    p = 2 * l - q

    r = hue_to_rgb(p, q, h + 1 / 3)
    g = hue_to_rgb(p, q, h)
    b = hue_to_rgb(p, q, h - 1 / 3)

    return (r, g, b)
|
|
|
|
|
def rgb_to_hsl(r: float, g: float, b: float) -> Tuple[float, float, float]:
    """
    Convert RGB in [0, 1] to HSL.

    Returns:
        (hue in degrees, saturation in percent, lightness in percent).
        Grey inputs (all channels equal) report hue = saturation = 0.
    """
    hi = max(r, g, b)
    lo = min(r, g, b)
    lightness = (hi + lo) / 2.0

    if hi == lo:
        return (0.0, 0.0, lightness * 100.0)

    chroma = hi - lo
    denom = (2.0 - hi - lo) if lightness > 0.5 else (hi + lo)
    saturation = chroma / denom

    # The dominant channel picks the sector; one sector unit is 60 degrees.
    if hi == r:
        sector = (g - b) / chroma + (6 if g < b else 0)
    elif hi == g:
        sector = (b - r) / chroma + 2
    else:
        sector = (r - g) / chroma + 4

    hue = sector / 6.0
    return (hue * 360.0, saturation * 100.0, lightness * 100.0)
|
|
|
|
|
|
|
|
|
def rotate_ab(a: float, b: float, degrees: float) -> Tuple[float, float]:
    """
    Rotate the point (a, b) in the OKLab a-b plane by ``degrees``.

    Positive angles turn from the +a axis toward the +b axis; the chroma
    (distance from the origin) is preserved.
    """
    theta = math.radians(degrees)
    c = math.cos(theta)
    s = math.sin(theta)
    new_a = a * c - b * s
    new_b = a * s + b * c
    return (new_a, new_b)
|
|
|
|
|
def set_chroma(a: float, b: float, target_c: float) -> Tuple[float, float]:
    """
    Rescale (a, b) so its magnitude in the a-b plane equals ``target_c``.

    The hue angle is kept. If the input chroma is below 1e-10 the hue is
    effectively undefined, so the result is placed on the +a axis.
    """
    magnitude = math.sqrt(a * a + b * b)
    if magnitude < 1e-10:
        return (target_c, 0.0)
    factor = target_c / magnitude
    return (a * factor, b * factor)
|
|
|
|
|
def get_chroma(a: float, b: float) -> float:
    """Return the chroma: the Euclidean length of (a, b) in the OKLab a-b plane."""
    squared_length = a * a + b * b
    return math.sqrt(squared_length)
|
|
|
|
|
def compute_delta_e_oklab(
    L1: float, a1: float, b1: float,
    L2: float, a2: float, b2: float,
) -> float:
    """Return ΔE in OKLab: the Euclidean distance between two OKLab colours."""
    dL = L1 - L2
    da = a1 - a2
    db = b1 - b2
    return math.sqrt(dL ** 2 + da ** 2 + db ** 2)
|
|
|
|
|
|
|
|
|
def hsl_to_oklab_batch(hsl: torch.Tensor) -> torch.Tensor:
    """
    Batch convert normalized HSL to OKLab, differentiably.

    Pipeline: HSL -> gamma-encoded sRGB -> linear sRGB -> LMS -> signed
    cube root -> OKLab, using the same formulas and coefficients as the
    scalar helpers in this module.

    Args:
        hsl: (..., 3) tensor with H, S, L in [0, 1] (H = 1 is a full turn).
            Hue is taken modulo 1, so H and H + 1 give the same colour.

    Returns:
        (..., 3) tensor with OKLab L, a, b in the last dimension.

    Notes:
        Gradients stay finite in two cases that would otherwise produce NaN:
        - pure black, where d/dx x**(1/3) is infinite at 0;
        - out-of-range inputs such as negative lightness, where the
          discarded sRGB power branch of ``torch.where`` would be NaN.
    """
    h = hsl[..., 0]
    s = hsl[..., 1]
    l = hsl[..., 2]

    # HSL -> RGB, vectorised version of hsl_to_rgb.
    q = torch.where(l < 0.5, l * (1 + s), l + s - l * s)
    p = 2 * l - q

    def hue2rgb(t: torch.Tensor) -> torch.Tensor:
        t = t % 1.0  # wrap hue offsets into [0, 1)
        return torch.where(
            t < 1 / 6,
            p + (q - p) * 6 * t,
            torch.where(
                t < 1 / 2,
                q,
                torch.where(t < 2 / 3, p + (q - p) * (2 / 3 - t) * 6, p),
            ),
        )

    r = hue2rgb(h + 1 / 3)
    g = hue2rgb(h)
    b = hue2rgb(h - 1 / 3)

    # Zero saturation is grey: every channel equals the lightness.
    # (1e-5 in normalized units == 0.001 percent.)
    achromatic = s < 1e-5
    r = torch.where(achromatic, l, r)
    g = torch.where(achromatic, l, g)
    b = torch.where(achromatic, l, b)

    def to_linear(c: torch.Tensor) -> torch.Tensor:
        # sRGB decoding. torch.where evaluates and backpropagates through
        # both branches, so the power branch's input is clamped to the range
        # where it is selected. This leaves the forward result unchanged.
        # Without the clamp, a base below -0.055 would be NaN there, and
        # 0 * NaN would poison the gradient.
        power = ((c.clamp_min(0.04045) + 0.055) / 1.055) ** 2.4
        return torch.where(c <= 0.04045, c / 12.92, power)

    r_lin = to_linear(r)
    g_lin = to_linear(g)
    b_lin = to_linear(b)

    # Linear sRGB -> LMS.
    l_ = 0.4122214708 * r_lin + 0.5363325363 * g_lin + 0.0514459929 * b_lin
    m_ = 0.2119034982 * r_lin + 0.6806995451 * g_lin + 0.1073969566 * b_lin
    s_ = 0.0883024619 * r_lin + 0.2817188376 * g_lin + 0.6299787005 * b_lin

    def signed_cbrt(x: torch.Tensor) -> torch.Tensor:
        # Real, sign-preserving cube root. |x| is floored at 1e-12 so the
        # derivative of |x|**(1/3) stays finite at 0 (it is infinite there,
        # and 0 * inf = NaN in backward). sign(0) = 0 still makes the output
        # exactly 0 at x = 0; only 0 < |x| < 1e-12 is affected, by < 1e-4.
        return torch.sign(x) * torch.abs(x).clamp_min(1e-12).pow(1 / 3)

    l_c = signed_cbrt(l_)
    m_c = signed_cbrt(m_)
    s_c = signed_cbrt(s_)

    # LMS' -> OKLab.
    L_ok = 0.2104542553 * l_c + 0.7936177850 * m_c - 0.0040720468 * s_c
    a_ok = 1.9779984951 * l_c - 2.4285922050 * m_c + 0.4505937099 * s_c
    b_ok = 0.0259040371 * l_c + 0.7827717662 * m_c - 0.8086757660 * s_c

    return torch.stack([L_ok, a_ok, b_ok], dim=-1)
|
|
|
|
|
def denormalize_hsl(hsl_norm: torch.Tensor) -> torch.Tensor:
    """
    Scale normalized HSL [0, 1] to degrees / percent / percent.

    A scaled copy is returned and the input is left untouched. Only the first
    three entries of the last dimension are scaled: H by 360, S and L by 100.
    """
    scaled = hsl_norm.clone()
    for channel, factor in enumerate((360.0, 100.0, 100.0)):
        scaled[..., channel] *= factor
    return scaled
|
|
|
|
|
class OKLabMSELoss(nn.Module):
    """
    Perceptually uniform loss: mean-squared error measured in OKLab space.

    Predicted and target HSL tensors are both converted with
    ``hsl_to_oklab_batch`` and then compared with
    ``torch.nn.functional.mse_loss``. Hue circularity comes for free
    (359° and 1° end up close together) because OKLab represents hue as
    Cartesian a-b coordinates rather than as an angle.
    """

    def forward(
        self,
        pred_hsl: torch.Tensor,
        target_hsl: torch.Tensor,
    ) -> torch.Tensor:
        """
        Return the OKLab-space MSE between ``pred_hsl`` and ``target_hsl``.

        Args:
            pred_hsl: (..., 3) normalized HSL predictions.
            target_hsl: (..., 3) normalized HSL targets.
        """
        return torch.nn.functional.mse_loss(
            hsl_to_oklab_batch(pred_hsl),
            hsl_to_oklab_batch(target_hsl),
        )
|
|
|