Add files copied from SPACES
Browse files- Dockerfile +45 -0
- README.md +120 -0
- generate_model_backup.py +1110 -0
- generate_model_gpu.py +1265 -0
- requirements.txt +12 -0
- run.sh +39 -0
- users_200k.json +2 -2
Dockerfile
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04
|
| 2 |
+
|
| 3 |
+
# Set environment variables
|
| 4 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 5 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 6 |
+
DEBIAN_FRONTEND=noninteractive
|
| 7 |
+
|
| 8 |
+
# Install system dependencies
|
| 9 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 10 |
+
python3 \
|
| 11 |
+
python3-pip \
|
| 12 |
+
python3-setuptools \
|
| 13 |
+
python3-dev \
|
| 14 |
+
build-essential \
|
| 15 |
+
git \
|
| 16 |
+
&& apt-get clean \
|
| 17 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 18 |
+
|
| 19 |
+
# Create a working directory
|
| 20 |
+
WORKDIR /app
|
| 21 |
+
|
| 22 |
+
# Installa dipendenze
|
| 23 |
+
COPY requirements.txt .
|
| 24 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 25 |
+
|
| 26 |
+
# Prepara directory di output con permessi corretti
|
| 27 |
+
RUN mkdir -p ./model_checkpoints ./model /tmp/huggingface_cache /app/embedding && \
|
| 28 |
+
chmod 777 ./model_checkpoints ./model /tmp/huggingface_cache /app/embedding
|
| 29 |
+
|
| 30 |
+
# Copia tutti i file necessari
|
| 31 |
+
COPY *.py .
|
| 32 |
+
COPY users.json .
|
| 33 |
+
COPY run.sh .
|
| 34 |
+
|
| 35 |
+
RUN chmod +x run.sh
|
| 36 |
+
|
| 37 |
+
# Crea uno script per eseguire il processo e poi copiare i file
|
| 38 |
+
#python generate_embeddings.py\n\
|
| 39 |
+
# Esegui lo script
|
| 40 |
+
|
| 41 |
+
# Esponi la porta per l'interfaccia web
|
| 42 |
+
EXPOSE 7860
|
| 43 |
+
|
| 44 |
+
# Set the entrypoint
|
| 45 |
+
ENTRYPOINT ["./run.sh"]
|
README.md
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Test01
|
| 3 |
+
emoji: 🐨
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 7860
|
| 8 |
+
startup_duration_timeout: 60m
|
| 9 |
+
pinned: false
|
| 10 |
+
license: apache-2.0
|
| 11 |
+
---
|
| 12 |
+
# User Embedding Model
|
| 13 |
+
|
| 14 |
+
This repository contains a PyTorch model for generating user embeddings based on DMP (Data Management Platform) data. The model creates dense vector representations of users that can be used for recommendation systems, user clustering, and similarity searches.
|
| 15 |
+
|
| 16 |
+
## Quick Start with Docker
|
| 17 |
+
|
| 18 |
+
To run the model using Docker:
|
| 19 |
+
|
| 20 |
+
```bash
|
| 21 |
+
docker build -t user-embedding-model .
|
| 22 |
+
|
| 23 |
+
docker run -v /path/to/your/data:/app/data \
|
| 24 |
+
-e DATA_PATH=/app/data/users.json \
|
| 25 |
+
-e NUM_EPOCHS=10 \
|
| 26 |
+
-e BATCH_SIZE=32 \
|
| 27 |
+
-v /path/to/output:/app/embeddings_output \
|
| 28 |
+
user-embedding-model
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Pushing to Hugging Face
|
| 32 |
+
|
| 33 |
+
To automatically push the model to Hugging Face, add your credentials:
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
docker run -v /path/to/your/data:/app/data \
|
| 37 |
+
-e DATA_PATH=/app/data/users.json \
|
| 38 |
+
-e HF_REPO_ID="your-username/your-model-name" \
|
| 39 |
+
-e HF_TOKEN="your-huggingface-token" \
|
| 40 |
+
-v /path/to/output:/app/embeddings_output \
|
| 41 |
+
user-embedding-model
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
## Input Data Format
|
| 45 |
+
|
| 46 |
+
The model expects user data in JSON format, with each user having DMP fields like:
|
| 47 |
+
|
| 48 |
+
```json
|
| 49 |
+
{
|
| 50 |
+
"dmp": {
|
| 51 |
+
"city": "milano",
|
| 52 |
+
"domains": ["example.com"],
|
| 53 |
+
"brands": ["brand1", "brand2"],
|
| 54 |
+
"clusters": ["cluster1", "cluster2"],
|
| 55 |
+
"industries": ["industry1"],
|
| 56 |
+
"tags": ["tag1", "tag2"],
|
| 57 |
+
"channels": ["channel1"],
|
| 58 |
+
"~click__host": "host1",
|
| 59 |
+
"~click__domain": "domain1",
|
| 60 |
+
"": {
|
| 61 |
+
"id": "user123"
|
| 62 |
+
}
|
| 63 |
+
}
|
| 64 |
+
}
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Environment Variables
|
| 68 |
+
|
| 69 |
+
- `DATA_PATH`: Path to your input JSON file (default: "users.json")
|
| 70 |
+
- `NUM_EPOCHS`: Number of training epochs (default: 10)
|
| 71 |
+
- `BATCH_SIZE`: Batch size for training (default: 32)
|
| 72 |
+
- `LEARNING_RATE`: Learning rate for optimizer (default: 0.001)
|
| 73 |
+
- `SAVE_INTERVAL`: Save checkpoint every N epochs (default: 2)
|
| 74 |
+
- `HF_REPO_ID`: Hugging Face repository ID for uploading
|
| 75 |
+
- `HF_TOKEN`: Hugging Face API token
|
| 76 |
+
|
| 77 |
+
## Output
|
| 78 |
+
|
| 79 |
+
The model generates:
|
| 80 |
+
|
| 81 |
+
1. `embeddings.json`: User embeddings in JSON format
|
| 82 |
+
2. `embeddings.npz`: User embeddings in NumPy format
|
| 83 |
+
3. `vocabularies.json`: Vocabulary mappings
|
| 84 |
+
4. `model.pth`: Trained PyTorch model
|
| 85 |
+
5. `model_config.json`: Model configuration
|
| 86 |
+
6. Hugging Face-compatible model files in the "huggingface" subdirectory
|
| 87 |
+
|
| 88 |
+
## Hardware Requirements
|
| 89 |
+
|
| 90 |
+
- **Recommended**: NVIDIA GPU with CUDA support
|
| 91 |
+
- The code uses parallel processing for triplet generation to utilize all available CPU cores
|
| 92 |
+
- For L40S GPU, recommended batch size: 32-64
|
| 93 |
+
- Memory requirement: At least 8GB RAM
|
| 94 |
+
|
| 95 |
+
## Model Architecture
|
| 96 |
+
|
| 97 |
+
The model consists of:
|
| 98 |
+
|
| 99 |
+
- Embedding layers for each user field
|
| 100 |
+
- Sequential fully connected layers for dimensionality reduction
|
| 101 |
+
- Output dimension: 256 (configurable)
|
| 102 |
+
- Training method: Triplet margin loss using similar/dissimilar users
|
| 103 |
+
|
| 104 |
+
## Performance Optimization
|
| 105 |
+
|
| 106 |
+
The code includes several optimizations:
|
| 107 |
+
|
| 108 |
+
- Parallel triplet generation using all available CPU cores
|
| 109 |
+
- GPU acceleration for model training
|
| 110 |
+
- Efficient memory handling for large datasets
|
| 111 |
+
- TensorBoard integration for monitoring training
|
| 112 |
+
|
| 113 |
+
## Troubleshooting
|
| 114 |
+
|
| 115 |
+
If you encounter issues:
|
| 116 |
+
|
| 117 |
+
1. Check that your input data follows the expected format
|
| 118 |
+
2. Ensure you have sufficient memory for your dataset size
|
| 119 |
+
3. For GPU errors, try reducing batch size
|
| 120 |
+
4. Check the logs for detailed error messages
|
generate_model_backup.py
ADDED
|
@@ -0,0 +1,1110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
| 3 |
+
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
|
| 4 |
+
import multiprocessing
|
| 5 |
+
try:
|
| 6 |
+
multiprocessing.set_start_method('spawn')
|
| 7 |
+
except RuntimeError:
|
| 8 |
+
pass # Il metodo è già stato impostato
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from typing import List, Dict
|
| 15 |
+
import logging
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import random
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import os
|
| 23 |
+
import multiprocessing
|
| 24 |
+
from multiprocessing import Pool
|
| 25 |
+
import psutil
|
| 26 |
+
import argparse
|
| 27 |
+
|
| 28 |
+
# =================================
|
| 29 |
+
# CONFIGURABLE PARAMETERS
|
| 30 |
+
# =================================
|
| 31 |
+
# Define default parameters that can be overridden via environment variables
|
| 32 |
+
DEFAULT_NUM_TRIPLETS = 150000 # Number of triplet examples to generate
|
| 33 |
+
DEFAULT_NUM_EPOCHS = 20 # Number of training epochs
|
| 34 |
+
DEFAULT_BATCH_SIZE = 64 # Batch size for training
|
| 35 |
+
DEFAULT_LEARNING_RATE = 0.001 # Learning rate for optimizer
|
| 36 |
+
DEFAULT_OUTPUT_DIM = 256 # Output dimension of embeddings
|
| 37 |
+
DEFAULT_MAX_SEQ_LENGTH = 15 # Maximum sequence length
|
| 38 |
+
DEFAULT_SAVE_INTERVAL = 2 # Save checkpoint every N epochs
|
| 39 |
+
DEFAULT_DATA_PATH = "./users.json" # Path to user data
|
| 40 |
+
DEFAULT_OUTPUT_DIR = "./model" # Output directory
|
| 41 |
+
|
| 42 |
+
# Read parameters from environment variables
|
| 43 |
+
NUM_TRIPLETS = int(os.environ.get("NUM_TRIPLETS", DEFAULT_NUM_TRIPLETS))
|
| 44 |
+
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", DEFAULT_NUM_EPOCHS))
|
| 45 |
+
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", DEFAULT_BATCH_SIZE))
|
| 46 |
+
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", DEFAULT_LEARNING_RATE))
|
| 47 |
+
OUTPUT_DIM = int(os.environ.get("OUTPUT_DIM", DEFAULT_OUTPUT_DIM))
|
| 48 |
+
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", DEFAULT_MAX_SEQ_LENGTH))
|
| 49 |
+
SAVE_INTERVAL = int(os.environ.get("SAVE_INTERVAL", DEFAULT_SAVE_INTERVAL))
|
| 50 |
+
DATA_PATH = os.environ.get("DATA_PATH", DEFAULT_DATA_PATH)
|
| 51 |
+
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)
|
| 52 |
+
|
| 53 |
+
# Configure logging
|
| 54 |
+
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 55 |
+
|
| 56 |
+
# Get CUDA device
|
| 57 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 58 |
+
logging.info(f"Using device: {device}")
|
| 59 |
+
|
| 60 |
+
# =================================
|
| 61 |
+
# MODEL ARCHITECTURE
|
| 62 |
+
# =================================
|
| 63 |
+
class UserEmbeddingModel(nn.Module):
|
| 64 |
+
def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int],
|
| 65 |
+
output_dim: int = 256, max_sequence_length: int = 15,
|
| 66 |
+
padded_fields_length: int = 10):
|
| 67 |
+
super().__init__()
|
| 68 |
+
|
| 69 |
+
self.max_sequence_length = max_sequence_length
|
| 70 |
+
self.padded_fields_length = padded_fields_length
|
| 71 |
+
self.padded_fields = {'dmp_channels', 'dmp_tags', 'dmp_clusters'}
|
| 72 |
+
self.embedding_layers = nn.ModuleDict()
|
| 73 |
+
|
| 74 |
+
# Create embedding layers for each field
|
| 75 |
+
for field, vocab_size in vocab_sizes.items():
|
| 76 |
+
self.embedding_layers[field] = nn.Embedding(
|
| 77 |
+
vocab_size,
|
| 78 |
+
embedding_dims.get(field, 16),
|
| 79 |
+
padding_idx=0
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Calculate total input dimension
|
| 83 |
+
self.total_input_dim = 0
|
| 84 |
+
for field, dim in embedding_dims.items():
|
| 85 |
+
if field in self.padded_fields:
|
| 86 |
+
self.total_input_dim += dim # Single dimension for padded field
|
| 87 |
+
else:
|
| 88 |
+
self.total_input_dim += dim
|
| 89 |
+
|
| 90 |
+
print(f"Total input dimension: {self.total_input_dim}")
|
| 91 |
+
|
| 92 |
+
self.fc = nn.Sequential(
|
| 93 |
+
nn.Linear(self.total_input_dim, self.total_input_dim // 2),
|
| 94 |
+
nn.ReLU(),
|
| 95 |
+
nn.Dropout(0.2),
|
| 96 |
+
nn.Linear(self.total_input_dim // 2, output_dim),
|
| 97 |
+
nn.LayerNorm(output_dim)
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def _process_sequence(self, embedding_layer: nn.Embedding, indices: torch.Tensor,
|
| 101 |
+
field_name: str) -> torch.Tensor:
|
| 102 |
+
"""Process normal sequences"""
|
| 103 |
+
batch_size = indices.size(0)
|
| 104 |
+
if indices.numel() == 0:
|
| 105 |
+
return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)
|
| 106 |
+
|
| 107 |
+
if field_name in ['dmp_city', 'dmp_domains']:
|
| 108 |
+
if indices.dim() == 1:
|
| 109 |
+
indices = indices.unsqueeze(0)
|
| 110 |
+
if indices.size(1) > 0:
|
| 111 |
+
return embedding_layer(indices[:, 0])
|
| 112 |
+
return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)
|
| 113 |
+
|
| 114 |
+
# Handle multiple sequences
|
| 115 |
+
embeddings = embedding_layer(indices)
|
| 116 |
+
return embeddings.mean(dim=1) # [batch_size, emb_dim]
|
| 117 |
+
|
| 118 |
+
def _process_padded_sequence(self, embedding_layer: nn.Embedding,
|
| 119 |
+
indices: torch.Tensor) -> torch.Tensor:
|
| 120 |
+
"""Process sequences with padding"""
|
| 121 |
+
batch_size = indices.size(0)
|
| 122 |
+
emb_dim = embedding_layer.embedding_dim
|
| 123 |
+
|
| 124 |
+
# Generate embeddings
|
| 125 |
+
embeddings = embedding_layer(indices) # [batch_size, seq_len, emb_dim]
|
| 126 |
+
|
| 127 |
+
# Average along sequence dimension
|
| 128 |
+
mask = (indices != 0).float().unsqueeze(-1)
|
| 129 |
+
masked_embeddings = embeddings * mask
|
| 130 |
+
sum_mask = mask.sum(dim=1).clamp(min=1.0)
|
| 131 |
+
|
| 132 |
+
return (masked_embeddings.sum(dim=1) / sum_mask) # [batch_size, emb_dim]
|
| 133 |
+
|
| 134 |
+
def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
|
| 135 |
+
batch_embeddings = []
|
| 136 |
+
|
| 137 |
+
for field in ['dmp_city', 'dmp_domains', 'dmp_brands',
|
| 138 |
+
'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
|
| 139 |
+
'link_host', 'link_path']:
|
| 140 |
+
if field in inputs and field in self.embedding_layers:
|
| 141 |
+
if field in self.padded_fields:
|
| 142 |
+
emb = self._process_padded_sequence(
|
| 143 |
+
self.embedding_layers[field],
|
| 144 |
+
inputs[field]
|
| 145 |
+
)
|
| 146 |
+
else:
|
| 147 |
+
emb = self._process_sequence(
|
| 148 |
+
self.embedding_layers[field],
|
| 149 |
+
inputs[field],
|
| 150 |
+
field
|
| 151 |
+
)
|
| 152 |
+
batch_embeddings.append(emb)
|
| 153 |
+
|
| 154 |
+
combined = torch.cat(batch_embeddings, dim=1)
|
| 155 |
+
return self.fc(combined)
|
| 156 |
+
|
| 157 |
+
# =================================
|
| 158 |
+
# EMBEDDING PIPELINE
|
| 159 |
+
# =================================
|
| 160 |
+
class UserEmbeddingPipeline:
|
| 161 |
+
def __init__(self, output_dim: int = 256, max_sequence_length: int = 15):
|
| 162 |
+
self.output_dim = output_dim
|
| 163 |
+
self.max_sequence_length = max_sequence_length
|
| 164 |
+
self.model = None
|
| 165 |
+
self.vocab_maps = {}
|
| 166 |
+
|
| 167 |
+
self.fields = [
|
| 168 |
+
'dmp_city', 'dmp_domains', 'dmp_brands',
|
| 169 |
+
'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
|
| 170 |
+
'link_host', 'link_path'
|
| 171 |
+
]
|
| 172 |
+
|
| 173 |
+
# Map of new JSON fields to old field names used in the model
|
| 174 |
+
self.field_mapping = {
|
| 175 |
+
'dmp_city': ('dmp', 'city'),
|
| 176 |
+
'dmp_domains': ('dmp', 'domains'),
|
| 177 |
+
'dmp_brands': ('dmp', 'brands'),
|
| 178 |
+
'dmp_clusters': ('dmp', 'clusters'),
|
| 179 |
+
'dmp_industries': ('dmp', 'industries'),
|
| 180 |
+
'dmp_tags': ('dmp', 'tags'),
|
| 181 |
+
'dmp_channels': ('dmp', 'channels'),
|
| 182 |
+
'link_host': ('dmp', '~click__host'),
|
| 183 |
+
'link_path': ('dmp', '~click__domain')
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
self.embedding_dims = {
|
| 187 |
+
'dmp_city': 16,
|
| 188 |
+
'dmp_domains': 16,
|
| 189 |
+
'dmp_brands': 32,
|
| 190 |
+
'dmp_clusters': 32,
|
| 191 |
+
'dmp_industries': 32,
|
| 192 |
+
'dmp_tags': 64,
|
| 193 |
+
'dmp_channels': 32,
|
| 194 |
+
'link_host': 32,
|
| 195 |
+
'link_path': 32
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
def _clean_value(self, value):
|
| 199 |
+
if isinstance(value, float) and np.isnan(value):
|
| 200 |
+
return []
|
| 201 |
+
if isinstance(value, str):
|
| 202 |
+
return [value.lower().strip()]
|
| 203 |
+
if isinstance(value, list):
|
| 204 |
+
return [str(v).lower().strip() for v in value if v is not None and str(v).strip()]
|
| 205 |
+
return []
|
| 206 |
+
|
| 207 |
+
def _get_field_from_user(self, user, field):
|
| 208 |
+
"""Extract field value from new JSON user format"""
|
| 209 |
+
mapping = self.field_mapping.get(field, (field,))
|
| 210 |
+
value = user
|
| 211 |
+
|
| 212 |
+
# Navigate through nested structure
|
| 213 |
+
for key in mapping:
|
| 214 |
+
if isinstance(value, dict):
|
| 215 |
+
value = value.get(key, {})
|
| 216 |
+
else:
|
| 217 |
+
# If not a dictionary and we're not at the last element
|
| 218 |
+
# of the mapping, return an empty list
|
| 219 |
+
value = []
|
| 220 |
+
break
|
| 221 |
+
|
| 222 |
+
# If we've reached the end and have a value that's not a list but should be,
|
| 223 |
+
# convert it to a list
|
| 224 |
+
if field in {'dmp_brands', 'dmp_channels', 'dmp_clusters', 'dmp_industries', 'dmp_tags', 'link_host', 'link_path'} and not isinstance(value, list):
|
| 225 |
+
# If it's a string or other single value, put it in a list
|
| 226 |
+
if value and not isinstance(value, dict):
|
| 227 |
+
value = [value]
|
| 228 |
+
else:
|
| 229 |
+
value = []
|
| 230 |
+
|
| 231 |
+
return value
|
| 232 |
+
|
| 233 |
+
def build_vocabularies(self, users_data: List[Dict]) -> Dict[str, Dict[str, int]]:
|
| 234 |
+
field_values = {field: {'<PAD>'} for field in self.fields}
|
| 235 |
+
|
| 236 |
+
# Extract the 'user' field from the JSON structure for each record
|
| 237 |
+
users = []
|
| 238 |
+
for data in users_data:
|
| 239 |
+
# Check if there's raw_json.user
|
| 240 |
+
if 'raw_json' in data and 'user' in data['raw_json']:
|
| 241 |
+
users.append(data['raw_json']['user'])
|
| 242 |
+
# Check if there's user
|
| 243 |
+
elif 'user' in data:
|
| 244 |
+
users.append(data['user'])
|
| 245 |
+
else:
|
| 246 |
+
users.append(data) # Assume it's already a user
|
| 247 |
+
|
| 248 |
+
for user in users:
|
| 249 |
+
for field in self.fields:
|
| 250 |
+
values = self._clean_value(self._get_field_from_user(user, field))
|
| 251 |
+
field_values[field].update(values)
|
| 252 |
+
|
| 253 |
+
self.vocab_maps = {
|
| 254 |
+
field: {val: idx for idx, val in enumerate(sorted(values))}
|
| 255 |
+
for field, values in field_values.items()
|
| 256 |
+
}
|
| 257 |
+
|
| 258 |
+
return self.vocab_maps
|
| 259 |
+
|
| 260 |
+
def _prepare_input(self, user: Dict) -> Dict[str, torch.Tensor]:
|
| 261 |
+
inputs = {}
|
| 262 |
+
|
| 263 |
+
for field in self.fields:
|
| 264 |
+
values = self._clean_value(self._get_field_from_user(user, field))
|
| 265 |
+
vocab = self.vocab_maps[field]
|
| 266 |
+
indices = [vocab.get(val, 0) for val in values]
|
| 267 |
+
inputs[field] = torch.tensor(indices, dtype=torch.long)
|
| 268 |
+
|
| 269 |
+
return inputs
|
| 270 |
+
|
| 271 |
+
def initialize_model(self) -> None:
|
| 272 |
+
vocab_sizes = {field: len(vocab) for field, vocab in self.vocab_maps.items()}
|
| 273 |
+
|
| 274 |
+
self.model = UserEmbeddingModel(
|
| 275 |
+
vocab_sizes=vocab_sizes,
|
| 276 |
+
embedding_dims=self.embedding_dims,
|
| 277 |
+
output_dim=self.output_dim,
|
| 278 |
+
max_sequence_length=self.max_sequence_length
|
| 279 |
+
)
|
| 280 |
+
self.model.to(device)
|
| 281 |
+
self.model.eval()
|
| 282 |
+
|
| 283 |
+
def generate_embeddings(self, users_data: List[Dict], batch_size: int = 32) -> Dict[str, np.ndarray]:
|
| 284 |
+
"""Generate embeddings for all users"""
|
| 285 |
+
embeddings = {}
|
| 286 |
+
self.model.eval() # Make sure model is in eval mode
|
| 287 |
+
|
| 288 |
+
# Extract the 'user' field from the JSON structure for each record
|
| 289 |
+
users = []
|
| 290 |
+
user_ids = []
|
| 291 |
+
|
| 292 |
+
for data in users_data:
|
| 293 |
+
# Check if there's raw_json.user
|
| 294 |
+
if 'raw_json' in data and 'user' in data['raw_json']:
|
| 295 |
+
user = data['raw_json']['user']
|
| 296 |
+
users.append(user)
|
| 297 |
+
# Use user.dmp[''].id as identifier
|
| 298 |
+
if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
|
| 299 |
+
user_ids.append(str(user['dmp']['']['id']))
|
| 300 |
+
else:
|
| 301 |
+
# Fallback to uid or id if dmp.id is not available
|
| 302 |
+
user_ids.append(str(user.get('uid', user.get('id', None))))
|
| 303 |
+
# Check if there's user
|
| 304 |
+
elif 'user' in data:
|
| 305 |
+
user = data['user']
|
| 306 |
+
users.append(user)
|
| 307 |
+
# Use user.dmp[''].id as identifier
|
| 308 |
+
if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
|
| 309 |
+
user_ids.append(str(user['dmp']['']['id']))
|
| 310 |
+
else:
|
| 311 |
+
# Fallback to uid or id if dmp.id is not available
|
| 312 |
+
user_ids.append(str(user.get('uid', user.get('id', None))))
|
| 313 |
+
else:
|
| 314 |
+
users.append(data) # Assume it's already a user
|
| 315 |
+
# Use user.dmp[''].id as identifier
|
| 316 |
+
if 'dmp' in data and '' in data['dmp'] and 'id' in data['dmp']['']:
|
| 317 |
+
user_ids.append(str(data['dmp']['']['id']))
|
| 318 |
+
else:
|
| 319 |
+
# Fallback to uid or id if dmp.id is not available
|
| 320 |
+
user_ids.append(str(data.get('uid', data.get('id', None))))
|
| 321 |
+
|
| 322 |
+
with torch.no_grad():
|
| 323 |
+
for i in tqdm(range(0, len(users), batch_size), desc="Generating embeddings"):
|
| 324 |
+
batch_users = users[i:i+batch_size]
|
| 325 |
+
batch_ids = user_ids[i:i+batch_size]
|
| 326 |
+
batch_inputs = []
|
| 327 |
+
valid_indices = []
|
| 328 |
+
|
| 329 |
+
for j, user in enumerate(batch_users):
|
| 330 |
+
if batch_ids[j] is not None:
|
| 331 |
+
batch_inputs.append(self._prepare_input(user))
|
| 332 |
+
valid_indices.append(j)
|
| 333 |
+
|
| 334 |
+
if batch_inputs:
|
| 335 |
+
# Use the same collate function as training for a single batch
|
| 336 |
+
anchor_batch, _, _ = collate_batch([(inputs, inputs, inputs) for inputs in batch_inputs])
|
| 337 |
+
|
| 338 |
+
# Move data to device
|
| 339 |
+
anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
|
| 340 |
+
|
| 341 |
+
# Generate embeddings
|
| 342 |
+
batch_embeddings = self.model(anchor_batch).cpu()
|
| 343 |
+
|
| 344 |
+
# Save embeddings
|
| 345 |
+
for j, idx in enumerate(valid_indices):
|
| 346 |
+
if batch_ids[idx]: # Verify that id is not None or empty
|
| 347 |
+
embeddings[batch_ids[idx]] = batch_embeddings[j].numpy()
|
| 348 |
+
|
| 349 |
+
return embeddings
|
| 350 |
+
|
| 351 |
+
def save_embeddings(self, embeddings: Dict[str, np.ndarray], output_dir: str) -> None:
|
| 352 |
+
"""Save embeddings to file"""
|
| 353 |
+
output_dir = Path(output_dir)
|
| 354 |
+
output_dir.mkdir(exist_ok=True)
|
| 355 |
+
|
| 356 |
+
# Save embeddings as JSON
|
| 357 |
+
json_path = output_dir / 'embeddings.json'
|
| 358 |
+
with open(json_path, 'w') as f:
|
| 359 |
+
json_embeddings = {user_id: emb.tolist() for user_id, emb in embeddings.items()}
|
| 360 |
+
json.dump(json_embeddings, f)
|
| 361 |
+
|
| 362 |
+
# Save embeddings as NPY
|
| 363 |
+
npy_path = output_dir / 'embeddings.npz'
|
| 364 |
+
np.savez_compressed(npy_path,
|
| 365 |
+
embeddings=np.stack(list(embeddings.values())),
|
| 366 |
+
user_ids=np.array(list(embeddings.keys())))
|
| 367 |
+
|
| 368 |
+
# Save vocabularies
|
| 369 |
+
vocab_path = output_dir / 'vocabularies.json'
|
| 370 |
+
with open(vocab_path, 'w') as f:
|
| 371 |
+
json.dump(self.vocab_maps, f)
|
| 372 |
+
|
| 373 |
+
logging.info(f"\nEmbeddings saved in {output_dir}:")
|
| 374 |
+
logging.info(f"- Embeddings JSON: {json_path}")
|
| 375 |
+
logging.info(f"- Embeddings NPY: {npy_path}")
|
| 376 |
+
logging.info(f"- Vocabularies: {vocab_path}")
|
| 377 |
+
|
| 378 |
+
def save_model(self, output_dir: str) -> None:
|
| 379 |
+
"""Save model in PyTorch format (.pth)"""
|
| 380 |
+
output_dir = Path(output_dir)
|
| 381 |
+
output_dir.mkdir(exist_ok=True)
|
| 382 |
+
|
| 383 |
+
# Save path
|
| 384 |
+
model_path = output_dir / 'model.pth'
|
| 385 |
+
|
| 386 |
+
# Prepare dictionary with model state and metadata
|
| 387 |
+
checkpoint = {
|
| 388 |
+
'model_state_dict': self.model.state_dict(),
|
| 389 |
+
'vocab_maps': self.vocab_maps,
|
| 390 |
+
'embedding_dims': self.embedding_dims,
|
| 391 |
+
'output_dim': self.output_dim,
|
| 392 |
+
'max_sequence_length': self.max_sequence_length
|
| 393 |
+
}
|
| 394 |
+
|
| 395 |
+
# Save model
|
| 396 |
+
torch.save(checkpoint, model_path)
|
| 397 |
+
|
| 398 |
+
logging.info(f"Model saved to: {model_path}")
|
| 399 |
+
|
| 400 |
+
# Also save a configuration file for reference
|
| 401 |
+
config_info = {
|
| 402 |
+
'model_type': 'UserEmbeddingModel',
|
| 403 |
+
'vocab_sizes': {field: len(vocab) for field, vocab in self.vocab_maps.items()},
|
| 404 |
+
'embedding_dims': self.embedding_dims,
|
| 405 |
+
'output_dim': self.output_dim,
|
| 406 |
+
'max_sequence_length': self.max_sequence_length,
|
| 407 |
+
'padded_fields': list(self.model.padded_fields),
|
| 408 |
+
'fields': self.fields
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
config_path = output_dir / 'model_config.json'
|
| 412 |
+
with open(config_path, 'w') as f:
|
| 413 |
+
json.dump(config_info, f, indent=2)
|
| 414 |
+
|
| 415 |
+
logging.info(f"Model configuration saved to: {config_path}")
|
| 416 |
+
|
| 417 |
+
# Save model in HuggingFace format
|
| 418 |
+
hf_dir = output_dir / 'huggingface'
|
| 419 |
+
hf_dir.mkdir(exist_ok=True)
|
| 420 |
+
|
| 421 |
+
# Save model in HF format
|
| 422 |
+
torch.save(self.model.state_dict(), hf_dir / 'pytorch_model.bin')
|
| 423 |
+
|
| 424 |
+
# Save config
|
| 425 |
+
with open(hf_dir / 'config.json', 'w') as f:
|
| 426 |
+
json.dump(config_info, f, indent=2)
|
| 427 |
+
|
| 428 |
+
logging.info(f"Model saved in HuggingFace format to: {hf_dir}")
|
| 429 |
+
|
| 430 |
+
def load_model(self, model_path: str) -> None:
    """Restore pipeline metadata and model weights from a saved checkpoint.

    Args:
        model_path: Path to a ``.pth`` checkpoint produced by ``save_model``.

    Side effects:
        Updates vocab/dimension attributes from the checkpoint (keeping the
        current values as fallbacks), lazily initializes the model if needed,
        loads the weights, moves the model to the active device and switches
        it to eval mode.
    """
    state = torch.load(model_path, map_location=device)

    # Restore serialized metadata; missing keys keep the current values.
    for attr in ('vocab_maps', 'embedding_dims', 'output_dim', 'max_sequence_length'):
        setattr(self, attr, state.get(attr, getattr(self, attr)))

    # Build the network lazily if this pipeline was never initialized.
    if self.model is None:
        self.initialize_model()

    self.model.load_state_dict(state['model_state_dict'])
    self.model.to(device)
    self.model.eval()

    logging.info(f"Model loaded from: {model_path}")
|
| 450 |
+
|
| 451 |
+
# =================================
|
| 452 |
+
# SIMILARITY AND TRIPLET GENERATION
|
| 453 |
+
# =================================
|
| 454 |
+
def calculate_similarity(user1, user2, pipeline):
    """Similarity of two users as the mean Jaccard index of their
    DMP channels and DMP clusters.

    Args:
        user1, user2: raw user records understood by the pipeline.
        pipeline: object exposing ``_get_field_from_user(user, field)``.

    Returns:
        Float in [0, 1]; 0.0 on any error (error is logged).
    """
    try:
        def _jaccard(field):
            # String-normalize values and drop Nones before comparing.
            left = {str(v) for v in pipeline._get_field_from_user(user1, field) if v is not None}
            right = {str(v) for v in pipeline._get_field_from_user(user2, field) if v is not None}
            return len(left & right) / max(1, len(left | right))

        return 0.5 * _jaccard('dmp_channels') + 0.5 * _jaccard('dmp_clusters')
    except Exception as e:
        logging.error(f"Error calculating similarity: {str(e)}")
        return 0.0
|
| 468 |
+
|
| 469 |
+
def process_batch_triplets(args):
    """Generate one batch of (anchor, positive, negative) index triplets.

    Args:
        args: tuple of (batch_idx, users, channel_index, cluster_index,
            num_triplets, pipeline), packed into a single argument so the
            function can be used with map-style multiprocessing APIs.

    Returns:
        List of (anchor_idx, positive_idx, negative_idx) tuples over `users`;
        an empty list on any error (the error is logged).
    """
    try:
        batch_idx, users, channel_index, cluster_index, num_triplets, pipeline = args
        batch_triplets = []

        # NOTE: the original wrapped this loop in torch.no_grad() and created
        # an unused CPU torch.device; no tensor operations happen here, so
        # both were dead code and have been removed.
        for _ in range(num_triplets):
            anchor_idx = random.randint(0, len(users) - 1)
            anchor_user = users[anchor_idx]

            # Candidate positives: users sharing at least one channel/cluster.
            # `or []` guards against a None field (would raise TypeError).
            candidates = set()
            for channel in (pipeline._get_field_from_user(anchor_user, 'dmp_channels') or []):
                candidates.update(channel_index.get(str(channel), []))
            for cluster in (pipeline._get_field_from_user(anchor_user, 'dmp_clusters') or []):
                candidates.update(cluster_index.get(str(cluster), []))

            # The anchor can never be its own positive.
            candidates.discard(anchor_idx)

            if not candidates:
                positive_idx = random.randint(0, len(users) - 1)
            else:
                # Score candidates; keep only strictly-positive similarities.
                similarities = []
                for idx in candidates:
                    sim = cpu_calculate_similarity(anchor_user, users[idx], pipeline)
                    if sim > 0:
                        similarities.append((idx, sim))

                if not similarities:
                    positive_idx = random.randint(0, len(users) - 1)
                else:
                    # Pick randomly among the top-K most similar candidates.
                    similarities.sort(key=lambda x: x[1], reverse=True)
                    top_k = min(10, len(similarities))
                    positive_idx = similarities[random.randint(0, top_k - 1)][0]

            # Negative: a random user with near-zero similarity to the anchor.
            max_attempts = 50
            negative_idx = None
            for _ in range(max_attempts):
                idx = random.randint(0, len(users) - 1)
                if idx != anchor_idx and idx != positive_idx:
                    if cpu_calculate_similarity(anchor_user, users[idx], pipeline) < 0.1:
                        negative_idx = idx
                        break

            if negative_idx is None:
                negative_idx = random.randint(0, len(users) - 1)

            batch_triplets.append((anchor_idx, positive_idx, negative_idx))

        return batch_triplets
    except Exception as e:
        logging.error(f"Error in batch triplet generation: {str(e)}")
        return []
|
| 535 |
+
|
| 536 |
+
# Versione CPU della funzione di calcolo similarità
|
| 537 |
+
def cpu_calculate_similarity(user1, user2, pipeline):
    """CPU-only user similarity: mean Jaccard index over DMP channels
    and DMP clusters (each field weighted 0.5).

    Args:
        user1, user2: raw user records understood by the pipeline.
        pipeline: object exposing ``_get_field_from_user(user, field)``.

    Returns:
        Float in [0, 1]; 0.0 on any error (error is logged).
    """
    try:
        score = 0.0
        for field in ('dmp_channels', 'dmp_clusters'):
            left = set(str(v) for v in pipeline._get_field_from_user(user1, field) if v is not None)
            right = set(str(v) for v in pipeline._get_field_from_user(user2, field) if v is not None)
            # max(1, ...) avoids division by zero when both sets are empty.
            score += 0.5 * (len(left & right) / max(1, len(left | right)))
        return score
    except Exception as e:
        logging.error(f"Error calculating similarity: {str(e)}")
        return 0.0
|
| 551 |
+
|
| 552 |
+
# =================================
|
| 553 |
+
# DATASET AND DATALOADER
|
| 554 |
+
# =================================
|
| 555 |
+
class UserSimilarityDataset(Dataset):
    """Dataset of (anchor, positive, negative) triplets for triplet-loss training.

    Positives are users sharing DMP channels/clusters with the anchor;
    negatives are randomly sampled users with near-zero similarity.
    """

    def __init__(self, pipeline, users_data, num_triplets=10, num_workers=None):
        """Unwrap records, pre-compute model inputs, index channels/clusters
        and generate the triplets.

        Args:
            pipeline: object exposing ``_get_field_from_user`` and ``_prepare_input``.
            users_data: list of raw JSON records, optionally wrapped as
                ``{'raw_json': {'user': ...}}`` or ``{'user': ...}``.
            num_triplets: number of triplets to generate.
            num_workers: worker-count hint; defaults to min(8, cpu_count).
        """
        self.triplets = []
        logging.info("Initializing UserSimilarityDataset...")

        # Unwrap each record: support raw_json.user, user, or a bare user dict.
        self.users = []
        for data in users_data:
            if 'raw_json' in data and 'user' in data['raw_json']:
                self.users.append(data['raw_json']['user'])
            elif 'user' in data:
                self.users.append(data['user'])
            else:
                self.users.append(data)  # Assume it's already a user

        self.pipeline = pipeline
        self.num_triplets = num_triplets

        # Determine number of workers for parallel processing.
        if num_workers is None:
            num_workers = max(1, min(8, os.cpu_count()))
        self.num_workers = num_workers

        # Pre-process model inputs once per user so __getitem__ is cheap.
        self.preprocessed_inputs = {}
        for idx, user in enumerate(self.users):
            self.preprocessed_inputs[idx] = pipeline._prepare_input(user)

        logging.info("Creating indexes for channels and clusters...")
        self.channel_index = defaultdict(list)
        self.cluster_index = defaultdict(list)

        for idx, user in enumerate(self.users):
            channels = pipeline._get_field_from_user(user, 'dmp_channels')
            clusters = pipeline._get_field_from_user(user, 'dmp_clusters')

            # BUG FIX: normalize to a list in ALL cases. The original only
            # transformed truthy values but then iterated unconditionally,
            # so a None field raised TypeError and aborted dataset creation.
            channels = [str(c) for c in channels if c is not None] if channels else []
            clusters = [str(c) for c in clusters if c is not None] if clusters else []

            for channel in channels:
                self.channel_index[channel].append(idx)
            for cluster in clusters:
                self.cluster_index[cluster].append(idx)

        logging.info(f"Found {len(self.channel_index)} unique channels and {len(self.cluster_index)} unique clusters")
        logging.info(f"Generating triplets using {self.num_workers} worker processes...")

        self.triplets = self._generate_triplets_gpu(num_triplets)
        logging.info(f"Generated {len(self.triplets)} triplets")

    def __len__(self):
        """Number of generated triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return pre-processed (anchor, positive, negative) inputs for triplet `idx`."""
        if idx >= len(self.triplets):
            raise IndexError(f"Index {idx} out of range for dataset with {len(self.triplets)} triplets")

        anchor_idx, positive_idx, negative_idx = self.triplets[idx]
        return (
            self.preprocessed_inputs[anchor_idx],
            self.preprocessed_inputs[positive_idx],
            self.preprocessed_inputs[negative_idx]
        )

    def _generate_triplets_gpu(self, num_triplets):
        """Generate triplets using a more reliable approach with batch processing"""
        logging.info("Generating triplets with batch approach...")

        triplets = []
        batch_size = 10  # triplets generated per batch
        num_batches = (num_triplets + batch_size - 1) // batch_size

        progress_bar = tqdm(
            range(num_batches),
            desc="Generating triplet batches",
            bar_format='{l_bar}{bar:10}{r_bar}'
        )

        for _ in progress_bar:
            batch_triplets = []

            for i in range(batch_size):
                if len(triplets) >= num_triplets:
                    break

                # Random anchor.
                anchor_idx = random.randint(0, len(self.users)-1)
                anchor_user = self.users[anchor_idx]

                # Candidates sharing channels or clusters with the anchor.
                # `or []` guards against a None field (same fix as __init__).
                candidates = set()
                for channel in (self.pipeline._get_field_from_user(anchor_user, 'dmp_channels') or []):
                    candidates.update(self.channel_index.get(str(channel), []))
                for cluster in (self.pipeline._get_field_from_user(anchor_user, 'dmp_clusters') or []):
                    candidates.update(self.cluster_index.get(str(cluster), []))
                candidates.discard(anchor_idx)

                # Positive: one of the top-10 most similar candidates,
                # falling back to a random user.
                if candidates:
                    similarities = []
                    for idx in list(candidates)[:50]:  # cap the search at 50 candidates
                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline)
                        if sim > 0:
                            similarities.append((idx, sim))

                    if similarities:
                        similarities.sort(key=lambda x: x[1], reverse=True)
                        top_k = min(10, len(similarities))
                        positive_idx = similarities[random.randint(0, top_k-1)][0]
                    else:
                        positive_idx = random.randint(0, len(self.users)-1)
                else:
                    positive_idx = random.randint(0, len(self.users)-1)

                # Negative: random user with similarity < 0.1 (max 20 attempts).
                attempts = 0
                negative_idx = None
                while attempts < 20 and negative_idx is None:
                    idx = random.randint(0, len(self.users)-1)
                    if idx != anchor_idx and idx != positive_idx:
                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline)
                        if sim < 0.1:
                            negative_idx = idx
                            break
                    attempts += 1

                if negative_idx is None:
                    negative_idx = random.randint(0, len(self.users)-1)

                batch_triplets.append((anchor_idx, positive_idx, negative_idx))

            triplets.extend(batch_triplets)

        return triplets[:num_triplets]  # return exactly num_triplets
|
| 701 |
+
|
| 702 |
+
def collate_batch(batch):
    """Collate (anchor, positive, negative) input dicts into padded batches.

    Each batch element is a triple of ``{field: 1-D tensor}`` dicts. For each
    group and field, tensors are right-padded with zeros (dtype long — assumes
    all field tensors are integer ids; TODO confirm) to the longest length in
    the group and stacked into one 2-D tensor.

    Returns:
        Tuple of three dicts (anchor, positive, negative), each mapping
        field -> stacked, padded tensor.
    """
    def _fit(tensor, length):
        # Right-pad with zeros, or truncate, to exactly `length` entries.
        n = tensor.size(0)
        if n < length:
            return torch.cat([tensor, torch.zeros(length - n, dtype=torch.long)])
        return tensor[:length]

    def _stack_group(dicts):
        batched = {}
        for field in dicts[0].keys():
            target_len = max(d[field].size(0) for d in dicts)
            batched[field] = torch.stack([_fit(d[field], target_len) for d in dicts])
        return batched

    anchor_inputs, positive_inputs, negative_inputs = zip(*batch)
    return (
        _stack_group(anchor_inputs),
        _stack_group(positive_inputs),
        _stack_group(negative_inputs),
    )
|
| 731 |
+
|
| 732 |
+
# =================================
|
| 733 |
+
# TRAINING FUNCTION
|
| 734 |
+
# =================================
|
| 735 |
+
|
| 736 |
+
def train_user_embeddings(model, users_data, pipeline, num_epochs=10, batch_size=32, lr=0.001, save_dir=None, save_interval=2, num_triplets=150):
    """Train the embedding model with triplet-margin loss.

    Args:
        model: the embedding network (called with a dict of field tensors).
        users_data: raw user records used to build the triplet dataset.
        pipeline: feature pipeline passed through to the dataset.
        num_epochs: training epochs.
        batch_size: triplets per batch.
        lr: Adam learning rate.
        save_dir: if set, write a checkpoint every `save_interval` epochs.
        save_interval: epoch interval between checkpoints.
        num_triplets: number of triplets to generate for the dataset.

    Returns:
        The trained model (same object, mutated in place).
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Decay the learning rate by 10% every 2 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=2,
        gamma=0.9
    )

    # Determine number of CPU cores to use for dataset construction.
    num_cpu_cores = max(1, min(32, os.cpu_count()))
    logging.info(f"Using {num_cpu_cores} CPU cores for data processing")

    dataset = UserSimilarityDataset(
        pipeline,
        users_data,
        num_triplets=num_triplets,
        num_workers=num_cpu_cores
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_batch,
        num_workers=0,        # batches are cheap: inputs are pre-processed
        pin_memory=True       # speed up host->GPU transfer
    )

    criterion = torch.nn.TripletMarginLoss(margin=1.0)

    epoch_pbar = tqdm(
        range(num_epochs),
        desc="Training Progress",
        bar_format='{l_bar}{bar:10}{r_bar}'
    )

    # Optional TensorBoard logging.
    try:
        from torch.utils.tensorboard import SummaryWriter
        log_dir = Path(save_dir) / "logs" if save_dir else Path("./logs")
        log_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(log_dir=log_dir)
        tensorboard_available = True
    except ImportError:
        logging.warning("TensorBoard not available, skipping logging")
        tensorboard_available = False

    for epoch in epoch_pbar:
        total_loss = 0
        num_batches = 0

        # NOTE: the original also built a second, fully disabled per-batch
        # tqdm bar (`batch_pbar` with disable=True, plus its `update_freq`)
        # that was never iterated; that dead code has been removed.
        epoch_progress = tqdm(
            total=len(dataloader),
            desc=f"Epoch {epoch+1}/{num_epochs}",
            leave=True,
            bar_format='{l_bar}{bar:10}{r_bar}'
        )

        for batch_idx, batch_inputs in enumerate(dataloader):
            try:
                # Each group is already a dict of padded tensors (collate_batch).
                anchor_batch, positive_batch, negative_batch = batch_inputs

                # Move data to the active device (GPU if available).
                anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
                positive_batch = {k: v.to(device) for k, v in positive_batch.items()}
                negative_batch = {k: v.to(device) for k, v in negative_batch.items()}

                anchor_emb = model(anchor_batch)
                positive_emb = model(positive_batch)
                negative_emb = model(negative_batch)

                loss = criterion(anchor_emb, positive_emb, negative_emb)

                optimizer.zero_grad()
                loss.backward()
                # Clip gradients for stability.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                # Update the epoch progress bar roughly every 10% of batches.
                update_interval = max(1, len(dataloader) // 10)
                if (batch_idx + 1) % update_interval == 0 or batch_idx == len(dataloader) - 1:
                    remaining = min(update_interval, len(dataloader) - epoch_progress.n)
                    epoch_progress.update(remaining)
                    current_avg_loss = total_loss / num_batches
                    epoch_progress.set_postfix(avg_loss=f"{current_avg_loss:.4f}",
                                               last_batch_loss=f"{loss.item():.4f}")

            except Exception as e:
                logging.error(f"Error during batch processing: {str(e)}")
                logging.error(f"Batch details: {str(e.__class__.__name__)}")
                continue

        epoch_progress.close()

        avg_loss = total_loss / max(1, num_batches)
        epoch_pbar.set_postfix(avg_loss=f"{avg_loss:.4f}")

        if tensorboard_available:
            writer.add_scalar('Loss/train', avg_loss, epoch)

        # Step the learning-rate scheduler at the end of each epoch.
        scheduler.step()

        # Incremental checkpointing if requested.
        if save_dir and (epoch + 1) % save_interval == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'scheduler_state_dict': scheduler.state_dict()
            }
            save_path = Path(save_dir) / f'model_checkpoint_epoch_{epoch+1}.pth'
            torch.save(checkpoint, save_path)
            logging.info(f"Checkpoint saved at epoch {epoch+1}: {save_path}")

    if tensorboard_available:
        writer.close()

    return model
|
| 892 |
+
|
| 893 |
+
# =================================
|
| 894 |
+
# MAIN FUNCTION
|
| 895 |
+
# =================================
|
| 896 |
+
def main():
    """End-to-end entry point: load user data, build vocabularies, train the
    embedding model and persist it (optionally also in HuggingFace format and
    pushed to the Hub via HF_REPO_ID/HF_TOKEN env vars)."""
    output_dir = Path(OUTPUT_DIR)

    # Hardware report.
    cuda_available = torch.cuda.is_available()
    logging.info(f"CUDA available: {cuda_available}")
    if cuda_available:
        logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
        logging.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    cpu_count = os.cpu_count()
    memory_info = psutil.virtual_memory()
    logging.info(f"CPU cores: {cpu_count}")
    logging.info(f"System memory: {memory_info.total / 1e9:.2f} GB")

    # Print configuration.
    logging.info("Running with the following configuration:")
    logging.info(f"- Number of triplets: {NUM_TRIPLETS}")
    logging.info(f"- Number of epochs: {NUM_EPOCHS}")
    logging.info(f"- Batch size: {BATCH_SIZE}")
    logging.info(f"- Learning rate: {LEARNING_RATE}")
    logging.info(f"- Output dimension: {OUTPUT_DIM}")
    logging.info(f"- Data path: {DATA_PATH}")
    logging.info(f"- Output directory: {OUTPUT_DIR}")

    # Load data, tolerating the non-standard "objects separated by commas"
    # format by wrapping the file contents in [ ].
    logging.info("Loading user data...")
    try:
        try:
            with open(DATA_PATH, 'r') as f:
                json_data = json.load(f)

            # Accept both an array of users and a single object.
            if isinstance(json_data, list):
                users_data = json_data
            elif isinstance(json_data, dict):
                users_data = [json_data]
            else:
                raise ValueError("Unsupported JSON format")

        except json.JSONDecodeError:
            logging.info("Detected possible non-standard JSON format, attempting correction...")
            with open(DATA_PATH, 'r') as f:
                text = f.read().strip()

            if not text.startswith('['):
                text = '[' + text
            if not text.endswith(']'):
                text = text + ']'

            users_data = json.loads(text)
            logging.info("JSON format successfully corrected")

        logging.info(f"Loaded {len(users_data)} records")
    except FileNotFoundError:
        logging.error(f"File {DATA_PATH} not found!")
        return
    except Exception as e:
        logging.error(f"Unable to load file: {str(e)}")
        return

    # Initialize pipeline.
    logging.info("Initializing pipeline...")
    pipeline = UserEmbeddingPipeline(
        output_dim=OUTPUT_DIM,
        max_sequence_length=MAX_SEQ_LENGTH
    )

    # Build vocabularies.
    logging.info("Building vocabularies...")
    try:
        pipeline.build_vocabularies(users_data)
        vocab_sizes = {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()}
        logging.info(f"Vocabulary sizes: {vocab_sizes}")
    except Exception as e:
        logging.error(f"Error building vocabularies: {str(e)}")
        return

    # Initialize model.
    logging.info("Initializing model...")
    try:
        pipeline.initialize_model()
        logging.info("Model initialized successfully")
    except Exception as e:
        logging.error(f"Error initializing model: {str(e)}")
        return

    # Training + persistence.
    logging.info("Starting training...")
    try:
        model_dir = output_dir / "model_checkpoints"
        model_dir.mkdir(exist_ok=True, parents=True)

        model = train_user_embeddings(
            pipeline.model,
            users_data,
            pipeline,
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            save_dir=model_dir,
            save_interval=SAVE_INTERVAL,
            num_triplets=NUM_TRIPLETS
        )
        logging.info("Training completed")
        pipeline.model = model

        logging.info("Saving model...")
        output_dir.mkdir(exist_ok=True)

        model_path = output_dir / 'model.pth'

        # Checkpoint carries the metadata needed to rebuild the pipeline.
        checkpoint = {
            'model_state_dict': pipeline.model.state_dict(),
            'vocab_maps': pipeline.vocab_maps,
            'embedding_dims': pipeline.embedding_dims,
            'output_dim': pipeline.output_dim,
            'max_sequence_length': pipeline.max_sequence_length
        }
        torch.save(checkpoint, model_path)
        logging.info(f"Model saved to: {model_path}")

        # Human-readable configuration for reference.
        config_info = {
            'model_type': 'UserEmbeddingModel',
            'vocab_sizes': {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()},
            'embedding_dims': pipeline.embedding_dims,
            'output_dim': pipeline.output_dim,
            'max_sequence_length': pipeline.max_sequence_length,
            'padded_fields': list(pipeline.model.padded_fields),
            'fields': pipeline.fields
        }
        config_path = output_dir / 'model_config.json'
        with open(config_path, 'w') as f:
            json.dump(config_info, f, indent=2)
        logging.info(f"Model configuration saved to: {config_path}")

        # Only save in HuggingFace format if requested via env var.
        save_hf = os.environ.get("SAVE_HF_FORMAT", "false").lower() == "true"
        if save_hf:
            logging.info("Saving in HuggingFace format...")
            hf_dir = output_dir / 'huggingface'
            hf_dir.mkdir(exist_ok=True)

            torch.save(pipeline.model.state_dict(), hf_dir / 'pytorch_model.bin')
            with open(hf_dir / 'config.json', 'w') as f:
                json.dump(config_info, f, indent=2)
            logging.info(f"Model saved in HuggingFace format to: {hf_dir}")

        # Push to HuggingFace Hub if credentials are provided.
        hf_repo_id = os.environ.get("HF_REPO_ID")
        hf_token = os.environ.get("HF_TOKEN")

        if save_hf and hf_repo_id and hf_token:
            try:
                from huggingface_hub import HfApi

                logging.info(f"Pushing model to HuggingFace: {hf_repo_id}")
                api = HfApi()

                api.create_repo(
                    repo_id=hf_repo_id,
                    token=hf_token,
                    exist_ok=True,
                    private=True
                )

                for file_path in (output_dir / "huggingface").glob("**/*"):
                    if file_path.is_file():
                        api.upload_file(
                            path_or_fileobj=str(file_path),
                            # BUG FIX: path_in_repo must be a POSIX-style
                            # string; the original passed a Path object.
                            path_in_repo=file_path.relative_to(output_dir / "huggingface").as_posix(),
                            repo_id=hf_repo_id,
                            token=hf_token
                        )

                logging.info(f"Model successfully pushed to HuggingFace: {hf_repo_id}")
            except Exception as e:
                logging.error(f"Error pushing to HuggingFace: {str(e)}")

    except Exception as e:
        logging.error(f"Error during training or saving: {str(e)}")
        import traceback
        traceback.print_exc()
        return

    logging.info("Process completed successfully!")


if __name__ == "__main__":
    main()
|
generate_model_gpu.py
ADDED
|
@@ -0,0 +1,1265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
|
| 3 |
+
#os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128'
|
| 4 |
+
import multiprocessing
# Force the 'spawn' start method: CUDA contexts cannot be safely
# re-initialized in forked worker processes.
try:
    multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass  # The start method has already been set (e.g. by a parent process)
|
| 9 |
+
import json
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import numpy as np
|
| 13 |
+
import pandas as pd
|
| 14 |
+
from typing import List, Dict
|
| 15 |
+
import logging
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
from torch.utils.data import Dataset, DataLoader
|
| 18 |
+
import random
|
| 19 |
+
from collections import defaultdict
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from tqdm import tqdm
|
| 22 |
+
import os
|
| 23 |
+
import multiprocessing
|
| 24 |
+
from multiprocessing import Pool
|
| 25 |
+
import psutil
|
| 26 |
+
import argparse
|
| 27 |
+
|
| 28 |
+
# =================================
# CONFIGURABLE PARAMETERS
# =================================
# Each default below can be overridden via a same-named environment variable.
DEFAULT_NUM_TRIPLETS = 150          # Number of triplet examples to generate
DEFAULT_NUM_EPOCHS = 1              # Number of training epochs
DEFAULT_BATCH_SIZE = 64             # Batch size for training
DEFAULT_LEARNING_RATE = 0.001       # Learning rate for optimizer
DEFAULT_OUTPUT_DIM = 256            # Output dimension of embeddings
DEFAULT_MAX_SEQ_LENGTH = 15         # Maximum sequence length
DEFAULT_SAVE_INTERVAL = 2           # Save checkpoint every N epochs
DEFAULT_DATA_PATH = "./users.json"  # Path to user data
DEFAULT_OUTPUT_DIR = "./model"      # Output directory


def _from_env(name, default, cast):
    """Return env var `name` converted with `cast`, or `default` if unset."""
    return cast(os.environ.get(name, default))


# Effective runtime parameters (environment overrides the defaults above).
NUM_TRIPLETS = _from_env("NUM_TRIPLETS", DEFAULT_NUM_TRIPLETS, int)
NUM_EPOCHS = _from_env("NUM_EPOCHS", DEFAULT_NUM_EPOCHS, int)
BATCH_SIZE = _from_env("BATCH_SIZE", DEFAULT_BATCH_SIZE, int)
LEARNING_RATE = _from_env("LEARNING_RATE", DEFAULT_LEARNING_RATE, float)
OUTPUT_DIM = _from_env("OUTPUT_DIM", DEFAULT_OUTPUT_DIM, int)
MAX_SEQ_LENGTH = _from_env("MAX_SEQ_LENGTH", DEFAULT_MAX_SEQ_LENGTH, int)
SAVE_INTERVAL = _from_env("SAVE_INTERVAL", DEFAULT_SAVE_INTERVAL, int)
DATA_PATH = os.environ.get("DATA_PATH", DEFAULT_DATA_PATH)
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Select CUDA when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
|
| 59 |
+
|
| 60 |
+
# =================================
|
| 61 |
+
# MODEL ARCHITECTURE
|
| 62 |
+
# =================================
|
| 63 |
+
class UserEmbeddingModel(nn.Module):
    """Maps categorical user fields to a dense vector of size `output_dim`.

    One `nn.Embedding` table per field; each field's token embeddings are
    reduced to a single vector (mask-aware mean for padded fields, plain
    mean or first-token lookup otherwise), concatenated, and projected
    through a two-layer MLP with LayerNorm.
    """

    def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int],
                 output_dim: int = 256, max_sequence_length: int = 15,
                 padded_fields_length: int = 10):
        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.padded_fields_length = padded_fields_length
        # Fields whose inputs arrive as fixed-length, zero-padded sequences.
        self.padded_fields = {'dmp_channels', 'dmp_tags', 'dmp_clusters'}
        self.embedding_layers = nn.ModuleDict()

        # Create embedding layers for each field; index 0 is reserved for <PAD>.
        for field, vocab_size in vocab_sizes.items():
            self.embedding_layers[field] = nn.Embedding(
                vocab_size,
                embedding_dims.get(field, 16),
                padding_idx=0
            )

        # Calculate total input dimension
        # NOTE(review): both branches add `dim`, so the if/else is redundant;
        # kept unchanged here to avoid touching behavior in a doc-only pass.
        self.total_input_dim = 0
        for field, dim in embedding_dims.items():
            if field in self.padded_fields:
                self.total_input_dim += dim  # Single dimension for padded field
            else:
                self.total_input_dim += dim

        print(f"Total input dimension: {self.total_input_dim}")

        # Projection head: total_input_dim -> total_input_dim // 2 -> output_dim.
        self.fc = nn.Sequential(
            nn.Linear(self.total_input_dim, self.total_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.total_input_dim // 2, output_dim),
            nn.LayerNorm(output_dim)
        )

    def _process_sequence(self, embedding_layer: nn.Embedding, indices: torch.Tensor,
                          field_name: str) -> torch.Tensor:
        """Process normal sequences.

        Returns a [batch_size, emb_dim] tensor; a zeros tensor when
        `indices` is empty.
        """
        batch_size = indices.size(0)
        if indices.numel() == 0:
            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)

        if field_name in ['dmp_city', 'dmp_domains']:
            # Single-valued fields: embed only the first token of the sequence.
            if indices.dim() == 1:
                indices = indices.unsqueeze(0)
            if indices.size(1) > 0:
                return embedding_layer(indices[:, 0])
            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)

        # Handle multiple sequences: plain (unmasked) mean over the sequence.
        embeddings = embedding_layer(indices)
        return embeddings.mean(dim=1)  # [batch_size, emb_dim]

    def _process_padded_sequence(self, embedding_layer: nn.Embedding,
                                 indices: torch.Tensor) -> torch.Tensor:
        """Process sequences with padding via a mask-aware average."""
        batch_size = indices.size(0)
        emb_dim = embedding_layer.embedding_dim

        # Generate embeddings
        embeddings = embedding_layer(indices)  # [batch_size, seq_len, emb_dim]

        # Average along sequence dimension, ignoring padding positions (index 0).
        mask = (indices != 0).float().unsqueeze(-1)
        masked_embeddings = embeddings * mask
        sum_mask = mask.sum(dim=1).clamp(min=1.0)  # clamp avoids division by zero

        return (masked_embeddings.sum(dim=1) / sum_mask)  # [batch_size, emb_dim]

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate per-field embeddings and project to the output space.

        NOTE(review): fields missing from `inputs` are silently skipped, which
        changes the concatenated width and would break `self.fc`; callers
        appear expected to always supply every configured field — confirm.
        """
        batch_embeddings = []

        for field in ['dmp_city', 'source', 'dmp_brands',  # changed: removed 'dmp_domains', added 'source'
                      'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
                      'device']:  # added 'device'
            if field in inputs and field in self.embedding_layers:
                if field in self.padded_fields:
                    emb = self._process_padded_sequence(
                        self.embedding_layers[field],
                        inputs[field]
                    )
                else:
                    emb = self._process_sequence(
                        self.embedding_layers[field],
                        inputs[field],
                        field
                    )
                batch_embeddings.append(emb)

        combined = torch.cat(batch_embeddings, dim=1)
        return self.fc(combined)
|
| 156 |
+
|
| 157 |
+
# =================================
|
| 158 |
+
# EMBEDDING PIPELINE
|
| 159 |
+
# =================================
|
| 160 |
+
class UserEmbeddingPipeline:
    """End-to-end pipeline around `UserEmbeddingModel`.

    Responsibilities: build per-field vocabularies from raw user JSON,
    convert users to index tensors, run the model to produce embeddings,
    and persist embeddings/model artifacts to disk.
    """

    def __init__(self, output_dim: int = 256, max_sequence_length: int = 15):
        self.output_dim = output_dim
        self.max_sequence_length = max_sequence_length
        self.model = None       # set by initialize_model() / load_model()
        self.vocab_maps = {}    # field -> {value: index}, built in build_vocabularies()

        # Fields fed to the model, in a fixed order.
        self.fields = [
            'dmp_city', 'source', 'dmp_brands',  # 'dmp_domains' removed, 'source' added
            'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
            'device'  # 'device' added
        ]

        # Map of new JSON fields to old field names used in the model
        self.field_mapping = {
            'dmp_city': ('dmp', 'city'),
            'source': ('dmp', '', 'source'),  # nested under an empty-string key
            'dmp_brands': ('dmp', 'brands'),
            'dmp_clusters': ('dmp', 'clusters'),
            'dmp_industries': ('dmp', 'industries'),
            'dmp_tags': ('dmp', 'tags'),
            'dmp_channels': ('dmp', 'channels'),
            'device': ('device',)  # new device field
        }

        # Per-field embedding dimensions.
        self.embedding_dims = {
            'dmp_city': 8,
            'source': 8,  # dimension for source
            'dmp_brands': 32,
            'dmp_clusters': 64,
            'dmp_industries': 32,
            'dmp_tags': 128,
            'dmp_channels': 64,
            'device': 8  # dimension for device
        }

    def _clean_value(self, value):
        """Normalize a raw field value into a list of lowercase strings."""
        if isinstance(value, float) and np.isnan(value):
            return []
        if isinstance(value, str):
            return [value.lower().strip()]
        if isinstance(value, list):
            return [str(v).lower().strip() for v in value if v is not None and str(v).strip()]
        return []

    def _get_field_from_user(self, user, field):
        """Extract field value from new JSON user format"""
        mapping = self.field_mapping.get(field, (field,))
        value = user

        # Navigate through nested structure
        for key in mapping:
            if isinstance(value, dict):
                value = value.get(key, {})
            else:
                # If not a dictionary and we're not at the last element
                # of the mapping, return an empty list
                value = []
                break

        # If we've reached the end and have a value that's not a list but should be,
        # convert it to a list
        if field in {'dmp_brands', 'dmp_channels', 'dmp_clusters', 'dmp_industries', 'dmp_tags'} and not isinstance(value, list):
            # If it's a string or other single value, put it in a list
            if value and not isinstance(value, dict):
                value = [value]
            else:
                value = []

        return value

    def build_vocabularies(self, users_data: List[Dict]) -> Dict[str, Dict[str, int]]:
        """Build {field: {value: index}} maps; '<PAD>' is always present."""
        field_values = {field: {'<PAD>'} for field in self.fields}

        # Extract the 'user' field from the JSON structure for each record
        users = []
        for data in users_data:
            # Check if there's raw_json.user
            if 'raw_json' in data and 'user' in data['raw_json']:
                users.append(data['raw_json']['user'])
            # Check if there's user
            elif 'user' in data:
                users.append(data['user'])
            else:
                users.append(data)  # Assume it's already a user

        # Collect every cleaned value observed for each field.
        for user in users:
            for field in self.fields:
                values = self._clean_value(self._get_field_from_user(user, field))
                field_values[field].update(values)

        # Deterministic indexing: values sorted alphabetically per field.
        self.vocab_maps = {
            field: {val: idx for idx, val in enumerate(sorted(values))}
            for field, values in field_values.items()
        }

        return self.vocab_maps

    def _prepare_input(self, user: Dict) -> Dict[str, torch.Tensor]:
        """Convert one user dict to per-field LongTensors of vocab indices.

        Unknown values map to index 0 (the '<PAD>' slot after sorting).
        """
        inputs = {}

        for field in self.fields:
            values = self._clean_value(self._get_field_from_user(user, field))
            vocab = self.vocab_maps[field]
            indices = [vocab.get(val, 0) for val in values]
            inputs[field] = torch.tensor(indices, dtype=torch.long)

        return inputs

    def initialize_model(self) -> None:
        """Instantiate the model from built vocabularies and set eval mode."""
        vocab_sizes = {field: len(vocab) for field, vocab in self.vocab_maps.items()}

        self.model = UserEmbeddingModel(
            vocab_sizes=vocab_sizes,
            embedding_dims=self.embedding_dims,
            output_dim=self.output_dim,
            max_sequence_length=self.max_sequence_length
        )
        self.model.to(device)
        self.model.eval()

    def generate_embeddings(self, users_data: List[Dict], batch_size: int = 32) -> Dict[str, np.ndarray]:
        """Generate embeddings for all users"""
        embeddings = {}
        self.model.eval()  # Make sure model is in eval mode

        # Extract the 'user' field from the JSON structure for each record
        users = []
        user_ids = []

        # NOTE(review): the three branches below duplicate the same id-extraction
        # logic for the raw_json.user / user / bare-user layouts.
        for data in users_data:
            # Check if there's raw_json.user
            if 'raw_json' in data and 'user' in data['raw_json']:
                user = data['raw_json']['user']
                users.append(user)
                # Use user.dmp[''].id as identifier
                if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
                    user_ids.append(str(user['dmp']['']['id']))
                else:
                    # Fallback to uid or id if dmp.id is not available
                    user_ids.append(str(user.get('uid', user.get('id', None))))
            # Check if there's user
            elif 'user' in data:
                user = data['user']
                users.append(user)
                # Use user.dmp[''].id as identifier
                if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
                    user_ids.append(str(user['dmp']['']['id']))
                else:
                    # Fallback to uid or id if dmp.id is not available
                    user_ids.append(str(user.get('uid', user.get('id', None))))
            else:
                users.append(data)  # Assume it's already a user
                # Use user.dmp[''].id as identifier
                if 'dmp' in data and '' in data['dmp'] and 'id' in data['dmp']['']:
                    user_ids.append(str(data['dmp']['']['id']))
                else:
                    # Fallback to uid or id if dmp.id is not available
                    user_ids.append(str(data.get('uid', data.get('id', None))))

        with torch.no_grad():
            for i in tqdm(range(0, len(users), batch_size), desc="Generating embeddings"):
                batch_users = users[i:i+batch_size]
                batch_ids = user_ids[i:i+batch_size]
                batch_inputs = []
                valid_indices = []

                # Keep only users with a usable identifier.
                for j, user in enumerate(batch_users):
                    if batch_ids[j] is not None:
                        batch_inputs.append(self._prepare_input(user))
                        valid_indices.append(j)

                if batch_inputs:
                    # Use the same collate function as training for a single batch
                    # (each input acts as its own anchor/positive/negative).
                    anchor_batch, _, _ = collate_batch([(inputs, inputs, inputs) for inputs in batch_inputs])

                    # Move data to device
                    anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}

                    # Generate embeddings
                    batch_embeddings = self.model(anchor_batch).cpu()

                    # Save embeddings
                    for j, idx in enumerate(valid_indices):
                        if batch_ids[idx]:  # Verify that id is not None or empty
                            embeddings[batch_ids[idx]] = batch_embeddings[j].numpy()

        return embeddings

    def save_embeddings(self, embeddings: Dict[str, np.ndarray], output_dir: str) -> None:
        """Save embeddings to file"""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)

        # Save embeddings as JSON
        json_path = output_dir / 'embeddings.json'
        with open(json_path, 'w') as f:
            json_embeddings = {user_id: emb.tolist() for user_id, emb in embeddings.items()}
            json.dump(json_embeddings, f)

        # Save embeddings as NPY
        # NOTE(review): np.stack raises on an empty `embeddings` dict — confirm
        # callers never reach this with zero users.
        npy_path = output_dir / 'embeddings.npz'
        np.savez_compressed(npy_path,
                            embeddings=np.stack(list(embeddings.values())),
                            user_ids=np.array(list(embeddings.keys())))

        # Save vocabularies
        vocab_path = output_dir / 'vocabularies.json'
        with open(vocab_path, 'w') as f:
            json.dump(self.vocab_maps, f)

        logging.info(f"\nEmbeddings saved in {output_dir}:")
        logging.info(f"- Embeddings JSON: {json_path}")
        logging.info(f"- Embeddings NPY: {npy_path}")
        logging.info(f"- Vocabularies: {vocab_path}")

    def save_model(self, output_dir: str) -> None:
        """Save model in PyTorch format (.pth)"""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)

        # Save path
        model_path = output_dir / 'model.pth'

        # Prepare dictionary with model state and metadata
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'vocab_maps': self.vocab_maps,
            'embedding_dims': self.embedding_dims,
            'output_dim': self.output_dim,
            'max_sequence_length': self.max_sequence_length
        }

        # Save model
        torch.save(checkpoint, model_path)

        logging.info(f"Model saved to: {model_path}")

        # Also save a configuration file for reference
        config_info = {
            'model_type': 'UserEmbeddingModel',
            'vocab_sizes': {field: len(vocab) for field, vocab in self.vocab_maps.items()},
            'embedding_dims': self.embedding_dims,
            'output_dim': self.output_dim,
            'max_sequence_length': self.max_sequence_length,
            'padded_fields': list(self.model.padded_fields),
            'fields': self.fields
        }

        config_path = output_dir / 'model_config.json'
        with open(config_path, 'w') as f:
            json.dump(config_info, f, indent=2)

        logging.info(f"Model configuration saved to: {config_path}")

        # Save model in HuggingFace format
        hf_dir = output_dir / 'huggingface'
        hf_dir.mkdir(exist_ok=True)

        # Save model in HF format
        torch.save(self.model.state_dict(), hf_dir / 'pytorch_model.bin')

        # Save config
        with open(hf_dir / 'config.json', 'w') as f:
            json.dump(config_info, f, indent=2)

        logging.info(f"Model saved in HuggingFace format to: {hf_dir}")

    def load_model(self, model_path: str) -> None:
        """Load a previously saved model"""
        checkpoint = torch.load(model_path, map_location=device)

        # Reload vocabularies and dimensions if needed
        self.vocab_maps = checkpoint.get('vocab_maps', self.vocab_maps)
        self.embedding_dims = checkpoint.get('embedding_dims', self.embedding_dims)
        self.output_dim = checkpoint.get('output_dim', self.output_dim)
        self.max_sequence_length = checkpoint.get('max_sequence_length', self.max_sequence_length)

        # Initialize model if not already done
        if self.model is None:
            self.initialize_model()

        # Load model weights
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        logging.info(f"Model loaded from: {model_path}")
|
| 449 |
+
|
| 450 |
+
# =================================
|
| 451 |
+
# SIMILARITY AND TRIPLET GENERATION
|
| 452 |
+
# =================================
|
| 453 |
+
|
| 454 |
+
def calculate_similarity(user1, user2, pipeline, filtered_tags=None):
    """Weighted Jaccard similarity between two users.

    Per-field Jaccard similarities over channels, clusters, tags, source,
    brands and device are combined with fixed integer weights:
    clusters=6, channels=5, tags=3, source=2, brands=5, device=3
    (total weight 24).

    Args:
        user1, user2: raw user dicts understood by
            ``pipeline._get_field_from_user``.
        pipeline: object exposing ``_get_field_from_user(user, field)``.
        filtered_tags: optional set of tags; when given, only tags contained
            in it contribute to the tag similarity.

    Returns:
        A float in [0, 1]; 0.0 if any error occurs (the error is logged).
    """
    try:
        # Extract the original fields: channels and clusters.
        channels1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_channels') if c is not None)
        channels2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_channels') if c is not None)
        clusters1 = set(str(c) for c in pipeline._get_field_from_user(user1, 'dmp_clusters') if c is not None)
        clusters2 = set(str(c) for c in pipeline._get_field_from_user(user2, 'dmp_clusters') if c is not None)

        # Extract tags (lower-cased/stripped); optionally filtered below.
        tags1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'dmp_tags') if c is not None)
        tags2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'dmp_tags') if c is not None)

        # Newer fields: source, brands, device.
        source1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'source') if c is not None)
        source2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'source') if c is not None)

        brands1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'dmp_brands') if c is not None)
        brands2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'dmp_brands') if c is not None)

        device1 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user1, 'device') if c is not None)
        device2 = set(str(c).lower().strip() for c in pipeline._get_field_from_user(user2, 'device') if c is not None)

        if filtered_tags is not None:
            # Keep only tags present in the filtered set.
            tags1 = {tag for tag in tags1 if tag in filtered_tags}
            tags2 = {tag for tag in tags2 if tag in filtered_tags}

        # Per-field Jaccard similarity; max(1, ...) guards empty unions.
        channel_sim = len(channels1 & channels2) / max(1, len(channels1 | channels2))
        # BUGFIX: the denominator previously used `clusters1 | channels2`,
        # mixing clusters with channels; the union must be over clusters.
        cluster_sim = len(clusters1 & clusters2) / max(1, len(clusters1 | clusters2))

        tag_sim = len(tags1 & tags2) / max(1, len(tags1 | tags2))

        source_sim = len(source1 & source2) / max(1, len(source1 | source2))
        brands_sim = len(brands1 & brands2) / max(1, len(brands1 | brands2))
        device_sim = len(device1 & device2) / max(1, len(device1 | device2))

        # Weighted combination; weights sum to 24.
        total_weight = 6 + 5 + 3 + 2 + 5 + 3
        weighted_sim = (
            6 * cluster_sim +
            5 * channel_sim +
            3 * tag_sim +
            2 * source_sim +   # source weight 2/24
            5 * brands_sim +   # brands weight 5/24
            3 * device_sim     # device weight 3/24
        ) / total_weight

        return weighted_sim
    except Exception as e:
        logging.error(f"Error calculating similarity: {str(e)}")
        return 0.0
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def process_batch_triplets(args):
    """Mine `num_triplets` (anchor, positive, negative) user-index triplets.

    Positives are drawn from users sharing at least one channel or cluster
    with the anchor (ranked by `calculate_similarity`, sampled from the top
    10); negatives are random users with similarity < 0.1.

    Args:
        args: tuple of (batch_idx, users, channel_index, cluster_index,
            num_triplets, pipeline), where `channel_index`/`cluster_index`
            map a value to the list of user indices that carry it.

    Returns:
        A list of (anchor_idx, positive_idx, negative_idx) index tuples;
        an empty list on any error (the error is logged).
    """
    try:
        batch_idx, users, channel_index, cluster_index, num_triplets, pipeline = args
        batch_triplets = []

        # Similarity mining is pure-Python on CPU; disable autograd
        # defensively for any tensor work triggered indirectly.
        # (Removed an unused `temp_device` local that suggested a device
        # switch which never happened.)
        with torch.no_grad():
            for _ in range(num_triplets):
                anchor_idx = random.randint(0, len(users)-1)
                anchor_user = users[anchor_idx]

                # Candidate positives: users sharing a channel or a cluster.
                candidates = set()
                for channel in pipeline._get_field_from_user(anchor_user, 'dmp_channels'):
                    candidates.update(channel_index.get(str(channel), []))
                for cluster in pipeline._get_field_from_user(anchor_user, 'dmp_clusters'):
                    candidates.update(cluster_index.get(str(cluster), []))

                # The anchor cannot be its own positive.
                candidates.discard(anchor_idx)

                # Find positive (similar) user.
                if not candidates:
                    positive_idx = random.randint(0, len(users)-1)
                else:
                    # Calculate similarities for candidates.
                    similarities = []
                    for idx in candidates:
                        sim = calculate_similarity(anchor_user, users[idx], pipeline)
                        if sim > 0:
                            similarities.append((idx, sim))

                    if not similarities:
                        positive_idx = random.randint(0, len(users)-1)
                    else:
                        # Sort by similarity, then sample one of the top K.
                        similarities.sort(key=lambda x: x[1], reverse=True)
                        top_k = min(10, len(similarities))
                        positive_idx = similarities[random.randint(0, top_k-1)][0]

                # Find negative (dissimilar) user: random probe until a user
                # with similarity below 0.1 is found.
                max_attempts = 50
                negative_idx = None

                for _ in range(max_attempts):
                    idx = random.randint(0, len(users)-1)
                    if idx != anchor_idx and idx != positive_idx:
                        if calculate_similarity(anchor_user, users[idx], pipeline) < 0.1:
                            negative_idx = idx
                            break

                if negative_idx is None:
                    # Fallback: accept any random index after exhausting attempts.
                    negative_idx = random.randint(0, len(users)-1)

                batch_triplets.append((anchor_idx, positive_idx, negative_idx))

        return batch_triplets
    except Exception as e:
        logging.error(f"Error in batch triplet generation: {str(e)}")
        return []
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
# =================================
|
| 582 |
+
# DATASET AND DATALOADER
|
| 583 |
+
# =================================
|
| 584 |
+
|
| 585 |
+
class UserSimilarityDataset(Dataset):
    """Triplet dataset for contrastive training of user embeddings.

    Each item is an (anchor, positive, negative) triple of pre-processed
    model inputs: the positive shares taxonomy attributes (channels,
    clusters, tags, brands, source, device) with the anchor, while the
    negative is a user whose computed similarity to the anchor is < 0.1.
    """

    def __init__(self, pipeline, users_data, num_triplets=10, num_workers=None, filtered_tags=None):
        """
        Args:
            pipeline: embedding pipeline exposing ``_prepare_input`` and
                ``_get_field_from_user``.
            users_data: raw records; each may wrap the user under
                ``raw_json.user``, under ``user``, or be the user itself.
            num_triplets: number of triplets to generate.
            num_workers: worker-count hint (defaults to min(8, cpu_count)).
            filtered_tags: optional whitelist of tags used for indexing and
                similarity; tags outside the set are ignored.
        """
        self.triplets = []
        self.filtered_tags = filtered_tags
        logging.info("Initializing UserSimilarityDataset...")

        # Unwrap the 'user' payload from each record.
        self.users = []
        for data in users_data:
            if 'raw_json' in data and 'user' in data['raw_json']:
                self.users.append(data['raw_json']['user'])
            elif 'user' in data:
                self.users.append(data['user'])
            else:
                self.users.append(data)  # Assume it's already a user

        self.pipeline = pipeline
        self.num_triplets = num_triplets

        # Determine number of workers for parallel processing.
        if num_workers is None:
            num_workers = max(1, min(8, os.cpu_count()))
        self.num_workers = num_workers

        # Pre-process the model inputs once per user.
        self.preprocessed_inputs = {}
        for idx, user in enumerate(self.users):
            self.preprocessed_inputs[idx] = pipeline._prepare_input(user)

        logging.info("Creating indexes for channels, clusters, tags, brands, source, and device...")
        self.channel_index = defaultdict(list)
        self.cluster_index = defaultdict(list)
        self.tag_index = defaultdict(list)
        self.brands_index = defaultdict(list)
        self.source_index = defaultdict(list)
        self.device_index = defaultdict(list)

        for idx, user in enumerate(self.users):
            channels = pipeline._get_field_from_user(user, 'dmp_channels')
            clusters = pipeline._get_field_from_user(user, 'dmp_clusters')
            tags = pipeline._get_field_from_user(user, 'dmp_tags')
            brands = pipeline._get_field_from_user(user, 'dmp_brands')
            source = pipeline._get_field_from_user(user, 'source')
            device = pipeline._get_field_from_user(user, 'device')

            # Normalize every field to a list of strings. The `or []` guard
            # fixes a crash in the original: a None field value reached the
            # indexing loops below and raised TypeError.
            channels = [str(c) for c in (channels or []) if c is not None]
            clusters = [str(c) for c in (clusters or []) if c is not None]
            tags = [str(c) for c in (tags or []) if c is not None]
            # Keep only whitelisted tags when a filter set was provided.
            if self.filtered_tags:
                tags = [tag for tag in tags if tag in self.filtered_tags]
            brands = [str(c) for c in (brands or []) if c is not None]

            # 'source' and 'device' may be single strings rather than lists.
            if isinstance(source, str):
                source = [source] if source else []
            else:
                source = [str(c) for c in (source or []) if c is not None]

            if isinstance(device, str):
                device = [device] if device else []
            else:
                device = [str(c) for c in (device or []) if c is not None]

            for channel in channels:
                self.channel_index[channel].append(idx)
            for cluster in clusters:
                self.cluster_index[cluster].append(idx)
            for tag in tags:
                self.tag_index[tag].append(idx)
            for brand in brands:
                self.brands_index[brand].append(idx)
            for s in source:
                self.source_index[s].append(idx)
            for d in device:
                self.device_index[d].append(idx)

        logging.info(f"Found {len(self.channel_index)} unique channels, {len(self.cluster_index)} unique clusters, {len(self.tag_index)} unique tags")
        logging.info(f"Found {len(self.brands_index)} unique brands, {len(self.source_index)} unique sources, and {len(self.device_index)} unique devices")

        logging.info(f"Generating triplets using {self.num_workers} worker processes...")

        self.triplets = self._generate_triplets_gpu(num_triplets)
        logging.info(f"Generated {len(self.triplets)} triplets")

    def __len__(self):
        """Number of generated triplets."""
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return the (anchor, positive, negative) pre-processed inputs."""
        if idx >= len(self.triplets):
            raise IndexError(f"Index {idx} out of range for dataset with {len(self.triplets)} triplets")

        anchor_idx, positive_idx, negative_idx = self.triplets[idx]
        return (
            self.preprocessed_inputs[anchor_idx],
            self.preprocessed_inputs[positive_idx],
            self.preprocessed_inputs[negative_idx]
        )

    def _generate_triplets_gpu(self, num_triplets):
        """Generate triplets using a more reliable approach with batch processing"""
        logging.info("Generating triplets with batch approach...")

        triplets = []
        batch_size = 10  # number of triplets generated per batch
        num_batches = (num_triplets + batch_size - 1) // batch_size

        progress_bar = tqdm(
            range(num_batches),
            desc="Generating triplet batches",
            bar_format='{l_bar}{bar:10}{r_bar}'
        )

        for _ in progress_bar:
            batch_triplets = []

            # Generate one batch of triplets.
            for i in range(batch_size):
                if len(triplets) >= num_triplets:
                    break

                # Pick a random anchor.
                anchor_idx = random.randint(0, len(self.users)-1)
                anchor_user = self.users[anchor_idx]

                # Candidates sharing channels, clusters, tags, brands, source
                # or device with the anchor. The `or []` guard fixes a crash
                # when a field is missing/None.
                candidates = set()
                for channel in (self.pipeline._get_field_from_user(anchor_user, 'dmp_channels') or []):
                    if channel is not None:
                        candidates.update(self.channel_index.get(str(channel), []))
                for cluster in (self.pipeline._get_field_from_user(anchor_user, 'dmp_clusters') or []):
                    if cluster is not None:
                        candidates.update(self.cluster_index.get(str(cluster), []))
                for tag in (self.pipeline._get_field_from_user(anchor_user, 'dmp_tags') or []):
                    if tag is not None and (self.filtered_tags is None or str(tag) in self.filtered_tags):
                        candidates.update(self.tag_index.get(str(tag), []))
                for brand in (self.pipeline._get_field_from_user(anchor_user, 'dmp_brands') or []):
                    if brand is not None:
                        candidates.update(self.brands_index.get(str(brand), []))
                for source in (self.pipeline._get_field_from_user(anchor_user, 'source') or []):
                    if source is not None:
                        candidates.update(self.source_index.get(str(source), []))
                for device in (self.pipeline._get_field_from_user(anchor_user, 'device') or []):
                    if device is not None:
                        candidates.update(self.device_index.get(str(device), []))

                # Remove the anchor from its own candidate pool.
                candidates.discard(anchor_idx)

                # Positive example: one of the top-10 most similar candidates
                # (similarity checked over at most the first 50 candidates).
                if candidates:
                    similarities = []
                    for idx in list(candidates)[:50]:
                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline, self.filtered_tags)
                        if sim > 0:
                            similarities.append((idx, sim))

                    if similarities:
                        similarities.sort(key=lambda x: x[1], reverse=True)
                        top_k = min(10, len(similarities))
                        positive_idx = similarities[random.randint(0, top_k-1)][0]
                    else:
                        positive_idx = random.randint(0, len(self.users)-1)
                else:
                    positive_idx = random.randint(0, len(self.users)-1)

                # Negative example: random user with similarity < 0.1
                # (up to 20 attempts, then fall back to any random user).
                attempts = 0
                negative_idx = None

                while attempts < 20 and negative_idx is None:
                    idx = random.randint(0, len(self.users)-1)
                    if idx != anchor_idx and idx != positive_idx:
                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline, self.filtered_tags)
                        if sim < 0.1:
                            negative_idx = idx
                            break
                    attempts += 1

                if negative_idx is None:
                    negative_idx = random.randint(0, len(self.users)-1)

                batch_triplets.append((anchor_idx, positive_idx, negative_idx))

            triplets.extend(batch_triplets)

        return triplets[:num_triplets]  # Return exactly num_triplets
|
| 788 |
+
|
| 789 |
+
|
| 790 |
+
def collate_batch(batch):
    """Collate (anchor, positive, negative) triplets into padded batches.

    Each element of ``batch`` is a triple of dicts mapping field name to a
    1-D tensor. For every field, sequences are right-padded with zeros to
    the longest length in the batch and stacked into a (batch, max_len)
    tensor.

    Args:
        batch: list of (anchor_inputs, positive_inputs, negative_inputs).

    Returns:
        Tuple of three dicts (anchor, positive, negative), each mapping
        field name -> stacked, padded tensor.
    """
    anchor_inputs, positive_inputs, negative_inputs = zip(*batch)

    def process_group_inputs(group_inputs):
        processed = {}
        for field in group_inputs[0].keys():
            # Find maximum length for this field in the batch
            max_len = max(inputs[field].size(0) for inputs in group_inputs)

            # Create padded tensors. Padding uses each tensor's own dtype:
            # the original hard-coded torch.long, which made torch.cat fail
            # for float-valued fields.
            padded = torch.stack([
                torch.cat([
                    inputs[field],
                    torch.zeros(max_len - inputs[field].size(0), dtype=inputs[field].dtype)
                ]) if inputs[field].size(0) < max_len else inputs[field][:max_len]
                for inputs in group_inputs
            ])

            processed[field] = padded

        return processed

    # Process each group (anchor, positive, negative)
    anchor_batch = process_group_inputs(anchor_inputs)
    positive_batch = process_group_inputs(positive_inputs)
    negative_batch = process_group_inputs(negative_inputs)

    return anchor_batch, positive_batch, negative_batch
|
| 819 |
+
|
| 820 |
+
# =================================
|
| 821 |
+
# TRAINING FUNCTION
|
| 822 |
+
# =================================
|
| 823 |
+
|
| 824 |
+
def train_user_embeddings(model, users_data, pipeline, num_epochs=10, batch_size=32, lr=0.001, save_dir=None, save_interval=2, num_triplets=150):
    """Train the embedding model on triplets with incremental checkpointing.

    Args:
        model: the embedding model to train (moved to the global ``device``).
        users_data: raw user records used to build the triplet dataset.
        pipeline: embedding pipeline (input preparation / field access).
        num_epochs: number of training epochs.
        batch_size: triplets per batch.
        lr: initial Adam learning rate (decayed 10% every 2 epochs).
        save_dir: directory for periodic checkpoints (None disables saving).
        save_interval: save a checkpoint every N epochs.
        num_triplets: number of triplets to generate for the dataset.

    Returns:
        The trained model.
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Learning rate scheduler: multiply lr by 0.9 every 2 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=2,
        gamma=0.9
    )

    # Determine number of CPU cores to use for dataset construction.
    num_cpu_cores = max(1, min(32, os.cpu_count()))
    logging.info(f"Using {num_cpu_cores} CPU cores for data processing")

    # Prepare dataset and dataloader with the custom collate function.
    dataset = UserSimilarityDataset(
        pipeline,
        users_data,
        num_triplets=num_triplets,
        num_workers=num_cpu_cores
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_batch,
        num_workers=0,
        pin_memory=True  # speed up host-to-device transfer
    )

    # Triplet loss: pull anchor/positive together, push negative away.
    criterion = torch.nn.TripletMarginLoss(margin=1.0)

    # Progress bar for epochs.
    epoch_pbar = tqdm(
        range(num_epochs),
        desc="Training Progress",
        bar_format='{l_bar}{bar:10}{r_bar}'
    )

    # Optional TensorBoard logging.
    try:
        from torch.utils.tensorboard import SummaryWriter
        log_dir = Path(save_dir) / "logs" if save_dir else Path("./logs")
        log_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(log_dir=log_dir)
        tensorboard_available = True
    except ImportError:
        logging.warning("TensorBoard not available, skipping logging")
        tensorboard_available = False

    for epoch in epoch_pbar:
        total_loss = 0
        num_batches = 0

        # Single progress indicator for the whole epoch, updated at ~10%
        # intervals. (The original also built a second, permanently
        # disabled per-batch tqdm bar that was never used nor closed; that
        # dead code has been removed.)
        epoch_progress = tqdm(
            total=len(dataloader),
            desc=f"Epoch {epoch+1}/{num_epochs}",
            leave=True,
            bar_format='{l_bar}{bar:10}{r_bar}'
        )

        for batch_idx, batch_inputs in enumerate(dataloader):
            try:
                # Each element is already a dict of padded tensors.
                anchor_batch, positive_batch, negative_batch = batch_inputs

                # Move data to device (GPU if available).
                anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
                positive_batch = {k: v.to(device) for k, v in positive_batch.items()}
                negative_batch = {k: v.to(device) for k, v in negative_batch.items()}

                # Generate embeddings.
                anchor_emb = model(anchor_batch)
                positive_emb = model(positive_batch)
                negative_emb = model(negative_batch)

                # Calculate loss.
                loss = criterion(anchor_emb, positive_emb, negative_emb)

                # Backward pass with gradient clipping, then optimize.
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                # Update the epoch progress bar only at 10% intervals or at the end.
                update_interval = max(1, len(dataloader) // 10)
                if (batch_idx + 1) % update_interval == 0 or batch_idx == len(dataloader) - 1:
                    remaining = min(update_interval, len(dataloader) - epoch_progress.n)
                    epoch_progress.update(remaining)
                    current_avg_loss = total_loss / num_batches
                    epoch_progress.set_postfix(avg_loss=f"{current_avg_loss:.4f}",
                                               last_batch_loss=f"{loss.item():.4f}")

            except Exception as e:
                logging.error(f"Error during batch processing: {str(e)}")
                logging.error(f"Batch details: {str(e.__class__.__name__)}")
                continue

        # Close the progress bar at the end of the epoch.
        epoch_progress.close()

        avg_loss = total_loss / max(1, num_batches)
        epoch_pbar.set_postfix(avg_loss=f"{avg_loss:.4f}")

        # Log to tensorboard.
        if tensorboard_available:
            writer.add_scalar('Loss/train', avg_loss, epoch)

        # Step the learning rate scheduler at the end of each epoch.
        scheduler.step()

        # Incremental model saving if requested.
        if save_dir and (epoch + 1) % save_interval == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'scheduler_state_dict': scheduler.state_dict()  # Save scheduler state
            }

            save_path = Path(save_dir) / f'model_checkpoint_epoch_{epoch+1}.pth'
            torch.save(checkpoint, save_path)
            logging.info(f"Checkpoint saved at epoch {epoch+1}: {save_path}")

    if tensorboard_available:
        writer.close()

    return model
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def compute_tag_frequencies(pipeline, users_data):
    """
    Compute tag frequencies over the dataset.

    Args:
        pipeline: embedding pipeline exposing ``_get_field_from_user``.
        users_data: list of user records; each may wrap the user under
            ``raw_json.user``, under ``user``, or be the user itself.

    Returns:
        dict: tag (lower-cased, stripped) -> occurrence count.
        int: total number of tag occurrences processed.
    """
    logging.info("Calcolando frequenze dei tag...")
    tag_frequencies = {}
    total_tags = 0

    # Extract the tags from every user.
    for data in users_data:
        # Unwrap the user payload from the JSON structure.
        if 'raw_json' in data and 'user' in data['raw_json']:
            user = data['raw_json']['user']
        elif 'user' in data:
            user = data['user']
        else:
            user = data  # Assume it's already a user

        # Extract and count the tags. The `or []` guard fixes a crash when
        # the field is missing/None.
        tags = pipeline._get_field_from_user(user, 'dmp_tags')
        for tag in (tags or []):
            if tag is not None:
                tag_str = str(tag).lower().strip()
                tag_frequencies[tag_str] = tag_frequencies.get(tag_str, 0) + 1
                total_tags += 1

    logging.info(f"Trovati {len(tag_frequencies)} tag unici su {total_tags} occorrenze totali")
    return tag_frequencies, total_tags
