# CanLex / tests / test_index.py
# Beemer
# Sweep-tune the regulation and back-matter penalties; revert the failed swap
# a7a22f5
"""Unit tests for the retrieval pipeline (canlex/index.py).
Fast, offline tests of the pure retrieval logic -- tokenisation, section-
reference parsing, the diversity cap, the result-set guarantee and the doc-type
flags. They build a bare LegislationIndex via __new__, so no corpus, embeddings
or reranker are loaded.
Run with: python -m unittest discover -s tests
"""
import unittest
from canlex.index import (
LegislationIndex, SOURCE_CAP, APPENDIX_CAP, tokenize, _section_refs,
_provision_units,
)
def chunk(doc_type="legislation", act_code="I-2.5", section="1",
          marginal_note="Title", part="", **extra):
    """Build a stub corpus chunk as a plain dict.

    The five named parameters are stored under keys of the same name;
    ``heading``, ``act_short`` and ``text`` get fixed placeholder values.
    Any additional keyword argument is merged in last, so it can override a
    placeholder (e.g. ``text=...``) or add a new key.
    """
    return {
        "doc_type": doc_type,
        "act_code": act_code,
        "section": section,
        "marginal_note": marginal_note,
        "part": part,
        "heading": "",
        "act_short": "X",
        "text": "",
        # Unpacked last so caller-supplied extras win over the placeholders.
        **extra,
    }
def bare_index(chunks):
    """Return a LegislationIndex on which only ``chunks`` has been set.

    The instance is allocated through ``__new__`` directly, so ``__init__``
    is not run; the tests then call the index's methods on this stub.
    """
    stub = LegislationIndex.__new__(LegislationIndex)  # bypasses __init__
    stub.chunks = chunks
    return stub
class TokenizeTests(unittest.TestCase):
    """Checks on tokenize(): case handling, stemming, splitting, empty input."""

    def test_case_insensitive(self):
        first = tokenize("REPORT goods")
        second = tokenize("report Goods")
        self.assertEqual(first, second)

    def test_stemming_unifies_word_forms(self):
        # Stemming is there so that different forms of a word share a token.
        same_stem_pairs = (
            ("reporting", "reported"),
            ("importation", "import"),
        )
        for left, right in same_stem_pairs:
            self.assertEqual(tokenize(left), tokenize(right))

    def test_splits_on_non_alphanumeric(self):
        expected_tokens = ["s", "34", "1", "a"]
        self.assertEqual(tokenize("s.34(1)(a)"), expected_tokens)

    def test_empty(self):
        no_tokens = []
        self.assertEqual(tokenize(""), no_tokens)
class SectionRefTests(unittest.TestCase):
    """Checks on _section_refs(): the set of section numbers found in a query."""

    def _assert_refs(self, query, expected):
        """Assert that _section_refs(query) equals the expected set."""
        self.assertEqual(_section_refs(query), expected)

    def test_plain_section(self):
        self._assert_refs("inadmissible under section 34", {"34"})

    def test_decimal_and_abbreviated(self):
        self._assert_refs("see s. 20.1 and section 5", {"20.1", "5"})

    def test_no_reference(self):
        self._assert_refs("what is a pre-removal risk assessment", set())
class ProvisionUnitsTests(unittest.TestCase):
    """Checks on _provision_units(): marked-up text yields units, flat text none."""

    def test_structured_provision_yields_units(self):
        text = "(1) The chapeau.\n(a) first paragraph\n(b) second paragraph"
        units = _provision_units(text)
        # Only truthiness is pinned here, not the exact units returned.
        self.assertTrue(units)

    def test_flat_provision_yields_nothing(self):
        units = _provision_units("A flat provision with no markers.")
        self.assertEqual(units, [])
class SourceKeyTests(unittest.TestCase):
    """_source_key decides what the diversity cap collapses."""

    def test_primary_instruments_are_never_capped(self):
        primary_chunks = [
            chunk(doc_type="legislation"),
            chunk(doc_type="agreement", act_code="FB"),
            chunk(doc_type="directive", act_code="d1"),
        ]
        idx = bare_index(primary_chunks)
        # A None key is what these tests expect for legislation, agreements
        # and directives.
        for position in range(len(primary_chunks)):
            self.assertIsNone(idx._source_key(position))

    def test_caselaw_and_memoranda_are_keyed(self):
        idx = bare_index([
            chunk(doc_type="memorandum", act_code="D-Memo", section="D1-1-1"),
            chunk(doc_type="caselaw", act_code="2019 SCC 65"),
        ])
        # The memorandum is keyed by its section, the decision by its act_code.
        expected_keys = [
            (0, ("memorandum", "D1-1-1")),
            (1, ("caselaw", "2019 SCC 65")),
        ]
        for position, key in expected_keys:
            self.assertEqual(idx._source_key(position), key)
