| """ |
| HSL Feature Extractor |
| |
| Replaces PaletteFeatureExtractor (which uses nn.Embedding for token IDs) |
| for the HSL color pipeline. |
| |
| Input: (B, H, W, 3) FloatTensor β HSL palette with channels [h, s, l] in [0, 1] |
| Output: (B, H, W, D) FloatTensor β spatial features |
| |
| Architecture: |
| 1. Circular hue encoding: h -> (sin(2*pi*h), cos(2*pi*h)) |
| 2. Stack: [sin_h, cos_h, s, l] -> 4D tensor |
| 3. Linear projection: nn.Linear(4, hidden_dim) |
| 4. VisionTransformer: reuse existing VisionTransformer from models.vit |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
|
|
| from .vit import VisionTransformer, trunc_normal_init_ |
|
|
|
|
class HSLFeatureExtractor(nn.Module):
    """
    Feature extractor for HSL color palettes.

    Hue is circular (hue 0 and hue 1 are the same color), so it is encoded
    as (sin(2*pi*h), cos(2*pi*h)) before being stacked with saturation and
    lightness.  The resulting 4-channel grid is linearly projected to the
    transformer width and passed through a VisionTransformer to produce
    spatial features.

    Args:
        hidden_dim: Transformer hidden dimension (default: 768)
        num_layers: Number of transformer layers (default: 6)
        num_heads: Number of attention heads (default: 8)
        patch_size: Patch size for ViT patchification (default: 4)
        dropout: Dropout probability (default: 0.1)
    """

    def __init__(
        self,
        hidden_dim: int = 768,
        num_layers: int = 6,
        num_heads: int = 8,
        patch_size: int = 4,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.hidden_dim = hidden_dim

        # 4 input channels: [sin(h), cos(h), s, l] -> hidden_dim.
        self.hsl_proj = nn.Linear(4, hidden_dim, bias=True)

        # Spatial backbone; reused as-is from models.vit.
        self.vit = VisionTransformer(
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            num_heads=num_heads,
            patch_size=patch_size,
            dropout=dropout,
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize hsl_proj weights with truncated normal."""
        # Scale by 1/sqrt(fan_in), matching the projection's input width.
        scale = 1.0 / math.sqrt(self.hsl_proj.in_features)
        trunc_normal_init_(self.hsl_proj.weight, std=scale)
        bias = self.hsl_proj.bias
        if bias is not None:
            bias.data.zero_()

    def forward(self, palette_hsl: torch.Tensor) -> torch.Tensor:
        """
        Extract spatial features from an HSL palette.

        Args:
            palette_hsl: (B, H, W, 3) FloatTensor with channels [h, s, l] in [0, 1]

        Returns:
            (B, H, W, D) FloatTensor spatial features
        """
        hue, sat, light = torch.unbind(palette_hsl, dim=-1)

        # Map hue onto the unit circle so that hue 0 and hue 1 coincide.
        angle = (2 * math.pi) * hue
        encoded = torch.stack(
            [torch.sin(angle), torch.cos(angle), sat, light], dim=-1
        )

        # (B, H, W, 4) -> (B, H, W, hidden_dim) -> (B, H, W, D)
        return self.vit(self.hsl_proj(encoded))
|
|