File size: 2,542 Bytes
99af1d9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 | """Tests for src.core.determinism."""
from __future__ import annotations
import os
import pyarrow as pa
from src.core import determinism
class TestPinThreads:
    """Tests for ``determinism.pin_threads()``.

    ``pin_threads()`` mutates process-global state: the thread-count
    environment variables below and pyarrow's CPU / I/O thread pools.
    ``setup_method`` snapshots that state and unsets the environment variables
    so every test starts from the same baseline regardless of the caller's
    environment (e.g. a CI runner exporting ``OMP_NUM_THREADS=4``).
    ``teardown_method`` restores the snapshot so no test leaks pinned values
    into later tests.
    """

    # Environment variables that pin_threads() is expected to default to "1".
    _THREAD_ENV_VARS = (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
    )

    def setup_method(self, method: object) -> None:
        """Snapshot global thread state, then unset the thread env vars."""
        # None marks "was not set" so teardown can delete rather than assign.
        self._saved_env = {
            name: os.environ.get(name) for name in self._THREAD_ENV_VARS
        }
        self._saved_cpu_count = pa.cpu_count()
        self._saved_io_thread_count = pa.io_thread_count()
        for name in self._THREAD_ENV_VARS:
            os.environ.pop(name, None)

    def teardown_method(self, method: object) -> None:
        """Restore the env vars and pyarrow pool sizes captured in setup."""
        for name, value in self._saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value
        pa.set_cpu_count(self._saved_cpu_count)
        pa.set_io_thread_count(self._saved_io_thread_count)

    def test_sets_omp_env_var(self) -> None:
        determinism.pin_threads()
        assert os.environ["OMP_NUM_THREADS"] == "1"

    def test_sets_openblas_env_var(self) -> None:
        determinism.pin_threads()
        assert os.environ["OPENBLAS_NUM_THREADS"] == "1"

    def test_sets_mkl_env_var(self) -> None:
        determinism.pin_threads()
        assert os.environ["MKL_NUM_THREADS"] == "1"

    def test_pins_pyarrow_cpu_count_to_1(self) -> None:
        # Start from a value other than 1 so the assertion proves
        # pin_threads() actually changed it.
        pa.set_cpu_count(4)
        determinism.pin_threads()
        assert pa.cpu_count() == 1

    def test_pins_pyarrow_io_thread_count_to_1(self) -> None:
        # Start from a value other than 1 so the assertion proves
        # pin_threads() actually changed it.
        pa.set_io_thread_count(4)
        determinism.pin_threads()
        assert pa.io_thread_count() == 1

    def test_does_not_override_existing_env(self) -> None:
        """A user's explicit OMP_NUM_THREADS=2 must win over pin_threads().

        pin_threads() uses os.environ.setdefault, so an upstream override is
        preserved.
        """
        os.environ["OMP_NUM_THREADS"] = "2"
        determinism.pin_threads()
        assert os.environ["OMP_NUM_THREADS"] == "2"

    def test_idempotent(self) -> None:
        """Calling pin_threads() twice leaves the same pinned state."""
        determinism.pin_threads()
        determinism.pin_threads()
        assert pa.cpu_count() == 1
        assert pa.io_thread_count() == 1
        for name in self._THREAD_ENV_VARS:
            assert os.environ[name] == "1", name
|