File size: 4,907 Bytes
2966f10
 
 
 
 
 
 
21626e7
 
 
2966f10
 
 
21626e7
 
 
2966f10
21626e7
2966f10
 
 
 
 
21626e7
 
 
 
 
 
 
 
 
 
 
589d46e
666cd44
 
 
 
 
 
21626e7
 
 
 
 
 
2966f10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21626e7
2966f10
 
 
 
 
 
 
 
21626e7
2966f10
 
 
21626e7
 
 
 
 
2966f10
21626e7
2966f10
21626e7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Build semantic embeddings for ingested chunks (local, key-free).

Uses BAAI's bge-small-en-v1.5 sentence-embedding model as ONNX, run on CPU via
onnxruntime -- no API key. A transformer embedding has far stronger retrieval
recall than a static one: it can connect a natural-language question to a
provision even when the two share few exact words.
"""
import json

import numpy as np
import onnxruntime as ort
from huggingface_hub import hf_hub_download
from tokenizers import Tokenizer

from .config import PROCESSED_DIR

# Hugging Face Hub repository that the ONNX model file and tokenizer.json are
# downloaded from (see Embedder.__init__).
EMB_REPO = "Xenova/bge-small-en-v1.5"
# Output archive written by build(): the chunk ids and their embedding vectors.
EMB_PATH = PROCESSED_DIR / "embeddings.npz"
# Truncation length (in tokens) applied by the tokenizer to every input text.
_MAX_TOKENS = 512
# Used as a slice bound on the chunk text in embed_text(), so this cap is
# counted in characters, not tokens.
_MAX_BODY = 2000   # cap embedded body text so long sections stay topically focused
# bge-small retrieval: the query is prefixed with this instruction; passages
# are embedded without it. The asymmetry is how the model was trained.
_QUERY_PREFIX = "Represent this sentence for searching relevant passages: "


def load_chunks():
    """Return all chunks from the ``*.json`` files in ``PROCESSED_DIR``.

    Files are read in sorted path order and the contents of each one are
    appended to a single flat list.
    """
    collected = []
    json_files = sorted(PROCESSED_DIR.glob("*.json"))
    for json_file in json_files:
        payload = json.loads(json_file.read_text(encoding="utf-8"))
        collected += payload
    return collected


def embed_text(chunk):
    """Build the compact, retrieval-focused string that is embedded for a chunk.

    The result joins, with ``" . "``, the act short name, the topical title
    (twice), the heading and the first ``_MAX_BODY`` characters of the text;
    empty fields are dropped.
    """
    # Imported here rather than at module level, as in the original layout.
    from .index import topical_title

    # The title is the strongest topical signal, hence it appears twice below.
    # Which field supplies it depends on the doc_type (see
    # index.topical_title): for a D-memo the marginal_note is a generic banner,
    # so the subject in 'part' is used; for case law the marginal_note is only
    # a paragraph range, so the proposition in 'heading' is used.
    title = topical_title(chunk)
    excerpt = chunk["text"][:_MAX_BODY]
    fields = (chunk["act_short"], title, title, chunk["heading"], excerpt)
    return " . ".join(filter(None, fields))


class Embedder:
    """Local transformer sentence-embedder: bge-small-en-v1.5 as ONNX on CPU.

    No API key; the model is downloaded once and cached. Produces L2-normalized
    vectors, so a dot product between them is cosine similarity.
    """

    def __init__(self):
        """Fetch the model and tokenizer files and open a CPU inference session.

        The quantized model file is tried first; the plain ``model.onnx`` is
        the fallback.

        Raises:
            RuntimeError: if neither model file could be downloaded. The last
                download error is attached as ``__cause__`` so the real reason
                (offline, missing file, ...) is not lost.
        """
        model_path = None
        last_error = None
        for name in ("onnx/model_quantized.onnx", "onnx/model.onnx"):
            try:
                model_path = hf_hub_download(EMB_REPO, name)
            except Exception as exc:
                # Deliberately best-effort: remember why this candidate failed
                # and fall through to the next one.
                last_error = exc
            else:
                break
        if model_path is None:
            raise RuntimeError(
                f"Could not download an ONNX model from {EMB_REPO}."
            ) from last_error
        tok_path = hf_hub_download(EMB_REPO, "tokenizer.json")
        self.session = ort.InferenceSession(model_path,
                                            providers=["CPUExecutionProvider"])
        # Remember which inputs the graph declares so _run() only feeds those.
        self.input_names = {i.name for i in self.session.get_inputs()}
        self.tokenizer = Tokenizer.from_file(tok_path)
        self.tokenizer.enable_truncation(max_length=_MAX_TOKENS)

    def _run(self, texts):
        """Tokenize, run the encoder, CLS-pool and L2-normalize one batch.

        Args:
            texts: iterable of strings forming a single batch.

        Returns:
            A float32 array with one L2-normalized row per text.

        Raises:
            ValueError: if the batch is empty.
        """
        encs = self.tokenizer.encode_batch(list(texts))
        if not encs:
            raise ValueError("_run() needs at least one text to embed")
        # Right-pad every sequence with zeros up to the longest in the batch.
        # NOTE(review): assumes 0 is an acceptable pad id for this model; the
        # attention mask is 0 on padded positions either way.
        shape = (len(encs), max(len(e.ids) for e in encs))
        input_ids = np.zeros(shape, dtype=np.int64)
        attention = np.zeros(shape, dtype=np.int64)
        type_ids = np.zeros(shape, dtype=np.int64)
        for row, enc in enumerate(encs):
            n = len(enc.ids)
            input_ids[row, :n] = enc.ids
            attention[row, :n] = enc.attention_mask
            type_ids[row, :n] = enc.type_ids
        feed = {"input_ids": input_ids, "attention_mask": attention}
        # Only pass token_type_ids when the exported graph declares that input.
        if "token_type_ids" in self.input_names:
            feed["token_type_ids"] = type_ids
        hidden = np.asarray(self.session.run(None, feed)[0], dtype=np.float32)
        # A 3-D output is per-token states, so take the first token; anything
        # else is used as-is (presumably already pooled).
        cls = hidden[:, 0, :] if hidden.ndim == 3 else hidden   # BGE: CLS pooling
        norms = np.linalg.norm(cls, axis=1, keepdims=True)
        # Clamp the norm so an all-zero vector cannot cause a division by zero.
        return cls / np.maximum(norms, 1e-9)

    def encode(self, texts, batch_size=32):
        """Return L2-normalized embeddings for passages, one row per text.

        Args:
            texts: iterable of passage strings (no query prefix is added).
            batch_size: number of texts per encoder call; must be >= 1.

        Returns:
            A 2-D float32 array with one row per text. For empty input this is
            a ``(0, 384)`` array.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 and there is at
                least one text to embed.
        """
        texts = list(texts)
        if not texts:
            # 384 is presumably the model's embedding width -- TODO confirm if
            # EMB_REPO is ever pointed at a different model.
            return np.zeros((0, 384), dtype=np.float32)
        if batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size!r}")
        rows = [self._run(texts[i:i + batch_size])
                for i in range(0, len(texts), batch_size)]
        return np.vstack(rows)

    def encode_query(self, text):
        """Return the L2-normalized embedding for one search query.

        The query is prefixed with ``_QUERY_PREFIX`` before encoding; passages
        embedded through ``encode`` are not.
        """
        return self._run([_QUERY_PREFIX + text])[0]


def build():
    """Embed every ingested chunk and save the result to ``EMB_PATH``.

    Prints a hint and returns without writing anything when no processed
    chunks are found. Otherwise the archive holds two arrays: ``ids`` (the
    chunk ids) and ``vectors`` (one embedding row per chunk, same order).
    """
    chunks = load_chunks()
    if not chunks:
        print(f"No processed data in {PROCESSED_DIR}. Run 'canlex.ingest' first.")
        return

    section_count = len(chunks)
    print(f"Embedding {section_count} sections with {EMB_REPO} ...")

    # Construct the embedder before preparing the passages, as before.
    encoder = Embedder()
    passages = [embed_text(chunk) for chunk in chunks]
    vectors = encoder.encode(passages)

    chunk_ids = np.array([chunk["id"] for chunk in chunks])
    np.savez(EMB_PATH, ids=chunk_ids, vectors=vectors)

    shape = vectors.shape
    print(f"  {shape[0]} vectors, dim {shape[1]} -> {EMB_PATH.name}")


# Script entry point: rebuild the embeddings archive when executed directly.
# The module uses a relative import (`from .config import ...`), so it has to
# be run as part of its package (e.g. with `python -m`), not as a loose file.
if __name__ == "__main__":
    build()