| """ |
| Guidelines RAG System - Retrieval-Augmented Generation for clinical guidelines |
| Uses FAISS for vector similarity search on chunked guideline PDFs. |
| """ |

import os
import json
import re
from pathlib import Path
from typing import List, Dict, Optional, Tuple

import numpy as np

# Paths to the guideline PDFs and the persisted index
GUIDELINES_DIR = Path(__file__).parent.parent / "guidelines"
INDEX_DIR = GUIDELINES_DIR / "index"
FAISS_INDEX_PATH = INDEX_DIR / "faiss.index"
CHUNKS_PATH = INDEX_DIR / "chunks.json"

# Chunking parameters, in approximate tokens
CHUNK_SIZE = 500
CHUNK_OVERLAP = 50
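# Note: _chunk_text converts these token budgets to word counts with a rough
# ~0.75 words-per-token heuristic, i.e. int(500 * 0.75) = 375 words per chunk
# and int(50 * 0.75) = 37 words of overlap.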


class GuidelinesRAG:
    """
    RAG system for clinical guidelines.
    Extracts text from PDFs, chunks it, creates embeddings, and provides search.
    """

    def __init__(self):
        self.index = None       # FAISS index, built or loaded lazily
        self.chunks = []        # chunk metadata, parallel to the index rows
        self.embedder = None    # SentenceTransformer, loaded on demand
        self.loaded = False

    def _load_embedder(self):
        """Load sentence transformer model for embeddings"""
        if self.embedder is None:
            from sentence_transformers import SentenceTransformer
            self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
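    # The import is deferred so the heavy sentence-transformers dependency is
    # only loaded when embeddings are actually needed; all-MiniLM-L6-v2
    # produces 384-dimensional sentence embeddings.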

    def _extract_pdf_text(self, pdf_path: Path) -> str:
        """Extract text from a PDF file"""
        try:
            import pdfplumber
            text_parts = []
            with pdfplumber.open(pdf_path) as pdf:
                for page in pdf.pages:
                    page_text = page.extract_text()
                    if page_text:
                        text_parts.append(page_text)
            return "\n\n".join(text_parts)
        except ImportError:
            # Fall back to PyPDF2 if pdfplumber is not installed
            from PyPDF2 import PdfReader
            reader = PdfReader(pdf_path)
            text_parts = []
            for page in reader.pages:
                text = page.extract_text()
                if text:
                    text_parts.append(text)
            return "\n\n".join(text_parts)

    def _clean_text(self, text: str) -> str:
        """Clean extracted text"""
        # Drop standalone page numbers; this must run before whitespace is
        # collapsed, since the pattern matches on surrounding newlines
        text = re.sub(r'\n\d+\s*\n', '\n', text)
        # Rejoin words hyphenated across line breaks
        text = re.sub(r'(\w)-\s+(\w)', r'\1\2', text)
        # Collapse runs of whitespace (including newlines) to single spaces
        text = re.sub(r'\s+', ' ', text)
        return text.strip()
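    # Example: "guide-\nlines for treat-  ment" becomes
    # "guidelines for treatment" - hyphenated line breaks are rejoined and the
    # remaining whitespace collapsed.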

    def _extract_pdf_with_pages(self, pdf_path: Path) -> List[Tuple[str, int]]:
        """Extract text from a PDF as (page_text, page_number) tuples"""
        try:
            import pdfplumber
            pages = []
            with pdfplumber.open(pdf_path) as pdf:
                for i, page in enumerate(pdf.pages, 1):
                    page_text = page.extract_text()
                    if page_text:
                        pages.append((page_text, i))
            return pages
        except ImportError:
            # Fall back to PyPDF2 if pdfplumber is not installed
            from PyPDF2 import PdfReader
            reader = PdfReader(pdf_path)
            pages = []
            for i, page in enumerate(reader.pages, 1):
                text = page.extract_text()
                if text:
                    pages.append((text, i))
            return pages
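    # pdfplumber is tried first because it generally preserves layout and
    # reading order better than PyPDF2; the fallback keeps extraction working
    # when only PyPDF2 is installed.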

    def _chunk_text(self, text: str, source: str, page_num: int = 0) -> List[Dict]:
        """
        Chunk text into overlapping segments.
        Returns list of dicts with 'text', 'source', 'chunk_id', 'page'.
        """
        # Approximate the token budgets in words (~0.75 words per token)
        words = text.split()
        chunk_words = int(CHUNK_SIZE * 0.75)
        overlap_words = int(CHUNK_OVERLAP * 0.75)

        chunks = []
        start = 0
        chunk_id = 0

        while start < len(words):
            end = start + chunk_words
            chunk_text = ' '.join(words[start:end])

            # Trim non-final chunks back to the last sentence boundary, as
            # long as that keeps at least 70% of the chunk
            if end < len(words):
                last_period = chunk_text.rfind('.')
                if last_period > len(chunk_text) * 0.7:
                    chunk_text = chunk_text[:last_period + 1]

            chunks.append({
                'text': chunk_text,
                'source': source,
                'chunk_id': chunk_id,
                'page': page_num
            })

            start = end - overlap_words
            chunk_id += 1

        return chunks
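    # Worked example with the defaults (chunk_words=375, overlap_words=37):
    # a 1000-word page yields three chunks spanning words 0-375, 338-713, and
    # 676-1000, each sharing 37 words with its neighbor (before any
    # sentence-boundary trimming).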

    def build_index(self, force_rebuild: bool = False) -> bool:
        """
        Build FAISS index from guideline PDFs.
        Returns True if an index is available (freshly built or loaded from
        cache), False if no chunks could be extracted.
        """
        # Reuse the cached index unless a rebuild is forced
        if not force_rebuild and FAISS_INDEX_PATH.exists() and CHUNKS_PATH.exists():
            return self.load_index()

        print("Building guidelines index...")
        self._load_embedder()

        # Make sure the index directory exists
        INDEX_DIR.mkdir(parents=True, exist_ok=True)

        # Extract, clean, and chunk every guideline PDF
        all_chunks = []
        pdf_files = list(GUIDELINES_DIR.glob("*.pdf"))

        for pdf_path in pdf_files:
            print(f" Processing: {pdf_path.name}")
            pages = self._extract_pdf_with_pages(pdf_path)
            pdf_chunks = 0
            for page_text, page_num in pages:
                cleaned = self._clean_text(page_text)
                chunks = self._chunk_text(cleaned, pdf_path.name, page_num)
                all_chunks.extend(chunks)
                pdf_chunks += len(chunks)
            print(f" -> {pdf_chunks} chunks from {len(pages)} pages")

        if not all_chunks:
            print("No chunks extracted from PDFs!")
            return False

        self.chunks = all_chunks
        print(f"Total chunks: {len(self.chunks)}")

        # Embed every chunk
        print("Generating embeddings...")
        texts = [c['text'] for c in self.chunks]
        embeddings = self.embedder.encode(texts, show_progress_bar=True)
        embeddings = np.array(embeddings).astype('float32')

        # Build an exact inner-product index
        import faiss
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)

        # L2-normalize so inner product equals cosine similarity
        faiss.normalize_L2(embeddings)
        self.index.add(embeddings)

        # Persist the index and the chunk metadata beside it
        faiss.write_index(self.index, str(FAISS_INDEX_PATH))
        with open(CHUNKS_PATH, 'w') as f:
            json.dump(self.chunks, f)

        print(f"Index saved to {INDEX_DIR}")
        self.loaded = True
        return True
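    # Because the vectors are L2-normalized before being added, the inner
    # products computed by IndexFlatIP are cosine similarities in [-1, 1].
    # IndexFlatIP is an exact brute-force index, which is appropriate for a
    # corpus of this size.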

    def load_index(self) -> bool:
        """Load the persisted FAISS index and chunk metadata"""
        if not FAISS_INDEX_PATH.exists() or not CHUNKS_PATH.exists():
            return False

        import faiss
        self.index = faiss.read_index(str(FAISS_INDEX_PATH))

        with open(CHUNKS_PATH, 'r') as f:
            self.chunks = json.load(f)

        self._load_embedder()
        self.loaded = True
        return True

    def search(self, query: str, k: int = 5) -> List[Dict]:
        """
        Search for relevant guideline chunks.
        Returns list of chunks with similarity scores.
        """
        if not self.loaded:
            if not self.load_index():
                self.build_index()
        if self.index is None or not self.chunks:
            # Nothing indexed (e.g., no PDFs found), so nothing to search
            return []

        import faiss

        # Embed and L2-normalize the query to match the indexed vectors
        query_embedding = self.embedder.encode([query])
        query_embedding = np.array(query_embedding).astype('float32')
        faiss.normalize_L2(query_embedding)

        # Retrieve the top-k most similar chunks
        scores, indices = self.index.search(query_embedding, k)

        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than k results exist
            if 0 <= idx < len(self.chunks):
                chunk = self.chunks[idx].copy()
                chunk['score'] = float(score)
                results.append(chunk)

        return results
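    # Illustrative call and result shape (file names hypothetical):
    #   rag.search("melanoma excision margins", k=3)
    #   -> [{'text': '...', 'source': 'melanoma_guidelines.pdf',
    #        'chunk_id': 4, 'page': 12, 'score': 0.62}, ...]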

    def get_management_context(self, diagnosis: str, features: Optional[str] = None) -> Tuple[str, List[Dict]]:
        """
        Get formatted context from guidelines for management recommendations.
        Returns tuple of (context_string, references_list).
        References can be used for citation hyperlinks.
        """
        # Build a retrieval query from the diagnosis plus optional features
        query = f"{diagnosis} management treatment recommendations"
        if features:
            query += f" {features}"

        chunks = self.search(query, k=5)

        if not chunks:
            return "No relevant guidelines found.", []

        context_parts = []
        references = []

        # Unicode superscripts used as inline citation markers
        superscripts = ['¹', '²', '³', '⁴', '⁵', '⁶', '⁷', '⁸', '⁹']

        for i, chunk in enumerate(chunks, 1):
            source = chunk['source'].replace('.pdf', '')
            page = chunk.get('page', 0)
            ref_id = f"ref{i}"
            superscript = superscripts[i - 1] if i <= len(superscripts) else f"[{i}]"

            # Label each chunk with its citation marker
            context_parts.append(f"[Source {superscript}] {chunk['text']}")

            # Record metadata so the frontend can render citation hyperlinks
            references.append({
                'id': ref_id,
                'source': source,
                'page': page,
                'file': chunk['source'],
                'score': chunk.get('score', 0)
            })

        context = "\n\n".join(context_parts)
        return context, references
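    # For diagnosis="melanoma" and features="ulcerated", the retrieval query
    # is "melanoma management treatment recommendations ulcerated", and the
    # returned context is a sequence of "[Source ¹] ..." blocks that an LLM
    # prompt can cite against the parallel references list.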

    def format_references_for_prompt(self, references: List[Dict]) -> str:
        """Format references for inclusion in LLM prompt"""
        if not references:
            return ""

        lines = ["\n**References:**"]
        for ref in references:
            lines.append(f"[{ref['id']}] {ref['source']}, p.{ref['page']}")
        return "\n".join(lines)
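    # Example output line (source name hypothetical):
    #   [ref1] melanoma_guidelines, p.12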

    def format_references_for_display(self, references: List[Dict]) -> str:
        """
        Format references with markers that frontend can parse into hyperlinks.
        Uses format: [REF:id:source:page:file:superscript]
        """
        if not references:
            return ""

        # Unicode superscripts used as inline citation markers
        superscripts = ['¹', '²', '³', '⁴', '⁵', '⁶', '⁷', '⁸', '⁹']

        lines = ["\n[REFERENCES]"]
        for i, ref in enumerate(references, 1):
            superscript = superscripts[i - 1] if i <= len(superscripts) else f"[{i}]"
            # One machine-parseable marker per reference
            lines.append(f"[REF:{ref['id']}:{ref['source']}:{ref['page']}:{ref['file']}:{superscript}]")
        # Close the block once, after all references
        lines.append("[/REFERENCES]")
        return "\n".join(lines)
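    # With two references the emitted block looks like (sources hypothetical):
    #   [REFERENCES]
    #   [REF:ref1:melanoma_guidelines:12:melanoma_guidelines.pdf:¹]
    #   [REF:ref2:excision_margins:3:excision_margins.pdf:²]
    #   [/REFERENCES]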


# Module-level singleton so all callers share one loaded index
_rag_instance = None


def get_guidelines_rag() -> GuidelinesRAG:
    """Get or create RAG instance"""
    global _rag_instance
    if _rag_instance is None:
        _rag_instance = GuidelinesRAG()
    return _rag_instance
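# Typical usage (illustrative):
#   rag = get_guidelines_rag()
#   context, refs = rag.get_management_context("basal cell carcinoma")
# Repeated calls return the same instance, so the index and embedding model
# are loaded at most once per process.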


if __name__ == "__main__":
    print("=" * 60)
    print(" Guidelines RAG System - Index Builder")
    print("=" * 60)

    rag = GuidelinesRAG()

    # Rebuild from scratch when --force is passed on the command line
    import sys
    force = "--force" in sys.argv
    rag.build_index(force_rebuild=force)

    # Smoke-test retrieval with a few representative queries
    print("\n" + "=" * 60)
    print(" Testing Search")
    print("=" * 60)

    test_queries = [
        "melanoma management",
        "actinic keratosis treatment",
        "surgical excision margins"
    ]

    for query in test_queries:
        print(f"\nQuery: '{query}'")
        results = rag.search(query, k=2)
        for r in results:
            print(f" [{r['score']:.3f}] {r['source']}: {r['text'][:100]}...")
|