| import json |
| import pickle |
| import torch |
| import torch.nn as nn |
| from tqdm import tqdm |
| import glob |
| import os |
| import redis |
| import numpy as np |
|
|
| |
# Connection settings for the Redis instance that receives document embeddings.
# Each value may be overridden through an environment variable so that real
# credentials never need to be committed; the literals below are placeholders.
REDIS_HOST = os.environ.get('REDIS_HOST', 'your-redis-host')
REDIS_PORT = int(os.environ.get('REDIS_PORT', '12345'))
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', 'your-redis-password')
# NOTE(review): not referenced in this section — presumably used by index
# creation / search code elsewhere; confirm.
INDEX_NAME = 'doc_index'
# Declared width of stored vectors; presumably must equal the CBOW embedding
# size (create_embedding_layer defaults to 128) — keep the two in sync.
VECTOR_DIM = 128

# Module-level client shared by the Redis helpers in this script.
# decode_responses=False because embeddings are written as raw float32 bytes.
r = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    password=REDIS_PASSWORD,
    decode_responses=False,
)
|
|
def load_latest_checkpoint(checkpoint_dir='cbow/checkpoints'):
    """Load the most recently written CBOW checkpoint.

    Args:
        checkpoint_dir: directory searched for ``*.pth`` files, relative to
            the current working directory. Defaults to ``cbow/checkpoints``.

    Returns:
        The object ``torch.load`` deserializes from the newest file
        (presumably the CBOW model state dict), with all tensors on the CPU.

    Raises:
        FileNotFoundError: if ``checkpoint_dir`` contains no ``*.pth`` file.
    """
    print("Loading latest CBOW checkpoint...")
    checkpoint_files = glob.glob(os.path.join(checkpoint_dir, '*.pth'))
    if not checkpoint_files:
        raise FileNotFoundError(f"No checkpoint files found in {checkpoint_dir}/")

    # Newest by *modification* time. getctime is the inode-change time on
    # Unix (bumped by chmod/rename/copy), so it can select an older file.
    latest_checkpoint = max(checkpoint_files, key=os.path.getmtime)
    print(f"Using checkpoint: {latest_checkpoint}")

    # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
    # the rest of this script builds a CPU nn.Embedding and calls .numpy().
    return torch.load(latest_checkpoint, map_location='cpu')
|
|
def load_tokenizer():
    """Read the CBOW vocabulary mappings from the ``cbow/`` directory.

    Returns:
        A ``(words_to_ids, ids_to_words)`` tuple, unpickled from
        ``cbow/tkn_words_to_ids.pkl`` and ``cbow/tkn_ids_to_words.pkl``
        (read in that order).

    NOTE: ``pickle.load`` can execute arbitrary code — only point this at
    files produced by your own training run.
    """
    print("Loading tokenizer...")
    mappings = []
    for pkl_path in ('cbow/tkn_words_to_ids.pkl', 'cbow/tkn_ids_to_words.pkl'):
        with open(pkl_path, 'rb') as handle:
            mappings.append(pickle.load(handle))
    forward, backward = mappings
    return forward, backward
|
|
def load_tokenized_triples(path='tokenized_triples.json'):
    """Load the tokenized (query, positive, negative) triples from JSON.

    Args:
        path: JSON file to read. Defaults to ``tokenized_triples.json`` in
            the current working directory.

    Returns:
        The decoded JSON value (whatever ``json.load`` produces).
    """
    print("Loading tokenized triples...")
    # JSON text is UTF-8 (RFC 8259); decode explicitly instead of relying on
    # the platform's locale encoding (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
def create_embedding_layer(state_dict, vocab_size, embedding_dim=128):
    """Build a frozen ``nn.Embedding`` initialised from CBOW weights.

    Args:
        state_dict: checkpoint mapping whose ``'emb.weight'`` tensor is copied
            into the new layer.
        vocab_size: number of rows the layer must have.
        embedding_dim: width of each embedding row (default 128).

    Returns:
        An ``nn.Embedding(vocab_size, embedding_dim)`` holding a copy of the
        checkpoint weights, with ``requires_grad=False`` on its weight.

    Raises:
        KeyError: if ``state_dict`` has no ``'emb.weight'`` entry.
        ValueError: if that tensor's shape is not ``(vocab_size, embedding_dim)``.
    """
    weights = state_dict['emb.weight']
    expected = (vocab_size, embedding_dim)
    # Validate up front: Tensor.copy_ broadcasts, so e.g. a (1, dim) tensor
    # would silently fill every row, and other mismatches give opaque errors.
    if tuple(weights.shape) != expected:
        raise ValueError(
            f"Checkpoint 'emb.weight' has shape {tuple(weights.shape)}, "
            f"expected {expected} (vocab_size, embedding_dim)"
        )

    embedding = nn.Embedding(vocab_size, embedding_dim)
    # Copy without recording autograd history (preferred over mutating .data).
    with torch.no_grad():
        embedding.weight.copy_(weights)
    # Used for inference only, so freeze the table.
    embedding.weight.requires_grad = False
    return embedding
