Jonttup committed on
Commit
41bfbd1
·
verified ·
1 Parent(s): 2f9ad67

Upload models/hsl_feature_extractor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. models/hsl_feature_extractor.py +102 -0
models/hsl_feature_extractor.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HSL Feature Extractor
3
+
4
+ Replaces PaletteFeatureExtractor (which uses nn.Embedding for token IDs)
5
+ for the HSL color pipeline.
6
+
7
+ Input: (B, H, W, 3) FloatTensor — HSL palette with channels [h, s, l] in [0, 1]
8
+ Output: (B, H, W, D) FloatTensor — spatial features
9
+
10
+ Architecture:
11
+ 1. Circular hue encoding: h -> (sin(2*pi*h), cos(2*pi*h))
12
+ 2. Stack: [sin_h, cos_h, s, l] -> 4D tensor
13
+ 3. Linear projection: nn.Linear(4, hidden_dim)
14
+ 4. VisionTransformer: reuse existing VisionTransformer from models.vit
15
+ """
16
+
17
+ import math
18
+ import torch
19
+ import torch.nn as nn
20
+
21
+ from .vit import VisionTransformer, trunc_normal_init_
22
+
23
+
24
+ class HSLFeatureExtractor(nn.Module):
25
+ """
26
+ Feature extractor for HSL color palettes.
27
+
28
+ Uses circular hue encoding (sin/cos) to handle hue's circular nature
29
+ (hue 0 ≈ hue 1), then projects the 4D encoded features through a linear
30
+ layer and a VisionTransformer for spatial feature extraction.
31
+
32
+ Args:
33
+ hidden_dim: Transformer hidden dimension (default: 768)
34
+ num_layers: Number of transformer layers (default: 6)
35
+ num_heads: Number of attention heads (default: 8)
36
+ patch_size: Patch size for ViT patchification (default: 4)
37
+ dropout: Dropout probability (default: 0.1)
38
+ """
39
+
40
+ def __init__(
41
+ self,
42
+ hidden_dim: int = 768,
43
+ num_layers: int = 6,
44
+ num_heads: int = 8,
45
+ patch_size: int = 4,
46
+ dropout: float = 0.1,
47
+ ):
48
+ super().__init__()
49
+
50
+ self.hidden_dim = hidden_dim
51
+
52
+ # Project 4D circular-encoded HSL to hidden_dim
53
+ self.hsl_proj = nn.Linear(4, hidden_dim, bias=True)
54
+
55
+ # Vision Transformer for spatial feature extraction
56
+ self.vit = VisionTransformer(
57
+ hidden_dim=hidden_dim,
58
+ num_layers=num_layers,
59
+ num_heads=num_heads,
60
+ patch_size=patch_size,
61
+ dropout=dropout,
62
+ )
63
+
64
+ # Initialize hsl_proj weights with truncated normal
65
+ self._init_weights()
66
+
67
+ def _init_weights(self):
68
+ """Initialize hsl_proj weights with truncated normal."""
69
+ std = 1.0 / math.sqrt(self.hsl_proj.in_features)
70
+ trunc_normal_init_(self.hsl_proj.weight, std=std)
71
+ if self.hsl_proj.bias is not None:
72
+ self.hsl_proj.bias.data.zero_()
73
+
74
+ def forward(self, palette_hsl: torch.Tensor) -> torch.Tensor:
75
+ """
76
+ Extract spatial features from an HSL palette.
77
+
78
+ Args:
79
+ palette_hsl: (B, H, W, 3) FloatTensor with channels [h, s, l] in [0, 1]
80
+
81
+ Returns:
82
+ (B, H, W, D) FloatTensor spatial features
83
+ """
84
+ # Split channels
85
+ h = palette_hsl[..., 0] # (B, H, W)
86
+ s = palette_hsl[..., 1] # (B, H, W)
87
+ l = palette_hsl[..., 2] # (B, H, W)
88
+
89
+ # Circular hue encoding — handles wraparound: hue 0 ≈ hue 1
90
+ sin_h = torch.sin(2 * math.pi * h) # (B, H, W)
91
+ cos_h = torch.cos(2 * math.pi * h) # (B, H, W)
92
+
93
+ # Stack into 4-channel tensor
94
+ encoded = torch.stack([sin_h, cos_h, s, l], dim=-1) # (B, H, W, 4)
95
+
96
+ # Project to hidden_dim
97
+ embedded = self.hsl_proj(encoded) # (B, H, W, D)
98
+
99
+ # Apply VisionTransformer for spatial feature extraction
100
+ features = self.vit(embedded) # (B, H, W, D)
101
+
102
+ return features