import glob
import json
import os
import pickle

import numpy as np
import redis
import torch
import torch.nn as nn
from dotenv import load_dotenv
from tqdm import tqdm


load_dotenv()


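# Connection to the Redis instance used as the embedding store. Host, port,
# and password come from the environment (.env); the defaults below are
# placeholders. decode_responses=False keeps values as raw bytes, since the
# embeddings are stored as binary blobs.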
REDIS_HOST = os.environ.get('REDIS_HOST', 'your-redis-host')
REDIS_PORT = int(os.environ.get('REDIS_PORT', 12345))
REDIS_PASSWORD = os.environ.get('REDIS_PASSWORD', 'your-redis-password')
r = redis.Redis(
    host=REDIS_HOST,
    port=REDIS_PORT,
    password=REDIS_PASSWORD,
    decode_responses=False
)


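# Load the CBOW vocabulary (word -> id). vocab_size and embedding_dim must
# match the pretrained checkpoint so its weights can be copied in below.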
with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
    words_to_ids = pickle.load(f)
vocab_size = len(words_to_ids)
embedding_dim = 128


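# Pick the most recently created CBOW checkpoint and load its weights on CPU.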
checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
state_dict = torch.load(latest_checkpoint, map_location='cpu')


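# Rebuild the embedding table from the pretrained CBOW weights and freeze it.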
embedding_layer = nn.Embedding(vocab_size, embedding_dim)
embedding_layer.weight.data.copy_(state_dict['emb.weight'])
embedding_layer.weight.requires_grad = False


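# Document tower: frozen CBOW embeddings followed by a single-layer GRU.
# The final GRU hidden state serves as the document embedding.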
class DocTower(nn.Module):
    def __init__(self, embedding_layer, hidden_size):
        super().__init__()
        self.embedding = embedding_layer
        self.embedding.weight.requires_grad = False  # keep pretrained embeddings frozen
        self.rnn = nn.GRU(
            input_size=self.embedding.embedding_dim,
            hidden_size=hidden_size,
            batch_first=True
        )

    def forward(self, x):
        # x is a plain list of token ids for a single document
        if not x:
            return None
        x = torch.tensor(x, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
        embeds = self.embedding(x)                          # (1, seq_len, embedding_dim)
        _, h_n = self.rnn(embeds)                           # h_n: (1, 1, hidden_size)
        return h_n.squeeze(0).squeeze(0)                    # (hidden_size,)


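# GRU hidden size, i.e. the dimensionality of the stored document embeddings.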
hidden_size = 128

docTower = DocTower(embedding_layer, hidden_size)
docTower.eval()


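# Load the tokenized triples; only the positive documents are embedded here.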
with open('tokenized_triples.json', 'r') as f:
    triples_data = json.load(f)


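# Deduplicate positive documents across splits, keyed by their token id tuples.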
seen = set()
documents = []
for split in ['train', 'validation', 'test']:
    for triple in triples_data[split]:
        doc_text = triple['positive_document']
        doc_tokens = tuple(triple['positive_document_tokens'])
        if doc_tokens not in seen:
            seen.add(doc_tokens)
            documents.append((doc_tokens, doc_text))

print(f"Found {len(documents)} unique positive documents.")


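# Store one document as a Redis hash: the raw float32 embedding bytes plus the
# original text and the document id.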
def save_doc_embedding_to_redis(doc_id, embedding, text):
    r.hset(doc_id, mapping={
        'embedding': embedding.astype(np.float32).tobytes(),
        'text': text,
        'doc_id': doc_id
    })


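# Embed every unique positive document and write it to Redis.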
saved = 0
with torch.no_grad():
    for idx, (doc_tokens, doc_text) in enumerate(tqdm(documents, desc='Saving doc embeddings to Redis')):
        if not doc_tokens:
            continue  # forward() returns None for an empty token list; skip such docs
        embedding = docTower(list(doc_tokens)).cpu().numpy()
        doc_id = f"doc:{idx}"
        save_doc_embedding_to_redis(doc_id, embedding, doc_text)
        saved += 1


| print(f"Saved {len(documents)} doc embeddings to Redis Cloud.") |