"""
Isengard - User Tower

Neural network that encodes a user's wine preferences from their reviewed wines.
Uses attention-weighted aggregation of wine embeddings based on user ratings.
"""
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from typing import Optional |
|
|
| from .config import ( |
| EMBEDDING_DIM, |
| USER_VECTOR_DIM, |
| HIDDEN_DIM, |
| ) |
|
|
|
|
class UserTower(nn.Module):
    """
    Isengard: Encodes user preferences from their reviewed wines.

    Architecture:
        1. Rating-weighted attention over wine embeddings
        2. MLP: embedding_dim -> hidden_dim -> output_dim (default 768 -> 256 -> 128)
        3. L2 normalization to the unit sphere

    Input:
        wine_embeddings: (batch, num_wines, embedding_dim) - embeddings of reviewed wines
        ratings: (batch, num_wines) - user ratings for each wine
        mask: (batch, num_wines) - optional mask for padding (1 = valid, 0 = pad)

    Output:
        user_vector: (batch, output_dim) - L2-normalized user embedding
    """

    def __init__(
        self,
        embedding_dim: int = EMBEDDING_DIM,
        hidden_dim: int = HIDDEN_DIM,
        output_dim: int = USER_VECTOR_DIM,
        rating_center: float = 2.5,
        rating_scale: float = 2.5,
        dropout_p: float = 0.1,
    ):
        """
        Args:
            embedding_dim: Dimensionality of the input wine embeddings.
            hidden_dim: Width of the hidden MLP layer.
            output_dim: Dimensionality of the output user vector.
            rating_center: Rating value mapped to a zero attention logit.
                Defaults to 2.5, the midpoint of a 1-5 rating scale.
            rating_scale: Divisor applied after centering, so ratings map
                roughly onto [-1, 1]. Defaults match the 1-5 scale.
            dropout_p: Dropout probability applied after the hidden layer.
        """
        super().__init__()

        self.embedding_dim = embedding_dim
        self.output_dim = output_dim
        # Parameterized so the tower generalizes beyond a 1-5 rating scale
        # (e.g. 0-100 scores) without changing default behavior.
        self.rating_center = rating_center
        self.rating_scale = rating_scale

        self.fc1 = nn.Linear(embedding_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

        self.dropout = nn.Dropout(dropout_p)

    def forward(
        self,
        wine_embeddings: torch.Tensor,
        ratings: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass through the user tower.

        Args:
            wine_embeddings: (batch, num_wines, embedding_dim)
            ratings: (batch, num_wines) - raw ratings (1-5 scale by default)
            mask: (batch, num_wines) - 1 for valid wines, 0 for padding

        Returns:
            user_vector: (batch, output_dim) - L2 normalized. Rows whose
            mask is entirely zero yield the zero vector (F.normalize's eps
            guards the division).
        """
        # Center and rescale ratings so above-average ratings get positive
        # logits and thus higher attention weight.
        logits = (ratings - self.rating_center) / self.rating_scale

        if mask is not None:
            # Exclude padded positions from the softmax itself by pushing
            # their logits toward -inf (finfo.min rather than -inf avoids
            # NaN on rows that are entirely padding).
            logits = logits.masked_fill(mask == 0, torch.finfo(logits.dtype).min)

        attention_weights = F.softmax(logits, dim=-1)

        if mask is not None:
            # Zero any residual weight on padding and renormalize; the eps
            # keeps all-padding rows finite (weights collapse to ~0).
            attention_weights = attention_weights * mask
            attention_weights = attention_weights / (
                attention_weights.sum(dim=-1, keepdim=True) + 1e-8
            )

        # Weighted sum over wines: (batch, 1, num_wines) @ (batch, num_wines, dim)
        # -> (batch, dim).
        aggregated = torch.bmm(
            attention_weights.unsqueeze(1),
            wine_embeddings,
        ).squeeze(1)

        # Two-layer MLP projection into the user-vector space.
        x = F.relu(self.fc1(aggregated))
        x = self.dropout(x)
        user_vector = self.fc2(x)

        # Project onto the unit sphere so downstream similarity is cosine.
        user_vector = F.normalize(user_vector, p=2, dim=-1)

        return user_vector
|
|