class DiversifyTests(unittest.TestCase):
    """Checks on _diversify(): case-law is capped per decision, agreements are not."""

    def test_caps_caselaw_per_decision(self):
        caselaw_total = SOURCE_CAP + 2
        corpus = [chunk(doc_type="caselaw", act_code="2019 SCC 65")
                  for _ in range(caselaw_total)]
        legislation_pos = len(corpus)  # the legislation chunk goes last
        corpus.append(chunk(doc_type="legislation"))
        idx = bare_index(corpus)

        ranked = idx._diversify(list(range(len(corpus))))

        # The head holds SOURCE_CAP case-law slots plus the legislation chunk.
        boundary = SOURCE_CAP + 1
        head = ranked[:boundary]
        tail = ranked[boundary:]
        self.assertIn(legislation_pos, head)  # legislation never capped
        caselaw_in_head = [i for i in head
                           if idx.chunks[i]["doc_type"] == "caselaw"]
        self.assertEqual(len(caselaw_in_head), SOURCE_CAP)
        self.assertEqual(len(tail), caselaw_total - SOURCE_CAP)

    def test_does_not_cap_agreements(self):
        total = SOURCE_CAP + 3
        articles = [chunk(doc_type="agreement", act_code="FB", section=str(i))
                    for i in range(total)]
        idx = bare_index(articles)
        original_order = list(range(total))
        # Uncapped: the ranking must come back in the order it went in.
        self.assertEqual(idx._diversify(list(range(total))), original_order)
class EnsurePrimaryTests(unittest.TestCase):
    """Checks on _ensure_primary(): primary sources reach the top_k window."""

    def test_pulls_primary_into_a_caselaw_dominated_top_k(self):
        corpus = [chunk(doc_type="caselaw", act_code=code)
                  for code in ("A", "B", "C")]
        corpus.extend(chunk(doc_type="legislation") for _ in range(2))
        idx = bare_index(corpus)

        ranked = idx._ensure_primary([0, 1, 2, 3, 4], top_k=3, q_tokens=set())

        head_types = [idx.chunks[i]["doc_type"] for i in ranked[:3]]
        self.assertGreaterEqual(head_types.count("legislation"), 2)
        self.assertEqual(ranked[0], 0)  # the #1 hit is preserved

    def test_no_op_when_primary_already_present(self):
        idx = bare_index([
            chunk(doc_type="legislation"),
            chunk(doc_type="legislation"),
            chunk(doc_type="caselaw", act_code="A"),
        ])
        ranked = idx._ensure_primary([0, 1, 2], top_k=3, q_tokens=set())
        self.assertEqual(ranked, [0, 1, 2])

    def test_counts_agreements_as_primary(self):
        # An agreement query that surfaces only case-law in top_k should
        # have the agreement article pulled in -- not just legislation.
        corpus = [chunk(doc_type="caselaw", act_code=code)
                  for code in ("A", "B", "C")]
        corpus.append(chunk(doc_type="agreement", act_code="FB", section="17",
                            marginal_note="discipline"))
        corpus.append(chunk(doc_type="agreement", act_code="FB", section="25",
                            marginal_note="hours of work"))
        idx = bare_index(corpus)

        ranked = idx._ensure_primary([0, 1, 2, 3, 4], top_k=3, q_tokens=set())

        head_types = [idx.chunks[i]["doc_type"] for i in ranked[:3]]
        self.assertGreaterEqual(head_types.count("agreement"), 2)
