gaurv007 committed on
Commit
f0f9872
·
verified ·
1 Parent(s): b5350d6

v4.0: Add chatbot.py — OCR + RAG Chatbot + Clause Redlining

Browse files
Files changed (1) hide show
  1. chatbot.py +406 -0
chatbot.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ClauseGuard — Contract Q&A Chatbot (RAG) v1.0
3
+ ═══════════════════════════════════════════════
4
+ Architecture:
5
+ User asks question about their contract
6
+
7
+ [1] Embed question with sentence-transformers (all-MiniLM-L6-v2)
8
+
9
+ [2] Retrieve top-5 most relevant chunks from contract
10
+
11
+ [3] Build prompt:
12
+ - System: ClauseGuard analysis results (clauses, entities, risk scores)
13
+ - Context: Retrieved contract chunks (≤2.5K tokens)
14
+ - User question
15
+
16
+ [4] Stream response from LLM via HF Inference API
17
+
18
+ Key design:
19
+ • Analyzed data (clauses, entities, risk scores) → system prompt
20
+ • Raw contract text → RAG retrieval
21
+ • This gives the model both structured analysis AND verbatim evidence
22
+ """
23
+
24
+ import os
25
+ import re
26
+ import numpy as np
27
+
28
+ # ── Embedding model (soft-fail) ─────────────────────────────────────
29
+ _HAS_EMBEDDER = False
30
+ _embedder = None
31
+
32
+ try:
33
+ from sentence_transformers import SentenceTransformer
34
+ _HAS_EMBEDDER = True
35
+ except ImportError:
36
+ pass
37
+
38
+ # ── HF Inference Client (soft-fail) ─────────────────────────────────
39
+ _HAS_INFERENCE = False
40
+ _llm_client = None
41
+
42
+ try:
43
+ from huggingface_hub import InferenceClient
44
+ _HAS_INFERENCE = True
45
+ except ImportError:
46
+ pass
47
+
48
+ # ═══════════════════════════════════════════════════════════════════════
49
+ # MODEL LOADING
50
+ # ═══════════════════════════════════════════════════════════════════════
51
+
52
+ _chatbot_status = {"embedder": "not_loaded", "llm": "not_loaded"}
53
+
54
def _load_embedder():
    """Return the shared embedding model, creating it on first use.

    The model is cached in the module-level ``_embedder`` and the outcome of
    every attempt is recorded under ``_chatbot_status["embedder"]``
    (``"loaded"``, ``"unavailable"`` or ``"failed: <error>"``).

    Returns:
        The cached ``SentenceTransformer`` instance, or ``None`` when the
        optional dependency is missing or the model could not be constructed.
    """
    global _embedder, _chatbot_status

    if _embedder is None:
        if not _HAS_EMBEDDER:
            # Optional dependency was not importable; there is nothing to load.
            _chatbot_status["embedder"] = "unavailable"
        else:
            try:
                print("[ClauseGuard Chat] Loading embedding model: all-MiniLM-L6-v2...")
                _embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
                _chatbot_status["embedder"] = "loaded"
                print("[ClauseGuard Chat] Embedding model loaded")
            except Exception as exc:
                # Soft-fail: record the reason and let callers degrade gracefully.
                _chatbot_status["embedder"] = f"failed: {exc}"
                print(f"[ClauseGuard Chat] Embedder load failed: {exc}")
                return None

    return _embedder
72
+
73
+
74
def _get_llm_client():
    """Return the shared HF ``InferenceClient``, creating it on first use.

    The client is cached in the module-level ``_llm_client``. The API key is
    read from the ``HF_TOKEN`` environment variable; an unset or empty value
    is passed on as ``None``. The outcome is recorded under
    ``_chatbot_status["llm"]``.

    Returns:
        The cached client, or ``None`` when ``huggingface_hub`` is not
        importable or the client could not be constructed.
    """
    global _llm_client, _chatbot_status

    if _llm_client is None:
        if not _HAS_INFERENCE:
            # Optional dependency was not importable; no client can be built.
            _chatbot_status["llm"] = "unavailable"
        else:
            try:
                hf_token = os.environ.get("HF_TOKEN", "")
                api_key = hf_token if hf_token else None
                _llm_client = InferenceClient(provider="hf-inference", api_key=api_key)
                _chatbot_status["llm"] = "loaded"
                print("[ClauseGuard Chat] HF Inference Client initialized")
            except Exception as exc:
                # Soft-fail: record the reason so the UI can surface it.
                _chatbot_status["llm"] = f"failed: {exc}"
                print(f"[ClauseGuard Chat] LLM client init failed: {exc}")
                return None

    return _llm_client
95
+
96
+
97
def get_chatbot_status():
    """Render ``_chatbot_status`` as a single human-readable line.

    Each component becomes ``"<icon> <label>: <status>"``: a check mark when
    the status is exactly ``"loaded"``, a warning sign when the status text
    contains ``"failed"``, and a cross otherwise. Entries are joined with
    ``" · "``.
    """
    labels = {"embedder": "Embeddings", "llm": "LLM API"}
    rendered = []
    for component, state in _chatbot_status.items():
        if state == "loaded":
            marker = "✅"
        elif "failed" in state:
            marker = "⚠️"
        else:
            marker = "❌"
        rendered.append(f"{marker} {labels[component]}: {state}")
    return " · ".join(rendered)
105
+
106
+
107
+ # ═══════════════════════════════════════════════════════════════════════
108
+ # TEXT CHUNKING (sentence-preserving, ~300 tokens, no overlap)
109
+ # ═══════════════════════════════════════════════════════════════════════
110
+
111
def chunk_contract_text(text, target_chunk_size=300, min_chunk_size=50):
    """
    Split contract text into chunks for RAG retrieval without losing text.

    Paragraphs (separated by blank lines) are packed greedily into chunks of
    at most ``target_chunk_size`` whitespace-separated words (a rough token
    proxy). Chunks never overlap. A paragraph is split into sentences when it
    is larger than the target on its own, or when the running chunk is still
    below ``min_chunk_size`` and needs topping up.

    Args:
        text: Raw contract text.
        target_chunk_size: Preferred maximum number of words per chunk.
        min_chunk_size: A running chunk is only closed once it holds at
            least this many words.

    Returns:
        List of chunk strings in document order. Every word of ``text``
        appears in exactly one chunk. A chunk may exceed
        ``target_chunk_size`` only when a single sentence is too long to fit,
        and the final chunk may be shorter than ``min_chunk_size`` (it is
        kept so that short contracts and trailing clauses stay searchable).
    """
    if not text or not text.strip():
        return []

    chunks = []
    current = ""
    current_words = 0

    def _emit():
        """Close the running chunk, if any."""
        nonlocal current, current_words
        if current:
            chunks.append(current)
        current, current_words = "", 0

    def _add(piece, n_words, separator):
        """Append a piece to the running chunk using the given separator."""
        nonlocal current, current_words
        current = f"{current}{separator}{piece}" if current else piece
        current_words += n_words

    for para in re.split(r'\n\n+', text):
        para = para.strip()
        if not para:
            continue
        para_words = len(para.split())

        # Common case: the whole paragraph fits into the running chunk.
        if current_words + para_words <= target_chunk_size:
            _add(para, para_words, "\n\n")
            continue

        # The running chunk is big enough to stand alone: close it, and start
        # the next chunk with this paragraph if it fits in one piece.
        if current_words >= min_chunk_size:
            _emit()
            if para_words <= target_chunk_size:
                _add(para, para_words, "\n\n")
                continue

        # Otherwise pack the paragraph sentence by sentence. The first
        # sentence keeps the paragraph break, the rest are joined by a space.
        separator = "\n\n"
        for sent in re.split(r'(?<=[.!?])\s+(?=[A-Z])', para):
            sent = sent.strip()
            if not sent:
                continue
            sent_words = len(sent.split())
            if (current_words >= min_chunk_size
                    and current_words + sent_words > target_chunk_size):
                _emit()
            # An undersized running chunk is extended rather than discarded,
            # even if that pushes it past the target size.
            _add(sent, sent_words, separator)
            separator = " "

    # Always keep the tail, however short, so no contract text is dropped.
    _emit()
    return chunks
159
+
160
+
161
+ # ═══════════════════════════════════════════════════════════════════════
162
+ # EMBEDDING & RETRIEVAL
163
+ # ═══════════════════════════════════════════════════════════════════════
164
+
165
def build_embeddings(chunks):
    """
    Encode contract chunks into L2-normalized embedding vectors.

    Args:
        chunks: List of chunk strings to embed.

    Returns:
        Whatever the embedder's ``encode`` returns for the batch (the file
        documents this as a numpy array of shape (N, 384)), or ``None`` when
        the embedder is unavailable, ``chunks`` is empty, or encoding fails.
    """
    model = _load_embedder()
    if model is None:
        return None
    if not chunks:
        return None

    try:
        return model.encode(
            chunks,
            normalize_embeddings=True,
            batch_size=32,
            show_progress_bar=False,
        )
    except Exception as exc:
        # Soft-fail so the rest of the app keeps working without the chatbot.
        print(f"[ClauseGuard Chat] Embedding error: {exc}")
        return None
184
+
185
+
186
def retrieve_chunks(query, chunks, embeddings, top_k=5, max_words=600):
    """
    Retrieve the most relevant chunks for a query.

    Similarity is the dot product between the query embedding and the chunk
    embeddings; both are L2-normalized, so this equals cosine similarity.

    Args:
        query: The user's question.
        chunks: Contract chunks (list of str), aligned row-for-row with
            ``embeddings``.
        embeddings: 2-D array of chunk embeddings, or None.
        top_k: Maximum number of chunks to return. Values <= 0 return [].
        max_words: Word budget for the returned context. Word count is only a
            rough token proxy; the default of 600 words stays well inside the
            ≤2.5K-token context budget. The best hit is always returned, even
            if it alone exceeds the budget.

    Returns:
        List of dicts ``{"text", "score", "index"}`` ordered from most to
        least similar; empty when nothing can be retrieved.
    """
    # Cheap checks first so the embedding model is not loaded for nothing.
    if not query or not query.strip() or not chunks or embeddings is None:
        return []
    if top_k <= 0:
        # A negative top_k would otherwise slice off the *worst* hit only.
        return []

    embedder = _load_embedder()
    if embedder is None:
        return []

    try:
        matrix = np.asarray(embeddings)
        if matrix.ndim != 2 or matrix.shape[0] != len(chunks):
            # Misaligned inputs would pair scores with the wrong chunk text.
            print("[ClauseGuard Chat] Retrieval error: embeddings do not match chunks")
            return []

        q_emb = embedder.encode([query], normalize_embeddings=True)
        scores = (q_emb @ matrix.T)[0]
        top_indices = np.argsort(scores)[::-1][:top_k]

        results = []
        total_words = 0
        for idx in top_indices:
            chunk = chunks[idx]
            chunk_words = len(chunk.split())
            # Stop at the first chunk that would overflow the budget, but
            # never return nothing just because the best hit is long.
            if results and total_words + chunk_words > max_words:
                break
            results.append({
                "text": chunk,
                "score": float(scores[idx]),
                "index": int(idx),
            })
            total_words += chunk_words

        return results
    except Exception as e:
        print(f"[ClauseGuard Chat] Retrieval error: {e}")
        return []
221
+
222
+
223
+ # ═══════════════════════════════════════════════════════════════════════
224
+ # SYSTEM PROMPT BUILDER
225
+ # ═══════════════════════════════════════════════════════════════════════
226
+
227
+ def _build_system_prompt(analysis_result, retrieved_chunks):
228
+ """
229
+ Build the system prompt with:
230
+ 1. ClauseGuard analysis results (clauses, entities, risk scores) — NOT through RAG
231
+ 2. Retrieved contract chunks — through RAG
232
+ """
233
+ parts = []
234
+
235
+ parts.append("""You are ClauseGuard AI, a legal contract analysis assistant. You help users understand their contracts by answering questions based on the contract text and analysis results.
236
+
237
+ RULES:
238
+ - Answer ONLY based on the provided contract text and analysis. Never make up information.
239
+ - If the answer isn't in the provided context, say "I don't see that information in the analyzed contract."
240
+ - Cite specific clauses or sections when possible.
241
+ - Be concise but thorough. Use plain language, not legal jargon.
242
+ - Always end with: "⚠️ This is AI analysis, not legal advice. Consult an attorney for legal decisions."
243
+ """)
244
+
245
+ # Add analysis summary if available
246
+ if analysis_result:
247
+ risk = analysis_result.get("risk", {})
248
+ parts.append(f"""
249
+ ═��═ CONTRACT ANALYSIS SUMMARY ═══
250
+ Risk Score: {risk.get('score', 'N/A')}/100 (Grade {risk.get('grade', 'N/A')})
251
+ Risk Breakdown: {risk.get('breakdown', {})}
252
+ Total Clauses Analyzed: {analysis_result.get('metadata', {}).get('total_clauses', 'N/A')}
253
+ Flagged Clauses: {analysis_result.get('metadata', {}).get('flagged_clauses', 'N/A')}
254
+ """)
255
+
256
+ # Add detected clauses summary
257
+ clauses = analysis_result.get("clauses", [])
258
+ if clauses:
259
+ clause_summary = []
260
+ seen = set()
261
+ for c in clauses:
262
+ key = c["label"]
263
+ if key not in seen:
264
+ seen.add(key)
265
+ risk_level = c.get("risk", "LOW")
266
+ clause_summary.append(f" • [{risk_level}] {key}: {c.get('description', '')}")
267
+ parts.append("═══ DETECTED CLAUSES ═══\n" + "\n".join(clause_summary[:20]))
268
+
269
+ # Add entities summary
270
+ entities = analysis_result.get("entities", [])
271
+ if entities:
272
+ entity_summary = []
273
+ seen = set()
274
+ for e in entities:
275
+ key = f"{e['type']}: {e['text']}"
276
+ if key not in seen and len(seen) < 15:
277
+ seen.add(key)
278
+ entity_summary.append(f" • {e['type']}: {e['text']}")
279
+ parts.append("═══ EXTRACTED ENTITIES ═══\n" + "\n".join(entity_summary))
280
+
281
+ # Add contradictions
282
+ contradictions = analysis_result.get("contradictions", [])
283
+ if contradictions:
284
+ contra_summary = []
285
+ for c in contradictions:
286
+ contra_summary.append(f" • [{c['type']}] {c['explanation']}")
287
+ parts.append("═══ CONTRADICTIONS / ISSUES ═══\n" + "\n".join(contra_summary))
288
+
289
+ # Add retrieved contract text
290
+ if retrieved_chunks:
291
+ context_text = "\n---\n".join(c["text"] for c in retrieved_chunks)
292
+ parts.append(f"""
293
+ ═══ RELEVANT CONTRACT TEXT (Retrieved) ═══
294
+ {context_text}
295
+ """)
296
+
297
+ return "\n\n".join(parts)
298
+
299
+
300
+ # ═══════════════════════════════════════════════════════════════════════
301
+ # CHAT RESPONSE (Streaming)
302
+ # ═══════════════════════════════════════════════════════════════════════
303
+
304
+ # LLM model to use
305
+ _LLM_MODEL = "Qwen/Qwen2.5-7B-Instruct"
306
+
307
def chat_respond(message, history, chunks, embeddings, analysis_result):
    """
    RAG chatbot response function for gr.ChatInterface.

    Args:
        message: User's question (str)
        history: Chat history. Normally a list of dicts with role/content;
            legacy ``(user, assistant)`` pairs are also accepted. Entries
            whose role or content is not a non-empty string are skipped.
        chunks: Contract text chunks (list of str)
        embeddings: Chunk embeddings (numpy array or None)
        analysis_result: Full analysis result dict (or None)

    Yields:
        The response accumulated so far (streaming), or a single
        user-facing warning when the request cannot be served.
    """
    # Validate inputs
    if not chunks or embeddings is None:
        yield ("⚠️ No contract loaded yet. Please upload and analyze a contract in the "
               "**📄 Single Contract Analysis** tab first, then come back here to ask questions.")
        return

    if not message or not message.strip():
        yield "Please ask a question about your contract."
        return

    # Step 1: Retrieve relevant chunks
    retrieved = retrieve_chunks(message, chunks, embeddings, top_k=5)

    # Step 2: Build system prompt with analysis + retrieved context
    system_prompt = _build_system_prompt(analysis_result, retrieved)

    # Step 3: Build message history for LLM
    messages = [{"role": "system", "content": system_prompt}]

    # Only the last 6 history entries are forwarded to stay in the context
    # window (6 messages in role/content format).
    for entry in (history or [])[-6:]:
        if isinstance(entry, dict):
            role, content = entry.get("role"), entry.get("content")
            if isinstance(role, str) and role and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        elif isinstance(entry, (list, tuple)) and len(entry) == 2:
            # Legacy pair format: (user_text, assistant_text).
            for role, content in zip(("user", "assistant"), entry):
                if isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})

    messages.append({"role": "user", "content": message})

    # Step 4: Stream response from LLM
    client = _get_llm_client()
    if client is None:
        yield ("⚠️ LLM service unavailable. Please ensure `huggingface_hub` is installed "
               "and `HF_TOKEN` is set.")
        return

    try:
        stream = client.chat_completion(
            model=_LLM_MODEL,
            messages=messages,
            max_tokens=1024,
            stream=True,
            temperature=0.3,  # Low temperature for factual responses
        )
        partial = ""
        for event in stream:
            # Some stream events carry no choices; indexing them would raise
            # and replace the partial answer with an error.
            choices = getattr(event, "choices", None)
            if not choices:
                continue
            delta = getattr(choices[0], "delta", None)
            partial += getattr(delta, "content", None) or ""
            yield partial
        if not partial:
            yield "⚠️ The model returned an empty response. Please try again."
    except Exception as e:
        error_msg = str(e)
        if "rate limit" in error_msg.lower() or "429" in error_msg:
            yield ("⚠️ Rate limit reached on the free HF Inference API. "
                   "Please wait a moment and try again.")
        elif "401" in error_msg or "unauthorized" in error_msg.lower():
            yield ("⚠️ Authentication error. Please set your HF_TOKEN in the Space settings.")
        else:
            yield f"⚠️ Error generating response: {error_msg}\n\nPlease try again."
376
+
377
+
378
+ # ═══════════════════════════════════════════════════════════════════════
379
+ # INDEXING HELPER (combines chunking + embedding)
380
+ # ═══════════════════════════════════════════════════════════════════════
381
+
382
def index_contract(text):
    """
    Prepare a contract for RAG retrieval by chunking and embedding it.

    Args:
        text: Raw contract text. Text that is empty or shorter than 50
            characters after stripping is rejected.

    Returns: (chunks, embeddings, status_message)
        chunks: list of str (empty when nothing could be indexed)
        embeddings: the result of ``build_embeddings`` or None
        status_message: str describing the outcome for display
    """
    has_enough_text = bool(text) and len(text.strip()) >= 50
    if not has_enough_text:
        return [], None, "⚠️ No contract text to index"

    pieces = chunk_contract_text(text)
    if not pieces:
        return [], None, "⚠️ Could not split contract into chunks"

    vectors = build_embeddings(pieces)
    if vectors is None:
        # Chunks are still returned so callers can tell chunking succeeded.
        return pieces, None, "⚠️ Embedding model unavailable — chatbot will not work"

    summary = f"✅ Indexed {len(pieces)} chunks ({len(text)} chars) — Ready to chat!"
    return pieces, vectors, summary