|
| 1018 |
+
|
| 1019 |
+
# funzione per filtrare i tag in base alla frequenza o al percentile
|
| 1020 |
+
|
| 1021 |
+
def filter_tags_by_criteria(tag_frequencies, min_frequency=100, percentile=None):
    """
    Filter tags by a frequency threshold or a percentile cut.

    Args:
        tag_frequencies: mapping of tag -> frequency.
        min_frequency: minimum frequency a tag must reach (default: 100).
        percentile: when given, keep only the tags up to this percentile
            by descending frequency (overrides ``min_frequency``).

    Returns:
        set: the tags satisfying the chosen criterion.
    """
    if percentile is None:
        # Simple absolute-frequency threshold.
        kept = {tag for tag, freq in tag_frequencies.items() if freq >= min_frequency}
        logging.info(f"Filtrati tag con frequenza < {min_frequency}. Mantenuti {len(kept)} tag")
        return kept

    # Rank tags by descending frequency and keep the requested fraction.
    ranked = sorted(tag_frequencies.items(), key=lambda item: item[1], reverse=True)
    cutoff = int(len(ranked) * (percentile / 100.0))
    kept = {tag for tag, _ in ranked[:cutoff]}

    lowest_kept = ranked[cutoff - 1][1] if cutoff > 0 else 0
    logging.info(f"Filtrati tag al {percentile}° percentile. Mantenuti {len(kept)} tag con frequenza >= {lowest_kept}")
    return kept
