""" HSL Feature Extractor Replaces PaletteFeatureExtractor (which uses nn.Embedding for token IDs) for the HSL color pipeline. Input: (B, H, W, 3) FloatTensor — HSL palette with channels [h, s, l] in [0, 1] Output: (B, H, W, D) FloatTensor — spatial features Architecture: 1. Circular hue encoding: h -> (sin(2*pi*h), cos(2*pi*h)) 2. Stack: [sin_h, cos_h, s, l] -> 4D tensor 3. Linear projection: nn.Linear(4, hidden_dim) 4. VisionTransformer: reuse existing VisionTransformer from models.vit """ import math import torch import torch.nn as nn from .vit import VisionTransformer, trunc_normal_init_ class HSLFeatureExtractor(nn.Module): """ Feature extractor for HSL color palettes. Uses circular hue encoding (sin/cos) to handle hue's circular nature (hue 0 ≈ hue 1), then projects the 4D encoded features through a linear layer and a VisionTransformer for spatial feature extraction. Args: hidden_dim: Transformer hidden dimension (default: 768) num_layers: Number of transformer layers (default: 6) num_heads: Number of attention heads (default: 8) patch_size: Patch size for ViT patchification (default: 4) dropout: Dropout probability (default: 0.1) """ def __init__( self, hidden_dim: int = 768, num_layers: int = 6, num_heads: int = 8, patch_size: int = 4, dropout: float = 0.1, ): super().__init__() self.hidden_dim = hidden_dim # Project 4D circular-encoded HSL to hidden_dim self.hsl_proj = nn.Linear(4, hidden_dim, bias=True) # Vision Transformer for spatial feature extraction self.vit = VisionTransformer( hidden_dim=hidden_dim, num_layers=num_layers, num_heads=num_heads, patch_size=patch_size, dropout=dropout, ) # Initialize hsl_proj weights with truncated normal self._init_weights() def _init_weights(self): """Initialize hsl_proj weights with truncated normal.""" std = 1.0 / math.sqrt(self.hsl_proj.in_features) trunc_normal_init_(self.hsl_proj.weight, std=std) if self.hsl_proj.bias is not None: self.hsl_proj.bias.data.zero_() def forward(self, palette_hsl: torch.Tensor) -> torch.Tensor: """ Extract spatial features from an HSL palette. Args: palette_hsl: (B, H, W, 3) FloatTensor with channels [h, s, l] in [0, 1] Returns: (B, H, W, D) FloatTensor spatial features """ # Split channels h = palette_hsl[..., 0] # (B, H, W) s = palette_hsl[..., 1] # (B, H, W) l = palette_hsl[..., 2] # (B, H, W) # Circular hue encoding — handles wraparound: hue 0 ≈ hue 1 sin_h = torch.sin(2 * math.pi * h) # (B, H, W) cos_h = torch.cos(2 * math.pi * h) # (B, H, W) # Stack into 4-channel tensor encoded = torch.stack([sin_h, cos_h, s, l], dim=-1) # (B, H, W, 4) # Project to hidden_dim embedded = self.hsl_proj(encoded) # (B, H, W, D) # Apply VisionTransformer for spatial feature extraction features = self.vit(embedded) # (B, H, W, D) return features