| import torch |
| from jaxtyping import Float |
|
|
|
|
| def srgb_to_linear(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]: |
| switch_val = 0.04045 |
| return torch.where( |
| torch.greater(x, switch_val), |
| ((x.clip(min=switch_val) + 0.055) / 1.055).pow(2.4), |
| x / 12.92, |
| ) |
|
|
|
|
| def linear_to_srgb(x: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]: |
| switch_val = 0.0031308 |
| return torch.where( |
| torch.greater(x, switch_val), |
| 1.055 * x.clip(min=switch_val).pow(1.0 / 2.4) - 0.055, |
| x * 12.92, |
| ) |
|
|
|
|
| def rgb_to_lab(srgb: Float[torch.Tensor, "*B C"]) -> Float[torch.Tensor, "*B C"]: |
| srgb_pixels = torch.reshape(srgb, [-1, 3]) |
|
|
| linear_mask = srgb_pixels <= 0.04045 |
| exponential_mask = srgb_pixels > 0.04045 |
| rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + ( |
| ((srgb_pixels + 0.055) / 1.055) ** 2.4 |
| ) * exponential_mask |
|
|
| rgb_to_xyz = ( |
| torch.tensor( |
| [ |
| |
| [0.412453, 0.212671, 0.019334], |
| [0.357580, 0.715160, 0.119193], |
| [0.180423, 0.072169, 0.950227], |
| ] |
| ) |
| .to(srgb.dtype) |
| .to(srgb.device) |
| ) |
|
|
| xyz_pixels = torch.mm(rgb_pixels, rgb_to_xyz) |
|
|
| xyz_normalized_pixels = torch.mul( |
| xyz_pixels, |
| torch.tensor([1 / 0.950456, 1.0, 1 / 1.088754]).to(srgb.dtype).to(srgb.device), |
| ) |
|
|
| epsilon = 6.0 / 29.0 |
| linear_mask = (xyz_normalized_pixels <= (epsilon**3)).to(srgb.dtype).to(srgb.device) |
|
|
| exponential_mask = ( |
| (xyz_normalized_pixels > (epsilon**3)).to(srgb.dtype).to(srgb.device) |
| ) |
|
|
| fxfyfz_pixels = ( |
| xyz_normalized_pixels / (3 * epsilon**2) + 4.0 / 29.0 |
| ) * linear_mask + ( |
| (xyz_normalized_pixels + 0.000001) ** (1.0 / 3.0) |
| ) * exponential_mask |
|
|
| fxfyfz_to_lab = ( |
| torch.tensor( |
| [ |
| |
| [0.0, 500.0, 0.0], |
| [116.0, -500.0, 200.0], |
| [0.0, 0.0, -200.0], |
| ] |
| ) |
| .to(srgb.dtype) |
| .to(srgb.device) |
| ) |
| lab_pixels = torch.mm(fxfyfz_pixels, fxfyfz_to_lab) + torch.tensor( |
| [-16.0, 0.0, 0.0] |
| ).to(srgb.dtype).to(srgb.device) |
| return torch.reshape(lab_pixels, srgb.shape) |
|
|