|
| 1047 |
+
|
| 1048 |
+
# =================================
|
| 1049 |
+
# MAIN FUNCTION
|
| 1050 |
+
# =================================
|
| 1051 |
+
def main():
    """Entry point: load data, build vocabularies, train and save the model.

    Reads configuration from module-level constants (DATA_PATH, OUTPUT_DIR,
    NUM_TRIPLETS, NUM_EPOCHS, BATCH_SIZE, LEARNING_RATE, OUTPUT_DIM,
    MAX_SEQ_LENGTH, SAVE_INTERVAL) and from the environment
    (SAVE_HF_FORMAT, HF_REPO_ID, HF_TOKEN).
    """
    # Configuration
    output_dir = Path(OUTPUT_DIR)

    # Check if CUDA is available
    cuda_available = torch.cuda.is_available()
    logging.info(f"CUDA available: {cuda_available}")
    if cuda_available:
        logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
        logging.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # CPU info
    cpu_count = os.cpu_count()
    memory_info = psutil.virtual_memory()
    logging.info(f"CPU cores: {cpu_count}")
    logging.info(f"System memory: {memory_info.total / 1e9:.2f} GB")

    # Print configuration
    logging.info("Running with the following configuration:")
    logging.info(f"- Number of triplets: {NUM_TRIPLETS}")
    logging.info(f"- Number of epochs: {NUM_EPOCHS}")
    logging.info(f"- Batch size: {BATCH_SIZE}")
    logging.info(f"- Learning rate: {LEARNING_RATE}")
    logging.info(f"- Output dimension: {OUTPUT_DIM}")
    logging.info(f"- Data path: {DATA_PATH}")
    logging.info(f"- Output directory: {OUTPUT_DIR}")

    # Load data
    logging.info("Loading user data...")
    try:
        try:
            # First try loading as normal JSON
            with open(DATA_PATH, 'r') as f:
                json_data = json.load(f)

            # Handle both cases: array of users or single object
            if isinstance(json_data, list):
                users_data = json_data
            elif isinstance(json_data, dict):
                # If it's a single record, put it in a list
                users_data = [json_data]
            else:
                raise ValueError("Unsupported JSON format")

        except json.JSONDecodeError:
            # The file might be a non-standard JSON (objects separated by commas)
            logging.info("Detected possible non-standard JSON format, attempting correction...")
            with open(DATA_PATH, 'r') as f:
                text = f.read().strip()

            # Add square brackets to create a valid JSON array
            if not text.startswith('['):
                text = '[' + text
            if not text.endswith(']'):
                text = text + ']'

            # Try loading the corrected text
            users_data = json.loads(text)
            logging.info("JSON format successfully corrected")

        logging.info(f"Loaded {len(users_data)} records")
    except FileNotFoundError:
        logging.error(f"File {DATA_PATH} not found!")
        return
    except Exception as e:
        logging.error(f"Unable to load file: {str(e)}")
        return

    # Initialize pipeline
    logging.info("Initializing pipeline...")
    pipeline = UserEmbeddingPipeline(
        output_dim=OUTPUT_DIM,
        max_sequence_length=MAX_SEQ_LENGTH
    )

    # Build vocabularies
    logging.info("Building vocabularies...")
    try:
        pipeline.build_vocabularies(users_data)
        vocab_sizes = {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()}
        logging.info(f"Vocabulary sizes: {vocab_sizes}")
    except Exception as e:
        logging.error(f"Error building vocabularies: {str(e)}")
        return

    # Initialize model
    logging.info("Initializing model...")
    try:
        pipeline.initialize_model()
        logging.info("Model initialized successfully")
    except Exception as e:
        logging.error(f"Error initializing model: {str(e)}")
        return

    # Training
    logging.info("Starting training...")
    try:
        # Create directory for checkpoints if it doesn't exist
        model_dir = output_dir / "model_checkpoints"
        model_dir.mkdir(exist_ok=True, parents=True)

        model = train_user_embeddings(
            pipeline.model,
            users_data,
            pipeline,  # Pass the pipeline to training
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            save_dir=model_dir,  # Incremental saving
            save_interval=SAVE_INTERVAL,  # Save every N epochs
            num_triplets=NUM_TRIPLETS
        )
        logging.info("Training completed")
        pipeline.model = model

        # Save only the model file
        logging.info("Saving model...")
        output_dir.mkdir(exist_ok=True)
        model_path = output_dir / 'model.pth'

        # Prepare dictionary with model state and metadata
        checkpoint = {
            'model_state_dict': pipeline.model.state_dict(),
            'vocab_maps': pipeline.vocab_maps,
            'embedding_dims': pipeline.embedding_dims,
            'output_dim': pipeline.output_dim,
            'max_sequence_length': pipeline.max_sequence_length
        }

        torch.save(checkpoint, model_path)
        logging.info(f"Model saved to: {model_path}")

        # Also save a configuration file for reference
        config_info = {
            'model_type': 'UserEmbeddingModel',
            'vocab_sizes': {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()},
            'embedding_dims': pipeline.embedding_dims,
            'output_dim': pipeline.output_dim,
            'max_sequence_length': pipeline.max_sequence_length,
            'padded_fields': list(pipeline.model.padded_fields),
            'fields': pipeline.fields
        }

        config_path = output_dir / 'model_config.json'
        with open(config_path, 'w') as f:
            json.dump(config_info, f, indent=2)

        logging.info(f"Model configuration saved to: {config_path}")

        # Only save in HuggingFace format if requested
        save_hf = os.environ.get("SAVE_HF_FORMAT", "false").lower() == "true"
        if save_hf:
            logging.info("Saving in HuggingFace format...")
            hf_dir = output_dir / 'huggingface'
            hf_dir.mkdir(exist_ok=True)

            # Save model weights in HF format
            torch.save(pipeline.model.state_dict(), hf_dir / 'pytorch_model.bin')

            # Save config
            with open(hf_dir / 'config.json', 'w') as f:
                json.dump(config_info, f, indent=2)

            logging.info(f"Model saved in HuggingFace format to: {hf_dir}")

        # Push to HuggingFace if environment variables are set
        hf_repo_id = os.environ.get("HF_REPO_ID")
        hf_token = os.environ.get("HF_TOKEN")

        if save_hf and hf_repo_id and hf_token:
            try:
                from huggingface_hub import HfApi

                logging.info(f"Pushing model to HuggingFace: {hf_repo_id}")
                api = HfApi()

                # Push the model directory
                api.create_repo(
                    repo_id=hf_repo_id,
                    token=hf_token,
                    exist_ok=True,
                    private=True
                )

                # Upload the files. path_in_repo must be a POSIX-style
                # string: the original passed a Path object, which the Hub
                # API does not accept (and would carry backslashes on
                # Windows).
                for file_path in (output_dir / "huggingface").glob("**/*"):
                    if file_path.is_file():
                        api.upload_file(
                            path_or_fileobj=str(file_path),
                            path_in_repo=file_path.relative_to(output_dir / "huggingface").as_posix(),
                            repo_id=hf_repo_id,
                            token=hf_token
                        )

                logging.info(f"Model successfully pushed to HuggingFace: {hf_repo_id}")
            except Exception as e:
                logging.error(f"Error pushing to HuggingFace: {str(e)}")

    except Exception as e:
        logging.error(f"Error during training or saving: {str(e)}")
        import traceback
        traceback.print_exc()
        return

    logging.info("Process completed successfully!")
