Spaces:
Sleeping
Sleeping
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -3,6 +3,7 @@ Nuremberg Trials AI - RAG-powered Q&A system
|
|
| 3 |
Deployed on HuggingFace Spaces
|
| 4 |
"""
|
| 5 |
|
|
|
|
| 6 |
import json
|
| 7 |
import gradio as gr
|
| 8 |
import numpy as np
|
|
@@ -14,9 +15,11 @@ from datasets import load_dataset
|
|
| 14 |
# Configuration
|
| 15 |
DATASET_ID = "Adherence/nuremberg-trials-rag"
|
| 16 |
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
| 17 |
-
LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
|
| 18 |
TOP_K = 5
|
| 19 |
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
class NurembergRAG:
|
| 22 |
def __init__(self):
|
|
@@ -50,9 +53,13 @@ class NurembergRAG:
|
|
| 50 |
)
|
| 51 |
self.index = faiss.read_index(index_path)
|
| 52 |
|
| 53 |
-
# Initialize LLM client
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
print(f" Loaded {len(self.chunks)} document chunks")
|
| 58 |
print("Ready!")
|
|
@@ -75,25 +82,29 @@ class NurembergRAG:
|
|
| 75 |
|
| 76 |
def generate_answer(self, question: str, context: str) -> str:
|
| 77 |
"""Generate answer using LLM with retrieved context."""
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
Context from Nuremberg Trial documents:
|
| 81 |
{context}
|
| 82 |
|
| 83 |
Question: {question}
|
| 84 |
|
| 85 |
-
Answer
|
| 86 |
|
| 87 |
try:
|
| 88 |
response = self.llm_client.text_generation(
|
| 89 |
prompt,
|
| 90 |
-
|
|
|
|
| 91 |
temperature=0.3,
|
| 92 |
-
do_sample=True,
|
| 93 |
)
|
| 94 |
return response
|
| 95 |
except Exception as e:
|
| 96 |
-
return f"
|
| 97 |
|
| 98 |
def query(self, question: str) -> tuple:
|
| 99 |
"""Full RAG pipeline: retrieve + generate."""
|
|
@@ -114,7 +125,7 @@ Answer (be specific and cite sources when possible):"""
|
|
| 114 |
context_parts.append(f"[{i}] {chunk['text'][:1000]}")
|
| 115 |
sources_md.append(
|
| 116 |
f"**[{i}] {chunk['source']}** (relevance: {score:.0%})\n\n"
|
| 117 |
-
f"{chunk['text'][:
|
| 118 |
)
|
| 119 |
|
| 120 |
context = "\n\n".join(context_parts)
|
|
|
|
| 3 |
Deployed on HuggingFace Spaces
|
| 4 |
"""
|
| 5 |
|
| 6 |
+
import os
|
| 7 |
import json
|
| 8 |
import gradio as gr
|
| 9 |
import numpy as np
|
|
|
|
| 15 |
# Configuration
|
| 16 |
DATASET_ID = "Adherence/nuremberg-trials-rag"
|
| 17 |
EMBEDDING_MODEL = "all-MiniLM-L6-v2"
|
|
|
|
| 18 |
TOP_K = 5
|
| 19 |
|
| 20 |
+
# Try to get HF token from environment (set in Space secrets)
|
| 21 |
+
HF_TOKEN = os.environ.get("HF_TOKEN")
|
| 22 |
+
|
| 23 |
|
| 24 |
class NurembergRAG:
|
| 25 |
def __init__(self):
|
|
|
|
| 53 |
)
|
| 54 |
self.index = faiss.read_index(index_path)
|
| 55 |
|
| 56 |
+
# Initialize LLM client if token available
|
| 57 |
+
if HF_TOKEN:
|
| 58 |
+
print(" Initializing LLM client...")
|
| 59 |
+
self.llm_client = InferenceClient(token=HF_TOKEN)
|
| 60 |
+
else:
|
| 61 |
+
print(" No HF_TOKEN - running in retrieval-only mode")
|
| 62 |
+
self.llm_client = None
|
| 63 |
|
| 64 |
print(f" Loaded {len(self.chunks)} document chunks")
|
| 65 |
print("Ready!")
|
|
|
|
| 82 |
|
| 83 |
def generate_answer(self, question: str, context: str) -> str:
|
| 84 |
"""Generate answer using LLM with retrieved context."""
|
| 85 |
+
if not self.llm_client:
|
| 86 |
+
# No LLM available - provide retrieval-only summary
|
| 87 |
+
return "**Retrieved passages below contain the answer.** (LLM generation requires HF_TOKEN)"
|
| 88 |
+
|
| 89 |
+
prompt = f"""You are an expert on the Nuremberg Trials. Answer the question based ONLY on the provided context from historical documents. If the context doesn't contain enough information, say so. Be concise.
|
| 90 |
|
| 91 |
Context from Nuremberg Trial documents:
|
| 92 |
{context}
|
| 93 |
|
| 94 |
Question: {question}
|
| 95 |
|
| 96 |
+
Answer:"""
|
| 97 |
|
| 98 |
try:
|
| 99 |
response = self.llm_client.text_generation(
|
| 100 |
prompt,
|
| 101 |
+
model="HuggingFaceH4/zephyr-7b-beta",
|
| 102 |
+
max_new_tokens=400,
|
| 103 |
temperature=0.3,
|
|
|
|
| 104 |
)
|
| 105 |
return response
|
| 106 |
except Exception as e:
|
| 107 |
+
return f"**Retrieved passages below contain the answer.** (LLM error: {str(e)[:100]})"
|
| 108 |
|
| 109 |
def query(self, question: str) -> tuple:
|
| 110 |
"""Full RAG pipeline: retrieve + generate."""
|
|
|
|
| 125 |
context_parts.append(f"[{i}] {chunk['text'][:1000]}")
|
| 126 |
sources_md.append(
|
| 127 |
f"**[{i}] {chunk['source']}** (relevance: {score:.0%})\n\n"
|
| 128 |
+
f"{chunk['text'][:600]}..."
|
| 129 |
)
|
| 130 |
|
| 131 |
context = "\n\n".join(context_parts)
|