| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import pickle |
| import glob |
| import os |
| import json |
| import wandb |
| import yaml |
|
|
def train(config=None):
    """Train a two-tower (query/document) retrieval model with frozen CBOW embeddings.

    Loads the vocabulary and the most recent CBOW checkpoint from ``cbow/``,
    builds two GRU towers sharing a frozen embedding layer, trains the GRUs on
    the triplets in ``tokenized_triples.json`` with a cosine-margin triplet
    loss, logs per-epoch average loss to wandb, then prints a qualitative
    evaluation on the first 5 training triples.

    Args:
        config: Optional dict / wandb config providing ``hidden_size``,
            ``learning_rate``, ``num_epochs`` and ``margin``. When launched by
            a wandb sweep agent, the sweep-supplied config takes over.
    """
    with wandb.init(config=config):
        config = wandb.config

        # Vocabulary size determines the embedding table's row count.
        with open('cbow/tkn_words_to_ids.pkl', 'rb') as f:
            words_to_ids = pickle.load(f)
        vocab_size = len(words_to_ids)
        embedding_dim = 128  # must match the CBOW checkpoint's embedding width

        # Pick the most recently created checkpoint file.
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        checkpoint_files = glob.glob('cbow/checkpoints/*.pth')
        latest_checkpoint = max(checkpoint_files, key=os.path.getctime)
        state_dict = torch.load(latest_checkpoint, map_location=torch.device('cpu'))

        # Shared embedding layer, initialised from CBOW weights and frozen.
        embedding_layer = nn.Embedding(vocab_size, embedding_dim)
        embedding_layer.weight.data.copy_(state_dict['emb.weight'])
        embedding_layer.weight.requires_grad = False

        class Tower(nn.Module):
            """GRU encoder over frozen embeddings; returns the final hidden state.

            The original query and document towers were byte-identical, so a
            single class is instantiated twice — each instance still owns its
            own GRU, only the (frozen) embedding layer is shared.
            """

            def __init__(self, embedding_layer, hidden_size):
                super().__init__()
                self.embedding = embedding_layer
                self.embedding.weight.requires_grad = False  # keep embeddings frozen
                self.rnn = nn.GRU(
                    input_size=self.embedding.embedding_dim,
                    hidden_size=hidden_size,
                    batch_first=True,
                )

            def forward(self, x):
                # An empty token list cannot be encoded; signal with None so
                # the caller can skip the triple.
                if not x:
                    return None
                x = torch.tensor(x, dtype=torch.long).unsqueeze(0)  # (1, seq_len)
                embeds = self.embedding(x)
                _, h_n = self.rnn(embeds)
                return h_n.squeeze(0).squeeze(0)  # -> (hidden_size,)

        qryTower = Tower(embedding_layer, config.hidden_size)
        docTower = Tower(embedding_layer, config.hidden_size)

        with open('tokenized_triples.json', 'r') as f:
            triples_data = json.load(f)

        # Only the GRU parameters are trainable (embeddings stay frozen).
        params = list(qryTower.rnn.parameters()) + list(docTower.rnn.parameters())
        optimizer = torch.optim.Adam(params, lr=config.learning_rate)
        num_epochs = config.num_epochs
        margin = config.margin
        print(f"\nTraining on all real triples from the train split with RNN towers for {num_epochs} epochs:\n")
        for epoch in range(num_epochs):
            total_loss = 0
            count = 0
            for triple in triples_data['train']:
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                # Skip triples where any component was empty.
                if qry is not None and pos is not None and neg is not None:
                    dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    # Triplet margin loss on cosine similarity:
                    # max(0, margin - (sim_pos - sim_neg)); F.relu avoids
                    # allocating fresh scalar tensors every iteration.
                    loss = F.relu(margin - (dst_pos - dst_neg))

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                    total_loss += loss.item()
                    count += 1
            avg_loss = total_loss / max(count, 1)  # guard: zero usable triples
            print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")
            wandb.log({'epoch': epoch+1, 'avg_loss': avg_loss})

        # Qualitative spot-check on a few training triples; no gradients needed.
        print("\nEvaluating on first 5 real triples after training:\n")
        with torch.no_grad():
            for i, triple in enumerate(triples_data['train'][:5]):
                qry_tokens = triple['query_tokens']
                pos_tokens = triple['positive_document_tokens']
                neg_tokens = triple['negative_document_tokens']
                qry_text = triple['query']
                pos_text = triple['positive_document']
                neg_text = triple['negative_document']

                qry = qryTower(qry_tokens)
                pos = docTower(pos_tokens)
                neg = docTower(neg_tokens)

                if qry is not None and pos is not None and neg is not None:
                    dst_pos = F.cosine_similarity(qry.unsqueeze(0), pos.unsqueeze(0))
                    dst_neg = F.cosine_similarity(qry.unsqueeze(0), neg.unsqueeze(0))
                    loss = F.relu(margin - (dst_pos - dst_neg))
                    print(f"Example {i+1}:")
                    print(f"Query: {qry_text}")
                    print(f"Positive doc: {pos_text[:100]}...")
                    print(f"Negative doc: {neg_text[:100]}...")
                    print(f"Cosine similarity (pos): {dst_pos.item():.4f}")
                    print(f"Cosine similarity (neg): {dst_neg.item():.4f}")
                    print(f"Triplet loss: {loss.item():.4f}\n")
                else:
                    print(f"Example {i+1}: One of the inputs was empty, skipping this triple.\n")
|
|
if __name__ == "__main__":
    # yaml is already imported at module level; the duplicate local import
    # has been removed.
    with open('sweep.yaml', 'r') as f:
        sweep_config = yaml.safe_load(f)

    # Build a single default configuration from the first value of each
    # swept parameter so the script can also run outside a wandb sweep agent.
    default_config = {
        'learning_rate': sweep_config['parameters']['learning_rate']['values'][0],
        'margin': sweep_config['parameters']['margin']['values'][0],
        'num_epochs': sweep_config['parameters']['num_epochs']['value'],
        'num_triples': sweep_config['parameters']['num_triples']['values'][0],
        'hidden_size': sweep_config['parameters']['hidden_size']['values'][0],
    }

    train(config=default_config)
|
|