swirl committed on
Commit
8b733d3
·
verified ·
1 Parent(s): ea9cf67

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +298 -0
model.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Two-Tower Model
3
+
4
+ Combined model with User Tower (Isengard) and Wine Tower (Mordor).
5
+ Computes match score via dot product of normalized embeddings.
6
+
7
+ Integrates with HuggingFace Hub for model upload/download via PyTorchModelHubMixin.
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional
13
+ import io
14
+
15
try:
    from huggingface_hub import PyTorchModelHubMixin

    HAS_HF_HUB = True
except ImportError:
    # Fallback for environments without huggingface_hub.
    #
    # A plain ``object`` cannot be used here: the model class is declared
    # with class keywords (``library_name=...``, ``tags=...``), and
    # ``object.__init_subclass__`` rejects any keyword argument, so the class
    # statement itself would raise TypeError at import time.

    class PyTorchModelHubMixin:
        """No-op stand-in used when ``huggingface_hub`` is not installed."""

        def __init_subclass__(cls, **kwargs):
            # Swallow the Hub-specific class keywords; nothing to configure.
            super().__init_subclass__()

    HAS_HF_HUB = False
23
+
24
+ from .user_tower import UserTower
25
+ from .wine_tower import WineTower
26
+ from .config import (
27
+ EMBEDDING_DIM,
28
+ USER_VECTOR_DIM,
29
+ WINE_VECTOR_DIM,
30
+ HIDDEN_DIM,
31
+ CATEGORICAL_ENCODING_DIM,
32
+ )
33
+
34
+
35
class TwoTowerModel(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="swirl-wine-recommendations",
    tags=["recommendation", "two-tower", "wine"],
):
    """
    Two-Tower Recommendation Model

    Isengard (User Tower): Encodes user preferences from reviewed wines
    Mordor (Wine Tower): Encodes wine characteristics

    Score = (dot_product(user_vector, wine_vector) + 1) * 50

    The towers are expected to return L2-normalized vectors, so the dot
    product lies in [-1, 1]; shifting by 1 and scaling by 50 maps it to a
    match percentage in [0, 100].

    HuggingFace Integration:
        # Upload to Hub
        model.push_to_hub("swirl/two-tower-recommender")

        # Load from Hub
        model = TwoTowerModel.from_pretrained("swirl/two-tower-recommender")
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
        categorical_dim: int = CATEGORICAL_ENCODING_DIM,
    ):
        """
        Build both towers with a shared output dimension.

        Args:
            embedding_dim: Passed to both towers as ``embedding_dim``.
            hidden_dim: Passed to both towers as ``hidden_dim``.
            output_dim: Passed to both towers as ``output_dim``, so the two
                tower outputs always have the same width.
            categorical_dim: Passed to the wine tower as ``categorical_dim``.
        """
        super().__init__()

        # Sanity check on the module-level config constants (not on the
        # arguments above, which share a single ``output_dim``).
        assert USER_VECTOR_DIM == WINE_VECTOR_DIM, "Tower output dims must match"

        # Store config for serialization (required by PyTorchModelHubMixin)
        self.config = {
            "embedding_dim": embedding_dim,
            "hidden_dim": hidden_dim,
            "output_dim": output_dim,
            "categorical_dim": categorical_dim,
        }

        self.user_tower = UserTower(
            embedding_dim=embedding_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
        )

        self.wine_tower = WineTower(
            embedding_dim=embedding_dim,
            categorical_dim=categorical_dim,
            hidden_dim=hidden_dim,
            output_dim=output_dim,
        )

    def forward(
        self,
        user_wine_embeddings: torch.Tensor,
        user_ratings: torch.Tensor,
        candidate_wine_embedding: torch.Tensor,
        candidate_categorical: torch.Tensor,
        user_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass computing match scores.

        Args:
            user_wine_embeddings: (batch, num_wines, 768)
            user_ratings: (batch, num_wines)
            candidate_wine_embedding: (batch, 768)
            candidate_categorical: (batch, categorical_dim)
            user_mask: (batch, num_wines) optional padding mask

        Returns:
            scores: (batch,) match scores in [0, 100]
        """
        # Get user embedding from reviewed wines
        user_vector = self.user_tower(user_wine_embeddings, user_ratings, user_mask)

        # Get wine embedding
        wine_vector = self.wine_tower(candidate_wine_embedding, candidate_categorical)

        # Single source of truth for the scoring formula.
        return self.score_from_embeddings(user_vector, wine_vector)

    def get_user_embedding(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Get user embedding for caching/batch scoring."""
        return self.user_tower(wine_embeddings, ratings, mask)

    def get_wine_embedding(
        self,
        wine_embedding: torch.Tensor,
        categorical_features: torch.Tensor,
    ) -> torch.Tensor:
        """Get wine embedding for caching/batch scoring."""
        return self.wine_tower(wine_embedding, categorical_features)

    def score_from_embeddings(
        self,
        user_vector: torch.Tensor,
        wine_vector: torch.Tensor,
    ) -> torch.Tensor:
        """
        Score from pre-computed tower embeddings.

        Takes the dot product over the last dimension and maps it with
        ``(dot + 1) * 50``; for unit-length inputs the result is in [0, 100].
        """
        dot_product = (user_vector * wine_vector).sum(dim=-1)
        return (dot_product + 1) * 50

    # =========================================================================
    # LEGACY SERIALIZATION (fallback when huggingface_hub not available)
    # =========================================================================

    def _checkpoint_payload(self) -> dict:
        """Return the dict written by ``save`` and ``to_bytes``."""
        return {
            "state_dict": self.state_dict(),
            "config": self.config,
        }

    @classmethod
    def _from_checkpoint_payload(cls, checkpoint: dict) -> "TwoTowerModel":
        """Rebuild a model in eval mode from a ``_checkpoint_payload`` dict."""
        model = cls(**checkpoint["config"])
        model.load_state_dict(checkpoint["state_dict"])
        model.eval()
        return model

    def save(self, path: str) -> None:
        """Save model state dict and config to file."""
        torch.save(self._checkpoint_payload(), path)

    @classmethod
    def load(cls, path: str) -> "TwoTowerModel":
        """
        Load model from file.

        The returned model lives on CPU and is in eval mode.
        """
        # weights_only=True restricts unpickling to tensors and plain
        # containers, which is all this checkpoint holds; it prevents
        # arbitrary code execution from a tampered file.
        checkpoint = torch.load(path, map_location="cpu", weights_only=True)
        return cls._from_checkpoint_payload(checkpoint)

    def to_bytes(self) -> bytes:
        """Serialize model state dict and config to bytes for storage."""
        buffer = io.BytesIO()
        torch.save(self._checkpoint_payload(), buffer)
        return buffer.getvalue()

    @classmethod
    def from_bytes(cls, data: bytes) -> "TwoTowerModel":
        """
        Load model from bytes produced by ``to_bytes``.

        The returned model lives on CPU and is in eval mode.
        """
        buffer = io.BytesIO(data)
        # See ``load`` for why weights_only=True is used.
        checkpoint = torch.load(buffer, map_location="cpu", weights_only=True)
        return cls._from_checkpoint_payload(checkpoint)
198
+
199
+
200
+ # =============================================================================
201
+ # TRAINING UTILITIES (for use with HuggingFace Spaces / AutoTrain)
202
+ # =============================================================================
203
+
204
+
205
def create_training_script() -> str:
    """
    Return the source of a stand-alone training script as a string.

    The returned text is meant to be written to a file (e.g. ``script.py``)
    and run remotely on a HuggingFace Space via AutoTrain SpaceRunner.

    Usage:
        autotrain spacerunner --project-name two-tower-training \\
            --script-path script.py \\
            --username swirl \\
            --token $HF_TOKEN \\
            --backend spaces-a10g-large
    """
    # The template is returned verbatim; nothing is interpolated into it.
    return '''
"""
Two-Tower Model Training Script for HuggingFace Spaces

Run with: autotrain spacerunner --script-path script.py
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from huggingface_hub import login
import os

# Login to HF
login(token=os.environ.get("HF_TOKEN"))

from two_tower.model import TwoTowerModel
from two_tower.config import TRIPLET_MARGIN, LEARNING_RATE, BATCH_SIZE

class WineRecommendationDataset(Dataset):
    """Dataset of (user_wines, positive_wine, negative_wine) triplets."""

    def __init__(self, triplets):
        self.triplets = triplets

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        return self.triplets[idx]


def train_model(
    model: TwoTowerModel,
    train_loader: DataLoader,
    epochs: int = 10,
    lr: float = LEARNING_RATE,
):
    """Train the two-tower model using triplet loss."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    triplet_loss = nn.TripletMarginLoss(margin=TRIPLET_MARGIN)

    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for batch in train_loader:
            optimizer.zero_grad()

            # Get embeddings
            anchor = model.get_user_embedding(batch["user_wines"], batch["ratings"])
            positive = model.get_wine_embedding(batch["positive_wine"], batch["positive_cat"])
            negative = model.get_wine_embedding(batch["negative_wine"], batch["negative_cat"])

            # Compute triplet loss
            loss = triplet_loss(anchor, positive, negative)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}")

    return model


if __name__ == "__main__":
    # Load training data (would be fetched from your database)
    # triplets = load_training_triplets()

    # Create model
    model = TwoTowerModel()

    # Train
    # train_loader = DataLoader(WineRecommendationDataset(triplets), batch_size=BATCH_SIZE)
    # model = train_model(model, train_loader, epochs=10)

    # Push to Hub
    model.push_to_hub("swirl/two-tower-recommender")
    print("Model uploaded to HuggingFace Hub!")
'''