swirl committed on
Commit
d820920
·
verified ·
1 Parent(s): 8b733d3

Upload user_tower.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. user_tower.py +102 -0
user_tower.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Isengard - User Tower
3
+
4
+ Neural network that encodes a user's wine preferences from their reviewed wines.
5
+ Uses attention-weighted aggregation of wine embeddings based on user ratings.
6
+ """
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from typing import Optional
12
+
13
+ from .config import (
14
+ EMBEDDING_DIM,
15
+ USER_VECTOR_DIM,
16
+ HIDDEN_DIM,
17
+ )
18
+
19
+
20
class UserTower(nn.Module):
    """
    Isengard: Encodes user preferences from their reviewed wines.

    Architecture:
        1. Rating-weighted attention over wine embeddings
        2. MLP: embedding_dim → hidden_dim → output_dim (default 768 → 256 → 128)
        3. L2 normalization to unit sphere

    Input:
        wine_embeddings: (batch, num_wines, embedding_dim) - embeddings of reviewed wines
        ratings: (batch, num_wines) - user ratings for each wine
        mask: (batch, num_wines) - optional mask for padding

    Output:
        user_vector: (batch, output_dim) - normalized user embedding
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
        dropout: float = 0.1,
    ):
        """
        Args:
            embedding_dim: Dimensionality of the input wine embeddings.
            hidden_dim: Width of the hidden MLP layer.
            output_dim: Dimensionality of the output user vector.
            dropout: Dropout probability applied after the hidden layer
                (default 0.1, matching the original hard-coded value).
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim

        # Two-layer MLP projecting the aggregated wine embedding
        # down to the shared user/item vector space.
        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        # Dropout for regularization between the two linear layers.
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the user tower.

        Args:
            wine_embeddings: (batch, num_wines, embedding_dim)
            ratings: (batch, num_wines) - raw ratings (1-5 scale)
            mask: (batch, num_wines) - 1 for valid wines, 0 for padding

        Returns:
            user_vector: (batch, output_dim) - L2 normalized
        """
        # Convert ratings to attention logits: higher ratings = more attention.
        # Center on the midpoint of the 1-5 scale: 1 → -0.6, 3 → 0.2, 5 → 1.0.
        attention_weights = (ratings - 2.5) / 2.5
        attention_weights = F.softmax(attention_weights, dim=-1)

        if mask is not None:
            # Zero out padded positions and re-normalize. This is mathematically
            # equivalent to masking with -inf before the softmax, but unlike that
            # form it degrades to an all-zero weight row (rather than NaN) when
            # every position in a row is masked; the 1e-8 guards the division.
            attention_weights = attention_weights * mask
            attention_weights = attention_weights / (
                attention_weights.sum(dim=-1, keepdim=True) + 1e-8
            )

        # Attention-weighted pooling over the user's wines:
        # (batch, 1, num_wines) @ (batch, num_wines, embed_dim) → (batch, embed_dim)
        aggregated = torch.bmm(
            attention_weights.unsqueeze(1),
            wine_embeddings,
        ).squeeze(1)

        # MLP projection into the user-vector space.
        x = F.relu(self.fc1(aggregated))
        x = self.dropout(x)
        user_vector = self.fc2(x)

        # L2 normalize so user/item similarity reduces to a dot product.
        user_vector = F.normalize(user_vector, p=2, dim=-1)

        return user_vector