Adherence commited on
Commit
df69d5c
·
verified ·
1 Parent(s): 4c96130

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -3,6 +3,7 @@ Nuremberg Trials AI - RAG-powered Q&A system
3
  Deployed on HuggingFace Spaces
4
  """
5
 
 
6
  import json
7
  import gradio as gr
8
  import numpy as np
@@ -14,9 +15,11 @@ from datasets import load_dataset
14
  # Configuration
15
  DATASET_ID = "Adherence/nuremberg-trials-rag"
16
  EMBEDDING_MODEL = "all-MiniLM-L6-v2"
17
- LLM_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
18
  TOP_K = 5
19
 
 
 
 
20
 
21
  class NurembergRAG:
22
  def __init__(self):
@@ -50,9 +53,13 @@ class NurembergRAG:
50
  )
51
  self.index = faiss.read_index(index_path)
52
 
53
- # Initialize LLM client (free inference API)
54
- print(" Initializing LLM client...")
55
- self.llm_client = InferenceClient(model=LLM_MODEL)
 
 
 
 
56
 
57
  print(f" Loaded {len(self.chunks)} document chunks")
58
  print("Ready!")
@@ -75,25 +82,29 @@ class NurembergRAG:
75
 
76
  def generate_answer(self, question: str, context: str) -> str:
77
  """Generate answer using LLM with retrieved context."""
78
- prompt = f"""You are an expert on the Nuremberg Trials. Answer the question based ONLY on the provided context from historical documents. If the context doesn't contain enough information, say so.
 
 
 
 
79
 
80
  Context from Nuremberg Trial documents:
81
  {context}
82
 
83
  Question: {question}
84
 
85
- Answer (be specific and cite sources when possible):"""
86
 
87
  try:
88
  response = self.llm_client.text_generation(
89
  prompt,
90
- max_new_tokens=500,
 
91
  temperature=0.3,
92
- do_sample=True,
93
  )
94
  return response
95
  except Exception as e:
96
- return f"Error generating answer: {str(e)}"
97
 
98
  def query(self, question: str) -> tuple:
99
  """Full RAG pipeline: retrieve + generate."""
@@ -114,7 +125,7 @@ Answer (be specific and cite sources when possible):"""
114
  context_parts.append(f"[{i}] {chunk['text'][:1000]}")
115
  sources_md.append(
116
  f"**[{i}] {chunk['source']}** (relevance: {score:.0%})\n\n"
117
- f"{chunk['text'][:500]}..."
118
  )
119
 
120
  context = "\n\n".join(context_parts)
 
3
  Deployed on HuggingFace Spaces
4
  """
5
 
6
+ import os
7
  import json
8
  import gradio as gr
9
  import numpy as np
 
15
  # Configuration
16
  DATASET_ID = "Adherence/nuremberg-trials-rag"
17
  EMBEDDING_MODEL = "all-MiniLM-L6-v2"
 
18
  TOP_K = 5
19
 
20
+ # Try to get HF token from environment (set in Space secrets)
21
+ HF_TOKEN = os.environ.get("HF_TOKEN")
22
+
23
 
24
  class NurembergRAG:
25
  def __init__(self):
 
53
  )
54
  self.index = faiss.read_index(index_path)
55
 
56
+ # Initialize LLM client if token available
57
+ if HF_TOKEN:
58
+ print(" Initializing LLM client...")
59
+ self.llm_client = InferenceClient(token=HF_TOKEN)
60
+ else:
61
+ print(" No HF_TOKEN - running in retrieval-only mode")
62
+ self.llm_client = None
63
 
64
  print(f" Loaded {len(self.chunks)} document chunks")
65
  print("Ready!")
 
82
 
83
  def generate_answer(self, question: str, context: str) -> str:
84
  """Generate answer using LLM with retrieved context."""
85
+ if not self.llm_client:
86
+ # No LLM available - provide retrieval-only summary
87
+ return "**Retrieved passages below contain the answer.** (LLM generation requires HF_TOKEN)"
88
+
89
+ prompt = f"""You are an expert on the Nuremberg Trials. Answer the question based ONLY on the provided context from historical documents. If the context doesn't contain enough information, say so. Be concise.
90
 
91
  Context from Nuremberg Trial documents:
92
  {context}
93
 
94
  Question: {question}
95
 
96
+ Answer:"""
97
 
98
  try:
99
  response = self.llm_client.text_generation(
100
  prompt,
101
+ model="HuggingFaceH4/zephyr-7b-beta",
102
+ max_new_tokens=400,
103
  temperature=0.3,
 
104
  )
105
  return response
106
  except Exception as e:
107
+ return f"**Retrieved passages below contain the answer.** (LLM error: {str(e)[:100]})"
108
 
109
  def query(self, question: str) -> tuple:
110
  """Full RAG pipeline: retrieve + generate."""
 
125
  context_parts.append(f"[{i}] {chunk['text'][:1000]}")
126
  sources_md.append(
127
  f"**[{i}] {chunk['source']}** (relevance: {score:.0%})\n\n"
128
+ f"{chunk['text'][:600]}..."
129
  )
130
 
131
  context = "\n\n".join(context_parts)