|
| 1263 |
+
|
| 1264 |
+
# Entry point guard: run the pipeline only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch==2.1.0
|
| 2 |
+
torchvision==0.16.0
|
| 3 |
+
numpy==1.24.3
|
| 4 |
+
pandas==2.0.3
|
| 5 |
+
tqdm==4.66.1
|
| 6 |
+
scikit-learn==1.3.0
|
| 7 |
+
matplotlib==3.7.2
|
| 8 |
+
transformers==4.34.0
|
| 9 |
+
datasets==2.14.5
|
| 10 |
+
huggingface_hub==0.17.3
|
| 11 |
+
tensorboard==2.14.1
|
| 12 |
+
psutil==5.9.5
|
run.sh
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
# Train the user-embedding model (unless already trained), then serve the
# resulting artifacts over HTTP on port 7860.
#
# Environment:
#   DATA_PATH  optional path to the input data (falls back to ./users.json)
#   SPACE_ID   set by Hugging Face Spaces; used to build the public URL
#
# Strict mode: the original only used `set -e`; add -u and pipefail, and
# guard optional variables with ${var:-} so -u does not abort on them.
set -euo pipefail

echo "Starting user embedding model pipeline..."

# Skip training when a model artifact is already present.
if [ -f "embeddings_output/model.pth" ]; then
  echo "Model already trained."
