# Excerp_v1 / use.py
# Author: LH-Tech-AI — commit d80d74f ("Create use.py", verified)
import numpy as np
import torch
from transformers import BertForQuestionAnswering, BertTokenizerFast
# ── Config ───────────────────────────────────────────────────
# Inference settings, then tokenizer/model loading for extractive QA.
MODEL_DIR = "model"   # directory handed to from_pretrained() for both objects
MAX_LENGTH = 384      # max tokens per (question, context) window
DOC_STRIDE = 128      # passed as the tokenizer's `stride` when the context overflows
N_BEST = 20           # how many top start/end logits are paired per window
MAX_ANS_LEN = 30      # longest accepted answer span, counted in tokens

DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = BertTokenizerFast.from_pretrained(MODEL_DIR)
# nn.Module.to() and .eval() both return the module itself, so they chain;
# eval() puts the model in inference mode (e.g. dropout disabled).
model = BertForQuestionAnswering.from_pretrained(MODEL_DIR).to(DEVICE).eval()
print(f"✅ Model loaded on {DEVICE}")
def _no_answer() -> dict:
    """Return a fresh result dict signalling that no valid span was found."""
    return {"answer": "No answer found.", "score": -999, "start": -1, "end": -1}


def _chunk_candidates(start_row, end_row, offsets, seq_ids, context, n_best, max_answer_len):
    """Yield every admissible answer span for one tokenized window.

    Only the ``n_best`` highest start logits and ``n_best`` highest end logits
    are paired. A pair is kept when both tokens belong to the context
    (sequence id 1), the end does not precede the start, and the span is at
    most ``max_answer_len`` tokens long. Candidates are yielded start-major,
    highest-scoring start first, so ``max`` keeps the same tie-breaking order.
    """
    # Indices of the n_best largest logits, highest first.
    top_starts = np.argsort(start_row)[-1:-n_best - 1:-1]
    top_ends = np.argsort(end_row)[-1:-n_best - 1:-1]
    for s in top_starts:
        for e in top_ends:
            # Question tokens (id 0) and special/padding tokens (None) are excluded.
            if seq_ids[s] != 1 or seq_ids[e] != 1:
                continue
            if e < s or e - s + 1 > max_answer_len:
                continue
            start_char, end_char = int(offsets[s][0]), int(offsets[e][1])
            yield {
                "score": float(start_row[s] + end_row[e]),
                "text": context[start_char:end_char],
                "start": start_char,
                "end": end_char,
            }


def answer_question(question: str, context: str, *, n_best=None, max_answer_len=None) -> dict:
    """Extract the best answer span for *question* from *context*.

    The context is split by the tokenizer into overlapping windows of at most
    MAX_LENGTH tokens (overlap DOC_STRIDE). All windows are scored in a single
    forward pass. The span with the highest start-logit + end-logit sum across
    all windows wins.

    Args:
        question: Natural-language question.
        context: Text to search; the answer is always a substring of it.
        n_best: Number of top start and top end positions paired per window.
            Defaults to the module-level N_BEST.
        max_answer_len: Longest allowed answer span in tokens. Defaults to the
            module-level MAX_ANS_LEN.

    Returns:
        A dict with "answer" (str), "score" (float rounded to 4 decimals), and
        "start"/"end" (character offsets into *context*, end-exclusive). When
        the context is blank or no admissible span exists, returns
        ``{"answer": "No answer found.", "score": -999, "start": -1, "end": -1}``.
    """
    # A blank context can never contain an answer; skip tokenizing and inference.
    if not context.strip():
        return _no_answer()
    n_best = N_BEST if n_best is None else n_best
    max_answer_len = MAX_ANS_LEN if max_answer_len is None else max_answer_len

    inputs = tokenizer(
        question,
        context,
        max_length=MAX_LENGTH,
        truncation="only_second",
        stride=DOC_STRIDE,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
        return_tensors="pt",
    )
    # Both keys must be removed before the forward pass: the model does not
    # accept them as keyword arguments.
    offset_mapping = inputs.pop("offset_mapping").numpy()  # (n_chunks, seq_len, 2)
    inputs.pop("overflow_to_sample_mapping")  # single sample: not needed
    sequence_ids = [inputs.sequence_ids(i) for i in range(len(inputs["input_ids"]))]
    model_inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**model_inputs)
    start_logits = outputs.start_logits.cpu().numpy()  # (n_chunks, seq_len)
    end_logits = outputs.end_logits.cpu().numpy()

    candidates = (
        cand
        for chunk_idx in range(len(start_logits))
        for cand in _chunk_candidates(
            start_logits[chunk_idx],
            end_logits[chunk_idx],
            offset_mapping[chunk_idx],
            sequence_ids[chunk_idx],
            context,
            n_best,
            max_answer_len,
        )
    )
    best = max(candidates, key=lambda c: c["score"], default=None)
    if best is None:
        return _no_answer()
    return {
        "answer": best["text"],
        "score": round(best["score"], 4),
        "start": best["start"],
        "end": best["end"],
    }
