import os

# Redirect the HuggingFace model cache to a writable tmp location; this must
# be set before anything imports `transformers`.
# NOTE(review): TRANSFORMERS_CACHE is deprecated in recent transformers
# releases in favour of HF_HOME -- confirm against the installed version.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'

import multiprocessing
# Force the 'spawn' start method (safer with CUDA + multiprocessing);
# set_start_method raises RuntimeError when already set -- deliberately ignored.
try:
    multiprocessing.set_start_method('spawn')
except RuntimeError:
    pass
import json
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
from typing import List, Dict
import logging
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import random
from collections import defaultdict
from pathlib import Path
from tqdm import tqdm
import os  # NOTE(review): duplicate import (os already imported above)
import multiprocessing  # NOTE(review): duplicate import (see top of file)
from multiprocessing import Pool
import psutil
import argparse
|
|
| |
| |
| |
| |
# ---------------------------------------------------------------------------
# Configuration defaults. Every value can be overridden by an environment
# variable of the same name (without the DEFAULT_ prefix), cast below.
# ---------------------------------------------------------------------------
DEFAULT_NUM_TRIPLETS = 150000       # triplets generated for triplet-loss training
DEFAULT_NUM_EPOCHS = 20
DEFAULT_BATCH_SIZE = 64
DEFAULT_LEARNING_RATE = 0.001
DEFAULT_OUTPUT_DIM = 256            # dimensionality of the final user embedding
DEFAULT_MAX_SEQ_LENGTH = 15
DEFAULT_SAVE_INTERVAL = 2           # checkpoint every N epochs
DEFAULT_DATA_PATH = "./users.json"
DEFAULT_OUTPUT_DIR = "./model"


# Environment-variable overrides (env values arrive as strings, hence the casts).
NUM_TRIPLETS = int(os.environ.get("NUM_TRIPLETS", DEFAULT_NUM_TRIPLETS))
NUM_EPOCHS = int(os.environ.get("NUM_EPOCHS", DEFAULT_NUM_EPOCHS))
BATCH_SIZE = int(os.environ.get("BATCH_SIZE", DEFAULT_BATCH_SIZE))
LEARNING_RATE = float(os.environ.get("LEARNING_RATE", DEFAULT_LEARNING_RATE))
OUTPUT_DIM = int(os.environ.get("OUTPUT_DIM", DEFAULT_OUTPUT_DIM))
MAX_SEQ_LENGTH = int(os.environ.get("MAX_SEQ_LENGTH", DEFAULT_MAX_SEQ_LENGTH))
SAVE_INTERVAL = int(os.environ.get("SAVE_INTERVAL", DEFAULT_SAVE_INTERVAL))
DATA_PATH = os.environ.get("DATA_PATH", DEFAULT_DATA_PATH)
OUTPUT_DIR = os.environ.get("OUTPUT_DIR", DEFAULT_OUTPUT_DIR)


# Root logger configuration shared by the whole script.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


