| """ |
| This example demonstrates the setup for Question-Answer-Retrieval. |
| |
| You can input a query or a question. The script then uses semantic search |
| to find relevant passages in Simple English Wikipedia (as it is smaller and fits better in RAM). |
| |
| As model, we use: nq-distilbert-base-v1 |
| |
| It was trained on the Natural Questions dataset, a dataset with real questions from Google Search |
| together with annotated data from Wikipedia providing the answer. For the passages, we encode the |
| Wikipedia article title together with the individual text passages. |
| |
| Google Colab Example: https://colab.research.google.com/drive/11GunvCqJuebfeTlgbJWkIMT0xJH6PWF1?usp=sharing |
| """ |
| import json |
| from sentence_transformers import SentenceTransformer, util |
| import time |
| import gzip |
| import os |
| import torch |
| |
|
|
| |
# Retrieval configuration.
# top_k is the number of hits requested from util.semantic_search per query.
top_k = 5
# model_name picks the bi-encoder and, further down, decides whether the
# pre-computed corpus embeddings can be used.
model_name = 'nq-distilbert-base-v1'

# Bi-encoder used to embed the questions (and the passages, when no
# pre-computed embeddings are available for the chosen model).
bi_encoder = SentenceTransformer(model_name)
|
|
| |
| |
|
|
# Local copy of the Simple English Wikipedia dump (gzip-compressed JSON lines).
wikipedia_filepath = 'data/simplewiki-2020-11-01.jsonl.gz'

# Fetch the dump only when it is not already on disk.
if not os.path.exists(wikipedia_filepath):
    util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01.jsonl.gz', wikipedia_filepath)

# Every passage is stored as a [title, paragraph] pair.
passages = []
with gzip.open(wikipedia_filepath, 'rt', encoding='utf8') as wiki_file:
    for raw_line in wiki_file:
        # One JSON document per line, with 'title' and 'paragraphs' keys.
        article = json.loads(raw_line.strip())
        for paragraph in article['paragraphs']:
            passages.append([article['title'], paragraph])
            # NOTE(review): the original indentation was lost; these prints are
            # assumed to run once per paragraph, right after the append -- confirm.
            print(article['title'])
            print(paragraph)
            print("________+________")

print("Passages:", len(passages))
|
|
| |
| |
# Build corpus_embeddings: download pre-computed embeddings when the default
# model is selected, otherwise encode every passage with the bi-encoder.
if model_name != 'nq-distilbert-base-v1':
    # No pre-computed file for this model: embed the whole corpus now.
    corpus_embeddings = bi_encoder.encode(passages, convert_to_tensor=True, show_progress_bar=True)
else:
    embeddings_filepath = 'simplewiki-2020-11-01-nq-distilbert-base-v1.pt'
    if not os.path.exists(embeddings_filepath):
        util.http_get('http://sbert.net/datasets/simplewiki-2020-11-01-nq-distilbert-base-v1.pt', embeddings_filepath)

    # Always deserialize onto the CPU first so the file loads on machines
    # without a GPU, then convert to float32.
    # NOTE(review): torch.load unpickles the file, and the file is fetched
    # over plain HTTP above -- consider HTTPS and/or a checksum; confirm.
    corpus_embeddings = torch.load(embeddings_filepath, map_location=torch.device('cpu'))
    corpus_embeddings = corpus_embeddings.float()

    # Move to the GPU when one is available; otherwise stay on the CPU.
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'
    corpus_embeddings = corpus_embeddings.to(target_device)
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
def search(query):
    """Print the ``top_k`` passages most similar to ``query``.

    The query is embedded with ``bi_encoder`` and compared against
    ``corpus_embeddings`` through ``util.semantic_search``. Each hit is
    printed as its score followed by the matching entry of ``passages``,
    together with the time the lookup took. Nothing is returned.
    """
    started = time.time()
    query_vector = bi_encoder.encode(query, convert_to_tensor=True)
    # Index 0 selects the result list belonging to our single query.
    best_matches = util.semantic_search(query_vector, corpus_embeddings, top_k=top_k)[0]
    elapsed = time.time() - started

    print("Input question:", query)
    print("Results (after {:.3f} seconds):".format(elapsed))
    for match in best_matches:
        # 'corpus_id' is the index of the hit in the passages list.
        print("\t{:.3f}\t{}".format(match['score'], passages[match['corpus_id']]))

    print("\n\n========\n")
|
|
def main():
    """Prompt once on stdin for a question and print its search results."""
    search(input("Please enter a question: "))
|
|
# Only prompt for a question when the file is executed as a script.
if __name__ == "__main__":
    main()