mekosotto Claude Sonnet 4.6 commited on
Commit
99af1d9
·
1 Parent(s): 9e9b239

feat(core): extract pin_threads() helper for determinism

Browse files

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

src/core/determinism.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Threading determinism: pin BLAS / OpenMP / pyarrow to single-threaded mode.
2
+
3
+ Multi-threaded floating-point reductions reorder operands non-deterministically
4
+ on each call, breaking the byte-identity guarantee in AGENTS.md §4 rule 3. Each
5
+ pipeline calls `pin_threads()` at import time to lock the process to a single
6
+ thread before any numerical work runs.
7
+
8
+ Honors pre-set env vars: if the caller exported `OMP_NUM_THREADS=4` upstream,
9
+ that value is preserved (we use `setdefault`, not `setitem`). The user is
10
+ responsible for the determinism trade-off in that case.
11
+ """
12
+ from __future__ import annotations
13
+
14
+ import os
15
+
16
+ import pyarrow as pa
17
+
18
+ _ENV_VARS: tuple[str, ...] = (
19
+ "OMP_NUM_THREADS",
20
+ "OPENBLAS_NUM_THREADS",
21
+ "MKL_NUM_THREADS",
22
+ )
23
+
24
+
25
+ def pin_threads() -> None:
26
+ """Pin BLAS / OpenMP / pyarrow to single-threaded mode (idempotent)."""
27
+ for var in _ENV_VARS:
28
+ os.environ.setdefault(var, "1")
29
+ pa.set_cpu_count(1)
30
+ pa.set_io_thread_count(1)
tests/core/test_determinism.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for src.core.determinism."""
2
+ from __future__ import annotations
3
+
4
+ import os
5
+
6
+ import pyarrow as pa
7
+
8
+ from src.core import determinism
9
+
10
+
11
+ class TestPinThreads:
12
+ def test_sets_omp_env_var(self):
13
+ original = os.environ.pop("OMP_NUM_THREADS", None)
14
+ try:
15
+ determinism.pin_threads()
16
+ assert os.environ["OMP_NUM_THREADS"] == "1"
17
+ finally:
18
+ if original is None:
19
+ os.environ.pop("OMP_NUM_THREADS", None)
20
+ else:
21
+ os.environ["OMP_NUM_THREADS"] = original
22
+
23
+ def test_sets_openblas_env_var(self):
24
+ original = os.environ.pop("OPENBLAS_NUM_THREADS", None)
25
+ try:
26
+ determinism.pin_threads()
27
+ assert os.environ["OPENBLAS_NUM_THREADS"] == "1"
28
+ finally:
29
+ if original is None:
30
+ os.environ.pop("OPENBLAS_NUM_THREADS", None)
31
+ else:
32
+ os.environ["OPENBLAS_NUM_THREADS"] = original
33
+
34
+ def test_sets_mkl_env_var(self):
35
+ original = os.environ.pop("MKL_NUM_THREADS", None)
36
+ try:
37
+ determinism.pin_threads()
38
+ assert os.environ["MKL_NUM_THREADS"] == "1"
39
+ finally:
40
+ if original is None:
41
+ os.environ.pop("MKL_NUM_THREADS", None)
42
+ else:
43
+ os.environ["MKL_NUM_THREADS"] = original
44
+
45
+ def test_pins_pyarrow_cpu_count_to_1(self):
46
+ pa.set_cpu_count(4)
47
+ determinism.pin_threads()
48
+ assert pa.cpu_count() == 1
49
+
50
+ def test_pins_pyarrow_io_thread_count_to_1(self):
51
+ pa.set_io_thread_count(4)
52
+ determinism.pin_threads()
53
+ assert pa.io_thread_count() == 1
54
+
55
+ def test_does_not_override_existing_env(self):
56
+ """User explicitly setting OMP_NUM_THREADS=2 must win — pin_threads()
57
+ uses os.environ.setdefault so an upstream override is preserved."""
58
+ original = os.environ.get("OMP_NUM_THREADS")
59
+ os.environ["OMP_NUM_THREADS"] = "2"
60
+ try:
61
+ determinism.pin_threads()
62
+ assert os.environ["OMP_NUM_THREADS"] == "2"
63
+ finally:
64
+ if original is None:
65
+ os.environ.pop("OMP_NUM_THREADS", None)
66
+ else:
67
+ os.environ["OMP_NUM_THREADS"] = original
68
+
69
+ def test_idempotent(self):
70
+ determinism.pin_threads()
71
+ determinism.pin_threads()
72
+ assert pa.cpu_count() == 1
73
+ assert os.environ["OMP_NUM_THREADS"] == "1"
74
+ assert os.environ["OPENBLAS_NUM_THREADS"] == "1"
75
+ assert os.environ["MKL_NUM_THREADS"] == "1"