# Single global device used by model construction, training and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")
|
|
| |
| |
| |
class UserEmbeddingModel(nn.Module):
    """Maps a user's categorical fields to a single dense embedding vector.

    Each field gets its own ``nn.Embedding`` (index 0 reserved for padding).
    Per-field embeddings are pooled into one vector per field, concatenated,
    and projected through a small MLP to ``output_dim`` features.
    """

    def __init__(self, vocab_sizes: Dict[str, int], embedding_dims: Dict[str, int],
                 output_dim: int = 256, max_sequence_length: int = 15,
                 padded_fields_length: int = 10):
        """
        Args:
            vocab_sizes: field name -> vocabulary size (including padding token).
            embedding_dims: field name -> embedding width. The MLP input width
                is the sum of these values, so the forward pass must receive
                every field listed here for shapes to line up.
            output_dim: dimensionality of the final user embedding.
            max_sequence_length: stored for checkpoint compatibility; not read
                by the forward pass.
            padded_fields_length: stored for interface compatibility; not read
                by the forward pass.
        """
        super().__init__()

        self.max_sequence_length = max_sequence_length
        self.padded_fields_length = padded_fields_length
        # Fields delivered as fixed-length, zero-padded sequences; these are
        # mean-pooled with a padding mask (see _process_padded_sequence).
        self.padded_fields = {'dmp_channels', 'dmp_tags', 'dmp_clusters'}
        self.embedding_layers = nn.ModuleDict()

        for field, vocab_size in vocab_sizes.items():
            self.embedding_layers[field] = nn.Embedding(
                vocab_size,
                embedding_dims.get(field, 16),  # default width for unlisted fields
                padding_idx=0
            )

        # FIX: the original branched on padded vs. non-padded fields but both
        # branches added the same value -- the total is simply the sum.
        self.total_input_dim = sum(embedding_dims.values())
        # FIX: use the module-wide logging setup instead of a bare print().
        logging.info(f"Total input dimension: {self.total_input_dim}")

        self.fc = nn.Sequential(
            nn.Linear(self.total_input_dim, self.total_input_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(self.total_input_dim // 2, output_dim),
            nn.LayerNorm(output_dim)
        )

    def _process_sequence(self, embedding_layer: nn.Embedding, indices: torch.Tensor,
                          field_name: str) -> torch.Tensor:
        """Pool a variable-length field into one vector per batch row.

        'dmp_city' / 'dmp_domains' are treated as single-valued: only the
        first index of each row is embedded. Other fields are mean-pooled
        over the sequence dimension (padding positions included -- the masked
        variant is _process_padded_sequence).
        """
        batch_size = indices.size(0)
        if indices.numel() == 0:
            # No values at all for this field: contribute a zero vector.
            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)

        if field_name in ['dmp_city', 'dmp_domains']:
            if indices.dim() == 1:
                indices = indices.unsqueeze(0)
            if indices.size(1) > 0:
                # Single-valued field: embed the first token only.
                return embedding_layer(indices[:, 0])
            return torch.zeros(batch_size, embedding_layer.embedding_dim, device=indices.device)

        embeddings = embedding_layer(indices)
        return embeddings.mean(dim=1)

    def _process_padded_sequence(self, embedding_layer: nn.Embedding,
                                 indices: torch.Tensor) -> torch.Tensor:
        """Masked mean-pool of a zero-padded sequence (padding id == 0)."""
        # FIX: dropped unused locals (batch_size, emb_dim) from the original.
        embeddings = embedding_layer(indices)

        # Zero out padding positions, then average over real tokens only;
        # the clamp avoids division by zero on all-padding rows.
        mask = (indices != 0).float().unsqueeze(-1)
        masked_embeddings = embeddings * mask
        sum_mask = mask.sum(dim=1).clamp(min=1.0)

        return masked_embeddings.sum(dim=1) / sum_mask

    def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Concatenate all pooled field embeddings and project to output_dim.

        Fields missing from ``inputs`` (or lacking an embedding layer) are
        silently skipped, so callers must supply every field counted in
        ``total_input_dim`` for the linear layer shapes to match.
        """
        batch_embeddings = []

        for field in ['dmp_city', 'dmp_domains', 'dmp_brands',
                      'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
                      'link_host', 'link_path']:
            if field in inputs and field in self.embedding_layers:
                if field in self.padded_fields:
                    emb = self._process_padded_sequence(
                        self.embedding_layers[field],
                        inputs[field]
                    )
                else:
                    emb = self._process_sequence(
                        self.embedding_layers[field],
                        inputs[field],
                        field
                    )
                batch_embeddings.append(emb)

        combined = torch.cat(batch_embeddings, dim=1)
        return self.fc(combined)
|
|
| |
| |
| |
class UserEmbeddingPipeline:
    """End-to-end helper around ``UserEmbeddingModel``.

    Responsibilities: building per-field vocabularies from raw user JSON,
    converting users to tensor inputs, creating the model, generating
    embeddings in batches, and saving/loading all artifacts.
    """

    def __init__(self, output_dim: int = 256, max_sequence_length: int = 15):
        self.output_dim = output_dim
        self.max_sequence_length = max_sequence_length
        self.model = None       # created lazily by initialize_model()
        self.vocab_maps = {}    # field -> {token: index}; filled by build_vocabularies()

        # Fields extracted from each user record, in a fixed order.
        self.fields = [
            'dmp_city', 'dmp_domains', 'dmp_brands',
            'dmp_clusters', 'dmp_industries', 'dmp_tags', 'dmp_channels',
            'link_host', 'link_path'
        ]

        # Field name -> nested key path inside the raw user dict.
        self.field_mapping = {
            'dmp_city': ('dmp', 'city'),
            'dmp_domains': ('dmp', 'domains'),
            'dmp_brands': ('dmp', 'brands'),
            'dmp_clusters': ('dmp', 'clusters'),
            'dmp_industries': ('dmp', 'industries'),
            'dmp_tags': ('dmp', 'tags'),
            'dmp_channels': ('dmp', 'channels'),
            'link_host': ('dmp', '~click__host'),
            'link_path': ('dmp', '~click__domain')
        }

        # Per-field embedding widths consumed by UserEmbeddingModel.
        self.embedding_dims = {
            'dmp_city': 16,
            'dmp_domains': 16,
            'dmp_brands': 32,
            'dmp_clusters': 32,
            'dmp_industries': 32,
            'dmp_tags': 64,
            'dmp_channels': 32,
            'link_host': 32,
            'link_path': 32
        }

    def _clean_value(self, value):
        """Normalise a raw field value to a list of lowercased, stripped strings.

        NaN floats and unsupported types become []; a bare string becomes a
        one-element list; list items that are None or blank are dropped.
        """
        if isinstance(value, float) and np.isnan(value):
            return []
        if isinstance(value, str):
            return [value.lower().strip()]
        if isinstance(value, list):
            return [str(v).lower().strip() for v in value if v is not None and str(v).strip()]
        return []

    def _get_field_from_user(self, user, field):
        """Extract field value from new JSON user format.

        Walks the nested key path from ``field_mapping``; a broken path
        yields []. Multi-valued fields are coerced to lists (scalars are
        wrapped, dict values are discarded).
        """
        mapping = self.field_mapping.get(field, (field,))
        value = user

        for key in mapping:
            if isinstance(value, dict):
                value = value.get(key, {})
            else:
                # Path broke before the last key: treat the field as missing.
                value = []
                break

        # Multi-valued fields must come out as lists.
        if field in {'dmp_brands', 'dmp_channels', 'dmp_clusters', 'dmp_industries', 'dmp_tags', 'link_host', 'link_path'} and not isinstance(value, list):
            if value and not isinstance(value, dict):
                value = [value]
            else:
                value = []

        return value

    def build_vocabularies(self, users_data: List[Dict]) -> Dict[str, Dict[str, int]]:
        """Collect distinct values per field and assign integer indices.

        NOTE(review): indices come from enumerating ``sorted(values)``, so
        '<PAD>' is NOT guaranteed to get index 0 (digit tokens sort before
        '<'), yet the model reserves padding_idx=0 and _prepare_input maps
        unknown tokens to 0 -- confirm this mismatch is intended.
        """
        field_values = {field: {'<PAD>'} for field in self.fields}

        # Unwrap the three supported record layouts into plain user dicts.
        users = []
        for data in users_data:
            if 'raw_json' in data and 'user' in data['raw_json']:
                users.append(data['raw_json']['user'])
            elif 'user' in data:
                users.append(data['user'])
            else:
                users.append(data)

        for user in users:
            for field in self.fields:
                values = self._clean_value(self._get_field_from_user(user, field))
                field_values[field].update(values)

        self.vocab_maps = {
            field: {val: idx for idx, val in enumerate(sorted(values))}
            for field, values in field_values.items()
        }

        return self.vocab_maps

    def _prepare_input(self, user: Dict) -> Dict[str, torch.Tensor]:
        """Convert one user dict to per-field 1-D LongTensors of vocab indices.

        Unknown tokens fall back to index 0.
        """
        inputs = {}

        for field in self.fields:
            values = self._clean_value(self._get_field_from_user(user, field))
            vocab = self.vocab_maps[field]
            indices = [vocab.get(val, 0) for val in values]
            inputs[field] = torch.tensor(indices, dtype=torch.long)

        return inputs

    def initialize_model(self) -> None:
        """Instantiate UserEmbeddingModel from the built vocabularies and put
        it on the module-level ``device`` in eval mode."""
        vocab_sizes = {field: len(vocab) for field, vocab in self.vocab_maps.items()}

        self.model = UserEmbeddingModel(
            vocab_sizes=vocab_sizes,
            embedding_dims=self.embedding_dims,
            output_dim=self.output_dim,
            max_sequence_length=self.max_sequence_length
        )
        self.model.to(device)
        self.model.eval()

    def generate_embeddings(self, users_data: List[Dict], batch_size: int = 32) -> Dict[str, np.ndarray]:
        """Generate embeddings for all users.

        Returns:
            Mapping of user id (string) -> 1-D numpy embedding vector.

        NOTE(review): when no id is found, the fallback produces the literal
        string "None" (via str(None)), so the `is not None` / truthiness
        checks below never filter such users and they all collapse onto one
        "None" key -- verify id extraction against the real data.
        """
        embeddings = {}
        self.model.eval()

        # Unwrap records and extract one id per user; the three branches
        # repeat the same logic for the three supported layouts.
        users = []
        user_ids = []

        for data in users_data:
            if 'raw_json' in data and 'user' in data['raw_json']:
                user = data['raw_json']['user']
                users.append(user)
                # Preferred id location: user['dmp']['']['id'].
                if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
                    user_ids.append(str(user['dmp']['']['id']))
                else:
                    user_ids.append(str(user.get('uid', user.get('id', None))))
            elif 'user' in data:
                user = data['user']
                users.append(user)
                if 'dmp' in user and '' in user['dmp'] and 'id' in user['dmp']['']:
                    user_ids.append(str(user['dmp']['']['id']))
                else:
                    user_ids.append(str(user.get('uid', user.get('id', None))))
            else:
                users.append(data)
                if 'dmp' in data and '' in data['dmp'] and 'id' in data['dmp']['']:
                    user_ids.append(str(data['dmp']['']['id']))
                else:
                    user_ids.append(str(data.get('uid', data.get('id', None))))

        with torch.no_grad():
            for i in tqdm(range(0, len(users), batch_size), desc="Generating embeddings"):
                batch_users = users[i:i+batch_size]
                batch_ids = user_ids[i:i+batch_size]
                batch_inputs = []
                valid_indices = []

                for j, user in enumerate(batch_users):
                    if batch_ids[j] is not None:
                        batch_inputs.append(self._prepare_input(user))
                        valid_indices.append(j)

                if batch_inputs:
                    # Reuse the training collate function by tripling each
                    # input; only the anchor slot is consumed here.
                    anchor_batch, _, _ = collate_batch([(inputs, inputs, inputs) for inputs in batch_inputs])

                    anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}

                    batch_embeddings = self.model(anchor_batch).cpu()

                    for j, idx in enumerate(valid_indices):
                        if batch_ids[idx]:
                            embeddings[batch_ids[idx]] = batch_embeddings[j].numpy()

        return embeddings

    def save_embeddings(self, embeddings: Dict[str, np.ndarray], output_dir: str) -> None:
        """Persist embeddings as JSON and compressed .npz, plus vocabularies."""
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)

        # Human-readable copy: {user_id: [floats]}.
        json_path = output_dir / 'embeddings.json'
        with open(json_path, 'w') as f:
            json_embeddings = {user_id: emb.tolist() for user_id, emb in embeddings.items()}
            json.dump(json_embeddings, f)

        # Compact binary copy: parallel 'embeddings' / 'user_ids' arrays.
        npy_path = output_dir / 'embeddings.npz'
        np.savez_compressed(npy_path,
                            embeddings=np.stack(list(embeddings.values())),
                            user_ids=np.array(list(embeddings.keys())))

        # Vocabularies are needed to reproduce the index mapping at inference.
        vocab_path = output_dir / 'vocabularies.json'
        with open(vocab_path, 'w') as f:
            json.dump(self.vocab_maps, f)

        logging.info(f"\nEmbeddings saved in {output_dir}:")
        logging.info(f"- Embeddings JSON: {json_path}")
        logging.info(f"- Embeddings NPY: {npy_path}")
        logging.info(f"- Vocabularies: {vocab_path}")

    def save_model(self, output_dir: str) -> None:
        """Save model in PyTorch format (.pth).

        Writes three artifacts: a full checkpoint (weights + vocabularies +
        hyper-parameters), a JSON config, and a HuggingFace-style directory
        (pytorch_model.bin + config.json).
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(exist_ok=True)

        model_path = output_dir / 'model.pth'

        # Everything load_model() needs to rebuild the model from scratch.
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'vocab_maps': self.vocab_maps,
            'embedding_dims': self.embedding_dims,
            'output_dim': self.output_dim,
            'max_sequence_length': self.max_sequence_length
        }

        torch.save(checkpoint, model_path)

        logging.info(f"Model saved to: {model_path}")

        # Side-car JSON so the architecture can be inspected without torch.
        config_info = {
            'model_type': 'UserEmbeddingModel',
            'vocab_sizes': {field: len(vocab) for field, vocab in self.vocab_maps.items()},
            'embedding_dims': self.embedding_dims,
            'output_dim': self.output_dim,
            'max_sequence_length': self.max_sequence_length,
            'padded_fields': list(self.model.padded_fields),
            'fields': self.fields
        }

        config_path = output_dir / 'model_config.json'
        with open(config_path, 'w') as f:
            json.dump(config_info, f, indent=2)

        logging.info(f"Model configuration saved to: {config_path}")

        # HuggingFace-style layout: raw state dict + config.json.
        hf_dir = output_dir / 'huggingface'
        hf_dir.mkdir(exist_ok=True)

        torch.save(self.model.state_dict(), hf_dir / 'pytorch_model.bin')

        with open(hf_dir / 'config.json', 'w') as f:
            json.dump(config_info, f, indent=2)

        logging.info(f"Model saved in HuggingFace format to: {hf_dir}")

    def load_model(self, model_path: str) -> None:
        """Load a previously saved model (counterpart of save_model).

        NOTE(review): torch.load unpickles arbitrary objects -- only load
        checkpoints from trusted sources (recent torch offers
        weights_only=True for safer loading).
        """
        checkpoint = torch.load(model_path, map_location=device)

        # Restore pipeline hyper-parameters, keeping current values as defaults.
        self.vocab_maps = checkpoint.get('vocab_maps', self.vocab_maps)
        self.embedding_dims = checkpoint.get('embedding_dims', self.embedding_dims)
        self.output_dim = checkpoint.get('output_dim', self.output_dim)
        self.max_sequence_length = checkpoint.get('max_sequence_length', self.max_sequence_length)

        # Build the architecture first if this pipeline never created one.
        if self.model is None:
            self.initialize_model()

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.to(device)
        self.model.eval()

        logging.info(f"Model loaded from: {model_path}")
