Add LLM client for Ollama
Browse files- doc_enricher/llm_client.py +181 -0
doc_enricher/llm_client.py
ADDED
|
@@ -0,0 +1,181 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
LLM Client for paragraph classification via Ollama.
|
| 3 |
+
|
| 4 |
+
Uses the /api/chat endpoint with JSON-constrained decoding
|
| 5 |
+
for reliable structured output from Llama3.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import json
|
| 9 |
+
import logging
|
| 10 |
+
import requests
|
| 11 |
+
from typing import Optional
|
| 12 |
+
|
| 13 |
+
# Module-level logger; handler and level configuration is left to the application.
logger = logging.getLogger(__name__)

# Default Ollama endpoint
DEFAULT_OLLAMA_URL = "http://localhost:11434"

# System prompt sent with every classification request. It defines the three
# labels, the heuristics for choosing between them, the "[index] text" input
# format, and the exact JSON shape the reply must have. Blank lines inside the
# literal are intentional paragraph breaks in the prompt.
SYSTEM_PROMPT = """You are a document structure classifier. Your task is to classify each paragraph of a document into exactly one of three categories:

- **TITLE**: The document title. Usually appears once near the top. Short, descriptive of the entire document's topic. Not a section name.
- **SECTION_HEADING**: A section or subsection heading. Short, labels a section of content. Typically a phrase or short sentence, not a full paragraph of prose.
- **BODY**: Regular body text. Sentences, bullet points, paragraphs of actual content. This is the default — if in doubt, classify as BODY.

Rules:
1. A document usually has exactly ONE title (the first significant text). If the first paragraph is short and describes the whole document, it's likely TITLE.
2. SECTION_HEADINGs are short (typically under 10 words) and introduce a topic. They are NOT sentences — they don't end with periods.
3. Everything else is BODY.
4. Consider context: a short line between two long paragraphs of prose is likely a SECTION_HEADING. A short line in a list of short lines is likely BODY.

You will receive paragraphs in the format:
[index] text...

You MUST respond with valid JSON in exactly this format:
{"classifications": [{"index": <int>, "label": "<TITLE|SECTION_HEADING|BODY>"}, ...]}

Include an entry for EVERY paragraph index provided. Do not skip any."""
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class OllamaClassifier:
    """Classifies paragraphs using a local Ollama LLM instance.

    Each call to :meth:`classify_batch` sends one non-streaming request to
    Ollama's ``/api/chat`` endpoint with ``format="json"`` and normalizes the
    reply into ``{"classifications": [{"index": ..., "label": ...}, ...]}``.
    """

    # Labels the model may emit; anything else is coerced to "BODY".
    _VALID_LABELS = frozenset({"TITLE", "SECTION_HEADING", "BODY"})

    # Paragraph text is truncated to this many characters in the prompt.
    _MAX_TEXT_CHARS = 300

    def __init__(
        self,
        model: str = "llama3",
        ollama_url: str = DEFAULT_OLLAMA_URL,
        temperature: float = 0.0,
        num_ctx: int = 8192,
        timeout: int = 180,
    ):
        """Store the request settings and verify that Ollama is reachable.

        Args:
            model: Ollama model name, optionally with a ``:tag`` suffix.
            ollama_url: Base URL of the Ollama server; a trailing slash is
                stripped.
            temperature: Sampling temperature passed in the request options.
            num_ctx: Context window size passed in the request options.
            timeout: Timeout in seconds for each ``/api/chat`` request.

        Raises:
            ConnectionError: If the Ollama server cannot be reached.
        """
        self.model = model
        self.ollama_url = ollama_url.rstrip("/")
        self.temperature = temperature
        self.num_ctx = num_ctx
        self.timeout = timeout

        # Verify Ollama is reachable
        self._check_connection()

    def _check_connection(self) -> None:
        """Check that Ollama is running and log whether the model is listed.

        A missing model only produces a warning; an unreachable server raises.

        Raises:
            ConnectionError: If the HTTP connection to Ollama fails.
        """
        try:
            resp = requests.get(f"{self.ollama_url}/api/tags", timeout=5)
            resp.raise_for_status()
            # "models" may be absent or null when nothing is installed.
            entries = resp.json().get("models") or []
        except requests.ConnectionError as exc:
            raise ConnectionError(
                f"Cannot connect to Ollama at {self.ollama_url}. "
                f"Is Ollama running? Start with: ollama serve"
            ) from exc

        # Skip entries without a usable name instead of crashing on them.
        models = [m["name"] for m in entries if isinstance(m, dict) and "name" in m]

        if self._model_is_listed(self.model, models):
            logger.info("Ollama connected. Model '%s' available.", self.model)
        else:
            logger.warning(
                "Model '%s' not found in Ollama. "
                "Available: %s. Will attempt to use it anyway "
                "(Ollama may auto-pull).",
                self.model,
                models,
            )

    @staticmethod
    def _model_is_listed(requested: str, available: list) -> bool:
        """Return True if *requested* names one of the *available* models.

        Listed names carry a tag (e.g. ``"llama3:latest"``). An untagged
        request is compared against its ``:latest`` form, so ``"llama3"``
        no longer matches unrelated names such as ``"llama3.1:8b"`` the way
        a substring test would.
        """
        if requested in available:
            return True
        if ":" not in requested:
            return f"{requested}:latest" in available
        return False

    def classify_batch(
        self,
        paragraphs: list[dict],
        formatting_hints: bool = True,
    ) -> dict:
        """Send a batch of paragraphs to the LLM for classification.

        Args:
            paragraphs: List of dicts with keys:
                - index (int): Paragraph index in the original document
                - text (str): Paragraph text
                - style_name (str, optional): Current style name
                - is_bold (bool, optional): Whether any run is bold
                - avg_font_size_pt (float, optional): Average font size
            formatting_hints: Whether to include formatting metadata in prompt

        Returns:
            Dict with "classifications" key containing list of
            {"index": int, "label": str} dicts. Labels are upper-cased and
            unknown labels are replaced by "BODY". Items lacking "index" or
            "label" are logged and omitted. An empty *paragraphs* list
            returns ``{"classifications": []}`` without contacting Ollama.

        Raises:
            ValueError: If the reply is not valid JSON or does not have a
                recognizable structure.
            requests.HTTPError: If Ollama answers with an error status.
        """
        if not paragraphs:
            # Nothing to classify; avoid a pointless round-trip to the model.
            return {"classifications": []}

        # Build the user message
        lines = [self._format_paragraph(p, formatting_hints) for p in paragraphs]
        user_content = (
            f"Classify these {len(paragraphs)} paragraphs from a document:\n\n"
            + "\n".join(lines)
        )

        logger.debug(
            "Sending %d paragraphs to LLM (%d chars)",
            len(paragraphs),
            len(user_content),
        )

        response = requests.post(
            f"{self.ollama_url}/api/chat",
            json={
                "model": self.model,
                "messages": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_content},
                ],
                "stream": False,
                "format": "json",
                "options": {
                    "temperature": self.temperature,
                    "num_ctx": self.num_ctx,
                },
            },
            timeout=self.timeout,
        )
        response.raise_for_status()

        try:
            raw_text = response.json()["message"]["content"]
        except (KeyError, TypeError) as exc:
            # Report an unexpected envelope the same way as any other
            # unusable reply instead of leaking a bare KeyError.
            raise ValueError(
                "Ollama response is missing 'message.content'"
            ) from exc
        if not isinstance(raw_text, str):
            raise ValueError(
                f"Ollama 'message.content' is not a string: {type(raw_text).__name__}"
            )
        logger.debug("LLM response: %s", raw_text[:500])

        try:
            parsed = json.loads(raw_text)
        except json.JSONDecodeError as e:
            logger.error(
                "Failed to parse LLM JSON response: %s\nRaw: %s", e, raw_text[:1000]
            )
            raise ValueError(f"LLM returned invalid JSON: {e}") from e

        return self._normalize_result(parsed)

    @classmethod
    def _format_paragraph(cls, p: dict, formatting_hints: bool) -> str:
        """Render one paragraph dict as a single ``[index] text (hints)`` line."""
        # A None text is treated as empty; line breaks are flattened so every
        # paragraph occupies exactly one "[index] ..." line in the prompt.
        text = " ".join((p["text"] or "").splitlines())
        line = f'[{p["index"]}] {text[:cls._MAX_TEXT_CHARS]}'
        if not formatting_hints:
            return line

        hints = []
        if p.get("style_name"):
            hints.append(f'style="{p["style_name"]}"')
        if p.get("is_bold") is not None:
            hints.append(f'bold={p["is_bold"]}')
        if p.get("avg_font_size_pt") is not None:
            hints.append(f'size={p["avg_font_size_pt"]:.1f}pt')
        if hints:
            line += f' ({", ".join(hints)})'
        return line

    @classmethod
    def _normalize_result(cls, parsed) -> dict:
        """Coerce the decoded LLM payload into the documented result dict."""
        # Some models return a flat list or a different single key; recover
        # from those shapes and reject everything else explicitly.
        if isinstance(parsed, dict) and "classifications" in parsed:
            result = parsed
            items = parsed["classifications"]
        elif isinstance(parsed, list):
            result = {}
            items = parsed
        elif isinstance(parsed, dict) and len(parsed) == 1:
            result = {}
            items = next(iter(parsed.values()))
        elif isinstance(parsed, dict):
            raise ValueError(
                f"LLM response missing 'classifications' key. Got keys: {list(parsed.keys())}"
            )
        else:
            raise ValueError(
                f"LLM response is not a JSON object or list: {type(parsed).__name__}"
            )

        if not isinstance(items, list):
            raise ValueError(
                f"LLM 'classifications' is not a list: {type(items).__name__}"
            )

        # Validate each classification
        cleaned = []
        for item in items:
            if not isinstance(item, dict) or "label" not in item or "index" not in item:
                logger.warning("Malformed classification item: %s", item)
                continue
            # str() guards against non-string labels such as null or numbers.
            label = str(item["label"]).upper().strip()
            if label not in cls._VALID_LABELS:
                logger.warning(
                    "Unknown label '%s' for index %s, defaulting to BODY",
                    label,
                    item["index"],
                )
                label = "BODY"
            item["label"] = label
            cleaned.append(item)

        result["classifications"] = cleaned
        return result
|