feat(core): extract pin_threads() helper for determinism
Browse filesCo-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- src/core/determinism.py +30 -0
- tests/core/test_determinism.py +75 -0
src/core/determinism.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Threading determinism: pin BLAS / OpenMP / pyarrow to single-threaded mode.
|
| 2 |
+
|
| 3 |
+
Multi-threaded floating-point reductions reorder operands non-deterministically
|
| 4 |
+
on each call, breaking the byte-identity guarantee in AGENTS.md §4 rule 3. Each
|
| 5 |
+
pipeline calls `pin_threads()` at import time to lock the process to a single
|
| 6 |
+
thread before any numerical work runs.
|
| 7 |
+
|
| 8 |
+
Honors pre-set env vars: if the caller exported `OMP_NUM_THREADS=4` upstream,
|
| 9 |
+
that value is preserved (we use `setdefault`, not `setitem`). The user is
|
| 10 |
+
responsible for the determinism trade-off in that case.
|
| 11 |
+
"""
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
import pyarrow as pa
|
| 17 |
+
|
| 18 |
+
_ENV_VARS: tuple[str, ...] = (
|
| 19 |
+
"OMP_NUM_THREADS",
|
| 20 |
+
"OPENBLAS_NUM_THREADS",
|
| 21 |
+
"MKL_NUM_THREADS",
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def pin_threads() -> None:
|
| 26 |
+
"""Pin BLAS / OpenMP / pyarrow to single-threaded mode (idempotent)."""
|
| 27 |
+
for var in _ENV_VARS:
|
| 28 |
+
os.environ.setdefault(var, "1")
|
| 29 |
+
pa.set_cpu_count(1)
|
| 30 |
+
pa.set_io_thread_count(1)
|
tests/core/test_determinism.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for src.core.determinism."""
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
import pyarrow as pa
|
| 7 |
+
|
| 8 |
+
from src.core import determinism
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class TestPinThreads:
|
| 12 |
+
def test_sets_omp_env_var(self):
|
| 13 |
+
original = os.environ.pop("OMP_NUM_THREADS", None)
|
| 14 |
+
try:
|
| 15 |
+
determinism.pin_threads()
|
| 16 |
+
assert os.environ["OMP_NUM_THREADS"] == "1"
|
| 17 |
+
finally:
|
| 18 |
+
if original is None:
|
| 19 |
+
os.environ.pop("OMP_NUM_THREADS", None)
|
| 20 |
+
else:
|
| 21 |
+
os.environ["OMP_NUM_THREADS"] = original
|
| 22 |
+
|
| 23 |
+
def test_sets_openblas_env_var(self):
|
| 24 |
+
original = os.environ.pop("OPENBLAS_NUM_THREADS", None)
|
| 25 |
+
try:
|
| 26 |
+
determinism.pin_threads()
|
| 27 |
+
assert os.environ["OPENBLAS_NUM_THREADS"] == "1"
|
| 28 |
+
finally:
|
| 29 |
+
if original is None:
|
| 30 |
+
os.environ.pop("OPENBLAS_NUM_THREADS", None)
|
| 31 |
+
else:
|
| 32 |
+
os.environ["OPENBLAS_NUM_THREADS"] = original
|
| 33 |
+
|
| 34 |
+
def test_sets_mkl_env_var(self):
|
| 35 |
+
original = os.environ.pop("MKL_NUM_THREADS", None)
|
| 36 |
+
try:
|
| 37 |
+
determinism.pin_threads()
|
| 38 |
+
assert os.environ["MKL_NUM_THREADS"] == "1"
|
| 39 |
+
finally:
|
| 40 |
+
if original is None:
|
| 41 |
+
os.environ.pop("MKL_NUM_THREADS", None)
|
| 42 |
+
else:
|
| 43 |
+
os.environ["MKL_NUM_THREADS"] = original
|
| 44 |
+
|
| 45 |
+
def test_pins_pyarrow_cpu_count_to_1(self):
|
| 46 |
+
pa.set_cpu_count(4)
|
| 47 |
+
determinism.pin_threads()
|
| 48 |
+
assert pa.cpu_count() == 1
|
| 49 |
+
|
| 50 |
+
def test_pins_pyarrow_io_thread_count_to_1(self):
|
| 51 |
+
pa.set_io_thread_count(4)
|
| 52 |
+
determinism.pin_threads()
|
| 53 |
+
assert pa.io_thread_count() == 1
|
| 54 |
+
|
| 55 |
+
def test_does_not_override_existing_env(self):
|
| 56 |
+
"""User explicitly setting OMP_NUM_THREADS=2 must win — pin_threads()
|
| 57 |
+
uses os.environ.setdefault so an upstream override is preserved."""
|
| 58 |
+
original = os.environ.get("OMP_NUM_THREADS")
|
| 59 |
+
os.environ["OMP_NUM_THREADS"] = "2"
|
| 60 |
+
try:
|
| 61 |
+
determinism.pin_threads()
|
| 62 |
+
assert os.environ["OMP_NUM_THREADS"] == "2"
|
| 63 |
+
finally:
|
| 64 |
+
if original is None:
|
| 65 |
+
os.environ.pop("OMP_NUM_THREADS", None)
|
| 66 |
+
else:
|
| 67 |
+
os.environ["OMP_NUM_THREADS"] = original
|
| 68 |
+
|
| 69 |
+
def test_idempotent(self):
|
| 70 |
+
determinism.pin_threads()
|
| 71 |
+
determinism.pin_threads()
|
| 72 |
+
assert pa.cpu_count() == 1
|
| 73 |
+
assert os.environ["OMP_NUM_THREADS"] == "1"
|
| 74 |
+
assert os.environ["OPENBLAS_NUM_THREADS"] == "1"
|
| 75 |
+
assert os.environ["MKL_NUM_THREADS"] == "1"
|