|
|
| |
| |
| |
def calculate_similarity(user1, user2, pipeline):
    """Jaccard-based similarity of two users.

    Averages the Jaccard overlap of the users' 'dmp_channels' and
    'dmp_clusters' fields (values stringified, None dropped). Returns 0.0
    on any extraction error -- best effort by design.
    """
    try:
        def _field_set(user, field):
            return {str(item) for item in pipeline._get_field_from_user(user, field) if item is not None}

        ch_a = _field_set(user1, 'dmp_channels')
        ch_b = _field_set(user2, 'dmp_channels')
        cl_a = _field_set(user1, 'dmp_clusters')
        cl_b = _field_set(user2, 'dmp_clusters')

        jaccard_channels = len(ch_a & ch_b) / max(1, len(ch_a | ch_b))
        jaccard_clusters = len(cl_a & cl_b) / max(1, len(cl_a | cl_b))

        return 0.5 * jaccard_channels + 0.5 * jaccard_clusters
    except Exception as e:
        logging.error(f"Error calculating similarity: {str(e)}")
        return 0.0
|
|
def process_batch_triplets(args):
    """Worker-style generator of (anchor, positive, negative) index triplets.

    Arguments arrive packed in one tuple so the function can be dispatched
    through ``multiprocessing.Pool``:
        (batch_idx, users, channel_index, cluster_index, num_triplets, pipeline)

    For each triplet: the anchor is a random user; the positive is sampled
    uniformly from the top-10 most similar users sharing a channel/cluster
    with the anchor (random fallback when none exist); the negative is a
    random user with similarity < 0.1 (random fallback after 50 attempts).

    Returns:
        List of (anchor_idx, positive_idx, negative_idx) tuples; [] on error.
    """
    try:
        # batch_idx is part of the Pool payload format but unused here.
        _batch_idx, users, channel_index, cluster_index, num_triplets, pipeline = args
        batch_triplets = []

        # FIX: removed the pointless torch.no_grad() wrapper and the unused
        # CPU device handle -- this function only builds indices, no tensors.
        for _ in range(num_triplets):
            anchor_idx = random.randint(0, len(users) - 1)
            anchor_user = users[anchor_idx]

            # Candidate positives: any user sharing a channel or a cluster.
            candidates = set()
            for channel in pipeline._get_field_from_user(anchor_user, 'dmp_channels'):
                candidates.update(channel_index.get(str(channel), []))
            for cluster in pipeline._get_field_from_user(anchor_user, 'dmp_clusters'):
                candidates.update(cluster_index.get(str(cluster), []))

            # Never pair a user with itself.
            candidates.discard(anchor_idx)

            if not candidates:
                positive_idx = random.randint(0, len(users) - 1)
            else:
                similarities = []
                for idx in candidates:
                    sim = cpu_calculate_similarity(anchor_user, users[idx], pipeline)
                    if sim > 0:
                        similarities.append((idx, sim))

                if not similarities:
                    positive_idx = random.randint(0, len(users) - 1)
                else:
                    # Sample uniformly among the 10 most similar candidates.
                    similarities.sort(key=lambda x: x[1], reverse=True)
                    top_k = min(10, len(similarities))
                    positive_idx = similarities[random.randint(0, top_k - 1)][0]

            # Negative: a dissimilar (< 0.1) random user, best effort.
            max_attempts = 50
            negative_idx = None

            for _ in range(max_attempts):
                idx = random.randint(0, len(users) - 1)
                if idx != anchor_idx and idx != positive_idx:
                    if cpu_calculate_similarity(anchor_user, users[idx], pipeline) < 0.1:
                        negative_idx = idx
                        break

            if negative_idx is None:
                negative_idx = random.randint(0, len(users) - 1)

            batch_triplets.append((anchor_idx, positive_idx, negative_idx))

        return batch_triplets
    except Exception as e:
        logging.error(f"Error in batch triplet generation: {str(e)}")
        return []
