CanLex / canlex /embed.py
Beemer
Use case-law topic, not paragraph range, as the retrieval title
666cd44
"""Build semantic embeddings for ingested chunks (local, key-free).
Uses BAAI's bge-small-en-v1.5 sentence-embedding model as ONNX, run on CPU via
onnxruntime -- no API key. A transformer embedding has far stronger retrieval
recall than a static one: it can connect a natural-language question to a
provision even when the two share few exact words.
"""
import json
import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer
from .config import PROCESSED_DIR
# Hugging Face Hub repo providing the ONNX export and tokenizer.json.
EMB_REPO: str = "Xenova/bge-small-en-v1.5"
# Where build() saves the ids/vectors archive.
EMB_PATH = PROCESSED_DIR / "embeddings.npz"
# Tokenizer truncation length applied to every passage and query.
_MAX_TOKENS: int = 512
# Upper bound (in characters) on the body text that is embedded, so that long
# sections stay topically focused.
_MAX_BODY: int = 2000
# Retrieval with bge-small is asymmetric: queries get this instruction prefix,
# passages are embedded without it -- matching how the model was trained.
_QUERY_PREFIX: str = (
    "Represent this sentence for searching relevant passages: "
)
def load_chunks():
    """Return every chunk from the processed JSON files as one flat list.

    Each ``*.json`` file directly under ``PROCESSED_DIR`` is decoded as UTF-8
    JSON and its items are concatenated. Files are visited in sorted path
    order so the result is deterministic.
    """
    json_files = sorted(PROCESSED_DIR.glob("*.json"))
    return [
        item
        for json_file in json_files
        for item in json.loads(json_file.read_text(encoding="utf-8"))
    ]
def embed_text(chunk):
    """Compact, retrieval-focused representation of one section.

    Joins with " . ": the act short name, the topical title (twice, for
    emphasis), the section heading (only if it differs from the title), and
    at most ``_MAX_BODY`` characters of body text. Falsy parts (empty strings,
    None) are skipped.

    Args:
        chunk: a processed chunk dict; reads the keys "act_short", "heading"
            and "text", plus whatever ``index.topical_title`` reads.

    Returns:
        The string to embed for this chunk.
    """
    # Imported inside the function, as originally written -- presumably to
    # avoid a circular import with .index; TODO confirm.
    from .index import topical_title

    # The section title is the strongest topical signal, so it is repeated to
    # emphasise it. Title selection is doc_type-aware (see index.topical_title):
    # a D-memo's marginal_note is a generic banner so its actual subject in
    # 'part' is used; a case-law chunk's marginal_note is just the paragraph
    # range so the case proposition in 'heading' is used.
    title = topical_title(chunk)
    heading = chunk["heading"]
    # When the title already *is* the heading (case law), appending the
    # heading again would weight it three times instead of the intended two.
    if heading == title:
        heading = None
    body = chunk["text"][:_MAX_BODY]
    parts = [chunk["act_short"], title, title, heading, body]
    return " . ".join(p for p in parts if p)
