import json
import pickle

from tqdm import tqdm
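
# Converts raw (query, positive_doc, negative_doc) triples into token-id
# triples using the previously trained CBOW vocabulary.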


def load_tokenizer():
    """Load the CBOW tokenizer word-to-id and id-to-word mappings."""
    with open('tkn_words_to_ids.pkl', 'rb') as f:
        words_to_ids = pickle.load(f)
    with open('tkn_ids_to_words.pkl', 'rb') as f:
        ids_to_words = pickle.load(f)
    return words_to_ids, ids_to_words
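
# Illustrative sanity check (assumes the pickles store plain dicts mapping
# str -> int and int -> str, so the two lookups round-trip):
#     words_to_ids, ids_to_words = load_tokenizer()
#     assert ids_to_words[words_to_ids["the"]] == "the"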


def tokenize_text(text, words_to_ids):
    """Tokenize text with the CBOW vocabulary: lowercase, split on
    whitespace, and map each word to its id."""
    words = text.lower().split()
    # Unknown words fall back to id 0; this assumes id 0 is reserved for
    # out-of-vocabulary tokens in the CBOW vocabulary.
    token_ids = [words_to_ids.get(word, 0) for word in words]
    return token_ids
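
# Example (hypothetical vocabulary): with words_to_ids = {"hello": 5, "world": 7},
# tokenize_text("Hello world!", words_to_ids) returns [5, 0] -- "world!" keeps its
# punctuation after the whitespace split, misses the vocabulary, and falls back
# to id 0.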


def process_triples(input_file, output_file):
    """Tokenize the queries and documents of every triple and save the result."""
    print("Loading tokenizer...")
    words_to_ids, ids_to_words = load_tokenizer()

    print("Loading triples...")
    with open(input_file, 'r') as f:
        data = json.load(f)

    tokenized_data = {
        'train': [],
        'validation': [],
        'test': []
    }

    for split in ['train', 'validation', 'test']:
        print(f"\nTokenizing {split} split...")
        for triple in tqdm(data[split]):
            query = triple['query']
            pos_doc = triple['positive_doc']
            neg_doc = triple['negative_doc']

            query_tokens = tokenize_text(query, words_to_ids)
            pos_doc_tokens = tokenize_text(pos_doc, words_to_ids)
            neg_doc_tokens = tokenize_text(neg_doc, words_to_ids)

            # Keep the raw text alongside the token ids for easier debugging.
            tokenized_data[split].append({
                'query_tokens': query_tokens,
                'positive_document_tokens': pos_doc_tokens,
                'negative_document_tokens': neg_doc_tokens,
                'query': query,
                'positive_document': pos_doc,
                'negative_document': neg_doc
            })

    print("Saving tokenized triples...")
    with open(output_file, 'w') as f:
        json.dump(tokenized_data, f, indent=2)

    # Report basic statistics for each split.
    for split in ['train', 'validation', 'test']:
        print(f"\n{split.upper()} split:")
        print(f"Number of tokenized triples: {len(tokenized_data[split])}")
        if tokenized_data[split]:
            sample = tokenized_data[split][0]
            print("\nSample tokenized triple:")
            print("Query tokens length:", len(sample['query_tokens']))
            print("Positive doc tokens length:", len(sample['positive_document_tokens']))
            print("Negative doc tokens length:", len(sample['negative_document_tokens']))


if __name__ == "__main__":
    input_file = "triples_small.json"
    output_file = "tokenized_triples.json"
    process_triples(input_file, output_file)
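
# Downstream usage sketch (assumes this script has already written
# tokenized_triples.json to the working directory):
#     with open("tokenized_triples.json") as f:
#         tokenized = json.load(f)
#     first_query_ids = tokenized["train"][0]["query_tokens"]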