dwijverma2 commited on
Commit
0fbb682
·
verified ·
1 Parent(s): 4043f4e

Add main orchestrator

Browse files
Files changed (1) hide show
  1. doc_enricher/enricher.py +239 -0
doc_enricher/enricher.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Document Enricher — Main orchestrator.
3
+
4
+ Coordinates the handler, chunker, and LLM client to produce
5
+ a re-enriched copy of a document with proper heading formatting.
6
+
7
+ Usage:
8
+ from doc_enricher import DocumentEnricher, DocxHandler
9
+
10
+ enricher = DocumentEnricher(
11
+ handler=DocxHandler(),
12
+ model="llama3",
13
+ ollama_url="http://localhost:11434",
14
+ )
15
+
16
+ output_path = enricher.enrich("input.docx", "output.docx")
17
+ """
18
+
19
+ import os
20
+ import logging
21
+ import time
22
+ from typing import Optional
23
+
24
+ from .base_handler import BaseHandler, ParagraphInfo
25
+ from .chunker import build_chunks, Chunk, estimate_tokens
26
+ from .llm_client import OllamaClassifier
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class DocumentEnricher:
32
+ """
33
+ Orchestrates the document re-enrichment pipeline:
34
+
35
+ 1. Extract paragraphs from the original document (via handler)
36
+ 2. Chunk paragraphs into LLM-digestible batches
37
+ 3. Classify each batch using the local LLM
38
+ 4. Apply classifications to a copy of the document (via handler)
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ handler: BaseHandler,
44
+ model: str = "llama3",
45
+ ollama_url: str = "http://localhost:11434",
46
+ max_tokens_per_chunk: int = 3000,
47
+ overlap: int = 3,
48
+ temperature: float = 0.0,
49
+ num_ctx: int = 8192,
50
+ timeout_per_request: int = 180,
51
+ include_formatting_hints: bool = True,
52
+ ):
53
+ """
54
+ Args:
55
+ handler: Document format handler (e.g. DocxHandler)
56
+ model: Ollama model name (e.g. "llama3", "llama3:8b")
57
+ ollama_url: Ollama API base URL
58
+ max_tokens_per_chunk: Max tokens of paragraph text per LLM call
59
+ overlap: Paragraphs of overlap between chunks
60
+ temperature: LLM temperature (0.0 = deterministic)
61
+ num_ctx: LLM context window size
62
+ timeout_per_request: HTTP timeout per LLM call in seconds
63
+ include_formatting_hints: Send existing formatting metadata to LLM
64
+ """
65
+ self.handler = handler
66
+ self.max_tokens_per_chunk = max_tokens_per_chunk
67
+ self.overlap = overlap
68
+ self.include_formatting_hints = include_formatting_hints
69
+
70
+ self.classifier = OllamaClassifier(
71
+ model=model,
72
+ ollama_url=ollama_url,
73
+ temperature=temperature,
74
+ num_ctx=num_ctx,
75
+ timeout=timeout_per_request,
76
+ )
77
+
78
+ def enrich(
79
+ self,
80
+ src_path: str,
81
+ dst_path: Optional[str] = None,
82
+ ) -> str:
83
+ """
84
+ Produce a re-enriched copy of the document.
85
+
86
+ Args:
87
+ src_path: Path to the original document
88
+ dst_path: Path for the re-enriched copy.
89
+ If None, uses "{name}_enriched.{ext}"
90
+
91
+ Returns:
92
+ Path to the re-enriched document
93
+ """
94
+ if not os.path.exists(src_path):
95
+ raise FileNotFoundError(f"Source document not found: {src_path}")
96
+
97
+ if dst_path is None:
98
+ base, ext = os.path.splitext(src_path)
99
+ dst_path = f"{base}_enriched{ext}"
100
+
101
+ logger.info(f"=== Starting re-enrichment: {src_path} → {dst_path} ===")
102
+ start_time = time.time()
103
+
104
+ # Step 1: Extract paragraphs
105
+ logger.info("Step 1/4: Extracting paragraphs...")
106
+ paragraphs = self.handler.extract_paragraphs(src_path)
107
+ logger.info(f" Found {len(paragraphs)} non-empty paragraphs")
108
+
109
+ if not paragraphs:
110
+ logger.warning("No paragraphs found. Producing unmodified copy.")
111
+ import shutil
112
+ shutil.copy2(src_path, dst_path)
113
+ return dst_path
114
+
115
+ # Step 2: Chunk paragraphs
116
+ logger.info("Step 2/4: Chunking paragraphs...")
117
+ chunks = build_chunks(
118
+ paragraphs,
119
+ max_tokens_per_chunk=self.max_tokens_per_chunk,
120
+ overlap=self.overlap,
121
+ )
122
+ logger.info(f" Created {len(chunks)} chunk(s)")
123
+
124
+ # Build lookup for paragraph info by index
125
+ para_lookup = {p.index: p for p in paragraphs}
126
+
127
+ # Step 3: Classify each chunk
128
+ logger.info("Step 3/4: Classifying paragraphs with LLM...")
129
+ all_classifications: dict[int, str] = {}
130
+
131
+ for chunk_num, chunk in enumerate(chunks, 1):
132
+ logger.info(f" Chunk {chunk_num}/{len(chunks)}: "
133
+ f"{len(chunk.all_indices)} paragraphs "
134
+ f"({len(chunk.classify_indices)} to classify)")
135
+
136
+ # Build the batch for this chunk
137
+ batch = []
138
+ for idx in chunk.all_indices:
139
+ p = para_lookup[idx]
140
+ entry = {
141
+ "index": idx,
142
+ "text": p.text,
143
+ }
144
+ if self.include_formatting_hints:
145
+ entry["style_name"] = p.style_name
146
+ entry["is_bold"] = p.is_bold
147
+ entry["avg_font_size_pt"] = p.avg_font_size_pt
148
+ batch.append(entry)
149
+
150
+ # Call the LLM
151
+ try:
152
+ result = self.classifier.classify_batch(
153
+ batch,
154
+ formatting_hints=self.include_formatting_hints,
155
+ )
156
+ except Exception as e:
157
+ logger.error(f" LLM classification failed for chunk {chunk_num}: {e}")
158
+ logger.info(" Defaulting all paragraphs in this chunk to BODY")
159
+ for idx in chunk.classify_indices:
160
+ if idx not in all_classifications:
161
+ all_classifications[idx] = "BODY"
162
+ continue
163
+
164
+ # Store results, but ONLY for primary (non-overlap) indices
165
+ classify_set = set(chunk.classify_indices)
166
+ for item in result["classifications"]:
167
+ idx = item["index"]
168
+ label = item["label"]
169
+ if idx in classify_set and idx not in all_classifications:
170
+ all_classifications[idx] = label
171
+
172
+ # Fill in any missing classifications with BODY
173
+ for p in paragraphs:
174
+ if p.index not in all_classifications:
175
+ logger.warning(f" Paragraph {p.index} not classified, defaulting to BODY")
176
+ all_classifications[p.index] = "BODY"
177
+
178
+ # Log classification summary
179
+ label_counts = {}
180
+ for label in all_classifications.values():
181
+ label_counts[label] = label_counts.get(label, 0) + 1
182
+ logger.info(f" Classification summary: {label_counts}")
183
+
184
+ # Step 4: Apply formatting to copy
185
+ logger.info("Step 4/4: Applying formatting to document copy...")
186
+ output = self.handler.apply_classifications(
187
+ src_path, dst_path, all_classifications
188
+ )
189
+
190
+ elapsed = time.time() - start_time
191
+ logger.info(f"=== Re-enrichment complete in {elapsed:.1f}s: {output} ===")
192
+
193
+ return output
194
+
195
+ def enrich_batch(
196
+ self,
197
+ src_dir: str,
198
+ dst_dir: str,
199
+ extension: str = ".docx",
200
+ ) -> list[str]:
201
+ """
202
+ Re-enrich all documents in a directory.
203
+
204
+ Args:
205
+ src_dir: Directory containing original documents
206
+ dst_dir: Directory where re-enriched copies will be saved
207
+ extension: File extension to filter by
208
+
209
+ Returns:
210
+ List of paths to re-enriched documents
211
+ """
212
+ os.makedirs(dst_dir, exist_ok=True)
213
+
214
+ src_files = sorted(
215
+ f for f in os.listdir(src_dir)
216
+ if f.lower().endswith(extension)
217
+ )
218
+
219
+ if not src_files:
220
+ logger.warning(f"No {extension} files found in {src_dir}")
221
+ return []
222
+
223
+ logger.info(f"Batch processing {len(src_files)} files from {src_dir}")
224
+ outputs = []
225
+
226
+ for i, filename in enumerate(src_files, 1):
227
+ src_path = os.path.join(src_dir, filename)
228
+ dst_path = os.path.join(dst_dir, filename)
229
+
230
+ logger.info(f"[{i}/{len(src_files)}] Processing: {filename}")
231
+ try:
232
+ output = self.enrich(src_path, dst_path)
233
+ outputs.append(output)
234
+ except Exception as e:
235
+ logger.error(f" Failed to process {filename}: {e}")
236
+ continue
237
+
238
+ logger.info(f"Batch complete: {len(outputs)}/{len(src_files)} files processed")
239
+ return outputs