def ask(question: str, context: str) -> None:
    """Answer *question* against *context* and print a formatted report.

    Prints the question, the extracted answer, its score, and its character
    range in *context*, followed by a 60-character separator line.
    """
    result = answer_question(question, context)
    print(f"❓ Question: {question}")
    print(f"💬 Answer : {result['answer']}")
    print(f"📊 Score : {result['score']}")
    # start/end are end-exclusive character offsets into context; they need a
    # separator, otherwise e.g. 120 and 128 print as "120128".
    print(f"📍 Position: Char {result['start']}–{result['end']}")
    print("-" * 60)
# ── Demo ─────────────────────────────────────────────────────
ctx1 = """
The Amazon rainforest, also known as Amazonia, is a moist broadleaf
tropical rainforest in the Amazon biome that covers most of the Amazon
basin of South America. This basin encompasses 7,000,000 km² of which
5,500,000 km² are covered by the rainforest. The majority of the forest
is contained within Brazil, with 60% of the rainforest.
"""
ctx2 = """
The Eiffel Tower is a wrought-iron lattice tower on the Champ de Mars
in Paris, France. It was constructed from 1887 to 1889 as the centerpiece
of the 1889 World's Fair. The tower is 330 metres tall and is the tallest
structure in Paris.
"""
# Repeated three times, presumably so the context exceeds MAX_LENGTH tokens
# and the overflow/stride windowing is exercised — TODO confirm token count.
ctx3 = """
Python is a high-level, general-purpose programming language. Its design
philosophy emphasizes code readability with the use of significant indentation.
Python is dynamically typed and garbage-collected. It supports multiple
programming paradigms, including structured, object-oriented and functional
programming. It was created by Guido van Rossum and first released in 1991.
Python consistently ranks as one of the most popular programming languages.
It is widely used in data science, machine learning, web development, and
automation. The Python Package Index (PyPI) hosts hundreds of thousands of
third-party modules. The standard library is very extensive, offering tools
suited to many tasks.
""" * 3

# (question, context) pairs, asked in order.
_DEMO_CASES = (
    ("How much of the Amazon rainforest is in Brazil?", ctx1),
    ("When was the Eiffel Tower built?", ctx2),
    ("When was Python first released?", ctx3),
)
for _demo_question, _demo_context in _DEMO_CASES:
    ask(_demo_question, _demo_context)
# ── Interactive mode ─────────────────────────────────────────
# Read one context, then answer questions about it until the user types
# 'quit' (case-insensitive). Blank questions are ignored.
print("\n" + "=" * 60)
print("🎮 Interactive mode – stop with 'quit'")
print("=" * 60)
try:
    context_interactive = input("📄 Input context:\n> ").strip()
    while True:
        q = input("\n❓ Question (or type 'quit'): ").strip()
        if q.lower() == "quit":
            print("👋 Bye.")
            break
        if not q:
            continue
        ask(q, context_interactive)
except (EOFError, KeyboardInterrupt):
    # input() raises EOFError when stdin is closed (Ctrl-D, exhausted pipe);
    # Ctrl-C raises KeyboardInterrupt. Exit cleanly instead of a traceback.
    print("\n👋 Bye.")