| import numpy as np |
| import torch |
| from transformers import BertForQuestionAnswering, BertTokenizerFast |
|
|
| |
# --- Inference configuration -------------------------------------------------
MODEL_DIR = "model"   # directory holding the fine-tuned checkpoint + tokenizer files
MAX_LENGTH = 384      # max tokens per chunk fed to the model
DOC_STRIDE = 128      # token overlap between consecutive context chunks
N_BEST = 20           # how many top start/end logits to consider per chunk
MAX_ANS_LEN = 30      # maximum allowed answer span length, in tokens
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Load tokenizer and QA model once at import time; eval() disables dropout
# so inference is deterministic.
tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR).to(DEVICE)
model.eval()
print(f"✅ Model loaded on {DEVICE}")
|
|
def answer_question(question: str, context: str) -> dict:
    """Extract the best answer span for *question* from *context*.

    The context is tokenized into overlapping chunks of at most MAX_LENGTH
    tokens (stride DOC_STRIDE) so long documents are fully covered.  For each
    chunk the N_BEST start and end logits are paired into candidate spans and
    the highest-scoring valid span across all chunks wins.

    Returns a dict with keys:
        answer (str)  : answer text, or "No answer found." if no valid span
        score  (float): start-logit + end-logit of the chosen span (rounded),
                        or -999 when no answer was found
        start  (int)  : character offset of the answer in *context* (-1 if none)
        end    (int)  : character offset one past the answer end (-1 if none)
    """
    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",        # only the context may be truncated/chunked
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,  # one encoded row per context chunk
        return_offsets_mapping=True,     # token index -> character span in *context*
        padding="max_length",
        return_tensors="pt",
    )

    # These helper tensors must be removed before the forward pass — the model
    # does not accept them as keyword arguments.
    offset_mapping = inputs.pop("offset_mapping")
    # Single (question, context) sample, so the chunk->sample mapping is unused.
    inputs.pop("overflow_to_sample_mapping")
    # sequence_ids: None for special tokens, 0 for question tokens, 1 for context.
    sequence_ids = [inputs.sequence_ids(i) for i in range(len(inputs["input_ids"]))]

    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    start_logits = outputs.start_logits.cpu().numpy()
    end_logits = outputs.end_logits.cpu().numpy()

    candidates = []

    for chunk_idx in range(len(start_logits)):
        offsets = offset_mapping[chunk_idx].numpy()
        seq_ids = sequence_ids[chunk_idx]
        # Hoist per-chunk rows out of the O(N_BEST^2) pair loop below.
        s_logits = start_logits[chunk_idx]
        e_logits = end_logits[chunk_idx]

        # Indices of the N_BEST largest logits, highest first.
        s_indexes = np.argsort(s_logits)[-1:-N_BEST-1:-1]
        e_indexes = np.argsort(e_logits)[-1:-N_BEST-1:-1]

        for s in s_indexes:
            for e in e_indexes:
                # Both endpoints must lie inside the context segment
                # (skips question tokens, [CLS]/[SEP] and padding).
                if seq_ids[s] != 1 or seq_ids[e] != 1:
                    continue
                # Reject inverted spans and spans longer than MAX_ANS_LEN tokens.
                if e < s or e - s + 1 > MAX_ANS_LEN:
                    continue
                candidates.append({
                    "score": float(s_logits[s] + e_logits[e]),
                    "text": context[offsets[s][0]: offsets[e][1]],
                    "start": int(offsets[s][0]),
                    "end": int(offsets[e][1]),
                })

    if not candidates:
        return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}

    best = max(candidates, key=lambda x: x["score"])
    return {
        "answer": best["text"],
        "score": round(best["score"], 4),
        "start": best["start"],
        "end": best["end"],
    }
|
|
|
|
def ask(question: str, context: str):
    """Run QA for (question, context) and pretty-print the result."""
    result = answer_question(question, context)
    report = [
        f"❓ Question: {question}",
        f"💬 Answer : {result['answer']}",
        f"📊 Score : {result['score']}",
        f"📍 Position: Char {result['start']}–{result['end']}",
        "-" * 60,
    ]
    print("\n".join(report))
|
|
|
|
|
|
# --- Demo 1: short factual context ------------------------------------------
ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf
tropical rainforest in the Amazon biome that covers most of the Amazon
basin of South America. This basin encompasses 7,000,000 km² of which
5,500,000 km² are covered by the rainforest. The majority of the forest
is contained within Brazil, with 60% of the rainforest.
"""
ask("How much of the Amazon rainforest is in Brazil?", ctx1)


# --- Demo 2: date-style answer -----------------------------------------------
ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars
in Paris, France. It was constructed from 1887 to 1889 as the centerpiece
of the 1889 World's Fair. The tower is 330 metres tall and is the tallest
structure in Paris.
"""
ask("When was the Eiffel Tower built?", ctx2)


# --- Demo 3: long context ----------------------------------------------------
# Repeated x3 to exceed MAX_LENGTH tokens and exercise the overlapping-chunk
# path (return_overflowing_tokens / DOC_STRIDE) in answer_question.
ctx3 = """
Python is a high-level, general-purpose programming language. Its design
philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected. It supports multiple
programming paradigms, including structured, object-oriented and functional
programming. It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and
automation. The Python Package Index (PyPI) hosts hundreds of thousands of
third-party modules. The standard library is very extensive, offering tools
suited to many tasks.
""" * 3


ask("When was Python first released?", ctx3)


# --- Interactive REPL: one context, many questions, until the user quits -----
print("\n" + "=" * 60)
print("🎮 Interactive mode – stop with 'quit'")
print("=" * 60)


context_interactive = input("📄 Input context:\n> ").strip()
while True:
    q = input("\n❓ Question (or type 'quit'): ").strip()
    if q.lower() == "quit":
        print("👋 Bye.")
        break
    # Ignore empty input and re-prompt.
    if not q:
        continue
    ask(q, context_interactive)