|
|
| |
def cpu_calculate_similarity(user1, user2, pipeline):
    """CPU-side twin of ``calculate_similarity``.

    Averages the Jaccard overlap of the 'dmp_channels' and 'dmp_clusters'
    fields of two users; returns 0.0 on any extraction error.
    """
    try:
        score = 0.0
        for field in ('dmp_channels', 'dmp_clusters'):
            left = {str(v) for v in pipeline._get_field_from_user(user1, field) if v is not None}
            right = {str(v) for v in pipeline._get_field_from_user(user2, field) if v is not None}
            # Equal weight (0.5) per field.
            score += 0.5 * (len(left & right) / max(1, len(left | right)))
        return score
    except Exception as e:
        logging.error(f"Error calculating similarity: {str(e)}")
        return 0.0
|
|
| |
| |
| |
class UserSimilarityDataset(Dataset):
    """Triplet dataset for triplet-margin training.

    Pre-computes model inputs for every user, builds inverted indexes
    (channel -> user indices, cluster -> user indices) used to find positive
    candidates, and generates all (anchor, positive, negative) triplets up
    front in __init__.
    """

    def __init__(self, pipeline, users_data, num_triplets=10, num_workers=None):
        """
        Args:
            pipeline: UserEmbeddingPipeline with vocabularies already built.
            users_data: raw records; each may wrap the user under
                'raw_json'->'user', under 'user', or be the user dict itself.
            num_triplets: number of triplets to generate.
            num_workers: NOTE(review): computed and logged below, but triplet
                generation is single-process -- this value is never used for
                actual parallelism.
        """
        self.triplets = []
        logging.info("Initializing UserSimilarityDataset...")

        # Unwrap the three supported record layouts into plain user dicts.
        self.users = []
        for data in users_data:
            if 'raw_json' in data and 'user' in data['raw_json']:
                self.users.append(data['raw_json']['user'])
            elif 'user' in data:
                self.users.append(data['user'])
            else:
                self.users.append(data)

        self.pipeline = pipeline
        self.num_triplets = num_triplets

        # Default worker count: capped at 8 and at the machine's core count.
        if num_workers is None:
            num_workers = max(1, min(8, os.cpu_count()))
        self.num_workers = num_workers

        # Tensorize every user once; __getitem__ then only does dict lookups.
        self.preprocessed_inputs = {}
        for idx, user in enumerate(self.users):
            self.preprocessed_inputs[idx] = pipeline._prepare_input(user)

        logging.info("Creating indexes for channels and clusters...")
        self.channel_index = defaultdict(list)
        self.cluster_index = defaultdict(list)

        for idx, user in enumerate(self.users):
            channels = pipeline._get_field_from_user(user, 'dmp_channels')
            clusters = pipeline._get_field_from_user(user, 'dmp_clusters')

            # Stringify values to match the keys used during lookup.
            if channels:
                channels = [str(c) for c in channels if c is not None]
            if clusters:
                clusters = [str(c) for c in clusters if c is not None]

            for channel in channels:
                self.channel_index[channel].append(idx)
            for cluster in clusters:
                self.cluster_index[cluster].append(idx)

        logging.info(f"Found {len(self.channel_index)} unique channels and {len(self.cluster_index)} unique clusters")
        logging.info(f"Generating triplets using {self.num_workers} worker processes...")

        self.triplets = self._generate_triplets_gpu(num_triplets)
        logging.info(f"Generated {len(self.triplets)} triplets")

    def __len__(self):
        # Dataset size equals the number of pre-generated triplets.
        return len(self.triplets)

    def __getitem__(self, idx):
        """Return the cached (anchor, positive, negative) input dicts."""
        if idx >= len(self.triplets):
            raise IndexError(f"Index {idx} out of range for dataset with {len(self.triplets)} triplets")

        anchor_idx, positive_idx, negative_idx = self.triplets[idx]
        return (
            self.preprocessed_inputs[anchor_idx],
            self.preprocessed_inputs[positive_idx],
            self.preprocessed_inputs[negative_idx]
        )

    def _generate_triplets_gpu(self, num_triplets):
        """Generate triplets using a more reliable approach with batch processing.

        NOTE(review): despite the name, this runs entirely on the CPU. It
        mirrors the module-level ``process_batch_triplets`` with two
        differences: candidates are truncated to 50 here, and the negative
        search uses 20 attempts instead of 50.
        """
        logging.info("Generating triplets with batch approach...")

        triplets = []
        batch_size = 10  # small batches keep the progress bar responsive
        num_batches = (num_triplets + batch_size - 1) // batch_size

        progress_bar = tqdm(
            range(num_batches),
            desc="Generating triplet batches",
            bar_format='{l_bar}{bar:10}{r_bar}'
        )

        for _ in progress_bar:
            batch_triplets = []

            for i in range(batch_size):
                if len(triplets) >= num_triplets:
                    break

                # Random anchor.
                anchor_idx = random.randint(0, len(self.users)-1)
                anchor_user = self.users[anchor_idx]

                # Positive candidates share a channel or cluster with the anchor.
                candidates = set()
                for channel in self.pipeline._get_field_from_user(anchor_user, 'dmp_channels'):
                    candidates.update(self.channel_index.get(str(channel), []))
                for cluster in self.pipeline._get_field_from_user(anchor_user, 'dmp_clusters'):
                    candidates.update(self.cluster_index.get(str(cluster), []))

                # Never pair a user with itself.
                candidates.discard(anchor_idx)

                if candidates:
                    # Score at most 50 candidates, then sample from the top 10.
                    similarities = []
                    for idx in list(candidates)[:50]:
                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline)
                        if sim > 0:
                            similarities.append((idx, sim))

                    if similarities:
                        similarities.sort(key=lambda x: x[1], reverse=True)
                        top_k = min(10, len(similarities))
                        positive_idx = similarities[random.randint(0, top_k-1)][0]
                    else:
                        positive_idx = random.randint(0, len(self.users)-1)
                else:
                    positive_idx = random.randint(0, len(self.users)-1)

                # Negative: a dissimilar (< 0.1) random user, 20 attempts max.
                attempts = 0
                negative_idx = None

                while attempts < 20 and negative_idx is None:
                    idx = random.randint(0, len(self.users)-1)
                    if idx != anchor_idx and idx != positive_idx:
                        sim = calculate_similarity(anchor_user, self.users[idx], self.pipeline)
                        if sim < 0.1:
                            negative_idx = idx
                            break
                    attempts += 1

                if negative_idx is None:
                    negative_idx = random.randint(0, len(self.users)-1)

                batch_triplets.append((anchor_idx, positive_idx, negative_idx))

            triplets.extend(batch_triplets)

        # Trim any over-generation from the final partial batch.
        return triplets[:num_triplets]
