"""Build semantic embeddings for ingested chunks (local, key-free). Uses BAAI's bge-small-en-v1.5 sentence-embedding model as ONNX, run on CPU via onnxruntime -- no API key. A transformer embedding has far stronger retrieval recall than a static one: it can connect a natural-language question to a provision even when the two share few exact words. """ import json import numpy as np import onnxruntime as ort from huggingface_hub import hf_hub_download from tokenizers import Tokenizer from .config import PROCESSED_DIR EMB_REPO = "Xenova/bge-small-en-v1.5" EMB_PATH = PROCESSED_DIR / "embeddings.npz" _MAX_TOKENS = 512 _MAX_BODY = 2000 # cap embedded body text so long sections stay topically focused # bge-small retrieval: the query is prefixed with this instruction; passages # are embedded without it. The asymmetry is how the model was trained. _QUERY_PREFIX = "Represent this sentence for searching relevant passages: " def load_chunks(): chunks = [] for path in sorted(PROCESSED_DIR.glob("*.json")): chunks.extend(json.loads(path.read_text(encoding="utf-8"))) return chunks def embed_text(chunk): """Compact, retrieval-focused representation of one section.""" # The section title is the strongest topical signal, so it is repeated to # emphasise it. Title selection is doc_type-aware (see index.topical_title): # a D-memo's marginal_note is a generic banner so its actual subject in # 'part' is used; a case-law chunk's marginal_note is just the paragraph # range so the case proposition in 'heading' is used. from .index import topical_title note = topical_title(chunk) body = chunk["text"][:_MAX_BODY] parts = [chunk["act_short"], note, note, chunk["heading"], body] return " . ".join(p for p in parts if p) class Embedder: """Local transformer sentence-embedder: bge-small-en-v1.5 as ONNX on CPU. No API key; the model is downloaded once and cached. Produces L2-normalized vectors, so a dot product between them is cosine similarity. """ def __init__(self): model_path = None for name in ("onnx/model_quantized.onnx", "onnx/model.onnx"): try: model_path = hf_hub_download(EMB_REPO, name) break except Exception: continue if model_path is None: raise RuntimeError(f"Could not download an ONNX model from {EMB_REPO}.") tok_path = hf_hub_download(EMB_REPO, "tokenizer.json") self.session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) self.input_names = {i.name for i in self.session.get_inputs()} self.tokenizer = Tokenizer.from_file(tok_path) self.tokenizer.enable_truncation(max_length=_MAX_TOKENS) def _run(self, texts): """Tokenize, run the encoder, CLS-pool and L2-normalize one batch.""" encs = self.tokenizer.encode_batch(list(texts)) width = max(len(e.ids) for e in encs) input_ids = np.zeros((len(encs), width), dtype=np.int64) attention = np.zeros((len(encs), width), dtype=np.int64) type_ids = np.zeros((len(encs), width), dtype=np.int64) for row, enc in enumerate(encs): n = len(enc.ids) input_ids[row, :n] = enc.ids attention[row, :n] = enc.attention_mask type_ids[row, :n] = enc.type_ids feed = {"input_ids": input_ids, "attention_mask": attention} if "token_type_ids" in self.input_names: feed["token_type_ids"] = type_ids hidden = np.asarray(self.session.run(None, feed)[0], dtype=np.float32) cls = hidden[:, 0, :] if hidden.ndim == 3 else hidden # BGE: CLS pooling norms = np.linalg.norm(cls, axis=1, keepdims=True) return cls / np.maximum(norms, 1e-9) def encode(self, texts, batch_size=32): """Return L2-normalized embeddings for passages, one row per text.""" texts = list(texts) if not texts: return np.zeros((0, 384), dtype=np.float32) rows = [self._run(texts[i:i + batch_size]) for i in range(0, len(texts), batch_size)] return np.vstack(rows) def encode_query(self, text): """Return the L2-normalized embedding for one search query.""" return self._run([_QUERY_PREFIX + text])[0] def build(): chunks = load_chunks() if not chunks: print(f"No processed data in {PROCESSED_DIR}. Run 'canlex.ingest' first.") return print(f"Embedding {len(chunks)} sections with {EMB_REPO} ...") vectors = Embedder().encode([embed_text(c) for c in chunks]) ids = np.array([c["id"] for c in chunks]) np.savez(EMB_PATH, ids=ids, vectors=vectors) print(f" {vectors.shape[0]} vectors, dim {vectors.shape[1]} -> {EMB_PATH.name}") if __name__ == "__main__": build()