Adherence commited on
Commit
34257f6
·
verified ·
1 Parent(s): 1a2c8d6

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +228 -0
app.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Nuremberg Trials AI - RAG-powered Q&A system
3
+ Deployed on HuggingFace Spaces
4
+ """
5
+
6
+ import json
7
+ import gradio as gr
8
+ import numpy as np
9
+ import faiss
10
+ from sentence_transformers import SentenceTransformer
11
+ from huggingface_hub import hf_hub_download, InferenceClient
12
+ from datasets import load_dataset
13
+
14
+ # Configuration
15
+ DATASET_ID = "Adherence/nuremberg-trials-rag"
16
+ EMBEDDING_MODEL = "all-MiniLM-L6-v2"
17
+ LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
18
+ TOP_K = 5
19
+
20
+
21
+ class NurembergRAG:
22
+ def __init__(self):
23
+ self.index = None
24
+ self.chunks = None
25
+ self.model = None
26
+ self.llm_client = None
27
+
28
+ def load(self):
29
+ """Load RAG components from HuggingFace."""
30
+ print("Loading Nuremberg Trials RAG system...")
31
+
32
+ # Load embedding model
33
+ print(" Loading embedding model...")
34
+ self.model = SentenceTransformer(EMBEDDING_MODEL)
35
+
36
+ # Load chunks from dataset
37
+ print(" Loading document chunks...")
38
+ dataset = load_dataset(DATASET_ID, split="train")
39
+ self.chunks = [
40
+ {"text": row["text"], "source": row["source"]}
41
+ for row in dataset
42
+ ]
43
+
44
+ # Load FAISS index
45
+ print(" Loading FAISS index...")
46
+ index_path = hf_hub_download(
47
+ repo_id=DATASET_ID,
48
+ filename="faiss_index.bin",
49
+ repo_type="dataset"
50
+ )
51
+ self.index = faiss.read_index(index_path)
52
+
53
+ # Initialize LLM client (free inference API)
54
+ print(" Initializing LLM client...")
55
+ self.llm_client = InferenceClient(model=LLM_MODEL)
56
+
57
+ print(f" Loaded {len(self.chunks)} document chunks")
58
+ print("Ready!")
59
+
60
+ def search(self, query: str, top_k: int = TOP_K):
61
+ """Search for relevant chunks."""
62
+ query_embedding = self.model.encode([query], convert_to_numpy=True)
63
+ distances, indices = self.index.search(
64
+ query_embedding.astype(np.float32), top_k
65
+ )
66
+
67
+ results = []
68
+ for idx, distance in zip(indices[0], distances[0]):
69
+ if idx < len(self.chunks):
70
+ chunk = self.chunks[idx]
71
+ similarity = 1 / (1 + distance)
72
+ results.append((chunk, similarity))
73
+
74
+ return results
75
+
76
+ def generate_answer(self, question: str, context: str) -> str:
77
+ """Generate answer using LLM with retrieved context."""
78
+ prompt = f"""You are an expert on the Nuremberg Trials. Answer the question based ONLY on the provided context from historical documents. If the context doesn't contain enough information, say so.
79
+
80
+ Context from Nuremberg Trial documents:
81
+ {context}
82
+
83
+ Question: {question}
84
+
85
+ Answer (be specific and cite sources when possible):"""
86
+
87
+ try:
88
+ response = self.llm_client.text_generation(
89
+ prompt,
90
+ max_new_tokens=500,
91
+ temperature=0.3,
92
+ do_sample=True,
93
+ )
94
+ return response
95
+ except Exception as e:
96
+ return f"Error generating answer: {str(e)}"
97
+
98
+ def query(self, question: str) -> tuple:
99
+ """Full RAG pipeline: retrieve + generate."""
100
+ if not question.strip():
101
+ return "Please enter a question.", ""
102
+
103
+ # Retrieve relevant passages
104
+ results = self.search(question, TOP_K)
105
+
106
+ if not results:
107
+ return "No relevant information found.", ""
108
+
109
+ # Format context for LLM
110
+ context_parts = []
111
+ sources_md = []
112
+
113
+ for i, (chunk, score) in enumerate(results, 1):
114
+ context_parts.append(f"[{i}] {chunk['text'][:1000]}")
115
+ sources_md.append(
116
+ f"**[{i}] {chunk['source']}** (relevance: {score:.0%})\n\n"
117
+ f"{chunk['text'][:500]}..."
118
+ )
119
+
120
+ context = "\n\n".join(context_parts)
121
+
122
+ # Generate answer
123
+ answer = self.generate_answer(question, context)
124
+
125
+ # Format sources
126
+ sources = "\n\n---\n\n".join(sources_md)
127
+
128
+ return answer, sources
129
+
130
+
131
+ # Initialize RAG system
132
+ print("Initializing Nuremberg Trials AI...")
133
+ rag = NurembergRAG()
134
+ rag.load()
135
+
136
+
137
+ def answer_question(question: str) -> tuple:
138
+ """Gradio interface function."""
139
+ return rag.query(question)
140
+
141
+
142
+ # Example questions
143
+ examples = [
144
+ "How many defendants were sentenced to death at Nuremberg?",
145
+ "What were the four counts in the Nuremberg indictment?",
146
+ "Who was the chief prosecutor for the United States?",
147
+ "What happened to Hermann Goering?",
148
+ "What was the legal basis for the Nuremberg trials?",
149
+ "Who were the judges at Nuremberg?",
150
+ "What was the verdict for Albert Speer?",
151
+ "What were the crimes against humanity?",
152
+ ]
153
+
154
+ # Build Gradio interface
155
+ with gr.Blocks(
156
+ title="Nuremberg Trials AI",
157
+ theme=gr.themes.Soft(),
158
+ ) as demo:
159
+ gr.Markdown(
160
+ """
161
+ # Nuremberg Trials AI
162
+
163
+ Ask questions about the Nuremberg Trials (1945-1946). This system uses
164
+ **Retrieval-Augmented Generation (RAG)** to search through 12,000+ passages from:
165
+
166
+ - **Harvard Law School Nuremberg Trials Project** - Full IMT transcript (17,268 pages)
167
+ - **Yale Avalon Project** - Judgments, indictments, charter documents
168
+ - **Wikipedia** - Defendant biographies
169
+
170
+ All answers are grounded in actual historical documents with source citations.
171
+ """
172
+ )
173
+
174
+ with gr.Row():
175
+ with gr.Column(scale=2):
176
+ question_input = gr.Textbox(
177
+ label="Your Question",
178
+ placeholder="e.g., How many defendants were sentenced to death?",
179
+ lines=2,
180
+ )
181
+ submit_btn = gr.Button("Ask", variant="primary")
182
+
183
+ with gr.Column(scale=1):
184
+ gr.Examples(
185
+ examples=examples,
186
+ inputs=question_input,
187
+ label="Example Questions",
188
+ )
189
+
190
+ with gr.Row():
191
+ with gr.Column():
192
+ answer_output = gr.Textbox(
193
+ label="Answer",
194
+ lines=8,
195
+ show_copy_button=True,
196
+ )
197
+
198
+ with gr.Accordion("Source Documents", open=False):
199
+ sources_output = gr.Markdown(label="Retrieved Passages")
200
+
201
+ submit_btn.click(
202
+ fn=answer_question,
203
+ inputs=question_input,
204
+ outputs=[answer_output, sources_output],
205
+ )
206
+
207
+ question_input.submit(
208
+ fn=answer_question,
209
+ inputs=question_input,
210
+ outputs=[answer_output, sources_output],
211
+ )
212
+
213
+ gr.Markdown(
214
+ """
215
+ ---
216
+ **About**: This project aims to make the historical record of the Nuremberg Trials
217
+ accessible through AI. Built with sentence-transformers, FAISS, and Mistral-7B.
218
+
219
+ **Data Sources**: [Harvard Nuremberg Project](https://nuremberg.law.harvard.edu/) |
220
+ [Yale Avalon Project](https://avalon.law.yale.edu/subject_menus/imt.asp)
221
+
222
+ **Code**: [GitHub](https://github.com/your-repo) |
223
+ **Dataset**: [HuggingFace](https://huggingface.co/datasets/Adherence/nuremberg-trials-rag)
224
+ """
225
+ )
226
+
227
+ if __name__ == "__main__":
228
+ demo.launch()