dwijverma2 committed on
Commit
e5a883e
·
verified ·
1 Parent(s): d2a9d6e

Add LLM client for Ollama

Browse files
Files changed (1) hide show
  1. doc_enricher/llm_client.py +181 -0
doc_enricher/llm_client.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Client for paragraph classification via Ollama.
3
+
4
+ Uses the /api/chat endpoint with JSON-constrained decoding
5
+ for reliable structured output from Llama3.
6
+ """
7
+
8
+ import json
9
+ import logging
10
+ import requests
11
+ from typing import Optional
12
+
13
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Base URL of the Ollama server used when the caller does not pass one.
DEFAULT_OLLAMA_URL = "http://localhost:11434"

# System message sent with every classification request. It defines the three
# labels (TITLE, SECTION_HEADING, BODY), the "[index] text..." input format,
# and the JSON shape the model must reply with. The text is part of the
# program's behaviour: changing any character can change the model's output.
SYSTEM_PROMPT = """You are a document structure classifier. Your task is to classify each paragraph of a document into exactly one of three categories:

- **TITLE**: The document title. Usually appears once near the top. Short, descriptive of the entire document's topic. Not a section name.
- **SECTION_HEADING**: A section or subsection heading. Short, labels a section of content. Typically a phrase or short sentence, not a full paragraph of prose.
- **BODY**: Regular body text. Sentences, bullet points, paragraphs of actual content. This is the default — if in doubt, classify as BODY.

Rules:
1. A document usually has exactly ONE title (the first significant text). If the first paragraph is short and describes the whole document, it's likely TITLE.
2. SECTION_HEADINGs are short (typically under 10 words) and introduce a topic. They are NOT sentences — they don't end with periods.
3. Everything else is BODY.
4. Consider context: a short line between two long paragraphs of prose is likely a SECTION_HEADING. A short line in a list of short lines is likely BODY.

You will receive paragraphs in the format:
[index] text...

You MUST respond with valid JSON in exactly this format:
{"classifications": [{"index": <int>, "label": "<TITLE|SECTION_HEADING|BODY>"}, ...]}