class DocTypeFlagTests(unittest.TestCase):
    """_build_note_tokens also flags regulations and agreement back-matter."""

    def setUp(self):
        # Positions matter: the expected flag lists below are index-aligned
        # with this corpus.
        corpus = [
            chunk(doc_type="legislation", act_code="I-2.5"),
            chunk(doc_type="legislation", act_code="SOR-2002-227"),
            chunk(doc_type="legislation", act_code="C.R.C.,_c._1041"),
            chunk(doc_type="agreement", act_code="FB", section="17"),
            chunk(doc_type="agreement", act_code="FB", section=""),
            chunk(doc_type="memorandum", act_code="D-Memo", section="D1-1-1",
                  marginal_note="Guidelines", part="Importing goods"),
        ]
        self.idx = bare_index(corpus)
        self.idx._build_note_tokens()

    def test_regulation_flag(self):
        # Only the SOR and C.R.C. chunks (positions 1 and 2) are expected
        # to be flagged.
        expected = [False, True, True, False, False, False]
        self.assertEqual(self.idx._is_regulation, expected)

    def test_agreement_backmatter_flag(self):
        # Only the agreement chunk with an empty section (position 4).
        expected = [False, False, False, False, True, False]
        self.assertEqual(self.idx._is_backmatter, expected)

    def test_memorandum_title_tokens_come_from_part(self):
        # A memo's marginal note is generic; its title is the 'part' field.
        memo_tokens = self.idx._note_tokens[5]
        self.assertEqual(memo_tokens, set(tokenize("Importing goods")))
class CosurfaceAppendixTests(unittest.TestCase):
    """_cosurface_appendices pulls a directive appendix into the result set
    when a directive result cites it but retrieval missed it."""

    def _directive_with_appendices(self):
        """Index of one d10 section citing Appendix C, plus Appendices C and B."""
        corpus = [
            chunk(doc_type="directive", act_code="d10", marginal_note="Meals",
                  text="paid the meal allowance at the rates in Appendix C."),
            chunk(doc_type="directive", act_code="d10",
                  marginal_note="Appendix C - Allowances", text="rate tables"),
            chunk(doc_type="directive", act_code="d10",
                  marginal_note="Appendix B - Kilometric Rates",
                  text="km rates"),
        ]
        idx = bare_index(corpus)
        idx._build_appendix_index()
        return idx

    def test_cited_appendix_is_pulled_in(self):
        results = self._directive_with_appendices()._cosurface_appendices([0])
        self.assertEqual(results, [0, 1])

    def test_no_duplicate_when_already_present(self):
        results = self._directive_with_appendices()._cosurface_appendices([0, 1])
        self.assertEqual(results, [0, 1])

    def test_uncited_appendix_is_left_out(self):
        # result 0 cites only Appendix C, so Appendix B (index 2) stays out.
        results = self._directive_with_appendices()._cosurface_appendices([0])
        self.assertNotIn(2, results)

    def test_cross_directive_citation_is_left_alone(self):
        corpus = [
            chunk(doc_type="directive", act_code="d10",
                  marginal_note="A section",
                  text="see Appendix C of the NJC Travel Directive"),
            chunk(doc_type="directive", act_code="d10",
                  marginal_note="Appendix C - Allowances", text="tables"),
        ]
        idx = bare_index(corpus)
        idx._build_appendix_index()
        self.assertEqual(idx._cosurface_appendices([0]), [0])

    def test_cap_keeps_the_most_cited_appendix(self):
        # Four appendices are cited; Appendix A by two sections, the rest once.
        # With the cap exceeded, the twice-cited appendix must survive.
        sections = [
            chunk(doc_type="directive", act_code="d1", marginal_note="S1",
                  text="see Appendix A"),
            chunk(doc_type="directive", act_code="d1", marginal_note="S2",
                  text="see Appendix A; see Appendix B"),
            chunk(doc_type="directive", act_code="d1", marginal_note="S3",
                  text="see Appendix C; see Appendix D"),
        ]
        appendices = [
            chunk(doc_type="directive", act_code="d1",
                  marginal_note="Appendix A"),
            chunk(doc_type="directive", act_code="d1",
                  marginal_note="Appendix B"),
            chunk(doc_type="directive", act_code="d1",
                  marginal_note="Appendix C"),
            chunk(doc_type="directive", act_code="d1",
                  marginal_note="Appendix D"),
        ]
        appendix_a_pos = len(sections)  # Appendix A follows the three sections
        idx = bare_index(sections + appendices)
        idx._build_appendix_index()

        results = idx._cosurface_appendices([0, 1, 2])

        self.assertEqual(len(results), len(sections) + APPENDIX_CAP)  # cap respected
        self.assertIn(appendix_a_pos, results)  # Appendix A survives
# Allow running this file directly as a script, in addition to unittest
# discovery.
if __name__ == "__main__":
    unittest.main()