class Embedder:
    """Local transformer sentence-embedder: bge-small-en-v1.5 as ONNX on CPU.

    No API key; the model and tokenizer are fetched with ``hf_hub_download``
    (which caches downloads) and run through onnxruntime's CPU provider.
    Produces L2-normalized vectors, so a dot product between them is cosine
    similarity.
    """

    def __init__(self):
        """Fetch the ONNX model and tokenizer and open an inference session.

        The quantized export is tried first, then the full-precision one.

        Raises:
            RuntimeError: if neither ONNX export can be downloaded. The last
                download error is chained as ``__cause__`` so the real reason
                (e.g. no network, a missing file) is not silently discarded.
        """
        model_path = None
        last_error = None
        for name in ("onnx/model_quantized.onnx", "onnx/model.onnx"):
            try:
                model_path = hf_hub_download(EMB_REPO, name)
            except Exception as err:  # any Hub failure: try the next export
                last_error = err
            else:
                break
        if model_path is None:
            raise RuntimeError(
                f"Could not download an ONNX model from {EMB_REPO}."
            ) from last_error
        tok_path = hf_hub_download(EMB_REPO, "tokenizer.json")
        self.session = ort.InferenceSession(
            model_path, providers=["CPUExecutionProvider"]
        )
        # Exported graphs differ in whether they declare token_type_ids;
        # _run only feeds the inputs this graph actually has.
        self.input_names = {i.name for i in self.session.get_inputs()}
        self.tokenizer = Tokenizer.from_file(tok_path)
        self.tokenizer.enable_truncation(max_length=_MAX_TOKENS)

    def _run(self, texts):
        """Tokenize, run the encoder, CLS-pool and L2-normalize one batch.

        Args:
            texts: a non-empty iterable of strings (``max`` below fails on an
                empty batch; ``encode`` never passes one).

        Returns:
            A float32 array with one L2-normalized row per input text.
        """
        encs = self.tokenizer.encode_batch(list(texts))
        # Right-pad with zeros to the longest sequence in the batch; padded
        # positions keep attention_mask == 0.
        width = max(len(e.ids) for e in encs)
        input_ids = np.zeros((len(encs), width), dtype=np.int64)
        attention = np.zeros((len(encs), width), dtype=np.int64)
        type_ids = np.zeros((len(encs), width), dtype=np.int64)
        for row, enc in enumerate(encs):
            n = len(enc.ids)
            input_ids[row, :n] = enc.ids
            attention[row, :n] = enc.attention_mask
            type_ids[row, :n] = enc.type_ids
        feed = {"input_ids": input_ids, "attention_mask": attention}
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = type_ids
        hidden = np.asarray(self.session.run(None, feed)[0], dtype=np.float32)
        # BGE uses CLS pooling (first token's hidden state); a 2-D output is
        # treated as already pooled.
        cls = hidden[:, 0, :] if hidden.ndim == 3 else hidden
        norms = np.linalg.norm(cls, axis=1, keepdims=True)
        # Floor the norm so an all-zero row cannot divide by zero.
        return cls / np.maximum(norms, 1e-9)

    def encode(self, texts, batch_size=32):
        """Return L2-normalized embeddings for passages, one row per text.

        Passages are embedded without the query prefix. An empty input yields
        a ``(0, 384)`` float32 array (384 is bge-small's embedding width).

        Args:
            texts: iterable of passage strings.
            batch_size: number of texts per encoder call; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` < 1 and ``texts`` is non-empty.
        """
        texts = list(texts)
        if not texts:
            return np.zeros((0, 384), dtype=np.float32)
        if batch_size < 1:
            # Without this, a zero step fails inside range() and a negative
            # one in np.vstack, both with unrelated messages.
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        rows = [self._run(texts[i:i + batch_size])
                for i in range(0, len(texts), batch_size)]
        return np.vstack(rows)

    def encode_query(self, text):
        """Return the L2-normalized embedding for one search query.

        The query is prefixed with ``_QUERY_PREFIX`` before encoding, as the
        bge retrieval setup expects; the result is a 1-D vector.
        """
        return self._run([_QUERY_PREFIX + text])[0]
def build():
    """Embed every processed chunk and save the result to ``EMB_PATH``.

    The ``.npz`` archive holds two arrays: ``ids`` (each chunk's "id") and
    ``vectors`` (one embedding row per chunk, in the same order). When no
    processed chunks exist, a hint is printed and nothing is written.
    """
    chunks = load_chunks()
    if not chunks:
        print(f"No processed data in {PROCESSED_DIR}. Run 'canlex.ingest' first.")
        return

    print(f"Embedding {len(chunks)} sections with {EMB_REPO} ...")
    # Load the model first; the passage texts are built afterwards.
    embedder = Embedder()
    passages = [embed_text(chunk) for chunk in chunks]
    vectors = embedder.encode(passages)
    chunk_ids = np.array([chunk["id"] for chunk in chunks])
    np.savez(EMB_PATH, ids=chunk_ids, vectors=vectors)
    n_vectors = vectors.shape[0]
    dim = vectors.shape[1]
    print(f" {n_vectors} vectors, dim {dim} -> {EMB_PATH.name}")
if __name__ == "__main__":
    # Run as a script: embed all processed chunks and write EMB_PATH.
    build()