else
  # Require some input data before launching training.
  if [ ! -f "${DATA_PATH:-}" ] && [ ! -f "users.json" ]; then
    # Error diagnostics go to stderr (the original wrote to stdout).
    echo "Error: No data file found. Please mount a volume with users.json or set DATA_PATH." >&2
    exit 1
  fi

  # Run the model generation script.
  python3 generate_model_gpu.py
  echo "Process completed. Check embeddings_output directory for results."
fi

# Get the hostname or public URL if available (for Hugging Face Spaces).
if [ -n "${SPACE_ID:-}" ]; then
  # Running in Hugging Face Spaces
  BASE_URL="https://${SPACE_ID}.hf.space"
else
  # Running locally or in generic Docker
  BASE_URL="http://localhost:7860"
fi

echo "=========================================================="
echo "Model is available for download at the following URLs:"
echo "${BASE_URL}/file=embeddings_output/model.pth"
echo "${BASE_URL}/file=embeddings_output/model_config.json"
echo "=========================================================="

# Start the web server to serve files.
echo "Starting web server on port 7860..."
python3 -m http.server 7860
|
users_200k.json
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d91ddb646dfcc955f89edcf756041415a6ba9f3d34cd217e8dc7a6f35cc9161
|
| 3 |
+
size 297966779
|