"""
ClauseGuard — Contract Q&A Chatbot (RAG) v1.0
═══════════════════════════════════════════════

Architecture:
    User asks a question about their contract
        ↓
    [1] Embed the question with sentence-transformers (all-MiniLM-L6-v2)
        ↓
    [2] Retrieve the top-5 most relevant chunks from the contract
        ↓
    [3] Build the prompt:
        - System: ClauseGuard analysis results (clauses, entities, risk scores)
        - Context: retrieved contract chunks (≤2.5K tokens)
        - User question
        ↓
    [4] Stream the response from the LLM via the HF Inference API

Key design:
    • Analyzed data (clauses, entities, risk scores) → system prompt
    • Raw contract text → RAG retrieval
    • This gives the model both structured analysis AND verbatim evidence

See the usage sketch after the imports below.
"""
import os
import re
import numpy as np
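# ── Usage sketch ─────────────────────────────────────────────────────
# How the pieces below fit together (hedged sketch; the calling app and
# the `analysis` dict come from elsewhere in ClauseGuard and are assumed
# here, not defined in this module):
#
#   chunks, embeddings, status = index_contract(contract_text)
#   for partial in chat_respond(
#       "What is the notice period?", history=[],
#       chunks=chunks, embeddings=embeddings, analysis_result=analysis,
#   ):
#       ...  # each yield is the growing answer string for the UI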
# ── Embedding model (soft-fail) ─────────────────────────────────────
_HAS_EMBEDDER = False
_embedder = None
try:
from sentence_transformers import SentenceTransformer
_HAS_EMBEDDER = True
except ImportError:
pass
# ── HF Inference Client (soft-fail) ─────────────────────────────────
_HAS_INFERENCE = False
_llm_client = None
try:
from huggingface_hub import InferenceClient
_HAS_INFERENCE = True
except ImportError:
pass
# ═══════════════════════════════════════════════════════════════════════
# MODEL LOADING
# ═══════════════════════════════════════════════════════════════════════
_chatbot_status = {"embedder": "not_loaded", "llm": "not_loaded"}
def _load_embedder():
"""Load sentence-transformers embedding model (lazy)."""
global _embedder, _chatbot_status
if _embedder is not None:
return _embedder
if not _HAS_EMBEDDER:
_chatbot_status["embedder"] = "unavailable"
return None
try:
print("[ClauseGuard Chat] Loading embedding model: all-MiniLM-L6-v2...")
_embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
_chatbot_status["embedder"] = "loaded"
print("[ClauseGuard Chat] Embedding model loaded")
return _embedder
except Exception as e:
_chatbot_status["embedder"] = f"failed: {e}"
print(f"[ClauseGuard Chat] Embedder load failed: {e}")
return None
def _get_llm_client():
"""Get or create HF Inference Client (lazy)."""
global _llm_client, _chatbot_status
if _llm_client is not None:
return _llm_client
if not _HAS_INFERENCE:
_chatbot_status["llm"] = "unavailable"
return None
try:
token = os.environ.get("HF_TOKEN", "")
_llm_client = InferenceClient(
provider="hf-inference",
api_key=token if token else None,
)
_chatbot_status["llm"] = "loaded"
print("[ClauseGuard Chat] HF Inference Client initialized")
return _llm_client
except Exception as e:
_chatbot_status["llm"] = f"failed: {e}"
print(f"[ClauseGuard Chat] LLM client init failed: {e}")
return None
def get_chatbot_status():
"""Return human-readable chatbot status."""
parts = []
for name, status in _chatbot_status.items():
icon = "βœ…" if status == "loaded" else "⚠️" if "failed" in status else "❌"
label = {"embedder": "Embeddings", "llm": "LLM API"}[name]
parts.append(f"{icon} {label}: {status}")
return " Β· ".join(parts)
# ═══════════════════════════════════════════════════════════════════════
# TEXT CHUNKING (sentence-preserving, ~300 tokens, no overlap)
# ═══════════════════════════════════════════════════════════════════════
def chunk_contract_text(text, target_chunk_size=300, min_chunk_size=50):
"""
Split contract text into chunks for RAG retrieval.
Sentence-preserving, ~300 tokens per chunk, 0% overlap.
Research (arxiv 2601.14123): overlap adds cost with zero benefit.
"""
if not text:
return []
# First split on paragraph boundaries
paragraphs = re.split(r'\n\n+', text)
chunks = []
current_chunk = ""
for para in paragraphs:
para = para.strip()
if not para:
continue
# Estimate word count (rough token proxy)
words_current = len(current_chunk.split())
words_para = len(para.split())
if words_current + words_para <= target_chunk_size:
current_chunk += ("\n\n" + para if current_chunk else para)
        else:
            # Current chunk is full enough — save it
            if words_current >= min_chunk_size:
                chunks.append(current_chunk.strip())
                current_chunk = para
            else:
                # Current chunk too small — split the paragraph into sentences
                sentences = re.split(r'(?<=[.!?])\s+(?=[A-Z])', para)
                for sent in sentences:
                    words_current = len(current_chunk.split())
                    words_sent = len(sent.split())
                    if words_current + words_sent <= target_chunk_size:
                        current_chunk += (" " + sent if current_chunk else sent)
                    elif words_current >= min_chunk_size:
                        chunks.append(current_chunk.strip())
                        current_chunk = sent
                    else:
                        # Chunk is still under the minimum size: exceed the
                        # target rather than silently dropping the sentence
                        current_chunk += (" " + sent if current_chunk else sent)
    # Keep the tail: merge a too-short final chunk into the previous one
    # instead of silently dropping contract text
    tail = current_chunk.strip()
    if tail:
        if len(tail.split()) >= min_chunk_size or not chunks:
            chunks.append(tail)
        else:
            chunks[-1] += "\n\n" + tail
    return chunks
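# Worked example (illustrative numbers): paragraphs of 250, 250, and 200
# words yield three chunks (250 / 250 / 200), since merging any pair would
# overshoot the 300-word target; paragraphs of 120 and 150 words instead
# merge into a single 270-word chunk.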
# ═══════════════════════════════════════════════════════════════════════
# EMBEDDING & RETRIEVAL
# ═══════════════════════════════════════════════════════════════════════
def build_embeddings(chunks):
"""
Embed chunks using sentence-transformers.
Returns numpy array of shape (N, 384) or None if embedder unavailable.
"""
embedder = _load_embedder()
if embedder is None or not chunks:
return None
try:
embeddings = embedder.encode(
chunks,
normalize_embeddings=True,
batch_size=32,
show_progress_bar=False,
)
return embeddings # numpy array (N, 384)
except Exception as e:
print(f"[ClauseGuard Chat] Embedding error: {e}")
return None
def retrieve_chunks(query, chunks, embeddings, top_k=5):
"""
Retrieve top-k most relevant chunks for a query.
    Uses cosine similarity (embeddings are L2-normalized → dot product = cosine).
    Context budget: top-5 chunks, ≤2.5K tokens.
"""
embedder = _load_embedder()
if embedder is None or embeddings is None or not chunks:
return []
try:
q_emb = embedder.encode([query], normalize_embeddings=True)
scores = (q_emb @ embeddings.T)[0]
top_indices = np.argsort(scores)[::-1][:top_k]
results = []
total_words = 0
        max_words = 1900  # ≈2.5K-token budget at ~1.3 tokens per word
for idx in top_indices:
chunk = chunks[idx]
chunk_words = len(chunk.split())
if total_words + chunk_words > max_words and results:
break
results.append({
"text": chunk,
"score": float(scores[idx]),
"index": int(idx),
})
total_words += chunk_words
return results
except Exception as e:
print(f"[ClauseGuard Chat] Retrieval error: {e}")
return []
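# Why the plain dot product above equals cosine similarity: with
# normalize_embeddings=True every vector has unit length, so q @ e is
# cos(angle). A tiny 2-D sanity check (toy vectors, not real embeddings):
#
#   a = np.array([3.0, 4.0]) / 5.0   # unit vector [0.6, 0.8]
#   b = np.array([4.0, 3.0]) / 5.0   # unit vector [0.8, 0.6]
#   float(a @ b)                     # 0.96 == cosine of the angle between them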
# ═══════════════════════════════════════════════════════════════════════
# SYSTEM PROMPT BUILDER
# ═══════════════════════════════════════════════════════════════════════
def _build_system_prompt(analysis_result, retrieved_chunks):
"""
Build the system prompt with:
    1. ClauseGuard analysis results (clauses, entities, risk scores) — NOT through RAG
    2. Retrieved contract chunks — through RAG
"""
parts = []
parts.append("""You are ClauseGuard AI, a legal contract analysis assistant. You help users understand their contracts by answering questions based on the contract text and analysis results.
RULES:
- Answer ONLY based on the provided contract text and analysis. Never make up information.
- If the answer isn't in the provided context, say "I don't see that information in the analyzed contract."
- Cite specific clauses or sections when possible.
- Be concise but thorough. Use plain language, not legal jargon.
- Always end with: "⚠️ This is AI analysis, not legal advice. Consult an attorney for legal decisions."
""")
# Add analysis summary if available
if analysis_result:
risk = analysis_result.get("risk", {})
parts.append(f"""
═══ CONTRACT ANALYSIS SUMMARY ═══
Risk Score: {risk.get('score', 'N/A')}/100 (Grade {risk.get('grade', 'N/A')})
Risk Breakdown: {risk.get('breakdown', {})}
Total Clauses Analyzed: {analysis_result.get('metadata', {}).get('total_clauses', 'N/A')}
Flagged Clauses: {analysis_result.get('metadata', {}).get('flagged_clauses', 'N/A')}
""")
# Add detected clauses summary
clauses = analysis_result.get("clauses", [])
if clauses:
clause_summary = []
seen = set()
for c in clauses:
key = c["label"]
if key not in seen:
seen.add(key)
risk_level = c.get("risk", "LOW")
                    clause_summary.append(f" • [{risk_level}] {key}: {c.get('description', '')}")
parts.append("═══ DETECTED CLAUSES ═══\n" + "\n".join(clause_summary[:20]))
# Add entities summary
entities = analysis_result.get("entities", [])
if entities:
entity_summary = []
seen = set()
for e in entities:
key = f"{e['type']}: {e['text']}"
if key not in seen and len(seen) < 15:
seen.add(key)
                    entity_summary.append(f" • {e['type']}: {e['text']}")
parts.append("═══ EXTRACTED ENTITIES ═══\n" + "\n".join(entity_summary))
# Add contradictions
contradictions = analysis_result.get("contradictions", [])
if contradictions:
contra_summary = []
for c in contradictions:
                contra_summary.append(f" • [{c['type']}] {c['explanation']}")
parts.append("═══ CONTRADICTIONS / ISSUES ═══\n" + "\n".join(contra_summary))
# Add retrieved contract text
if retrieved_chunks:
context_text = "\n---\n".join(c["text"] for c in retrieved_chunks)
parts.append(f"""
═══ RELEVANT CONTRACT TEXT (Retrieved) ═══
{context_text}
""")
return "\n\n".join(parts)
# ═══════════════════════════════════════════════════════════════════════
# CHAT RESPONSE (Streaming)
# ═══════════════════════════════════════════════════════════════════════
# LLM model to use
_LLM_MODEL = "Qwen/Qwen2.5-7B-Instruct"
def chat_respond(message, history, chunks, embeddings, analysis_result):
"""
RAG chatbot response function for gr.ChatInterface.
Args:
message: User's question (str)
history: Chat history (list of dicts with role/content)
chunks: Contract text chunks (list of str)
embeddings: Chunk embeddings (numpy array or None)
analysis_result: Full analysis result dict (or None)
Yields:
Partial response string (streaming)
"""
# Validate inputs
if not chunks or embeddings is None:
yield ("⚠️ No contract loaded yet. Please upload and analyze a contract in the "
"**πŸ“„ Single Contract Analysis** tab first, then come back here to ask questions.")
return
if not message or not message.strip():
yield "Please ask a question about your contract."
return
# Step 1: Retrieve relevant chunks
retrieved = retrieve_chunks(message, chunks, embeddings, top_k=5)
# Step 2: Build system prompt with analysis + retrieved context
system_prompt = _build_system_prompt(analysis_result, retrieved)
# Step 3: Build message history for LLM
messages = [{"role": "system", "content": system_prompt}]
    # Add recent history (last 6 messages, i.e. 3 user/assistant turns, to stay within the context window)
if history:
for h in history[-6:]:
messages.append({"role": h["role"], "content": h["content"]})
messages.append({"role": "user", "content": message})
# Step 4: Stream response from LLM
client = _get_llm_client()
if client is None:
yield ("⚠️ LLM service unavailable. Please ensure `huggingface_hub` is installed "
"and `HF_TOKEN` is set.")
return
try:
stream = client.chat_completion(
model=_LLM_MODEL,
messages=messages,
max_tokens=1024,
stream=True,
temperature=0.3, # Low temperature for factual responses
)
partial = ""
for chunk in stream:
token = chunk.choices[0].delta.content or ""
partial += token
yield partial
except Exception as e:
error_msg = str(e)
if "rate limit" in error_msg.lower() or "429" in error_msg:
yield ("⚠️ Rate limit reached on the free HF Inference API. "
"Please wait a moment and try again.")
elif "401" in error_msg or "unauthorized" in error_msg.lower():
yield ("⚠️ Authentication error. Please set your HF_TOKEN in the Space settings.")
else:
yield f"⚠️ Error generating response: {error_msg}\n\nPlease try again."
# ═══════════════════════════════════════════════════════════════════════
# INDEXING HELPER (combines chunking + embedding)
# ═══════════════════════════════════════════════════════════════════════
def index_contract(text):
"""
Chunk and embed contract text for RAG retrieval.
Returns: (chunks, embeddings, status_message)
chunks: list of str
embeddings: numpy array or None
status_message: str
"""
    if not text or len(text.strip()) < 50:
        return [], None, "⚠️ No contract text to index"
    chunks = chunk_contract_text(text)
    if not chunks:
        return [], None, "⚠️ Could not split contract into chunks"
    embeddings = build_embeddings(chunks)
    if embeddings is None:
        return chunks, None, "⚠️ Embedding model unavailable — chatbot will not work"
    return (
        chunks,
        embeddings,
        f"✅ Indexed {len(chunks)} chunks ({len(text)} chars) — Ready to chat!"
    )
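# ── Smoke test ───────────────────────────────────────────────────────
# Minimal offline check of the chunk → embed → retrieve pipeline when the
# module is run directly. The sample contract is invented for illustration;
# the LLM call is skipped, so no HF_TOKEN is needed (sentence-transformers
# must be installed for retrieval to return anything).
if __name__ == "__main__":
    sample = (
        "1. Term. This Agreement commences on January 1, 2025 and continues "
        "for twelve (12) months unless terminated earlier as provided herein.\n\n"
        "2. Termination. Either party may terminate this Agreement with thirty "
        "(30) days prior written notice to the other party.\n\n"
        "3. Liability. Neither party's aggregate liability shall exceed the "
        "total fees paid in the twelve (12) months preceding the claim."
    )
    chunks, embeddings, status = index_contract(sample)
    print(status)
    for hit in retrieve_chunks("How can I terminate this contract?", chunks, embeddings):
        print(f"  score={hit['score']:.3f}  chunk#{hit['index']}: {hit['text'][:60]}...")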