rickytato committed on
Commit
350b54b
·
1 Parent(s): f714691

Add files copied from SPACES

Files changed (7)
  1. Dockerfile +45 -0
  2. README.md +120 -0
  3. generate_model_backup.py +1110 -0
  4. generate_model_gpu.py +1265 -0
  5. requirements.txt +12 -0
  6. run.sh +39 -0
  7. users_200k.json +2 -2
Dockerfile ADDED
@@ -0,0 +1,45 @@
+ FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
+
+ # Set environment variables
+ ENV PYTHONUNBUFFERED=1 \
+     PYTHONDONTWRITEBYTECODE=1 \
+     DEBIAN_FRONTEND=noninteractive
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     python3 \
+     python3-pip \
+     python3-setuptools \
+     python3-dev \
+     build-essential \
+     git \
+     && apt-get clean \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Create a working directory
+ WORKDIR /app
+
+ # Install dependencies
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # Prepare output directories with the correct permissions
+ RUN mkdir -p ./model_checkpoints ./model /tmp/huggingface_cache /app/embedding && \
+     chmod 777 ./model_checkpoints ./model /tmp/huggingface_cache /app/embedding
+
+ # Copy all required files
+ COPY *.py .
+ COPY users.json .
+ COPY run.sh .
+
+ RUN chmod +x run.sh
+
+ # Create a script to run the process and then copy the files
+ #python generate_embeddings.py\n\
+ # Run the script
+
+ # Expose the port for the web interface
+ EXPOSE 7860
+
+ # Set the entrypoint
+ ENTRYPOINT ["./run.sh"]
README.md ADDED
@@ -0,0 +1,120 @@
+ ---
+ title: Test01
+ emoji: 🐨
+ colorFrom: gray
+ colorTo: red
+ sdk: docker
+ app_port: 7860
+ startup_duration_timeout: 60m
+ pinned: false
+ license: apache-2.0
+ ---
+ # User Embedding Model
+
+ This repository contains a PyTorch model for generating user embeddings based on DMP (Data Management Platform) data. The model creates dense vector representations of users that can be used for recommendation systems, user clustering, and similarity searches.
+
+ ## Quick Start with Docker
+
+ To run the model using Docker:
+
+ ```bash
+ docker build -t user-embedding-model .
+
+ docker run -v /path/to/your/data:/app/data \
+     -e DATA_PATH=/app/data/users.json \
+     -e NUM_EPOCHS=10 \
+     -e BATCH_SIZE=32 \
+     -v /path/to/output:/app/embeddings_output \
+     user-embedding-model
+ ```
+
+ ## Pushing to Hugging Face
+
+ To automatically push the model to Hugging Face, add your credentials:
+
+ ```bash
+ docker run -v /path/to/your/data:/app/data \
+     -e DATA_PATH=/app/data/users.json \
+     -e HF_REPO_ID="your-username/your-model-name" \
+     -e HF_TOKEN="your-huggingface-token" \
+     -v /path/to/output:/app/embeddings_output \
+     user-embedding-model
+ ```
+
+ ## Input Data Format
+
+ The model expects user data in JSON format, with each user having DMP fields like:
+
+ ```json
+ {
+   "dmp": {
+     "city": "milano",
+     "domains": ["example.com"],
+     "brands": ["brand1", "brand2"],
+     "clusters": ["cluster1", "cluster2"],
+     "industries": ["industry1"],
+     "tags": ["tag1", "tag2"],
+     "channels": ["channel1"],
+     "~click__host": "host1",
+     "~click__domain": "domain1",
+     "": {
+       "id": "user123"
+     }
+   }
+ }
+ ```
+
+ ## Environment Variables
+
+ - `DATA_PATH`: Path to your input JSON file (default: "users.json")
+ - `NUM_EPOCHS`: Number of training epochs (default: 10)
+ - `BATCH_SIZE`: Batch size for training (default: 32)
+ - `LEARNING_RATE`: Learning rate for optimizer (default: 0.001)
+ - `SAVE_INTERVAL`: Save checkpoint every N epochs (default: 2)
+ - `HF_REPO_ID`: Hugging Face repository ID for uploading
+ - `HF_TOKEN`: Hugging Face API token
+
+ ## Output
+
+ The model generates:
+
+ 1. `embeddings.json`: User embeddings in JSON format
+ 2. `embeddings.npz`: User embeddings in NumPy format
+ 3. `vocabularies.json`: Vocabulary mappings
+ 4. `model.pth`: Trained PyTorch model
+ 5. `model_config.json`: Model configuration
+ 6. Hugging Face-compatible model files in the "huggingface" subdirectory
+
+ ## Hardware Requirements
+
+ - **Recommended**: NVIDIA GPU with CUDA support
+ - The code uses parallel processing for triplet generation to utilize all available CPU cores
+ - For L40S GPU, recommended batch size: 32-64
+ - Memory requirement: At least 8GB RAM
+
+ ## Model Architecture
+
+ The model consists of:
+
+ - Embedding layers for each user field
+ - Sequential fully connected layers for dimensionality reduction
+ - Output dimension: 256 (configurable)
+ - Training method: Triplet margin loss using similar/dissimilar users
+
+ ## Performance Optimization
+
+ The code includes several optimizations:
+
+ - Parallel triplet generation using all available CPU cores
+ - GPU acceleration for model training
+ - Efficient memory handling for large datasets
+ - TensorBoard integration for monitoring training
+
+ ## Troubleshooting
+
+ If you encounter issues:
+
+ 1. Check that your input data follows the expected format
+ 2. Ensure you have sufficient memory for your dataset size
+ 3. For GPU errors, try reducing batch size
+ 4. Check the logs for detailed error messages
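
The README describes `embeddings.json` as a map from user ID to embedding vector. A dependency-free sketch of consuming that output for a similarity search (the `most_similar` helper and the toy vectors below are illustrative, not part of the repository):

```python
import json
import math

def cosine_similarity(a, b):
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return dot / (norm_a * norm_b)

def most_similar(embeddings, query_id, top_k=5):
    """Rank all other users by cosine similarity to query_id."""
    query = embeddings[query_id]
    scored = [(uid, cosine_similarity(query, emb))
              for uid, emb in embeddings.items() if uid != query_id]
    scored.sort(key=lambda t: t[1], reverse=True)
    return scored[:top_k]

# In practice the map would come from: json.load(open("embeddings.json"))
embeddings = {
    "user1": [1.0, 0.0, 0.0],
    "user2": [0.9, 0.1, 0.0],
    "user3": [0.0, 1.0, 0.0],
}
print(most_similar(embeddings, "user1", top_k=2))
```

The same vectors could equally be loaded from `embeddings.npz` with NumPy; the JSON form is used here to keep the sketch standard-library only.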
generate_model_backup.py ADDED
@@ -0,0 +1,1110 @@
+ import os
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
+ #os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
+ import multiprocessing
+ try:
+     multiprocessing.set_start_method('spawn')
+ except RuntimeError:
+     pass  # The start method has already been set
+ import json
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ import pandas as pd
+ from typing import List, Dict
+ import logging
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+ import random
+ from collections import defaultdict
+ from pathlib import Path
+ from tqdm import tqdm
+ import os
+ import multiprocessing
+ from multiprocessing import Pool
+ import psutil
+ import argparse
+
+ # =================================
+ # CONFIGURABLE PARAMETERS
+ # =================================
+ # Define default parameters that can be overridden via environment variables
+ DEFAULT_NUM_TRIPLETS = 150000       # Number of triplet examples to generate
+ DEFAULT_NUM_EPOCHS = 20             # Number of training epochs
+ DEFAULT_BATCH_SIZE = 64             # Batch size for training
+ DEFAULT_LEARNING_RATE = 0.001       # Learning rate for optimizer
+ DEFAULT_OUTPUT_DIM = 256            # Output dimension of embeddings
+ DEFAULT_MAX_SEQ_LENGTH = 15         # Maximum sequence length
+ DEFAULT_SAVE_INTERVAL = 2           # Save checkpoint every N epochs
+ DEFAULT_DATA_PATH = "./users.json"  # Path to user data
+ DEFAULT_OUTPUT_DIR = "./model"      # Output directory
+
+ # Read parameters from environment variables
+ NUM_TRIPLETS = int(os.environ.get("NUM_TRIPLETS", DEFAULT_NUM_TRIPLETS))
+ NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", DEFAULT_NUM_EPOCHS))
+ BATCH_SIZE = int(os.environ.get("BATCH_SIZE", DEFAULT_BATCH_SIZE))
+ LEARNING_RATE = float(os.environ.get("LEARNING_RATE", DEFAULT_LEARNING_RATE))
+ OUTPUT_DIM = int(os.environ.get("OUTPUT_DIM", DEFAULT_OUTPUT_DIM))
+ MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", DEFAULT_MAX_SEQ_LENGTH))
+ SAVE_INTERVAL = int(os.environ.get("SAVE_INTERVAL", DEFAULT_SAVE_INTERVAL))
+ DATA_PATH = os.environ.get("DATA_PATH", DEFAULT_DATA_PATH)
+ OUTPUT_DIR = os.environ.get("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # Get CUDA device
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logging.info(f"Using device: {device}")
+
+ # =================================
+ # MODEL ARCHITECTURE
+ # =================================
+ class UserEmbeddingModel(nn.Module):
+     def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int],
+                  output_dim: int = 256, max_sequence_length: int = 15,
+                  padded_fields_length: int = 10):
+         super().__init__()
+
+         self.max_sequence_length = max_sequence_length
+         self.padded_fields_length = padded_fields_length
+         self.padded_fields = {'dmp_channels', 'dmp_tags', 'dmp_clusters'}
+         self.embedding_layers = nn.ModuleDict()
+
+         # Create embedding layers for each field
+         for field, vocab_size in vocab_sizes.items():
+             self.embedding_layers[field] = nn.Embedding(
+                 vocab_size,
+                 embedding_dims.get(field, 16),
+                 padding_idx=0
+             )
+
+         # Calculate total input dimension
+         self.total_input_dim = 0
+         for field, dim in embedding_dims.items():
+             if field in self.padded_fields:
+                 self.total_input_dim += dim  # Single dimension for padded field
+             else:
+                 self.total_input_dim += dim
+
+         print(f"Total input dimension: {self.total_input_dim}")
+
+         self.fc = nn.Sequential(
+             nn.Linear(self.total_input_dim, self.total_input_dim // 2),
+             nn.ReLU(),
+             nn.Dropout(0.2),
+             nn.Linear(self.total_input_dim // 2, output_dim),
+             nn.LayerNorm(output_dim)
+         )
+
+     def _process_sequence(self, embedding_layer: nn.Embedding, indices: torch.Tensor,
+                           field_name: str) -> torch.Tensor:
+         """Process normal sequences"""
+         batch_size = indices.size(0)
+         if indices.numel() == 0:
+             return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)
+
+         if field_name in ['dmp_city', 'dmp_domains']:
+             if indices.dim() == 1:
+                 indices = indices.unsqueeze(0)
+             if indices.size(1) > 0:
+                 return embedding_layer(indices[:, 0])
+             return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)
+
+         # Handle multiple sequences
+         embeddings = embedding_layer(indices)
+         return embeddings.mean(dim=1)  # [batch_size, emb_dim]
+
+     def _process_padded_sequence(self, embedding_layer: nn.Embedding,
+                                  indices: torch.Tensor) -> torch.Tensor:
+         """Process sequences with padding"""
+         batch_size = indices.size(0)
+         emb_dim = embedding_layer.embedding_dim
+
+         # Generate embeddings
+         embeddings = embedding_layer(indices)  # [batch_size, seq_len, emb_dim]
+
+         # Average along sequence dimension
+         mask = (indices != 0).float().unsqueeze(-1)
+         masked_embeddings = embeddings * mask
+         sum_mask = mask.sum(dim=1).clamp(min=1.0)
+
+         return (masked_embeddings.sum(dim=1) / sum_mask)  # [batch_size, emb_dim]
+
+     def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+         batch_embeddings = []
+
+         for field in ['dmp_city', 'dmp_domains', 'dmp_brands',
+                       'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
+                       'link_host', 'link_path']:
+             if field in inputs and field in self.embedding_layers:
+                 if field in self.padded_fields:
+                     emb = self._process_padded_sequence(
+                         self.embedding_layers[field],
+                         inputs[field]
+                     )
+                 else:
+                     emb = self._process_sequence(
+                         self.embedding_layers[field],
+                         inputs[field],
+                         field
+                     )
+                 batch_embeddings.append(emb)
+
+         combined = torch.cat(batch_embeddings, dim=1)
+         return self.fc(combined)
+
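
The masked averaging that `_process_padded_sequence` performs on tensors can be illustrated in plain Python (a dependency-free sketch of the same computation; the vectors and indices below are made up):

```python
def masked_mean(vectors, indices):
    """Plain-Python analogue of the model's masked average: positions whose
    vocabulary index is 0 (<PAD>) are excluded from the mean, and the divisor
    is clamped to 1 so an all-padding sequence yields a zero vector."""
    kept = [vec for vec, idx in zip(vectors, indices) if idx != 0]
    dim = len(vectors[0])
    denom = max(1, len(kept))  # mirrors mask.sum(dim=1).clamp(min=1.0)
    return [sum(vec[d] for vec in kept) / denom for d in range(dim)]

# Two real tokens followed by one padding slot: the padded vector is ignored
print(masked_mean([[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]], [5, 7, 0]))  # [2.0, 3.0]
```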
+ # =================================
+ # EMBEDDING PIPELINE
+ # =================================
+ class UserEmbeddingPipeline:
+     def __init__(self, output_dim: int = 256, max_sequence_length: int = 15):
+         self.output_dim = output_dim
+         self.max_sequence_length = max_sequence_length
+         self.model = None
+         self.vocab_maps = {}
+
+         self.fields = [
+             'dmp_city', 'dmp_domains', 'dmp_brands',
+             'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
+             'link_host', 'link_path'
+         ]
+
+         # Map of new JSON fields to old field names used in the model
+         self.field_mapping = {
+             'dmp_city': ('dmp', 'city'),
+             'dmp_domains': ('dmp', 'domains'),
+             'dmp_brands': ('dmp', 'brands'),
+             'dmp_clusters': ('dmp', 'clusters'),
+             'dmp_industries': ('dmp', 'industries'),
+             'dmp_tags': ('dmp', 'tags'),
+             'dmp_channels': ('dmp', 'channels'),
+             'link_host': ('dmp', '~click__host'),
+             'link_path': ('dmp', '~click__domain')
+         }
+
+         self.embedding_dims = {
+             'dmp_city': 16,
+             'dmp_domains': 16,
+             'dmp_brands': 32,
+             'dmp_clusters': 32,
+             'dmp_industries': 32,
+             'dmp_tags': 64,
+             'dmp_channels': 32,
+             'link_host': 32,
+             'link_path': 32
+         }
+
+     def _clean_value(self, value):
+         if isinstance(value, float) and np.isnan(value):
+             return []
+         if isinstance(value, str):
+             return [value.lower().strip()]
+         if isinstance(value, list):
+             return [str(v).lower().strip() for v in value if v is not None and str(v).strip()]
+         return []
+
+     def _get_field_from_user(self, user, field):
+         """Extract field value from new JSON user format"""
+         mapping = self.field_mapping.get(field, (field,))
+         value = user
+
+         # Navigate through nested structure
+         for key in mapping:
+             if isinstance(value, dict):
+                 value = value.get(key, {})
+             else:
+                 # If not a dictionary and we're not at the last element
+                 # of the mapping, return an empty list
+                 value = []
+                 break
+
+         # If we've reached the end and have a value that's not a list but should be,
+         # convert it to a list
+         if field in {'dmp_brands', 'dmp_channels', 'dmp_clusters', 'dmp_industries', 'dmp_tags', 'link_host', 'link_path'} and not isinstance(value, list):
+             # If it's a string or other single value, put it in a list
+             if value and not isinstance(value, dict):
+                 value = [value]
+             else:
+                 value = []
+
+         return value
+
+     def build_vocabularies(self, users_data: List[Dict]) -> Dict[str, Dict[str, int]]:
+         field_values = {field: {'<PAD>'} for field in self.fields}
+
+         # Extract the 'user' field from the JSON structure for each record
+         users = []
+         for data in users_data:
+             # Check if there's raw_json.user
+             if 'raw_json' in data and 'user' in data['raw_json']:
+                 users.append(data['raw_json']['user'])
+             # Check if there's user
+             elif 'user' in data:
+                 users.append(data['user'])
+             else:
+                 users.append(data)  # Assume it's already a user
+
+         for user in users:
+             for field in self.fields:
+                 values = self._clean_value(self._get_field_from_user(user, field))
+                 field_values[field].update(values)
+
+         self.vocab_maps = {
+             field: {val: idx for idx, val in enumerate(sorted(values))}
+             for field, values in field_values.items()
+         }
+
+         return self.vocab_maps
+
+     def _prepare_input(self, user: Dict) -> Dict[str, torch.Tensor]:
+         inputs = {}
+
+         for field in self.fields:
+             values = self._clean_value(self._get_field_from_user(user, field))
+             vocab = self.vocab_maps[field]
+             indices = [vocab.get(val, 0) for val in values]
+             inputs[field] = torch.tensor(indices, dtype=torch.long)
+
+         return inputs
+
+     def initialize_model(self) -> None:
+         vocab_sizes = {field: len(vocab) for field, vocab in self.vocab_maps.items()}
+
+         self.model = UserEmbeddingModel(
+             vocab_sizes=vocab_sizes,
+             embedding_dims=self.embedding_dims,
+             output_dim=self.output_dim,
+             max_sequence_length=self.max_sequence_length
+         )
+         self.model.to(device)
+         self.model.eval()
+
+     def generate_embeddings(self, users_data: List[Dict], batch_size: int = 32) -> Dict[str, np.ndarray]:
+         """Generate embeddings for all users"""
+         embeddings = {}
+         self.model.eval()  # Make sure model is in eval mode
+
+         # Extract the 'user' field from the JSON structure for each record
+         users = []
+         user_ids = []
+
+         for data in users_data:
+             # Check if there's raw_json.user
+             if 'raw_json' in data and 'user' in data['raw_json']:
+                 user = data['raw_json']['user']
+                 users.append(user)
+                 # Use user.dmp[''].id as identifier
+                 if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
+                     user_ids.append(str(user['dmp']['']['id']))
+                 else:
+                     # Fallback to uid or id if dmp.id is not available
+                     user_ids.append(str(user.get('uid', user.get('id', None))))
+             # Check if there's user
+             elif 'user' in data:
+                 user = data['user']
+                 users.append(user)
+                 # Use user.dmp[''].id as identifier
+                 if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
+                     user_ids.append(str(user['dmp']['']['id']))
+                 else:
+                     # Fallback to uid or id if dmp.id is not available
+                     user_ids.append(str(user.get('uid', user.get('id', None))))
+             else:
+                 users.append(data)  # Assume it's already a user
+                 # Use user.dmp[''].id as identifier
+                 if 'dmp' in data and '' in data['dmp'] and 'id' in data['dmp']['']:
+                     user_ids.append(str(data['dmp']['']['id']))
+                 else:
+                     # Fallback to uid or id if dmp.id is not available
+                     user_ids.append(str(data.get('uid', data.get('id', None))))
+
+         with torch.no_grad():
+             for i in tqdm(range(0, len(users), batch_size), desc="Generating embeddings"):
+                 batch_users = users[i:i+batch_size]
+                 batch_ids = user_ids[i:i+batch_size]
+                 batch_inputs = []
+                 valid_indices = []
+
+                 for j, user in enumerate(batch_users):
+                     if batch_ids[j] is not None:
+                         batch_inputs.append(self._prepare_input(user))
+                         valid_indices.append(j)
+
+                 if batch_inputs:
+                     # Use the same collate function as training for a single batch
+                     anchor_batch, _, _ = collate_batch([(inputs, inputs, inputs) for inputs in batch_inputs])
+
+                     # Move data to device
+                     anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
+
+                     # Generate embeddings
+                     batch_embeddings = self.model(anchor_batch).cpu()
+
+                     # Save embeddings
+                     for j, idx in enumerate(valid_indices):
+                         if batch_ids[idx]:  # Verify that id is not None or empty
+                             embeddings[batch_ids[idx]] = batch_embeddings[j].numpy()
+
+         return embeddings
+
+     def save_embeddings(self, embeddings: Dict[str, np.ndarray], output_dir: str) -> None:
+         """Save embeddings to file"""
+         output_dir = Path(output_dir)
+         output_dir.mkdir(exist_ok=True)
+
+         # Save embeddings as JSON
+         json_path = output_dir / 'embeddings.json'
+         with open(json_path, 'w') as f:
+             json_embeddings = {user_id: emb.tolist() for user_id, emb in embeddings.items()}
+             json.dump(json_embeddings, f)
+
+         # Save embeddings as NPZ
+         npy_path = output_dir / 'embeddings.npz'
+         np.savez_compressed(npy_path,
+                             embeddings=np.stack(list(embeddings.values())),
+                             user_ids=np.array(list(embeddings.keys())))
+
+         # Save vocabularies
+         vocab_path = output_dir / 'vocabularies.json'
+         with open(vocab_path, 'w') as f:
+             json.dump(self.vocab_maps, f)
+
+         logging.info(f"\nEmbeddings saved in {output_dir}:")
+         logging.info(f"- Embeddings JSON: {json_path}")
+         logging.info(f"- Embeddings NPZ: {npy_path}")
+         logging.info(f"- Vocabularies: {vocab_path}")
+
+     def save_model(self, output_dir: str) -> None:
+         """Save model in PyTorch format (.pth)"""
+         output_dir = Path(output_dir)
+         output_dir.mkdir(exist_ok=True)
+
+         # Save path
+         model_path = output_dir / 'model.pth'
+
+         # Prepare dictionary with model state and metadata
+         checkpoint = {
+             'model_state_dict': self.model.state_dict(),
+             'vocab_maps': self.vocab_maps,
+             'embedding_dims': self.embedding_dims,
+             'output_dim': self.output_dim,
+             'max_sequence_length': self.max_sequence_length
+         }
+
+         # Save model
+         torch.save(checkpoint, model_path)
+
+         logging.info(f"Model saved to: {model_path}")
+
+         # Also save a configuration file for reference
+         config_info = {
+             'model_type': 'UserEmbeddingModel',
+             'vocab_sizes': {field: len(vocab) for field, vocab in self.vocab_maps.items()},
+             'embedding_dims': self.embedding_dims,
+             'output_dim': self.output_dim,
+             'max_sequence_length': self.max_sequence_length,
+             'padded_fields': list(self.model.padded_fields),
+             'fields': self.fields
+         }
+
+         config_path = output_dir / 'model_config.json'
+         with open(config_path, 'w') as f:
+             json.dump(config_info, f, indent=2)
+
+         logging.info(f"Model configuration saved to: {config_path}")
+
+         # Save model in HuggingFace format
+         hf_dir = output_dir / 'huggingface'
+         hf_dir.mkdir(exist_ok=True)
+
+         # Save model in HF format
+         torch.save(self.model.state_dict(), hf_dir / 'pytorch_model.bin')
+
+         # Save config
+         with open(hf_dir / 'config.json', 'w') as f:
+             json.dump(config_info, f, indent=2)
+
+         logging.info(f"Model saved in HuggingFace format to: {hf_dir}")
+
+     def load_model(self, model_path: str) -> None:
+         """Load a previously saved model"""
+         checkpoint = torch.load(model_path, map_location=device)
+
+         # Reload vocabularies and dimensions if needed
+         self.vocab_maps = checkpoint.get('vocab_maps', self.vocab_maps)
+         self.embedding_dims = checkpoint.get('embedding_dims', self.embedding_dims)
+         self.output_dim = checkpoint.get('output_dim', self.output_dim)
+         self.max_sequence_length = checkpoint.get('max_sequence_length', self.max_sequence_length)
+
+         # Initialize model if not already done
+         if self.model is None:
+             self.initialize_model()
+
+         # Load model weights
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+         self.model.to(device)
+         self.model.eval()
+
+         logging.info(f"Model loaded from: {model_path}")
+
+ # =================================
+ # SIMILARITY AND TRIPLET GENERATION
+ # =================================
+ def calculate_similarity(user1, user2, pipeline):
+     try:
+         channels1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_channels') if c is not None)
+         channels2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_channels') if c is not None)
+         clusters1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_clusters') if c is not None)
+         clusters2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_clusters') if c is not None)
+
+         channel_sim = len(channels1 & channels2) / max(1, len(channels1 | channels2))
+         cluster_sim = len(clusters1 & clusters2) / max(1, len(clusters1 | clusters2))
+
+         return 0.5 * channel_sim + 0.5 * cluster_sim
+     except Exception as e:
+         logging.error(f"Error calculating similarity: {str(e)}")
+         return 0.0
+
+ def process_batch_triplets(args):
+     try:
+         batch_idx, users, channel_index, cluster_index, num_triplets, pipeline = args
+         batch_triplets = []
+
+         # Force CPU for all computations
+         with torch.no_grad():
+             # Temporarily set the device to CPU for the similarity computation
+             temp_device = torch.device("cpu")
+
+             for _ in range(num_triplets):
+                 anchor_idx = random.randint(0, len(users)-1)
+                 anchor_user = users[anchor_idx]
+
+                 # Find candidates that share channels or clusters
+                 candidates = set()
+                 for channel in pipeline._get_field_from_user(anchor_user, 'dmp_channels'):
+                     candidates.update(channel_index.get(str(channel), []))
+                 for cluster in pipeline._get_field_from_user(anchor_user, 'dmp_clusters'):
+                     candidates.update(cluster_index.get(str(cluster), []))
+
+                 # Remove anchor
+                 candidates.discard(anchor_idx)
+
+                 # Find positive (similar) user
+                 if not candidates:
+                     positive_idx = random.randint(0, len(users)-1)
+                 else:
+                     # Calculate similarities for candidates
+                     similarities = []
+                     for idx in candidates:
+                         # Compute similarity without CUDA
+                         sim = cpu_calculate_similarity(anchor_user, users[idx], pipeline)
+                         if sim > 0:
+                             similarities.append((idx, sim))
+
+                     if not similarities:
+                         positive_idx = random.randint(0, len(users)-1)
+                     else:
+                         # Sort by similarity
+                         similarities.sort(key=lambda x: x[1], reverse=True)
+                         # Return one of the top K most similar
+                         top_k = min(10, len(similarities))
+                         positive_idx = similarities[random.randint(0, top_k-1)][0]
+
+                 # Find negative (dissimilar) user
+                 max_attempts = 50
+                 negative_idx = None
+
+                 for _ in range(max_attempts):
+                     idx = random.randint(0, len(users)-1)
+                     if idx != anchor_idx and idx != positive_idx:
+                         # Compute similarity without CUDA
+                         if cpu_calculate_similarity(anchor_user, users[idx], pipeline) < 0.1:
+                             negative_idx = idx
+                             break
+
+                 if negative_idx is None:
+                     negative_idx = random.randint(0, len(users)-1)
+
+                 batch_triplets.append((anchor_idx, positive_idx, negative_idx))
+
+         return batch_triplets
+     except Exception as e:
+         logging.error(f"Error in batch triplet generation: {str(e)}")
+         return []
+
+ # CPU version of the similarity function
+ def cpu_calculate_similarity(user1, user2, pipeline):
+     try:
+         channels1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_channels') if c is not None)
+         channels2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_channels') if c is not None)
+         clusters1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_clusters') if c is not None)
+         clusters2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_clusters') if c is not None)
+
+         channel_sim = len(channels1 & channels2) / max(1, len(channels1 | channels2))
+         cluster_sim = len(clusters1 & clusters2) / max(1, len(clusters1 | clusters2))
+
+         return 0.5 * channel_sim + 0.5 * cluster_sim
+     except Exception as e:
+         logging.error(f"Error calculating similarity: {str(e)}")
+         return 0.0
+
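
The similarity functions above are an equal-weight blend of two Jaccard coefficients over channel and cluster sets. A standalone sketch with a worked example (the field values are illustrative):

```python
def jaccard(a, b):
    """Jaccard coefficient, with the same max(1, ...) guard against empty unions."""
    a, b = set(a), set(b)
    return len(a & b) / max(1, len(a | b))

def user_similarity(channels1, channels2, clusters1, clusters2):
    """Equal-weight blend of channel and cluster overlap, as in calculate_similarity."""
    return 0.5 * jaccard(channels1, channels2) + 0.5 * jaccard(clusters1, clusters2)

# One shared channel out of three, identical clusters: 0.5 * (1/3) + 0.5 * 1.0
print(user_similarity(["a", "b"], ["b", "c"], ["x"], ["x"]))
```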
552
+ # =================================
553
+ # DATASET AND DATALOADER
554
+ # =================================
555
+ class UserSimilarityDataset(Dataset):
556
+ def __init__(self, pipeline, users_data, num_triplets=10, num_workers=None):
557
+ self.triplets = []
558
+ logging.info("Initializing UserSimilarityDataset...")
559
+
560
+ # Extract the 'user' field from the JSON structure for each record
561
+ self.users = []
562
+ for data in users_data:
563
+ # Check if there's raw_json.user
564
+ if 'raw_json' in data and 'user' in data['raw_json']:
565
+ self.users.append(data['raw_json']['user'])
566
+ # Check if there's user
567
+ elif 'user' in data:
568
+ self.users.append(data['user'])
569
+ else:
570
+ self.users.append(data) # Assume it's already a user
571
+
572
+ self.pipeline = pipeline
573
+ self.num_triplets = num_triplets
574
+
575
+ # Determine number of workers for parallel processing
576
+ if num_workers is None:
577
+ num_workers = max(1, min(8, os.cpu_count()))
578
+ self.num_workers = num_workers
579
+
580
+ # Pre-process inputs for each user
581
+ self.preprocessed_inputs = {}
582
+ for idx, user in enumerate(self.users):
583
+ self.preprocessed_inputs[idx] = pipeline._prepare_input(user)
584
+
585
+ logging.info("Creating indexes for channels and clusters...")
586
+ self.channel_index = defaultdict(list)
587
+ self.cluster_index = defaultdict(list)
588
+
589
+ for idx, user in enumerate(self.users):
590
+ channels = pipeline._get_field_from_user(user, 'dmp_channels')
591
+ clusters = pipeline._get_field_from_user(user, 'dmp_clusters')
592
+
593
+ if channels:
594
+ channels = [str(c) for c in channels if c is not None]
595
+ if clusters:
596
+ clusters = [str(c) for c in clusters if c is not None]
597
+
598
+ for channel in channels:
599
+ self.channel_index[channel].append(idx)
600
+ for cluster in clusters:
601
+ self.cluster_index[cluster].append(idx)
602
+
603
+ logging.info(f"Found {len(self.channel_index)} unique channels and {len(self.cluster_index)} unique clusters")
604
+ logging.info(f"Generating triplets using {self.num_workers} worker processes...")
605
+
606
+ self.triplets = self._generate_triplets_gpu(num_triplets)
607
+ logging.info(f"Generated {len(self.triplets)} triplets")
608
+
609
+ # Verifica che questo metodo sia definito correttamente
610
+ def __len__(self):
611
+ return len(self.triplets)
612
+
613
+ def __getitem__(self, idx):
614
+ if idx >= len(self.triplets):
615
+ raise IndexError(f"Index {idx} out of range for dataset with {len(self.triplets)} triplets")
616
+
617
+ anchor_idx, positive_idx, negative_idx = self.triplets[idx]
618
+ return (
619
+ self.preprocessed_inputs[anchor_idx],
620
+ self.preprocessed_inputs[positive_idx],
621
+ self.preprocessed_inputs[negative_idx]
622
+ )
623
+
624
+
625
+
626
+
627
+    def _generate_triplets_gpu(self, num_triplets):
+        """Generate triplets using a more reliable approach with batch processing"""
+        logging.info("Generating triplets with batch approach...")
+
+        triplets = []
+        batch_size = 10  # Number of triplets to generate per batch
+        num_batches = (num_triplets + batch_size - 1) // batch_size
+
+        progress_bar = tqdm(
+            range(num_batches),
+            desc="Generating triplet batches",
+            bar_format='{l_bar}{bar:10}{r_bar}'
+        )
+
+        for _ in progress_bar:
+            batch_triplets = []
+
+            # Generate one batch of triplets
+            for i in range(batch_size):
+                if len(triplets) >= num_triplets:
+                    break
+
+                # Pick a random anchor
+                anchor_idx = random.randint(0, len(self.users) - 1)
+                anchor_user = self.users[anchor_idx]
+
+                # Find candidates that share channels or clusters
+                candidates = set()
+                for channel in self.pipeline._get_field_from_user(anchor_user, 'dmp_channels'):
+                    candidates.update(self.channel_index.get(str(channel), []))
+                for cluster in self.pipeline._get_field_from_user(anchor_user, 'dmp_clusters'):
+                    candidates.update(self.cluster_index.get(str(cluster), []))
+
+                # Remove the anchor from the candidates
+                candidates.discard(anchor_idx)
+
+                # Find a positive example
+                if candidates:
+                    similarities = []
+                    for idx in list(candidates)[:50]:  # Limit the search to the first 50 candidates
+                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline)
+                        if sim > 0:
+                            similarities.append((idx, sim))
+
+                    if similarities:
+                        similarities.sort(key=lambda x: x[1], reverse=True)
+                        top_k = min(10, len(similarities))
+                        positive_idx = similarities[random.randint(0, top_k - 1)][0]
+                    else:
+                        positive_idx = random.randint(0, len(self.users) - 1)
+                else:
+                    positive_idx = random.randint(0, len(self.users) - 1)
+
+                # Find a negative example
+                attempts = 0
+                negative_idx = None
+
+                while attempts < 20 and negative_idx is None:  # Reduced to 20 attempts
+                    idx = random.randint(0, len(self.users) - 1)
+                    if idx != anchor_idx and idx != positive_idx:
+                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline)
+                        if sim < 0.1:
+                            negative_idx = idx
+                            break
+                    attempts += 1
+
+                if negative_idx is None:
+                    negative_idx = random.randint(0, len(self.users) - 1)
+
+                batch_triplets.append((anchor_idx, positive_idx, negative_idx))
+
+            triplets.extend(batch_triplets)
+
+        return triplets[:num_triplets]  # Make sure to return exactly num_triplets
+
+def collate_batch(batch):
+    """Custom collate function to properly handle tensor dimensions"""
+    anchor_inputs, positive_inputs, negative_inputs = zip(*batch)
+
+    def process_group_inputs(group_inputs):
+        processed = {}
+        for field in group_inputs[0].keys():
+            # Find maximum length for this field in the batch
+            max_len = max(inputs[field].size(0) for inputs in group_inputs)
+
+            # Create padded tensors
+            padded = torch.stack([
+                torch.cat([
+                    inputs[field],
+                    torch.zeros(max_len - inputs[field].size(0), dtype=torch.long)
+                ]) if inputs[field].size(0) < max_len else inputs[field][:max_len]
+                for inputs in group_inputs
+            ])
+
+            processed[field] = padded
+
+        return processed
+
+    # Process each group (anchor, positive, negative)
+    anchor_batch = process_group_inputs(anchor_inputs)
+    positive_batch = process_group_inputs(positive_inputs)
+    negative_batch = process_group_inputs(negative_inputs)
+
+    return anchor_batch, positive_batch, negative_batch
+
+# =================================
+# TRAINING FUNCTION
+# =================================
+
+def train_user_embeddings(model, users_data, pipeline, num_epochs=10, batch_size=32, lr=0.001, save_dir=None, save_interval=2, num_triplets=150):
+    """Main training of the model with proper batch handling and incremental saving"""
+    model.train()
+    model.to(device)
+    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
+
+    # Add learning rate scheduler
+    scheduler = torch.optim.lr_scheduler.StepLR(
+        optimizer,
+        step_size=2,  # Decay every 2 epochs
+        gamma=0.9     # Multiply by 0.9 (10% reduction)
+    )
+
+    # Determine number of CPU cores to use
+    num_cpu_cores = max(1, min(32, os.cpu_count()))
+    logging.info(f"Using {num_cpu_cores} CPU cores for data processing")
+
+    # Prepare dataset and dataloader with custom collate function
+    dataset = UserSimilarityDataset(
+        pipeline,
+        users_data,
+        num_triplets=num_triplets,
+        num_workers=num_cpu_cores
+    )
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=batch_size,
+        shuffle=True,
+        collate_fn=collate_batch,
+        num_workers=0,    # For loading batches
+        pin_memory=True   # Speed up data transfer to GPU
+    )
+
+    # Loss function
+    criterion = torch.nn.TripletMarginLoss(margin=1.0)
+
+    # Progress bar for epochs
+    epoch_pbar = tqdm(
+        range(num_epochs),
+        desc="Training Progress",
+        bar_format='{l_bar}{bar:10}{r_bar}'
+    )
+
+    # Set up tensorboard for logging
+    try:
+        from torch.utils.tensorboard import SummaryWriter
+        log_dir = Path(save_dir) / "logs" if save_dir else Path("./logs")
+        log_dir.mkdir(exist_ok=True, parents=True)
+        writer = SummaryWriter(log_dir=log_dir)
+        tensorboard_available = True
+    except ImportError:
+        logging.warning("TensorBoard not available, skipping logging")
+        tensorboard_available = False
+
+    for epoch in epoch_pbar:
+        total_loss = 0
+        num_batches = 0
+
+        # Single progress indicator for the entire epoch
+        epoch_progress = tqdm(
+            total=len(dataloader),
+            desc=f"Epoch {epoch+1}/{num_epochs}",
+            leave=True,
+            bar_format='{l_bar}{bar:10}{r_bar}'
+        )
+
+        for batch_idx, batch_inputs in enumerate(dataloader):
+            try:
+                # Each element in the batch is already a dict of padded tensors
+                anchor_batch, positive_batch, negative_batch = batch_inputs
+
+                # Move data to device (GPU if available)
+                anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
+                positive_batch = {k: v.to(device) for k, v in positive_batch.items()}
+                negative_batch = {k: v.to(device) for k, v in negative_batch.items()}
+
+                # Generate embeddings
+                anchor_emb = model(anchor_batch)
+                positive_emb = model(positive_batch)
+                negative_emb = model(negative_batch)
+
+                # Calculate loss
+                loss = criterion(anchor_emb, positive_emb, negative_emb)
+
+                # Backward and optimize
+                optimizer.zero_grad()
+                loss.backward()
+                # Add gradient clipping
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
+                optimizer.step()
+
+                total_loss += loss.item()
+                num_batches += 1
+
+                # Update the epoch progress bar only at 10% intervals or at the end
+                update_interval = max(1, len(dataloader) // 10)
+                if (batch_idx + 1) % update_interval == 0 or batch_idx == len(dataloader) - 1:
+                    remaining = min(update_interval, len(dataloader) - epoch_progress.n)
+                    epoch_progress.update(remaining)
+                    # Update stats with current average loss
+                    current_avg_loss = total_loss / num_batches
+                    epoch_progress.set_postfix(avg_loss=f"{current_avg_loss:.4f}",
+                                               last_batch_loss=f"{loss.item():.4f}")
+
+            except Exception as e:
+                logging.error(f"Error during batch processing: {str(e)}")
+                logging.error(f"Batch details: {str(e.__class__.__name__)}")
+                continue
+
+        # Close the progress bar at the end of the epoch
+        epoch_progress.close()
+
+        avg_loss = total_loss / max(1, num_batches)
+        # Update epoch progress bar with average loss
+        epoch_pbar.set_postfix(avg_loss=f"{avg_loss:.4f}")
+
+        # Log to tensorboard
+        if tensorboard_available:
+            writer.add_scalar('Loss/train', avg_loss, epoch)
+
+        # Step the learning rate scheduler at the end of each epoch
+        scheduler.step()
+
+        # Incremental model saving if requested
+        if save_dir and (epoch + 1) % save_interval == 0:
+            checkpoint = {
+                'epoch': epoch,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'loss': avg_loss,
+                'scheduler_state_dict': scheduler.state_dict()  # Save scheduler state
+            }
+
+            save_path = Path(save_dir) / f'model_checkpoint_epoch_{epoch+1}.pth'
+            torch.save(checkpoint, save_path)
+            logging.info(f"Checkpoint saved at epoch {epoch+1}: {save_path}")
+
+    if tensorboard_available:
+        writer.close()
+
+    return model
+
+# =================================
+# MAIN FUNCTION
+# =================================
+def main():
+    # Configuration
+    output_dir = Path(OUTPUT_DIR)
+
+    # Check if CUDA is available
+    cuda_available = torch.cuda.is_available()
+    logging.info(f"CUDA available: {cuda_available}")
+    if cuda_available:
+        logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
+        logging.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
+
+    # CPU info
+    cpu_count = os.cpu_count()
+    memory_info = psutil.virtual_memory()
+    logging.info(f"CPU cores: {cpu_count}")
+    logging.info(f"System memory: {memory_info.total / 1e9:.2f} GB")
+
+    # Print configuration
+    logging.info("Running with the following configuration:")
+    logging.info(f"- Number of triplets: {NUM_TRIPLETS}")
+    logging.info(f"- Number of epochs: {NUM_EPOCHS}")
+    logging.info(f"- Batch size: {BATCH_SIZE}")
+    logging.info(f"- Learning rate: {LEARNING_RATE}")
+    logging.info(f"- Output dimension: {OUTPUT_DIM}")
+    logging.info(f"- Data path: {DATA_PATH}")
+    logging.info(f"- Output directory: {OUTPUT_DIR}")
+
+    # Load data
+    logging.info("Loading user data...")
+    try:
+        try:
+            # First try loading as normal JSON
+            with open(DATA_PATH, 'r') as f:
+                json_data = json.load(f)
+
+            # Handle both cases: array of users or single object
+            if isinstance(json_data, list):
+                users_data = json_data
+            elif isinstance(json_data, dict):
+                # If it's a single record, put it in a list
+                users_data = [json_data]
+            else:
+                raise ValueError("Unsupported JSON format")
+
+        except json.JSONDecodeError:
+            # If it fails, the file might be non-standard JSON (objects separated by commas)
+            logging.info("Detected possible non-standard JSON format, attempting correction...")
+            with open(DATA_PATH, 'r') as f:
+                text = f.read().strip()
+
+            # Add square brackets to create a valid JSON array
+            if not text.startswith('['):
+                text = '[' + text
+            if not text.endswith(']'):
+                text = text + ']'
+
+            # Try loading the corrected text
+            users_data = json.loads(text)
+            logging.info("JSON format successfully corrected")
+
+        logging.info(f"Loaded {len(users_data)} records")
+    except FileNotFoundError:
+        logging.error(f"File {DATA_PATH} not found!")
+        return
+    except Exception as e:
+        logging.error(f"Unable to load file: {str(e)}")
+        return
+
+    # Initialize pipeline
+    logging.info("Initializing pipeline...")
+    pipeline = UserEmbeddingPipeline(
+        output_dim=OUTPUT_DIM,
+        max_sequence_length=MAX_SEQ_LENGTH
+    )
+
+    # Build vocabularies
+    logging.info("Building vocabularies...")
+    try:
+        pipeline.build_vocabularies(users_data)
+        vocab_sizes = {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()}
+        logging.info(f"Vocabulary sizes: {vocab_sizes}")
+    except Exception as e:
+        logging.error(f"Error building vocabularies: {str(e)}")
+        return
+
+    # Initialize model
+    logging.info("Initializing model...")
+    try:
+        pipeline.initialize_model()
+        logging.info("Model initialized successfully")
+    except Exception as e:
+        logging.error(f"Error initializing model: {str(e)}")
+        return
+
+    # Training
+    logging.info("Starting training...")
+    try:
+        # Create directory for checkpoints if it doesn't exist
+        model_dir = output_dir / "model_checkpoints"
+        model_dir.mkdir(exist_ok=True, parents=True)
+
+        model = train_user_embeddings(
+            pipeline.model,
+            users_data,
+            pipeline,                     # Pass the pipeline to training
+            num_epochs=NUM_EPOCHS,
+            batch_size=BATCH_SIZE,
+            lr=LEARNING_RATE,
+            save_dir=model_dir,           # Add incremental saving
+            save_interval=SAVE_INTERVAL,  # Save every N epochs
+            num_triplets=NUM_TRIPLETS
+        )
+        logging.info("Training completed")
+        pipeline.model = model
+
+        # Save only the model file
+        logging.info("Saving model...")
+
+        # Create output directory
+        output_dir.mkdir(exist_ok=True)
+
+        # Save path
+        model_path = output_dir / 'model.pth'
+
+        # Prepare dictionary with model state and metadata
+        checkpoint = {
+            'model_state_dict': pipeline.model.state_dict(),
+            'vocab_maps': pipeline.vocab_maps,
+            'embedding_dims': pipeline.embedding_dims,
+            'output_dim': pipeline.output_dim,
+            'max_sequence_length': pipeline.max_sequence_length
+        }
+
+        # Save model
+        torch.save(checkpoint, model_path)
+
+        logging.info(f"Model saved to: {model_path}")
+
+        # Also save a configuration file for reference
+        config_info = {
+            'model_type': 'UserEmbeddingModel',
+            'vocab_sizes': {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()},
+            'embedding_dims': pipeline.embedding_dims,
+            'output_dim': pipeline.output_dim,
+            'max_sequence_length': pipeline.max_sequence_length,
+            'padded_fields': list(pipeline.model.padded_fields),
+            'fields': pipeline.fields
+        }
+
+        config_path = output_dir / 'model_config.json'
+        with open(config_path, 'w') as f:
+            json.dump(config_info, f, indent=2)
+
+        logging.info(f"Model configuration saved to: {config_path}")
+
+        # Only save in HuggingFace format if requested
+        save_hf = os.environ.get("SAVE_HF_FORMAT", "false").lower() == "true"
+        if save_hf:
+            logging.info("Saving in HuggingFace format...")
+            hf_dir = output_dir / 'huggingface'
+            hf_dir.mkdir(exist_ok=True)
+
+            # Save model in HF format
+            torch.save(pipeline.model.state_dict(), hf_dir / 'pytorch_model.bin')
+
+            # Save config
+            with open(hf_dir / 'config.json', 'w') as f:
+                json.dump(config_info, f, indent=2)
+
+            logging.info(f"Model saved in HuggingFace format to: {hf_dir}")
+
+        # Push to HuggingFace if environment variables are set
+        hf_repo_id = os.environ.get("HF_REPO_ID")
+        hf_token = os.environ.get("HF_TOKEN")
+
+        if save_hf and hf_repo_id and hf_token:
+            try:
+                from huggingface_hub import HfApi
+
+                logging.info(f"Pushing model to HuggingFace: {hf_repo_id}")
+                api = HfApi()
+
+                # Create the repository (no-op if it already exists)
+                api.create_repo(
+                    repo_id=hf_repo_id,
+                    token=hf_token,
+                    exist_ok=True,
+                    private=True
+                )
+
+                # Upload the files
+                for file_path in (output_dir / "huggingface").glob("**/*"):
+                    if file_path.is_file():
+                        api.upload_file(
+                            path_or_fileobj=str(file_path),
+                            path_in_repo=str(file_path.relative_to(output_dir / "huggingface")),
+                            repo_id=hf_repo_id,
+                            token=hf_token
+                        )
+
+                logging.info(f"Model successfully pushed to HuggingFace: {hf_repo_id}")
+            except Exception as e:
+                logging.error(f"Error pushing to HuggingFace: {str(e)}")
+
+    except Exception as e:
+        logging.error(f"Error during training or saving: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return
+
+    logging.info("Process completed successfully!")
+
+if __name__ == "__main__":
+    main()
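The triplet-mining strategy in `_generate_triplets_gpu` above (candidates sharing channels/clusters become positives, low-similarity random users become negatives) can be sketched in isolation with only the standard library. This is a minimal illustrative sketch, not the script's actual code: the toy user records, the Jaccard stand-in for `calculate_similarity`, and the 0.1 threshold are assumptions for demonstration.

```python
import random
from collections import defaultdict

# Toy user records: each user has a set of channels (illustrative data only).
users = [
    {"channels": {"sport", "news"}},
    {"channels": {"sport"}},
    {"channels": {"cooking"}},
    {"channels": {"news", "travel"}},
]

def jaccard(a, b):
    # Stand-in for the script's calculate_similarity (assumption: set overlap).
    return len(a & b) / len(a | b) if a | b else 0.0

# Inverted index channel -> user indices, as built in the dataset class.
channel_index = defaultdict(list)
for idx, u in enumerate(users):
    for ch in u["channels"]:
        channel_index[ch].append(idx)

def sample_triplet(rng=random):
    anchor = rng.randrange(len(users))
    # Candidates sharing at least one channel with the anchor.
    cands = {i for ch in users[anchor]["channels"] for i in channel_index[ch]}
    cands.discard(anchor)
    if cands:
        # Positive: the most similar candidate.
        positive = max(cands, key=lambda i: jaccard(users[anchor]["channels"],
                                                    users[i]["channels"]))
    else:
        positive = rng.randrange(len(users))
    # Negative: a random user with (near-)zero overlap, if one exists.
    negatives = [i for i in range(len(users))
                 if i not in (anchor, positive)
                 and jaccard(users[anchor]["channels"], users[i]["channels"]) < 0.1]
    negative = rng.choice(negatives) if negatives else rng.randrange(len(users))
    return anchor, positive, negative

a, p, n = sample_triplet()
print(a, p, n)
```

The same structure (inverted index for positives, rejection sampling for negatives) is what keeps the real implementation from scanning every user pair per triplet.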
generate_model_gpu.py ADDED
@@ -0,0 +1,1265 @@
+import os
+os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
+#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
+import multiprocessing
+try:
+    multiprocessing.set_start_method('spawn')
+except RuntimeError:
+    pass  # The start method has already been set
+import json
+import torch
+import torch.nn as nn
+import numpy as np
+import pandas as pd
+from typing import List, Dict
+import logging
+import torch.nn.functional as F
+from torch.utils.data import Dataset, DataLoader
+import random
+from collections import defaultdict
+from pathlib import Path
+from tqdm import tqdm
+from multiprocessing import Pool
+import psutil
+import argparse
+
+# =================================
+# CONFIGURABLE PARAMETERS
+# =================================
+# Define default parameters that can be overridden via environment variables
+DEFAULT_NUM_TRIPLETS = 150          # Number of triplet examples to generate
+DEFAULT_NUM_EPOCHS = 1              # Number of training epochs
+DEFAULT_BATCH_SIZE = 64             # Batch size for training
+DEFAULT_LEARNING_RATE = 0.001       # Learning rate for optimizer
+DEFAULT_OUTPUT_DIM = 256            # Output dimension of embeddings
+DEFAULT_MAX_SEQ_LENGTH = 15         # Maximum sequence length
+DEFAULT_SAVE_INTERVAL = 2           # Save checkpoint every N epochs
+DEFAULT_DATA_PATH = "./users.json"  # Path to user data
+DEFAULT_OUTPUT_DIR = "./model"      # Output directory
+
+# Read parameters from environment variables
+NUM_TRIPLETS = int(os.environ.get("NUM_TRIPLETS", DEFAULT_NUM_TRIPLETS))
+NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", DEFAULT_NUM_EPOCHS))
+BATCH_SIZE = int(os.environ.get("BATCH_SIZE", DEFAULT_BATCH_SIZE))
+LEARNING_RATE = float(os.environ.get("LEARNING_RATE", DEFAULT_LEARNING_RATE))
+OUTPUT_DIM = int(os.environ.get("OUTPUT_DIM", DEFAULT_OUTPUT_DIM))
+MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", DEFAULT_MAX_SEQ_LENGTH))
+SAVE_INTERVAL = int(os.environ.get("SAVE_INTERVAL", DEFAULT_SAVE_INTERVAL))
+DATA_PATH = os.environ.get("DATA_PATH", DEFAULT_DATA_PATH)
+OUTPUT_DIR = os.environ.get("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+# Get CUDA device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+logging.info(f"Using device: {device}")
+
+# =================================
+# MODEL ARCHITECTURE
+# =================================
+class UserEmbeddingModel(nn.Module):
+    def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int],
+                 output_dim: int = 256, max_sequence_length: int = 15,
+                 padded_fields_length: int = 10):
+        super().__init__()
+
+        self.max_sequence_length = max_sequence_length
+        self.padded_fields_length = padded_fields_length
+        self.padded_fields = {'dmp_channels', 'dmp_tags', 'dmp_clusters'}
+        self.embedding_layers = nn.ModuleDict()
+
+        # Create embedding layers for each field
+        for field, vocab_size in vocab_sizes.items():
+            self.embedding_layers[field] = nn.Embedding(
+                vocab_size,
+                embedding_dims.get(field, 16),
+                padding_idx=0
+            )
+
+        # Calculate total input dimension: every field (padded or not) is
+        # pooled to a single embedding vector, so each contributes one dim
+        self.total_input_dim = sum(embedding_dims.values())
+
+        print(f"Total input dimension: {self.total_input_dim}")
+
+        self.fc = nn.Sequential(
+            nn.Linear(self.total_input_dim, self.total_input_dim // 2),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(self.total_input_dim // 2, output_dim),
+            nn.LayerNorm(output_dim)
+        )
+
+    def _process_sequence(self, embedding_layer: nn.Embedding, indices: torch.Tensor,
+                          field_name: str) -> torch.Tensor:
+        """Process normal sequences"""
+        batch_size = indices.size(0)
+        if indices.numel() == 0:
+            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)
+
+        if field_name in ['dmp_city', 'dmp_domains']:
+            if indices.dim() == 1:
+                indices = indices.unsqueeze(0)
+            if indices.size(1) > 0:
+                return embedding_layer(indices[:, 0])
+            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)
+
+        # Handle multiple sequences
+        embeddings = embedding_layer(indices)
+        return embeddings.mean(dim=1)  # [batch_size, emb_dim]
+
+    def _process_padded_sequence(self, embedding_layer: nn.Embedding,
+                                 indices: torch.Tensor) -> torch.Tensor:
+        """Process sequences with padding"""
+        batch_size = indices.size(0)
+        emb_dim = embedding_layer.embedding_dim
+
+        # Generate embeddings
+        embeddings = embedding_layer(indices)  # [batch_size, seq_len, emb_dim]
+
+        # Average along the sequence dimension, ignoring padding positions
+        mask = (indices != 0).float().unsqueeze(-1)
+        masked_embeddings = embeddings * mask
+        sum_mask = mask.sum(dim=1).clamp(min=1.0)
+
+        return (masked_embeddings.sum(dim=1) / sum_mask)  # [batch_size, emb_dim]
+
+    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+        batch_embeddings = []
+
+        for field in ['dmp_city', 'source', 'dmp_brands',  # changed: removed 'dmp_domains', added 'source'
+                      'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
+                      'device']:  # added 'device'
+            if field in inputs and field in self.embedding_layers:
+                if field in self.padded_fields:
+                    emb = self._process_padded_sequence(
+                        self.embedding_layers[field],
+                        inputs[field]
+                    )
+                else:
+                    emb = self._process_sequence(
+                        self.embedding_layers[field],
+                        inputs[field],
+                        field
+                    )
+                batch_embeddings.append(emb)
+
+        combined = torch.cat(batch_embeddings, dim=1)
+        return self.fc(combined)
+
+# =================================
+# EMBEDDING PIPELINE
+# =================================
+class UserEmbeddingPipeline:
+    def __init__(self, output_dim: int = 256, max_sequence_length: int = 15):
+        self.output_dim = output_dim
+        self.max_sequence_length = max_sequence_length
+        self.model = None
+        self.vocab_maps = {}
+
+        self.fields = [
+            'dmp_city', 'source', 'dmp_brands',  # 'dmp_domains' removed, 'source' added
+            'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
+            'device'  # 'device' added
+        ]
+
+        # Map of new JSON fields to old field names used in the model
+        self.field_mapping = {
+            'dmp_city': ('dmp', 'city'),
+            'source': ('dmp', '', 'source'),
+            'dmp_brands': ('dmp', 'brands'),
+            'dmp_clusters': ('dmp', 'clusters'),
+            'dmp_industries': ('dmp', 'industries'),
+            'dmp_tags': ('dmp', 'tags'),
+            'dmp_channels': ('dmp', 'channels'),
+            'device': ('device',)  # New device field
+        }
+
+        self.embedding_dims = {
+            'dmp_city': 8,
+            'source': 8,       # Embedding size for source
+            'dmp_brands': 32,
+            'dmp_clusters': 64,
+            'dmp_industries': 32,
+            'dmp_tags': 128,
+            'dmp_channels': 64,
+            'device': 8        # Embedding size for device
+        }
+
+    def _clean_value(self, value):
+        if isinstance(value, float) and np.isnan(value):
+            return []
+        if isinstance(value, str):
+            return [value.lower().strip()]
+        if isinstance(value, list):
+            return [str(v).lower().strip() for v in value if v is not None and str(v).strip()]
+        return []
+
+    def _get_field_from_user(self, user, field):
+        """Extract field value from the new JSON user format"""
+        mapping = self.field_mapping.get(field, (field,))
+        value = user
+
+        # Navigate through the nested structure
+        for key in mapping:
+            if isinstance(value, dict):
+                value = value.get(key, {})
+            else:
+                # Not a dictionary and not yet at the last element
+                # of the mapping: return an empty list
+                value = []
+                break
+
+        # If we've reached the end with a value that should be a list
+        # but isn't, convert it to a list
+        if field in {'dmp_brands', 'dmp_channels', 'dmp_clusters', 'dmp_industries', 'dmp_tags'} and not isinstance(value, list):
+            # If it's a string or other single value, wrap it in a list
+            if value and not isinstance(value, dict):
+                value = [value]
+            else:
+                value = []
+
+        return value
+
+    def build_vocabularies(self, users_data: List[Dict]) -> Dict[str, Dict[str, int]]:
+        field_values = {field: {'<PAD>'} for field in self.fields}
+
+        # Extract the 'user' field from the JSON structure for each record
+        users = []
+        for data in users_data:
+            # Check if there's raw_json.user
+            if 'raw_json' in data and 'user' in data['raw_json']:
+                users.append(data['raw_json']['user'])
+            # Check if there's user
+            elif 'user' in data:
+                users.append(data['user'])
+            else:
+                users.append(data)  # Assume it's already a user
+
+        for user in users:
+            for field in self.fields:
+                values = self._clean_value(self._get_field_from_user(user, field))
+                field_values[field].update(values)
+
+        self.vocab_maps = {
+            field: {val: idx for idx, val in enumerate(sorted(values))}
+            for field, values in field_values.items()
+        }
+
+ return self.vocab_maps
258
+
259
+ def _prepare_input(self, user: Dict) -> Dict[str, torch.Tensor]:
260
+ inputs = {}
261
+
262
+ for field in self.fields:
263
+ values = self._clean_value(self._get_field_from_user(user, field))
264
+ vocab = self.vocab_maps[field]
265
+ indices = [vocab.get(val, 0) for val in values]
266
+ inputs[field] = torch.tensor(indices, dtype=torch.long)
267
+
268
+ return inputs
269
+
270
+ def initialize_model(self) -> None:
271
+ vocab_sizes = {field: len(vocab) for field, vocab in self.vocab_maps.items()}
272
+
273
+ self.model = UserEmbeddingModel(
274
+ vocab_sizes=vocab_sizes,
275
+ embedding_dims=self.embedding_dims,
276
+ output_dim=self.output_dim,
277
+ max_sequence_length=self.max_sequence_length
278
+ )
279
+ self.model.to(device)
280
+ self.model.eval()
281
+
282
+ def generate_embeddings(self, users_data: List[Dict], batch_size: int = 32) -> Dict[str, np.ndarray]:
283
+ """Generate embeddings for all users"""
284
+ embeddings = {}
285
+ self.model.eval() # Make sure model is in eval mode
286
+
287
+ # Extract the 'user' field from the JSON structure for each record
288
+ users = []
289
+ user_ids = []
290
+
291
+ for data in users_data:
292
+ # Check if there's raw_json.user
293
+ if 'raw_json' in data and 'user' in data['raw_json']:
294
+ user = data['raw_json']['user']
295
+ users.append(user)
296
+ # Use user.dmp[''].id as identifier
297
+ if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
298
+ user_ids.append(str(user['dmp']['']['id']))
299
+ else:
300
+ # Fallback to uid or id if dmp.id is not available
301
+ user_ids.append(str(user.get('uid', user.get('id', None))))
302
+ # Check if there's user
303
+ elif 'user' in data:
304
+ user = data['user']
305
+ users.append(user)
306
+ # Use user.dmp[''].id as identifier
307
+ if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
308
+ user_ids.append(str(user['dmp']['']['id']))
309
+ else:
310
+ # Fallback to uid or id if dmp.id is not available
311
+ user_ids.append(str(user.get('uid', user.get('id', None))))
312
+ else:
313
+ users.append(data) # Assume it's already a user
314
+ # Use user.dmp[''].id as identifier
315
+ if 'dmp' in data and '' in data['dmp'] and 'id' in data['dmp']['']:
316
+ user_ids.append(str(data['dmp']['']['id']))
317
+ else:
318
+ # Fallback to uid or id if dmp.id is not available
319
+ user_ids.append(str(data.get('uid', data.get('id', None))))
320
+
321
+ with torch.no_grad():
322
+ for i in tqdm(range(0, len(users), batch_size), desc="Generating embeddings"):
323
+ batch_users = users[i:i+batch_size]
324
+ batch_ids = user_ids[i:i+batch_size]
325
+ batch_inputs = []
326
+ valid_indices = []
327
+
328
+ for j, user in enumerate(batch_users):
329
+ if batch_ids[j] is not None:
330
+ batch_inputs.append(self._prepare_input(user))
331
+ valid_indices.append(j)
332
+
333
+ if batch_inputs:
334
+ # Use the same collate function as training for a single batch
335
+ anchor_batch, _, _ = collate_batch([(inputs, inputs, inputs) for inputs in batch_inputs])
336
+
337
+ # Move data to device
338
+ anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
339
+
340
+ # Generate embeddings
341
+ batch_embeddings = self.model(anchor_batch).cpu()
342
+
343
+ # Save embeddings
344
+ for j, idx in enumerate(valid_indices):
345
+ if batch_ids[idx]: # Verify that id is not None or empty
346
+ embeddings[batch_ids[idx]] = batch_embeddings[j].numpy()
347
+
348
+ return embeddings
349
+
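The batched loop above walks the user list in fixed-size slices and skips entries without a usable id. A minimal, self-contained sketch of that slicing pattern (the `batched` helper and toy ids are illustrative, not part of the pipeline):

```python
# Sketch of the fixed-size slicing used by generate_embeddings, in isolation.
def batched(items, batch_size):
    """Yield successive fixed-size slices of a list."""
    for i in range(0, len(items), batch_size):
        yield items[i:i + batch_size]

# Entries with a missing id (None) are dropped inside each slice.
ids = ['a', None, 'b', 'c', 'd']
batches = [[x for x in chunk if x is not None] for chunk in batched(ids, 2)]
print(batches)  # → [['a'], ['b', 'c'], ['d']]
```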
350
+ def save_embeddings(self, embeddings: Dict[str, np.ndarray], output_dir: str) -> None:
351
+ """Save embeddings to file"""
352
+ output_dir = Path(output_dir)
353
+ output_dir.mkdir(exist_ok=True)
354
+
355
+ # Save embeddings as JSON
356
+ json_path = output_dir / 'embeddings.json'
357
+ with open(json_path, 'w') as f:
358
+ json_embeddings = {user_id: emb.tolist() for user_id, emb in embeddings.items()}
359
+ json.dump(json_embeddings, f)
360
+
361
+ # Save embeddings as NPY
362
+ npy_path = output_dir / 'embeddings.npz'
363
+ np.savez_compressed(npy_path,
364
+ embeddings=np.stack(list(embeddings.values())),
365
+ user_ids=np.array(list(embeddings.keys())))
366
+
367
+ # Save vocabularies
368
+ vocab_path = output_dir / 'vocabularies.json'
369
+ with open(vocab_path, 'w') as f:
370
+ json.dump(self.vocab_maps, f)
371
+
372
+ logging.info(f"\nEmbeddings saved in {output_dir}:")
373
+ logging.info(f"- Embeddings JSON: {json_path}")
374
+ logging.info(f"- Embeddings NPY: {npy_path}")
375
+ logging.info(f"- Vocabularies: {vocab_path}")
376
+
377
+ def save_model(self, output_dir: str) -> None:
378
+ """Save model in PyTorch format (.pth)"""
379
+ output_dir = Path(output_dir)
380
+ output_dir.mkdir(exist_ok=True)
381
+
382
+ # Save path
383
+ model_path = output_dir / 'model.pth'
384
+
385
+ # Prepare dictionary with model state and metadata
386
+ checkpoint = {
387
+ 'model_state_dict': self.model.state_dict(),
388
+ 'vocab_maps': self.vocab_maps,
389
+ 'embedding_dims': self.embedding_dims,
390
+ 'output_dim': self.output_dim,
391
+ 'max_sequence_length': self.max_sequence_length
392
+ }
393
+
394
+ # Save model
395
+ torch.save(checkpoint, model_path)
396
+
397
+ logging.info(f"Model saved to: {model_path}")
398
+
399
+ # Also save a configuration file for reference
400
+ config_info = {
401
+ 'model_type': 'UserEmbeddingModel',
402
+ 'vocab_sizes': {field: len(vocab) for field, vocab in self.vocab_maps.items()},
403
+ 'embedding_dims': self.embedding_dims,
404
+ 'output_dim': self.output_dim,
405
+ 'max_sequence_length': self.max_sequence_length,
406
+ 'padded_fields': list(self.model.padded_fields),
407
+ 'fields': self.fields
408
+ }
409
+
410
+ config_path = output_dir / 'model_config.json'
411
+ with open(config_path, 'w') as f:
412
+ json.dump(config_info, f, indent=2)
413
+
414
+ logging.info(f"Model configuration saved to: {config_path}")
415
+
416
+ # Save model in HuggingFace format
417
+ hf_dir = output_dir / 'huggingface'
418
+ hf_dir.mkdir(exist_ok=True)
419
+
420
+ # Save model in HF format
421
+ torch.save(self.model.state_dict(), hf_dir / 'pytorch_model.bin')
422
+
423
+ # Save config
424
+ with open(hf_dir / 'config.json', 'w') as f:
425
+ json.dump(config_info, f, indent=2)
426
+
427
+ logging.info(f"Model saved in HuggingFace format to: {hf_dir}")
428
+
429
+ def load_model(self, model_path: str) -> None:
430
+ """Load a previously saved model"""
431
+ checkpoint = torch.load(model_path, map_location=device)
432
+
433
+ # Reload vocabularies and dimensions if needed
434
+ self.vocab_maps = checkpoint.get('vocab_maps', self.vocab_maps)
435
+ self.embedding_dims = checkpoint.get('embedding_dims', self.embedding_dims)
436
+ self.output_dim = checkpoint.get('output_dim', self.output_dim)
437
+ self.max_sequence_length = checkpoint.get('max_sequence_length', self.max_sequence_length)
438
+
439
+ # Initialize model if not already done
440
+ if self.model is None:
441
+ self.initialize_model()
442
+
443
+ # Load model weights
444
+ self.model.load_state_dict(checkpoint['model_state_dict'])
445
+ self.model.to(device)
446
+ self.model.eval()
447
+
448
+ logging.info(f"Model loaded from: {model_path}")
449
+
450
+ # =================================
451
+ # SIMILARITY AND TRIPLET GENERATION
452
+ # =================================
453
+
454
+ def calculate_similarity(user1, user2, pipeline, filtered_tags=None):
455
+ try:
456
+ # Extract the original fields: channels and clusters
457
+ channels1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_channels') if c is not None)
458
+ channels2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_channels') if c is not None)
459
+ clusters1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_clusters') if c is not None)
460
+ clusters2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_clusters') if c is not None)
461
+
462
+ # REMOVED domains1/domains2
463
+
464
+ # Extract the tags and apply the filter if needed
465
+ tags1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'dmp_tags') if c is not None)
466
+ tags2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'dmp_tags') if c is not None)
467
+
468
+ # NEW FIELDS: source, brands, device
469
+ source1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'source') if c is not None)
470
+ source2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'source') if c is not None)
471
+
472
+ brands1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'dmp_brands') if c is not None)
473
+ brands2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'dmp_brands') if c is not None)
474
+
475
+ device1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'device') if c is not None)
476
+ device2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'device') if c is not None)
477
+
478
+ if filtered_tags is not None:
479
+ # Keep only the tags present in the filtered tag set
480
+ tags1 = {tag for tag in tags1 if tag in filtered_tags}
481
+ tags2 = {tag for tag in tags2 if tag in filtered_tags}
482
+
483
+ # Compute the Jaccard similarity for each field
484
+ channel_sim = len(channels1 & channels2) / max(1, len(channels1 | channels2))
485
+ cluster_sim = len(clusters1 & clusters2) / max(1, len(clusters1 | clusters2))
486
+
487
+ tag_sim = len(tags1 & tags2) / max(1, len(tags1 | tags2))
488
+
489
+ # New similarities
490
+ source_sim = len(source1 & source2) / max(1, len(source1 | source2))
491
+ brands_sim = len(brands1 & brands2) / max(1, len(brands1 | brands2))
492
+ device_sim = len(device1 & device2) / max(1, len(device1 | device2))
493
+
494
+ # Compute the total similarity with the specified weights:
495
+ # 6 for clusters, 5 for channels, 3 for tags,
496
+ # 2 for source, 5 for brands, 3 for device
497
+ total_weight = 6 + 5 + 3 + 2 + 5 + 3 # Sum of weights = 24
498
+ weighted_sim = (
499
+ 6 * cluster_sim +
500
+ 5 * channel_sim +
501
+ 3 * tag_sim +
502
+ 2 * source_sim + # New: source with weight 2/24 ≈ 0.083
503
+ 5 * brands_sim + # New: brands with weight 5/24 ≈ 0.208
504
+ 3 * device_sim # New: device with weight 3/24 ≈ 0.125
505
+ ) / total_weight
506
+
507
+ return weighted_sim
508
+ except Exception as e:
509
+ logging.error(f"Error calculating similarity: {str(e)}")
510
+ return 0.0
511
+
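The weighted-Jaccard scheme implemented by `calculate_similarity` can be sketched in isolation. The following is an illustrative stand-alone version with toy sets in place of real user records; the field names and weights mirror the code above, everything else (the `WEIGHTS` dict, `jaccard`, `weighted_similarity`, the sample users) is hypothetical:

```python
# Self-contained sketch of the weighted Jaccard similarity used above.
def jaccard(a, b):
    """Jaccard similarity of two sets, guarding against an empty union."""
    return len(a & b) / max(1, len(a | b))

# Per-field weights as in the code: clusters 6, channels 5, tags 3,
# source 2, brands 5, device 3 (total = 24).
WEIGHTS = {'clusters': 6, 'channels': 5, 'tags': 3,
           'source': 2, 'brands': 5, 'device': 3}

def weighted_similarity(u1, u2):
    """Weighted average of per-field Jaccard similarities."""
    total = sum(WEIGHTS.values())  # 24
    return sum(w * jaccard(u1.get(f, set()), u2.get(f, set()))
               for f, w in WEIGHTS.items()) / total

user_a = {'clusters': {'1', '2'}, 'channels': {'sport'}, 'tags': {'news'}}
user_b = {'clusters': {'2'}, 'channels': {'sport'}, 'tags': {'music'}}
print(round(weighted_similarity(user_a, user_b), 4))  # (6*0.5 + 5*1.0) / 24
```

Missing fields contribute a similarity of 0, which matches the `max(1, …)` guard in the original code.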
512
+
513
+ def process_batch_triplets(args):
514
+ try:
515
+ batch_idx, users, channel_index, cluster_index, num_triplets, pipeline = args
516
+ batch_triplets = []
517
+
518
+ # Force CPU for all computations
519
+ with torch.no_grad():
520
+ # Similarity computation runs in pure Python, so no CUDA is needed here
521
+ temp_device = torch.device("cpu") # Note: currently unused
522
+
523
+ for _ in range(num_triplets):
524
+ anchor_idx = random.randint(0, len(users)-1)
525
+ anchor_user = users[anchor_idx]
526
+
527
+ # Find candidates that share channels or clusters
528
+ candidates = set()
529
+ for channel in pipeline._get_field_from_user(anchor_user, 'dmp_channels') or []:
530
+ candidates.update(channel_index.get(str(channel), []))
531
+ for cluster in pipeline._get_field_from_user(anchor_user, 'dmp_clusters') or []:
532
+ candidates.update(cluster_index.get(str(cluster), []))
533
+
534
+ # Remove anchor
535
+ candidates.discard(anchor_idx)
536
+
537
+ # Find positive (similar) user
538
+ if not candidates:
539
+ positive_idx = random.randint(0, len(users)-1)
540
+ else:
541
+ # Calculate similarities for candidates
542
+ similarities = []
543
+ for idx in candidates:
544
+ # Similarity computed without CUDA
545
+ sim = calculate_similarity(anchor_user, users[idx], pipeline)
546
+ if sim > 0:
547
+ similarities.append((idx, sim))
548
+
549
+ if not similarities:
550
+ positive_idx = random.randint(0, len(users)-1)
551
+ else:
552
+ # Sort by similarity
553
+ similarities.sort(key=lambda x: x[1], reverse=True)
554
+ # Return one of the top K most similar
555
+ top_k = min(10, len(similarities))
556
+ positive_idx = similarities[random.randint(0, top_k-1)][0]
557
+
558
+ # Find negative (dissimilar) user
559
+ max_attempts = 50
560
+ negative_idx = None
561
+
562
+ for _ in range(max_attempts):
563
+ idx = random.randint(0, len(users)-1)
564
+ if idx != anchor_idx and idx != positive_idx:
565
+ # Similarity computed without CUDA
566
+ if calculate_similarity(anchor_user, users[idx], pipeline) < 0.1:
567
+ negative_idx = idx
568
+ break
569
+
570
+ if negative_idx is None:
571
+ negative_idx = random.randint(0, len(users)-1)
572
+
573
+ batch_triplets.append((anchor_idx, positive_idx, negative_idx))
574
+
575
+ return batch_triplets
576
+ except Exception as e:
577
+ logging.error(f"Error in batch triplet generation: {str(e)}")
578
+ return []
579
+
580
+
581
+ # =================================
582
+ # DATASET AND DATALOADER
583
+ # =================================
584
+
585
+ class UserSimilarityDataset(Dataset):
586
+ def __init__(self, pipeline, users_data, num_triplets=10, num_workers=None, filtered_tags=None):
587
+ self.triplets = []
588
+ self.filtered_tags = filtered_tags
589
+ logging.info("Initializing UserSimilarityDataset...")
590
+
591
+ # Extract the 'user' field from the JSON structure for each record
592
+ self.users = []
593
+ for data in users_data:
594
+ # Check if there's raw_json.user
595
+ if 'raw_json' in data and 'user' in data['raw_json']:
596
+ self.users.append(data['raw_json']['user'])
597
+ # Check if there's user
598
+ elif 'user' in data:
599
+ self.users.append(data['user'])
600
+ else:
601
+ self.users.append(data) # Assume it's already a user
602
+
603
+ self.pipeline = pipeline
604
+ self.num_triplets = num_triplets
605
+
606
+ # Determine number of workers for parallel processing
607
+ if num_workers is None:
608
+ num_workers = max(1, min(8, os.cpu_count()))
609
+ self.num_workers = num_workers
610
+
611
+ # Pre-process the inputs for each user
612
+ self.preprocessed_inputs = {}
613
+ for idx, user in enumerate(self.users):
614
+ self.preprocessed_inputs[idx] = pipeline._prepare_input(user)
615
+
616
+ logging.info("Creating indexes for channels, clusters, tags, brands, source, and device...")
617
+ self.channel_index = defaultdict(list)
618
+ self.cluster_index = defaultdict(list)
619
+ self.tag_index = defaultdict(list)
620
+ # Removed domain_index
621
+ self.brands_index = defaultdict(list) # New index for brands
622
+ self.source_index = defaultdict(list) # New index for source
623
+ self.device_index = defaultdict(list) # New index for device
624
+
625
+ for idx, user in enumerate(self.users):
626
+ channels = pipeline._get_field_from_user(user, 'dmp_channels')
627
+ clusters = pipeline._get_field_from_user(user, 'dmp_clusters')
628
+ tags = pipeline._get_field_from_user(user, 'dmp_tags')
629
+ # Removed domains
630
+ brands = pipeline._get_field_from_user(user, 'dmp_brands') # Added brands
631
+ source = pipeline._get_field_from_user(user, 'source') # Added source
632
+ device = pipeline._get_field_from_user(user, 'device') # Added device
633
+
634
+ if channels:
635
+ channels = [str(c) for c in channels if c is not None]
636
+ if clusters:
637
+ clusters = [str(c) for c in clusters if c is not None]
638
+ if tags:
639
+ tags = [str(c) for c in tags if c is not None]
640
+ # Filter the tags if a filtered tag set was provided
641
+ if self.filtered_tags:
642
+ tags = [tag for tag in tags if tag in self.filtered_tags]
643
+ # Removed code for domains
644
+ if brands:
645
+ brands = [str(c) for c in brands if c is not None]
646
+
647
+
648
+ if source:
649
+ if isinstance(source, str):
650
+ source = [source] # If it's a single string, wrap it in a list
651
+ else:
652
+ source = [str(c) for c in source if c is not None] # Otherwise treat it as before
653
+
654
+ if device:
655
+ if isinstance(device, str):
656
+ device = [device] # If it's a single string, wrap it in a list
657
+ else:
658
+ device = [str(c) for c in device if c is not None]
659
+
660
+ for channel in channels or []:
661
+ self.channel_index[channel].append(idx)
662
+ for cluster in clusters or []:
663
+ self.cluster_index[cluster].append(idx)
664
+ for tag in tags or []:
665
+ self.tag_index[tag].append(idx)
666
+ # Removed loop for domains
667
+ for brand in brands or []:
668
+ self.brands_index[brand].append(idx)
669
+ for s in source or []:
670
+ self.source_index[s].append(idx)
671
+ for d in device or []:
672
+ self.device_index[d].append(idx)
673
+
674
+ logging.info(f"Found {len(self.channel_index)} unique channels, {len(self.cluster_index)} unique clusters, {len(self.tag_index)} unique tags")
675
+ logging.info(f"Found {len(self.brands_index)} unique brands, {len(self.source_index)} unique sources, and {len(self.device_index)} unique devices")
676
+
677
+ logging.info(f"Generating triplets using {self.num_workers} worker processes...")
678
+
679
+ self.triplets = self._generate_triplets_gpu(num_triplets)
680
+ logging.info(f"Generated {len(self.triplets)} triplets")
681
+
682
+ def __len__(self):
683
+ return len(self.triplets)
684
+
685
+ def __getitem__(self, idx):
686
+ if idx >= len(self.triplets):
687
+ raise IndexError(f"Index {idx} out of range for dataset with {len(self.triplets)} triplets")
688
+
689
+ anchor_idx, positive_idx, negative_idx = self.triplets[idx]
690
+ return (
691
+ self.preprocessed_inputs[anchor_idx],
692
+ self.preprocessed_inputs[positive_idx],
693
+ self.preprocessed_inputs[negative_idx]
694
+ )
695
+
696
+ def _generate_triplets_gpu(self, num_triplets):
697
+ """Generate triplets using a more reliable approach with batch processing"""
698
+ logging.info("Generating triplets with batch approach...")
699
+
700
+ triplets = []
701
+ batch_size = 10 # Number of triplets to generate per batch
702
+ num_batches = (num_triplets + batch_size - 1) // batch_size
703
+
704
+ progress_bar = tqdm(
705
+ range(num_batches),
706
+ desc="Generating triplet batches",
707
+ bar_format='{l_bar}{bar:10}{r_bar}'
708
+ )
709
+
710
+ for _ in progress_bar:
711
+ batch_triplets = []
712
+
713
+ # Generate a batch of triplets
714
+ for i in range(batch_size):
715
+ if len(triplets) >= num_triplets:
716
+ break
717
+
718
+ # Select a random anchor
719
+ anchor_idx = random.randint(0, len(self.users)-1)
720
+ anchor_user = self.users[anchor_idx]
721
+
722
+ # Find candidates that share channels, clusters, tags, brands, source, or device
724
+ candidates = set()
725
+ for channel in self.pipeline._get_field_from_user(anchor_user, 'dmp_channels') or []:
726
+ if channel is not None:
727
+ candidates.update(self.channel_index.get(str(channel), []))
728
+ for cluster in self.pipeline._get_field_from_user(anchor_user, 'dmp_clusters') or []:
729
+ if cluster is not None:
730
+ candidates.update(self.cluster_index.get(str(cluster), []))
731
+ for tag in self.pipeline._get_field_from_user(anchor_user, 'dmp_tags') or []:
732
+ if tag is not None and (self.filtered_tags is None or str(tag) in self.filtered_tags):
733
+ candidates.update(self.tag_index.get(str(tag), []))
734
+ # Removed the loop for domains
735
+
736
+ # New loops for brands, source, device
737
+ for brand in self.pipeline._get_field_from_user(anchor_user, 'dmp_brands') or []:
738
+ if brand is not None:
739
+ candidates.update(self.brands_index.get(str(brand), []))
740
+ for source in self.pipeline._get_field_from_user(anchor_user, 'source') or []:
741
+ if source is not None:
742
+ candidates.update(self.source_index.get(str(source), []))
743
+ for device in self.pipeline._get_field_from_user(anchor_user, 'device') or []:
744
+ if device is not None:
745
+ candidates.update(self.device_index.get(str(device), []))
746
+
747
+ # Remove the anchor from the candidates
748
+ candidates.discard(anchor_idx)
749
+
750
+ # Find a positive example
751
+ if candidates:
752
+ similarities = []
753
+ for idx in list(candidates)[:50]: # Limit the search to the first 50 candidates
754
+ sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline, self.filtered_tags)
755
+ if sim > 0:
756
+ similarities.append((idx, sim))
757
+
758
+ if similarities:
759
+ similarities.sort(key=lambda x: x[1], reverse=True)
760
+ top_k = min(10, len(similarities))
761
+ positive_idx = similarities[random.randint(0, top_k-1)][0]
762
+ else:
763
+ positive_idx = random.randint(0, len(self.users)-1)
764
+ else:
765
+ positive_idx = random.randint(0, len(self.users)-1)
766
+
767
+ # Find a negative example
768
+ attempts = 0
769
+ negative_idx = None
770
+
771
+ while attempts < 20 and negative_idx is None: # Reduced to 20 attempts
772
+ idx = random.randint(0, len(self.users)-1)
773
+ if idx != anchor_idx and idx != positive_idx:
774
+ sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline, self.filtered_tags)
775
+ if sim < 0.1:
776
+ negative_idx = idx
777
+ break
778
+ attempts += 1
779
+
780
+ if negative_idx is None:
781
+ negative_idx = random.randint(0, len(self.users)-1)
782
+
783
+ batch_triplets.append((anchor_idx, positive_idx, negative_idx))
784
+
785
+ triplets.extend(batch_triplets)
786
+
787
+ return triplets[:num_triplets] # Make sure to return exactly num_triplets
788
+
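The mining strategy above (positive = among the most-similar candidates, negative = a random user below a similarity threshold) can be sketched with a toy similarity function. Everything in this snippet (`mine_triplet`, the toy set-based users, the `sim` lambda) is illustrative, not the pipeline's actual API:

```python
# Toy sketch of the triplet-mining strategy used by _generate_triplets_gpu.
import random

def mine_triplet(anchor, users, sim, neg_threshold=0.1, rng=random):
    """Pick (anchor, positive, negative): positive is the most-similar user,
    negative is a random user whose similarity is below neg_threshold."""
    scored = sorted(((u, sim(anchor, u)) for u in users if u != anchor),
                    key=lambda x: x[1], reverse=True)
    positive = scored[0][0]  # most similar candidate
    negatives = [u for u, s in scored if s < neg_threshold]
    # Fall back to any candidate if no sufficiently dissimilar user exists
    negative = rng.choice(negatives) if negatives else rng.choice([u for u, _ in scored])
    return anchor, positive, negative

users = [{1, 2}, {1, 2, 3}, {9}]
a, p, n = mine_triplet(users[0], users, lambda x, y: len(x & y) / len(x | y))
print(p, n)  # positive shares elements with the anchor; negative shares none
```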
789
+
790
+ def collate_batch(batch):
791
+ """Custom collate function to properly handle tensor dimensions"""
792
+ anchor_inputs, positive_inputs, negative_inputs = zip(*batch)
793
+
794
+ def process_group_inputs(group_inputs):
795
+ processed = {}
796
+ for field in group_inputs[0].keys():
797
+ # Find maximum length for this field in the batch
798
+ max_len = max(inputs[field].size(0) for inputs in group_inputs)
799
+
800
+ # Create padded tensors
801
+ padded = torch.stack([
802
+ torch.cat([
803
+ inputs[field],
804
+ torch.zeros(max_len - inputs[field].size(0), dtype=torch.long)
805
+ ]) if inputs[field].size(0) < max_len else inputs[field][:max_len]
806
+ for inputs in group_inputs
807
+ ])
808
+
809
+ processed[field] = padded
810
+
811
+ return processed
812
+
813
+ # Process each group (anchor, positive, negative)
814
+ anchor_batch = process_group_inputs(anchor_inputs)
815
+ positive_batch = process_group_inputs(positive_inputs)
816
+ negative_batch = process_group_inputs(negative_inputs)
817
+
818
+ return anchor_batch, positive_batch, negative_batch
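The core of `collate_batch` is right-padding every sequence in a field to the batch maximum. A stand-alone sketch of that padding step using plain lists instead of torch tensors (the `pad_field` helper is illustrative):

```python
# Sketch of the per-field padding performed inside collate_batch.
def pad_field(sequences, pad_value=0):
    """Right-pad all sequences to the length of the longest one."""
    max_len = max(len(s) for s in sequences)
    return [s + [pad_value] * (max_len - len(s)) for s in sequences]

batch = [[5, 7], [1], [2, 3, 4]]
print(pad_field(batch))  # → [[5, 7, 0], [1, 0, 0], [2, 3, 4]]
```

In the real function the same idea is expressed with `torch.cat` against a zero tensor, followed by `torch.stack` to build the batch tensor.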
819
+
820
+ # =================================
821
+ # TRAINING FUNCTION
822
+ # =================================
823
+
824
+ def train_user_embeddings(model, users_data, pipeline, num_epochs=10, batch_size=32, lr=0.001, save_dir=None, save_interval=2, num_triplets=150):
825
+ """Main training of the model with proper batch handling and incremental saving"""
826
+ model.train()
827
+ model.to(device)
828
+ optimizer = torch.optim.Adam(model.parameters(), lr=lr)
829
+
830
+ # Add learning rate scheduler
831
+ scheduler = torch.optim.lr_scheduler.StepLR(
832
+ optimizer,
833
+ step_size=2, # Decay every 2 epochs
834
+ gamma=0.9 # Multiply by 0.9 (10% reduction)
835
+ )
836
+
837
+ # Determine number of CPU cores to use
838
+ num_cpu_cores = max(1, min(32, os.cpu_count()))
839
+ logging.info(f"Using {num_cpu_cores} CPU cores for data processing")
840
+
841
+ # Prepare dataset and dataloader with custom collate function
842
+ dataset = UserSimilarityDataset(
843
+ pipeline,
844
+ users_data,
845
+ num_triplets=num_triplets,
846
+ num_workers=num_cpu_cores
847
+ )
848
+
849
+ dataloader = DataLoader(
850
+ dataset,
851
+ batch_size=batch_size,
852
+ shuffle=True,
853
+ collate_fn=collate_batch,
854
+ num_workers=0, # For loading batches
855
+ pin_memory=True # Speed up data transfer to GPU
856
+ )
857
+
858
+ # Loss function
859
+ criterion = torch.nn.TripletMarginLoss(margin=1.0)
860
+
861
+ # Progress bar for epochs
862
+ epoch_pbar = tqdm(
863
+ range(num_epochs),
864
+ desc="Training Progress",
865
+ bar_format='{l_bar}{bar:10}{r_bar}'
866
+ )
867
+ # Set up tensorboard for logging
868
+ try:
869
+ from torch.utils.tensorboard import SummaryWriter
870
+ log_dir = Path(save_dir) / "logs" if save_dir else Path("./logs")
871
+ log_dir.mkdir(exist_ok=True, parents=True)
872
+ writer = SummaryWriter(log_dir=log_dir)
873
+ tensorboard_available = True
874
+ except ImportError:
875
+ logging.warning("TensorBoard not available, skipping logging")
876
+ tensorboard_available = False
877
+
878
+ for epoch in epoch_pbar:
879
+ total_loss = 0
880
+ num_batches = 0
881
+
882
894
+
895
+ # Create a single progress indicator for the entire epoch
896
+ epoch_progress = tqdm(
897
+ total=len(dataloader),
898
+ desc=f"Epoch {epoch+1}/{num_epochs}",
899
+ leave=True,
900
+ bar_format='{l_bar}{bar:10}{r_bar}'
901
+ )
902
+
903
+ # Batch processing loop
904
+ for batch_idx, batch_inputs in enumerate(dataloader):
905
+ try:
906
+ # Each element in the batch is already a dict of padded tensors
907
+ anchor_batch, positive_batch, negative_batch = batch_inputs
908
+
909
+ # Move data to device (GPU if available)
910
+ anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
911
+ positive_batch = {k: v.to(device) for k, v in positive_batch.items()}
912
+ negative_batch = {k: v.to(device) for k, v in negative_batch.items()}
913
+
914
+ # Generate embeddings
915
+ anchor_emb = model(anchor_batch)
916
+ positive_emb = model(positive_batch)
917
+ negative_emb = model(negative_batch)
918
+
919
+ # Calculate loss
920
+ loss = criterion(anchor_emb, positive_emb, negative_emb)
921
+
922
+ # Backward and optimize
923
+ optimizer.zero_grad()
924
+ loss.backward()
925
+ # Add gradient clipping
926
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
927
+ optimizer.step()
928
+
929
+ total_loss += loss.item()
930
+ num_batches += 1
931
+
932
+ # Update the epoch progress bar only at 10% intervals or at the end
933
+ update_interval = max(1, len(dataloader) // 10)
934
+ if (batch_idx + 1) % update_interval == 0 or batch_idx == len(dataloader) - 1:
935
+ # Update progress
936
+ remaining = min(update_interval, len(dataloader) - epoch_progress.n)
937
+ epoch_progress.update(remaining)
938
+ # Update stats with current average loss
939
+ current_avg_loss = total_loss / num_batches
940
+ epoch_progress.set_postfix(avg_loss=f"{current_avg_loss:.4f}",
941
+ last_batch_loss=f"{loss.item():.4f}")
942
+
943
+ except Exception as e:
944
+ logging.error(f"Error during batch processing: {str(e)}")
945
+ logging.error(f"Exception type: {e.__class__.__name__}")
946
+ continue
947
+
948
+ # Close the progress bar at the end of the epoch
949
+ epoch_progress.close()
950
+
951
+ avg_loss = total_loss / max(1, num_batches)
952
+ # Update epoch progress bar with average loss
953
+ epoch_pbar.set_postfix(avg_loss=f"{avg_loss:.4f}")
954
+
955
+ # Log to tensorboard
956
+ if tensorboard_available:
957
+ writer.add_scalar('Loss/train', avg_loss, epoch)
958
+
959
+ # Step the learning rate scheduler at the end of each epoch
960
+ scheduler.step()
961
+
962
+ # Incremental model saving if requested
963
+ if save_dir and (epoch + 1) % save_interval == 0:
964
+ checkpoint = {
965
+ 'epoch': epoch,
966
+ 'model_state_dict': model.state_dict(),
967
+ 'optimizer_state_dict': optimizer.state_dict(),
968
+ 'loss': avg_loss,
969
+ 'scheduler_state_dict': scheduler.state_dict() # Save scheduler state
970
+ }
971
+
972
+ save_path = Path(save_dir) / f'model_checkpoint_epoch_{epoch+1}.pth'
973
+ torch.save(checkpoint, save_path)
974
+ logging.info(f"Checkpoint saved at epoch {epoch+1}: {save_path}")
975
+
976
+ if tensorboard_available:
977
+ writer.close()
978
+
979
+ return model
980
+
981
+
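The `StepLR` scheduler configured above (`step_size=2`, `gamma=0.9`) multiplies the learning rate by 0.9 every 2 epochs. A quick back-of-envelope check of that schedule, computed without torch (the `stepped_lr` helper is illustrative):

```python
# Closed-form value of a StepLR schedule: lr * gamma ** (epoch // step_size).
def stepped_lr(base_lr, epoch, step_size=2, gamma=0.9):
    return base_lr * gamma ** (epoch // step_size)

print(stepped_lr(0.001, 0))             # epochs 0-1 keep the base lr
print(round(stepped_lr(0.001, 4), 6))   # epoch 4 → 0.001 * 0.9**2 = 0.00081
```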
982
+ def compute_tag_frequencies(pipeline, users_data):
983
+ """
984
+ Compute the tag frequencies in the dataset.
985
+
986
+ Args:
987
+ pipeline: Embedding pipeline
988
+ users_data: List of users
989
+
990
+ Returns:
991
+ dict: Dictionary with tags as keys and frequencies as values
992
+ int: Total number of tags processed
993
+ """
994
+ logging.info("Computing tag frequencies...")
995
+ tag_frequencies = {}
996
+ total_tags = 0
997
+
998
+ # Extract the tags from all users
999
+ for data in users_data:
1000
+ # Extract the user from the JSON structure
1001
+ if 'raw_json' in data and 'user' in data['raw_json']:
1002
+ user = data['raw_json']['user']
1003
+ elif 'user' in data:
1004
+ user = data['user']
1005
+ else:
1006
+ user = data # Assume it's already a user
1007
+
1008
+ # Extract and count the tags
1009
+ tags = pipeline._get_field_from_user(user, 'dmp_tags')
1010
+ for tag in tags:
1011
+ if tag is not None:
1012
+ tag_str = str(tag).lower().strip()
1013
+ tag_frequencies[tag_str] = tag_frequencies.get(tag_str, 0) + 1
1014
+ total_tags += 1
1015
+
1016
+ logging.info(f"Found {len(tag_frequencies)} unique tags across {total_tags} total occurrences")
1017
+ return tag_frequencies, total_tags
1018
+
1019
+ # Helper to filter tags by frequency or percentile
1020
+
1021
+ def filter_tags_by_criteria(tag_frequencies, min_frequency=100, percentile=None):
1022
+ """
1023
+ Filter tags based on frequency or percentile criteria.
1024
+
1025
+ Args:
1026
+ tag_frequencies: Dictionary with tags and frequencies
1027
+ min_frequency: Minimum required frequency (default: 100)
1028
+ percentile: If specified, keep only the tags up to the given percentile
1029
+
1030
+ Returns:
1031
+ set: Set of tags that satisfy the criteria
1032
+ """
1033
+ if percentile is not None:
1034
+ # Ordina i tag per frequenza e mantieni solo fino al percentile specificato
1035
+ sorted_tags = sorted(tag_frequencies.items(), key=lambda x: x[1], reverse=True)
1036
+ cutoff_index = int(len(sorted_tags) * (percentile / 100.0))
1037
+ filtered_tags = {tag for tag, _ in sorted_tags[:cutoff_index]}
1038
+
1039
+ min_freq_in_set = sorted_tags[cutoff_index-1][1] if cutoff_index > 0 else 0
1040
+ logging.info(f"Filtered tags at the {percentile}th percentile. Kept {len(filtered_tags)} tags with frequency >= {min_freq_in_set}")
1041
+ else:
1042
+ # Filter based only on the minimum frequency
1043
+ filtered_tags = {tag for tag, freq in tag_frequencies.items() if freq >= min_frequency}
1044
+ logging.info(f"Filtrati tag con frequenza < {min_frequency}. Mantenuti {len(filtered_tags)} tag")
1045
+
1046
+ return filtered_tags
1047
+
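The percentile branch of `filter_tags_by_criteria` sorts tags by frequency and keeps the top slice. A self-contained sketch with toy frequencies (the `filter_by_percentile` helper and sample data are illustrative):

```python
# Sketch of the percentile-based tag filtering used above.
def filter_by_percentile(tag_frequencies, percentile):
    """Keep the most frequent `percentile` percent of tags."""
    sorted_tags = sorted(tag_frequencies.items(), key=lambda x: x[1], reverse=True)
    cutoff = int(len(sorted_tags) * (percentile / 100.0))
    return {tag for tag, _ in sorted_tags[:cutoff]}

freqs = {'sport': 500, 'news': 300, 'music': 120, 'rare': 2}
print(sorted(filter_by_percentile(freqs, 50)))  # top 50% → ['news', 'sport']
```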
1048
+ # =================================
1049
+ # MAIN FUNCTION
1050
+ # =================================
1051
+ def main():
1052
+ # Configuration
1053
+ output_dir = Path(OUTPUT_DIR)
1054
+
1055
+ # Check if CUDA is available
1056
+ cuda_available = torch.cuda.is_available()
1057
+ logging.info(f"CUDA available: {cuda_available}")
1058
+ if cuda_available:
1059
+ logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
1060
+ logging.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
1061
+
1062
+ # CPU info
1063
+ cpu_count = os.cpu_count()
1064
+ memory_info = psutil.virtual_memory()
1065
+ logging.info(f"CPU cores: {cpu_count}")
1066
+ logging.info(f"System memory: {memory_info.total / 1e9:.2f} GB")
1067
+
1068
+ # Print configuration
1069
+ logging.info("Running with the following configuration:")
1070
+ logging.info(f"- Number of triplets: {NUM_TRIPLETS}")
1071
+ logging.info(f"- Number of epochs: {NUM_EPOCHS}")
1072
+ logging.info(f"- Batch size: {BATCH_SIZE}")
1073
+ logging.info(f"- Learning rate: {LEARNING_RATE}")
1074
+ logging.info(f"- Output dimension: {OUTPUT_DIM}")
1075
+ logging.info(f"- Data path: {DATA_PATH}")
1076
+ logging.info(f"- Output directory: {OUTPUT_DIR}")
1077
+
1078
+ # Load data
1079
+ logging.info("Loading user data...")
1080
+ try:
1081
+ try:
1082
+ # First try loading as normal JSON
1083
+ with open(DATA_PATH, 'r') as f:
1084
+ json_data = json.load(f)
1085
+
1086
+ # Handle both cases: array of users or single object
1087
+ if isinstance(json_data, list):
1088
+ users_data = json_data
1089
+ elif isinstance(json_data, dict):
1090
+ # If it's a single record, put it in a list
1091
+ users_data = [json_data]
1092
+ else:
1093
+ raise ValueError("Unsupported JSON format")
1094
+
1095
+ except json.JSONDecodeError:
1096
+ # If it fails, the file might be a non-standard JSON (objects separated by commas)
1097
+ logging.info("Detected possible non-standard JSON format, attempting correction...")
1098
+ with open(DATA_PATH, 'r') as f:
1099
+            text = f.read().strip()
+
+            # Add square brackets to create a valid JSON array
+            if not text.startswith('['):
+                text = '[' + text
+            if not text.endswith(']'):
+                text = text + ']'
+
+            # Try loading the corrected text
+            users_data = json.loads(text)
+            logging.info("JSON format successfully corrected")
+
+        logging.info(f"Loaded {len(users_data)} records")
+    except FileNotFoundError:
+        logging.error(f"File {DATA_PATH} not found!")
+        return
+    except Exception as e:
+        logging.error(f"Unable to load file: {str(e)}")
+        return
+
+    # Initialize pipeline
+    logging.info("Initializing pipeline...")
+    pipeline = UserEmbeddingPipeline(
+        output_dim=OUTPUT_DIM,
+        max_sequence_length=MAX_SEQ_LENGTH
+    )
+
+    # Build vocabularies
+    logging.info("Building vocabularies...")
+    try:
+        pipeline.build_vocabularies(users_data)
+        vocab_sizes = {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()}
+        logging.info(f"Vocabulary sizes: {vocab_sizes}")
+    except Exception as e:
+        logging.error(f"Error building vocabularies: {str(e)}")
+        return
+
+    # Initialize model
+    logging.info("Initializing model...")
+    try:
+        pipeline.initialize_model()
+        logging.info("Model initialized successfully")
+    except Exception as e:
+        logging.error(f"Error initializing model: {str(e)}")
+        return
+
+    # Training
+    logging.info("Starting training...")
+    try:
+        # Create the checkpoint directory if it doesn't exist
+        model_dir = output_dir / "model_checkpoints"
+        model_dir.mkdir(exist_ok=True, parents=True)
+
+        model = train_user_embeddings(
+            pipeline.model,
+            users_data,
+            pipeline,  # Pass the pipeline to training
+            num_epochs=NUM_EPOCHS,
+            batch_size=BATCH_SIZE,
+            lr=LEARNING_RATE,
+            save_dir=model_dir,  # Enable incremental saving
+            save_interval=SAVE_INTERVAL,  # Save every N epochs
+            num_triplets=NUM_TRIPLETS
+        )
+        logging.info("Training completed")
+        pipeline.model = model
+
+        # Save only the model file
+        logging.info("Saving model...")
+
+        # Create output directory
+        output_dir.mkdir(exist_ok=True)
+
+        # Save path
+        model_path = output_dir / 'model.pth'
+
+        # Prepare dictionary with model state and metadata
+        checkpoint = {
+            'model_state_dict': pipeline.model.state_dict(),
+            'vocab_maps': pipeline.vocab_maps,
+            'embedding_dims': pipeline.embedding_dims,
+            'output_dim': pipeline.output_dim,
+            'max_sequence_length': pipeline.max_sequence_length
+        }
+
+        # Save model
+        torch.save(checkpoint, model_path)
+
+        logging.info(f"Model saved to: {model_path}")
+
+        # Also save a configuration file for reference
+        config_info = {
+            'model_type': 'UserEmbeddingModel',
+            'vocab_sizes': {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()},
+            'embedding_dims': pipeline.embedding_dims,
+            'output_dim': pipeline.output_dim,
+            'max_sequence_length': pipeline.max_sequence_length,
+            'padded_fields': list(pipeline.model.padded_fields),
+            'fields': pipeline.fields
+        }
+
+        config_path = output_dir / 'model_config.json'
+        with open(config_path, 'w') as f:
+            json.dump(config_info, f, indent=2)
+
+        logging.info(f"Model configuration saved to: {config_path}")
+
+        # Only save in HuggingFace format if requested
+        save_hf = os.environ.get("SAVE_HF_FORMAT", "false").lower() == "true"
+        if save_hf:
+            logging.info("Saving in HuggingFace format...")
+            hf_dir = output_dir / 'huggingface'
+            hf_dir.mkdir(exist_ok=True)
+
+            # Save model weights in HF format
+            torch.save(pipeline.model.state_dict(), hf_dir / 'pytorch_model.bin')
+
+            # Save config
+            with open(hf_dir / 'config.json', 'w') as f:
+                json.dump(config_info, f, indent=2)
+
+            logging.info(f"Model saved in HuggingFace format to: {hf_dir}")
+
+        # Push to HuggingFace if the environment variables are set
+        hf_repo_id = os.environ.get("HF_REPO_ID")
+        hf_token = os.environ.get("HF_TOKEN")
+
+        if save_hf and hf_repo_id and hf_token:
+            try:
+                from huggingface_hub import HfApi
+
+                logging.info(f"Pushing model to HuggingFace: {hf_repo_id}")
+                api = HfApi()
+
+                # Create the target repo (no-op if it already exists)
+                api.create_repo(
+                    repo_id=hf_repo_id,
+                    token=hf_token,
+                    exist_ok=True,
+                    private=True
+                )
+
+                # Upload the files
+                for file_path in (output_dir / "huggingface").glob("**/*"):
+                    if file_path.is_file():
+                        api.upload_file(
+                            path_or_fileobj=str(file_path),
+                            path_in_repo=str(file_path.relative_to(output_dir / "huggingface")),
+                            repo_id=hf_repo_id,
+                            token=hf_token
+                        )
+
+                logging.info(f"Model successfully pushed to HuggingFace: {hf_repo_id}")
+            except Exception as e:
+                logging.error(f"Error pushing to HuggingFace: {str(e)}")
+
+    except Exception as e:
+        logging.error(f"Error during training or saving: {str(e)}")
+        import traceback
+        traceback.print_exc()
+        return
+
+    logging.info("Process completed successfully!")
+
+if __name__ == "__main__":
+    main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ torch==2.1.0
+ torchvision==0.16.0
+ numpy==1.24.3
+ pandas==2.0.3
+ tqdm==4.66.1
+ scikit-learn==1.3.0
+ matplotlib==3.7.2
+ transformers==4.34.0
+ datasets==2.14.5
+ huggingface_hub==0.17.3
+ tensorboard==2.14.1
+ psutil==5.9.5
run.sh ADDED
@@ -0,0 +1,39 @@
+ #!/bin/bash
+ set -e
+
+ echo "Starting user embedding model pipeline..."
+
+ # Check if the model already exists
+ if [ -f "embeddings_output/model.pth" ]; then
+     echo "Model already trained."
+ else
+     # Check if input data exists
+     if [ ! -f "$DATA_PATH" ] && [ ! -f "users.json" ]; then
+         echo "Error: No data file found. Please mount a volume with users.json or set DATA_PATH."
+         exit 1
+     fi
+
+     # Run the model generation script
+     python3 generate_model_gpu.py
+     echo "Process completed. Check the embeddings_output directory for results."
+ fi
+
+ # Get the hostname or public URL if available (for Hugging Face Spaces)
+ if [ -n "$SPACE_ID" ]; then
+     # Running in Hugging Face Spaces
+     BASE_URL="https://${SPACE_ID}.hf.space"
+ else
+     # Running locally or in generic Docker
+     BASE_URL="http://localhost:7860"
+ fi
+
+ echo "=========================================================="
+ echo "Model is available for download at the following URLs:"
+ echo "${BASE_URL}/embeddings_output/model.pth"
+ echo "${BASE_URL}/embeddings_output/model_config.json"
+ echo "=========================================================="
+
+ # Start the web server to serve files (plain paths; http.server
+ # does not understand Gradio-style /file= URLs)
+ echo "Starting web server on port 7860..."
+ python3 -m http.server 7860
users_200k.json CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:066da954febcc008311c23735766b751be086b4e61394b797692916f488ccbe4
- size 297966576
+ oid sha256:5d91ddb646dfcc955f89edcf756041415a6ba9f3d34cd217e8dc7a6f35cc9161
+ size 297966779
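For reference, the `model.pth` checkpoint written by `generate_model_gpu.py` is a plain dictionary of weights plus metadata, so it round-trips with `torch.save`/`torch.load`. The sketch below mirrors that structure with a stand-in `nn.Linear` model; the real `UserEmbeddingModel` class and its constructor arguments live in `generate_model_gpu.py` and are not reproduced here.

```python
import torch
import torch.nn as nn

# Stand-in model, only to demonstrate the checkpoint round-trip;
# the actual UserEmbeddingModel is defined in generate_model_gpu.py.
model = nn.Linear(8, 4)

# Mirror the checkpoint structure written by the training script.
checkpoint = {
    "model_state_dict": model.state_dict(),
    "vocab_maps": {"country": {"IT": 0, "US": 1}},   # illustrative values
    "embedding_dims": {"country": 8},                # illustrative values
    "output_dim": 4,
    "max_sequence_length": 16,
}
torch.save(checkpoint, "/tmp/model.pth")

# Reloading: read the metadata first, then restore the weights
# into a freshly constructed model of the same shape.
restored = torch.load("/tmp/model.pth", map_location="cpu")
new_model = nn.Linear(8, 4)
new_model.load_state_dict(restored["model_state_dict"])
new_model.eval()
```

Inference code would use `restored["vocab_maps"]` and `restored["max_sequence_length"]` to encode inputs exactly as the pipeline did at training time.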