|
|
def average_pool(tokens, embedding_layer):
    """Mean-pool the embeddings of ``tokens`` into a single vector.

    Args:
        tokens: sequence of integer token ids (row indices into
            ``embedding_layer``).
        embedding_layer: an ``nn.Embedding`` (must expose ``embedding_dim``
            and ``weight`` and be callable on a LongTensor).

    Returns:
        1-D NumPy array of length ``embedding_layer.embedding_dim`` with the
        layer's weight dtype. An empty ``tokens`` yields an all-zeros vector
        rather than the NaNs ``torch.mean`` over zero rows would produce.
    """
    if len(tokens) == 0:
        # NaN vectors would otherwise be written to Redis and the JSON output.
        # NOTE(review): a zero vector has no direction under cosine distance;
        # callers may prefer to skip such documents instead.
        return torch.zeros(
            embedding_layer.embedding_dim, dtype=embedding_layer.weight.dtype
        ).numpy()

    tokens_tensor = torch.tensor(tokens, dtype=torch.long)
    # Inference only: no autograd graph needed, so .numpy() is safe directly.
    with torch.no_grad():
        pooled = embedding_layer(tokens_tensor).mean(dim=0)
    return pooled.numpy()
|
|
def save_doc_embedding_to_redis(doc_id, embedding, text):
    """Write one document as a Redis hash stored under the key ``doc_id``.

    The hash gets three fields: ``embedding`` (the vector cast to float32 and
    serialized with ``tobytes()``), ``text`` (the raw document), and
    ``doc_id`` (a copy of the key). Uses the module-level client ``r``.

    Args:
        doc_id: Redis key for the hash, also stored as a field.
        embedding: NumPy array to store (cast to float32 first).
        text: document text stored alongside the vector.
    """
    fields = {}
    fields['embedding'] = embedding.astype(np.float32).tobytes()
    fields['text'] = text
    fields['doc_id'] = doc_id
    r.hset(doc_id, mapping=fields)
|
|
| |
| |
|
|
def process_triples(data, embedding_layer):
    """Average-pool every triple and index its positive document in Redis.

    Splits are visited in the order ``train``, ``validation``, ``test``. For
    each triple the query, positive and negative token lists are pooled (in
    that order) with ``average_pool``. The positive document's vector is
    then written to Redis under a sequential key ``doc:<n>``. The counter
    starts at 0 on every call and runs across all splits. A document that is
    the positive for several triples is stored once per triple, with no
    de-duplication.

    Args:
        data: mapping from split name to a list of triple dicts with keys
            ``query_tokens``, ``positive_document_tokens``,
            ``negative_document_tokens``, ``query``, ``positive_document``
            and ``negative_document``.
        embedding_layer: layer passed through to ``average_pool``.

    Returns:
        Dict keyed by split name. Each value is a list of dicts holding the
        three vectors (as plain lists) plus the original texts.
    """
    split_names = ('train', 'validation', 'test')
    results = {name: [] for name in split_names}
    next_doc = 0
    for name in split_names:
        print(f"\nProcessing {name} split...")
        bucket = results[name]
        for triple in tqdm(data[name]):
            q_vec, pos_vec, neg_vec = (
                average_pool(triple[field], embedding_layer)
                for field in (
                    'query_tokens',
                    'positive_document_tokens',
                    'negative_document_tokens',
                )
            )

            save_doc_embedding_to_redis(
                f"doc:{next_doc}", pos_vec, triple['positive_document']
            )
            next_doc += 1

            bucket.append({
                'query_vector': q_vec.tolist(),
                'positive_document_vector': pos_vec.tolist(),
                'negative_document_vector': neg_vec.tolist(),
                'query': triple['query'],
                'positive_document': triple['positive_document'],
                'negative_document': triple['negative_document'],
            })
    return results
|
|
def main():
    """Embed every tokenized triple with the CBOW model and persist results.

    Pipeline:
    1. Load the newest checkpoint and the vocabulary.
    2. Build a frozen embedding layer.
    3. Average-pool all triples; ``process_triples`` also indexes the
       positive documents in Redis.
    4. Write the vectors to ``triple_embeddings_cbow.json``.
    5. Print a per-split summary.
    """
    state_dict = load_latest_checkpoint()
    # Only the word->id map is needed: its length is the vocabulary size.
    words_to_ids, _ = load_tokenizer()
    data = load_tokenized_triples()

    vocab_size = len(words_to_ids)
    # Tie the layer width to the module's declared VECTOR_DIM instead of
    # relying on create_embedding_layer's separate hard-coded default.
    embedding_layer = create_embedding_layer(
        state_dict, vocab_size, embedding_dim=VECTOR_DIM
    )

    processed_data = process_triples(data, embedding_layer)

    print("\nSaving processed data...")
    with open('triple_embeddings_cbow.json', 'w') as f:
        json.dump(processed_data, f)

    # Summary: triple count per split plus the vector lengths of its first triple.
    for split in ['train', 'validation', 'test']:
        print(f"\n{split.upper()} split:")
        print(f"Number of processed triples: {len(processed_data[split])}")
        if processed_data[split]:
            sample = processed_data[split][0]
            print("\nSample vector shapes:")
            print("Query vector shape:", len(sample['query_vector']))
            print("Positive doc vector shape:", len(sample['positive_document_vector']))
            print("Negative doc vector shape:", len(sample['negative_document_vector']))


if __name__ == "__main__":
    main()