hackathon / tests /core /test_determinism.py
mekosotto's picture
feat(core): extract pin_threads() helper for determinism
99af1d9
"""Tests for src.core.determinism."""
from __future__ import annotations

import os
from contextlib import contextmanager

import pyarrow as pa

from src.core import determinism
# Thread-count environment variables that pin_threads() is expected to
# default to "1" (each is asserted individually in TestPinThreads below).
_THREAD_ENV_VARS = ("OMP_NUM_THREADS", "OPENBLAS_NUM_THREADS", "MKL_NUM_THREADS")


@contextmanager
def _isolated_thread_state(**env: str):
    """Run the ``with`` body against a clean, restorable copy of the global
    state that ``pin_threads()`` touches.

    On entry, every variable in ``_THREAD_ENV_VARS`` is removed from
    ``os.environ`` and then any ``env`` keyword overrides are applied. On
    exit, all of those variables and pyarrow's CPU and I/O thread-pool sizes
    are restored to their pre-entry values. The restore happens even if an
    assertion fails.

    Why: this prevents the caller's shell or CI environment from changing
    what ``pin_threads()`` sees, and it prevents these tests from leaking
    single-threaded pyarrow pools into the rest of the test session.

    Args:
        **env: Environment variables to set after clearing, e.g.
            ``OMP_NUM_THREADS="2"`` to simulate an upstream override.
    """
    # Snapshot every name we may touch, including any extra overrides.
    names = (*_THREAD_ENV_VARS, *env)
    saved_env = {name: os.environ.get(name) for name in names}
    saved_cpu_count = pa.cpu_count()
    saved_io_thread_count = pa.io_thread_count()
    try:
        for name in _THREAD_ENV_VARS:
            os.environ.pop(name, None)
        os.environ.update(env)
        yield
    finally:
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        pa.set_cpu_count(saved_cpu_count)
        pa.set_io_thread_count(saved_io_thread_count)


class TestPinThreads:
    """Tests for ``determinism.pin_threads()``.

    Each test runs inside ``_isolated_thread_state``. The thread-count
    environment variables therefore start out unset, and both they and
    pyarrow's thread-pool sizes are restored once the test finishes.
    """

    @staticmethod
    def _assert_env_defaults_to_1(name: str) -> None:
        """With ``name`` unset, pin_threads() must set it to ``"1"``."""
        with _isolated_thread_state():
            determinism.pin_threads()
            assert os.environ[name] == "1"

    def test_sets_omp_env_var(self):
        self._assert_env_defaults_to_1("OMP_NUM_THREADS")

    def test_sets_openblas_env_var(self):
        self._assert_env_defaults_to_1("OPENBLAS_NUM_THREADS")

    def test_sets_mkl_env_var(self):
        self._assert_env_defaults_to_1("MKL_NUM_THREADS")

    def test_pins_pyarrow_cpu_count_to_1(self):
        with _isolated_thread_state():
            pa.set_cpu_count(4)
            determinism.pin_threads()
            assert pa.cpu_count() == 1

    def test_pins_pyarrow_io_thread_count_to_1(self):
        with _isolated_thread_state():
            pa.set_io_thread_count(4)
            determinism.pin_threads()
            assert pa.io_thread_count() == 1

    def test_does_not_override_existing_env(self):
        """User explicitly setting OMP_NUM_THREADS=2 must win — pin_threads()
        uses os.environ.setdefault so an upstream override is preserved."""
        with _isolated_thread_state(OMP_NUM_THREADS="2"):
            determinism.pin_threads()
            assert os.environ["OMP_NUM_THREADS"] == "2"

    def test_idempotent(self):
        # Start from a cleared environment. Otherwise a pre-set variable
        # (which pin_threads() deliberately preserves) would make this
        # test depend on the shell it runs in.
        with _isolated_thread_state():
            determinism.pin_threads()
            determinism.pin_threads()
            assert pa.cpu_count() == 1
            assert pa.io_thread_count() == 1
            for name in _THREAD_ENV_VARS:
                assert os.environ[name] == "1"