Include an entry for EVERY paragraph index provided. Do not skip any."""
37
+
38
+
39
class OllamaClassifier:
    """Classify document paragraphs with a local Ollama LLM instance.

    Each batch is posted to Ollama's ``/api/chat`` endpoint with JSON output
    requested, and the reply is normalised into a list of
    ``{"index": int, "label": str}`` entries.
    """

    # Labels the model may emit; any other label is coerced to "BODY".
    VALID_LABELS = frozenset({"TITLE", "SECTION_HEADING", "BODY"})

    # Only this many leading characters of each paragraph are sent to the model.
    MAX_PARAGRAPH_CHARS = 300

    def __init__(
        self,
        model: str = "llama3",
        ollama_url: str = DEFAULT_OLLAMA_URL,
        temperature: float = 0.0,
        num_ctx: int = 8192,
        timeout: int = 180,
    ):
        """Store the settings and verify that the Ollama server is reachable.

        Args:
            model: Name of the Ollama model to use.
            ollama_url: Base URL of the Ollama server; trailing slashes are
                removed.
            temperature: Sampling temperature sent in the request options.
            num_ctx: Context size sent in the request options.
            timeout: Timeout in seconds for each classification request.

        Raises:
            ConnectionError: If the Ollama server cannot be reached.
        """
        self.model = model
        self.ollama_url = ollama_url.rstrip("/")
        self.temperature = temperature
        self.num_ctx = num_ctx
        self.timeout = timeout

        # Fail fast if Ollama is not reachable.
        self._check_connection()

    @staticmethod
    def _with_default_tag(name: str) -> str:
        """Return ``name`` with ``:latest`` appended when it carries no tag."""
        # Only the last path segment can hold the tag; an earlier ":" would
        # belong to a registry host:port prefix.
        last_segment = name.rsplit("/", 1)[-1]
        return name if ":" in last_segment else f"{name}:latest"

    @classmethod
    def _same_model(cls, requested: str, available: str) -> bool:
        """Return True when both names refer to the same model and tag.

        Ollama resolves an untagged name to ``<name>:latest``, so both names
        receive that default before being compared. Exact comparison avoids
        false positives of substring matching (e.g. ``llama3`` vs
        ``llama3.1:8b``).
        """
        return cls._with_default_tag(requested) == cls._with_default_tag(available)

    def _check_connection(self):
        """Check that Ollama is running and whether the model is installed.

        A missing model is only logged as a warning; it does not stop
        construction.

        Raises:
            ConnectionError: If the server at ``self.ollama_url`` is unreachable.
            requests.HTTPError: If ``/api/tags`` answers with an error status.
        """
        try:
            resp = requests.get(f"{self.ollama_url}/api/tags", timeout=5)
        except requests.ConnectionError as exc:
            raise ConnectionError(
                f"Cannot connect to Ollama at {self.ollama_url}. "
                f"Is Ollama running? Start with: ollama serve"
            ) from exc
        resp.raise_for_status()

        models = [
            entry["name"]
            for entry in (resp.json().get("models") or [])
            if isinstance(entry, dict) and entry.get("name")
        ]
        if any(self._same_model(self.model, name) for name in models):
            logger.info("Ollama connected. Model '%s' available.", self.model)
        else:
            logger.warning(
                "Model '%s' not found in Ollama. Available: %s. "
                "Will attempt to use it anyway (Ollama may auto-pull).",
                self.model,
                models,
            )

    def _build_user_message(self, paragraphs: list[dict], formatting_hints: bool) -> str:
        """Render the paragraphs as one ``[index] text (hints)`` line each."""
        lines = []
        for p in paragraphs:
            # Collapse internal whitespace so a paragraph containing line
            # breaks cannot spill over into what looks like another entry.
            text = " ".join(p["text"].split())[: self.MAX_PARAGRAPH_CHARS]
            line = f'[{p["index"]}] {text}'
            if formatting_hints:
                hints = []
                if p.get("style_name"):
                    hints.append(f'style="{p["style_name"]}"')
                if p.get("is_bold") is not None:
                    hints.append(f'bold={p["is_bold"]}')
                if p.get("avg_font_size_pt") is not None:
                    hints.append(f'size={p["avg_font_size_pt"]:.1f}pt')
                if hints:
                    line += f' ({", ".join(hints)})'
            lines.append(line)

        return (
            f"Classify these {len(paragraphs)} paragraphs from a document:\n\n"
            + "\n".join(lines)
        )

    @staticmethod
    def _extract_items(parsed) -> list:
        """Pull the list of classification items out of the decoded reply.

        Accepts the requested ``{"classifications": [...]}`` object, a bare
        list, or an object with a single (differently named) key.

        Raises:
            ValueError: If no list of classifications can be located.
        """
        if isinstance(parsed, list):
            items = parsed
        elif isinstance(parsed, dict):
            if "classifications" in parsed:
                items = parsed["classifications"]
            elif len(parsed) == 1:
                # Some models use a different key name; accept the only value.
                items = next(iter(parsed.values()))
            else:
                raise ValueError(
                    f"LLM response missing 'classifications' key. Got keys: {list(parsed.keys())}"
                )
        else:
            raise ValueError(
                f"LLM response must be a JSON object or array, got {type(parsed).__name__}"
            )

        if not isinstance(items, list):
            raise ValueError(
                f"LLM 'classifications' must be a list, got {type(items).__name__}"
            )
        return items

    @staticmethod
    def _coerce_index(value) -> Optional[int]:
        """Return ``value`` as an int, or None when it is not an integer."""
        if isinstance(value, bool):
            # bool is a subclass of int but is never a paragraph index.
            return None
        if isinstance(value, int):
            return value
        if isinstance(value, float) and value.is_integer():
            return int(value)
        if isinstance(value, str):
            try:
                return int(value.strip())
            except ValueError:
                return None
        return None

    def _clean_classifications(self, items: list) -> list[dict]:
        """Drop malformed items and normalise the labels of the rest.

        Items that are not dicts, lack ``index``/``label``, or carry a
        non-integer index are logged and omitted, so every returned entry has
        an int ``index`` and a ``label`` from ``VALID_LABELS``.
        """
        cleaned = []
        for item in items:
            if not isinstance(item, dict) or "label" not in item or "index" not in item:
                logger.warning("Malformed classification item: %s", item)
                continue
            index = self._coerce_index(item["index"])
            if index is None:
                logger.warning("Malformed classification item: %s", item)
                continue

            # str() keeps non-string labels (null, numbers) from raising; they
            # can never equal a valid label and so fall back to BODY.
            label = str(item["label"]).upper().strip()
            if label not in self.VALID_LABELS:
                logger.warning(
                    "Unknown label '%s' for index %s, defaulting to BODY", label, index
                )
                label = "BODY"
            cleaned.append({**item, "index": index, "label": label})
        return cleaned

    def classify_batch(
        self,
        paragraphs: list[dict],
        formatting_hints: bool = True,
    ) -> dict:
        """
        Send a batch of paragraphs to the LLM for classification.

        Args:
            paragraphs: List of dicts with keys:
                - index (int): Paragraph index in the original document
                - text (str): Paragraph text
                - style_name (str, optional): Current style name
                - is_bold (bool, optional): Whether any run is bold
                - avg_font_size_pt (float, optional): Average font size
            formatting_hints: Whether to include formatting metadata in prompt

        Returns:
            Dict with "classifications" key containing list of
            {"index": int, "label": str} dicts. Malformed items returned by
            the model are omitted. An empty batch yields an empty list
            without contacting the model.

        Raises:
            ValueError: If the reply is not valid JSON or does not contain a
                list of classifications.
            requests.RequestException: If the HTTP request fails or returns
                an error status.
        """
        if not paragraphs:
            # Nothing to classify: skip the round trip to the model.
            return {"classifications": []}

        user_content = self._build_user_message(paragraphs, formatting_hints)
        logger.debug(
            "Sending %d paragraphs to LLM (%d chars)", len(paragraphs), len(user_content)
        )

        response = requests.post(
            f"{self.ollama_url}/api/chat",
            json={
                "model": self.model,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                "stream": False,
                "format": "json",
                "options": {
                    "temperature": self.temperature,
                    "num_ctx": self.num_ctx,
                },
            },
            timeout=self.timeout,
        )
        response.raise_for_status()

        try:
            raw_text = response.json()["message"]["content"]
        except (KeyError, TypeError) as exc:
            raise ValueError("Ollama response has no message content") from exc
        if not isinstance(raw_text, str):
            raise ValueError("Ollama response has no message content")
        logger.debug("LLM response: %s", raw_text[:500])

        try:
            parsed = json.loads(raw_text)
        except json.JSONDecodeError as exc:
            logger.error(
                "Failed to parse LLM JSON response: %s\nRaw: %s", exc, raw_text[:1000]
            )
            raise ValueError(f"LLM returned invalid JSON: {exc}") from exc

        items = self._extract_items(parsed)
        return {"classifications": self._clean_classifications(items)}