|
|
def collate_batch(batch):
    """Collate (anchor, positive, negative) input dicts into padded batches.

    Each element of ``batch`` is a triple of ``{field: 1-D LongTensor}``
    dicts. For every field, the tensors are right-padded with zeros (the
    padding index) up to the longest sequence in the group and stacked into
    a single ``(batch, max_len)`` tensor.

    Returns:
        Three dicts of stacked tensors: anchors, positives, negatives.
    """

    def _pad_and_stack(group):
        stacked = {}
        for field in group[0]:
            longest = max(item[field].size(0) for item in group)
            rows = []
            for item in group:
                seq = item[field]
                shortfall = longest - seq.size(0)
                if shortfall > 0:
                    # Right-pad with zeros (the embedding padding index).
                    seq = torch.cat([seq, torch.zeros(shortfall, dtype=torch.long)])
                else:
                    seq = seq[:longest]
                rows.append(seq)
            stacked[field] = torch.stack(rows)
        return stacked

    anchors, positives, negatives = zip(*batch)
    return _pad_and_stack(anchors), _pad_and_stack(positives), _pad_and_stack(negatives)
|
|
| |
| |
| |
|
|
def train_user_embeddings(model, users_data, pipeline, num_epochs=10, batch_size=32, lr=0.001, save_dir=None, save_interval=2, num_triplets=150):
    """Train ``model`` with a triplet-margin objective.

    Builds a UserSimilarityDataset of (anchor, positive, negative) triplets,
    trains for ``num_epochs`` with Adam + StepLR, checkpoints every
    ``save_interval`` epochs into ``save_dir`` (when given), and logs the
    average loss per epoch (to TensorBoard too, when available).

    Returns:
        The trained model (the same object passed in as ``model``).
    """
    model.train()
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Decay the learning rate by 10% every 2 epochs.
    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer,
        step_size=2,
        gamma=0.9
    )

    # NOTE(review): this core count is only passed to the dataset, whose
    # triplet generation is single-process; the DataLoader uses 0 workers.
    num_cpu_cores = max(1, min(32, os.cpu_count()))
    logging.info(f"Using {num_cpu_cores} CPU cores for data processing")

    # All triplets are generated up front inside the dataset constructor.
    dataset = UserSimilarityDataset(
        pipeline,
        users_data,
        num_triplets=num_triplets,
        num_workers=num_cpu_cores
    )

    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_batch,
        num_workers=0,
        pin_memory=True
    )

    # Pull anchor/positive together; push negative at least margin=1.0 away.
    criterion = torch.nn.TripletMarginLoss(margin=1.0)

    epoch_pbar = tqdm(
        range(num_epochs),
        desc="Training Progress",
        bar_format='{l_bar}{bar:10}{r_bar}'
    )

    # TensorBoard is optional; training proceeds without it.
    try:
        from torch.utils.tensorboard import SummaryWriter
        log_dir = Path(save_dir) / "logs" if save_dir else Path("./logs")
        log_dir.mkdir(exist_ok=True, parents=True)
        writer = SummaryWriter(log_dir=log_dir)
        tensorboard_available = True
    except ImportError:
        logging.warning("TensorBoard not available, skipping logging")
        tensorboard_available = False

    for epoch in epoch_pbar:
        total_loss = 0
        num_batches = 0

        total_batches = len(dataloader)
        update_freq = max(1, total_batches // 10)
        # NOTE(review): batch_pbar is created disabled and never iterated --
        # dead code; epoch_progress below is the bar that is actually shown.
        batch_pbar = tqdm(
            dataloader,
            desc=f"Epoch {epoch+1}/{num_epochs}",
            leave=False,
            miniters=update_freq,
            bar_format='{l_bar}{bar:10}{r_bar}',
            disable=True
        )

        epoch_progress = tqdm(
            total=len(dataloader),
            desc=f"Epoch {epoch+1}/{num_epochs}",
            leave=True,
            bar_format='{l_bar}{bar:10}{r_bar}'
        )

        for batch_idx, batch_inputs in enumerate(dataloader):
            try:
                anchor_batch, positive_batch, negative_batch = batch_inputs

                # Move the whole triplet batch to the training device.
                anchor_batch = {k: v.to(device) for k, v in anchor_batch.items()}
                positive_batch = {k: v.to(device) for k, v in positive_batch.items()}
                negative_batch = {k: v.to(device) for k, v in negative_batch.items()}

                anchor_emb = model(anchor_batch)
                positive_emb = model(positive_batch)
                negative_emb = model(negative_batch)

                loss = criterion(anchor_emb, positive_emb, negative_emb)

                optimizer.zero_grad()
                loss.backward()
                # Clip gradients to stabilise training.
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

                # Refresh the visible bar roughly 10 times per epoch.
                update_interval = max(1, len(dataloader) // 10)
                if (batch_idx + 1) % update_interval == 0 or batch_idx == len(dataloader) - 1:
                    remaining = min(update_interval, len(dataloader) - epoch_progress.n)
                    epoch_progress.update(remaining)

                    current_avg_loss = total_loss / num_batches
                    epoch_progress.set_postfix(avg_loss=f"{current_avg_loss:.4f}",
                                               last_batch_loss=f"{loss.item():.4f}")

            except Exception as e:
                # Best effort: skip a failing batch rather than abort the epoch.
                logging.error(f"Error during batch processing: {str(e)}")
                logging.error(f"Batch details: {str(e.__class__.__name__)}")
                continue

        epoch_progress.close()

        # max(1, ...) guards against an epoch where every batch failed.
        avg_loss = total_loss / max(1, num_batches)

        epoch_pbar.set_postfix(avg_loss=f"{avg_loss:.4f}")

        if tensorboard_available:
            writer.add_scalar('Loss/train', avg_loss, epoch)

        scheduler.step()

        # Periodic checkpoint with optimizer/scheduler state for resuming.
        if save_dir and (epoch + 1) % save_interval == 0:
            checkpoint = {
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
                'scheduler_state_dict': scheduler.state_dict()
            }

            save_path = Path(save_dir) / f'model_checkpoint_epoch_{epoch+1}.pth'
            torch.save(checkpoint, save_path)
            logging.info(f"Checkpoint saved at epoch {epoch+1}: {save_path}")

    if tensorboard_available:
        writer.close()

    return model
| |
| |
| |
| |
def main():
    """Script entry point: load data, build vocabularies, train, save.

    All configuration comes from the module-level constants (which in turn
    honour environment-variable overrides). Returns early on any fatal
    error after logging it.
    """
    output_dir = Path(OUTPUT_DIR)

    # Log the hardware the run will use.
    cuda_available = torch.cuda.is_available()
    logging.info(f"CUDA available: {cuda_available}")
    if cuda_available:
        logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
        logging.info(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    cpu_count = os.cpu_count()
    memory_info = psutil.virtual_memory()
    logging.info(f"CPU cores: {cpu_count}")
    logging.info(f"System memory: {memory_info.total / 1e9:.2f} GB")

    # Echo the effective configuration for reproducibility.
    logging.info("Running with the following configuration:")
    logging.info(f"- Number of triplets: {NUM_TRIPLETS}")
    logging.info(f"- Number of epochs: {NUM_EPOCHS}")
    logging.info(f"- Batch size: {BATCH_SIZE}")
    logging.info(f"- Learning rate: {LEARNING_RATE}")
    logging.info(f"- Output dimension: {OUTPUT_DIM}")
    logging.info(f"- Data path: {DATA_PATH}")
    logging.info(f"- Output directory: {OUTPUT_DIR}")

    # ---- Load user data -------------------------------------------------
    logging.info("Loading user data...")
    try:
        try:
            with open(DATA_PATH, 'r') as f:
                json_data = json.load(f)

            # Accept either a list of records or a single record object.
            if isinstance(json_data, list):
                users_data = json_data
            elif isinstance(json_data, dict):
                users_data = [json_data]
            else:
                raise ValueError("Unsupported JSON format")

        except json.JSONDecodeError:
            # Fallback for files that are a bare sequence of JSON objects.
            # NOTE(review): wrapping in [...] only fixes input that is already
            # comma-separated; concatenated objects will still fail to parse.
            logging.info("Detected possible non-standard JSON format, attempting correction...")
            with open(DATA_PATH, 'r') as f:
                text = f.read().strip()

            if not text.startswith('['):
                text = '[' + text
            if not text.endswith(']'):
                text = text + ']'

            users_data = json.loads(text)
            logging.info("JSON format successfully corrected")

        logging.info(f"Loaded {len(users_data)} records")
    except FileNotFoundError:
        logging.error(f"File {DATA_PATH} not found!")
        return
    except Exception as e:
        logging.error(f"Unable to load file: {str(e)}")
        return

    # ---- Pipeline and vocabularies --------------------------------------
    logging.info("Initializing pipeline...")
    pipeline = UserEmbeddingPipeline(
        output_dim=OUTPUT_DIM,
        max_sequence_length=MAX_SEQ_LENGTH
    )

    logging.info("Building vocabularies...")
    try:
        pipeline.build_vocabularies(users_data)
        vocab_sizes = {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()}
        logging.info(f"Vocabulary sizes: {vocab_sizes}")
    except Exception as e:
        logging.error(f"Error building vocabularies: {str(e)}")
        return

    # ---- Model ----------------------------------------------------------
    logging.info("Initializing model...")
    try:
        pipeline.initialize_model()
        logging.info("Model initialized successfully")
    except Exception as e:
        logging.error(f"Error initializing model: {str(e)}")
        return

    # ---- Training and saving --------------------------------------------
    logging.info("Starting training...")
    try:
        model_dir = output_dir / "model_checkpoints"
        model_dir.mkdir(exist_ok=True, parents=True)

        model = train_user_embeddings(
            pipeline.model,
            users_data,
            pipeline,
            num_epochs=NUM_EPOCHS,
            batch_size=BATCH_SIZE,
            lr=LEARNING_RATE,
            save_dir=model_dir,
            save_interval=SAVE_INTERVAL,
            num_triplets=NUM_TRIPLETS
        )
        logging.info("Training completed")
        pipeline.model = model

        # NOTE(review): the saving code below duplicates
        # UserEmbeddingPipeline.save_model almost verbatim -- consider
        # calling that method instead to keep the two in sync.
        logging.info("Saving model...")

        output_dir.mkdir(exist_ok=True)

        model_path = output_dir / 'model.pth'

        # Full checkpoint: weights plus everything needed to rebuild the model.
        checkpoint = {
            'model_state_dict': pipeline.model.state_dict(),
            'vocab_maps': pipeline.vocab_maps,
            'embedding_dims': pipeline.embedding_dims,
            'output_dim': pipeline.output_dim,
            'max_sequence_length': pipeline.max_sequence_length
        }

        torch.save(checkpoint, model_path)

        logging.info(f"Model saved to: {model_path}")

        # Side-car JSON so the architecture can be inspected without torch.
        config_info = {
            'model_type': 'UserEmbeddingModel',
            'vocab_sizes': {field: len(vocab) for field, vocab in pipeline.vocab_maps.items()},
            'embedding_dims': pipeline.embedding_dims,
            'output_dim': pipeline.output_dim,
            'max_sequence_length': pipeline.max_sequence_length,
            'padded_fields': list(pipeline.model.padded_fields),
            'fields': pipeline.fields
        }

        config_path = output_dir / 'model_config.json'
        with open(config_path, 'w') as f:
            json.dump(config_info, f, indent=2)

        logging.info(f"Model configuration saved to: {config_path}")

        # Optional HuggingFace-style export, gated by SAVE_HF_FORMAT=true.
        save_hf = os.environ.get("SAVE_HF_FORMAT", "false").lower() == "true"
        if save_hf:
            logging.info("Saving in HuggingFace format...")

            hf_dir = output_dir / 'huggingface'
            hf_dir.mkdir(exist_ok=True)

            torch.save(pipeline.model.state_dict(), hf_dir / 'pytorch_model.bin')

            with open(hf_dir / 'config.json', 'w') as f:
                json.dump(config_info, f, indent=2)

            logging.info(f"Model saved in HuggingFace format to: {hf_dir}")

        # Optional push to the HuggingFace Hub (requires repo id + token).
        hf_repo_id = os.environ.get("HF_REPO_ID")
        hf_token = os.environ.get("HF_TOKEN")

        if save_hf and hf_repo_id and hf_token:
            try:
                from huggingface_hub import HfApi

                logging.info(f"Pushing model to HuggingFace: {hf_repo_id}")
                api = HfApi()

                # Create the (private) repo if it does not exist yet.
                api.create_repo(
                    repo_id=hf_repo_id,
                    token=hf_token,
                    exist_ok=True,
                    private=True
                )

                # Upload every file from the exported HF directory.
                # NOTE(review): path_in_repo receives a Path object here;
                # confirm the installed huggingface_hub accepts non-str paths.
                for file_path in (output_dir / "huggingface").glob("**/*"):
                    if file_path.is_file():
                        api.upload_file(
                            path_or_fileobj=str(file_path),
                            path_in_repo=file_path.relative_to(output_dir / "huggingface"),
                            repo_id=hf_repo_id,
                            token=hf_token
                        )

                logging.info(f"Model successfully pushed to HuggingFace: {hf_repo_id}")
            except Exception as e:
                logging.error(f"Error pushing to HuggingFace: {str(e)}")

    except Exception as e:
        logging.error(f"Error during training or saving: {str(e)}")
        import traceback
        traceback.print_exc()
        return

    logging.info("Process completed successfully!")
|
| if __name__ == "__main__": |
| main() |