File size: 2,542 Bytes
99af1d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
"""Tests for src.core.determinism."""
from __future__ import annotations

import contextlib
import os

import pyarrow as pa

from src.core import determinism


class TestPinThreads:
    """Tests for ``determinism.pin_threads()``.

    Every test isolates the process-global state it touches (thread-count
    environment variables and the pyarrow thread-pool sizes) and restores it
    afterwards, so results depend neither on the ambient environment nor on
    the order in which tests run.
    """

    # Environment variables that pin_threads() is expected to populate.
    _THREAD_ENV_VARS = (
        "OMP_NUM_THREADS",
        "OPENBLAS_NUM_THREADS",
        "MKL_NUM_THREADS",
    )

    @staticmethod
    @contextlib.contextmanager
    def _env_unset(*names: str):
        """Run the ``with`` body with *names* removed from ``os.environ``.

        On exit each variable is put back to the value it had on entry, or
        removed again if it was not set on entry, regardless of what the body
        did to it.
        """
        saved = {name: os.environ.pop(name, None) for name in names}
        try:
            yield
        finally:
            for name, value in saved.items():
                if value is None:
                    os.environ.pop(name, None)
                else:
                    os.environ[name] = value

    @staticmethod
    @contextlib.contextmanager
    def _pyarrow_pools_restored():
        """Restore pyarrow's CPU and I/O thread-pool sizes after the ``with`` body."""
        cpu_threads = pa.cpu_count()
        io_threads = pa.io_thread_count()
        try:
            yield
        finally:
            pa.set_cpu_count(cpu_threads)
            pa.set_io_thread_count(io_threads)

    def test_sets_omp_env_var(self) -> None:
        """An unset OMP_NUM_THREADS becomes "1"."""
        with self._env_unset("OMP_NUM_THREADS"):
            determinism.pin_threads()
            assert os.environ["OMP_NUM_THREADS"] == "1"

    def test_sets_openblas_env_var(self) -> None:
        """An unset OPENBLAS_NUM_THREADS becomes "1"."""
        with self._env_unset("OPENBLAS_NUM_THREADS"):
            determinism.pin_threads()
            assert os.environ["OPENBLAS_NUM_THREADS"] == "1"

    def test_sets_mkl_env_var(self) -> None:
        """An unset MKL_NUM_THREADS becomes "1"."""
        with self._env_unset("MKL_NUM_THREADS"):
            determinism.pin_threads()
            assert os.environ["MKL_NUM_THREADS"] == "1"

    def test_pins_pyarrow_cpu_count_to_1(self) -> None:
        """The pyarrow CPU pool is forced to 1 even if it was larger before."""
        with self._pyarrow_pools_restored():
            pa.set_cpu_count(4)
            determinism.pin_threads()
            assert pa.cpu_count() == 1

    def test_pins_pyarrow_io_thread_count_to_1(self) -> None:
        """The pyarrow I/O pool is forced to 1 even if it was larger before."""
        with self._pyarrow_pools_restored():
            pa.set_io_thread_count(4)
            determinism.pin_threads()
            assert pa.io_thread_count() == 1

    def test_does_not_override_existing_env(self) -> None:
        """User explicitly setting OMP_NUM_THREADS=2 must win — pin_threads()
        uses os.environ.setdefault so an upstream override is preserved."""
        with self._env_unset("OMP_NUM_THREADS"):
            os.environ["OMP_NUM_THREADS"] = "2"
            determinism.pin_threads()
            assert os.environ["OMP_NUM_THREADS"] == "2"

    def test_idempotent(self) -> None:
        """Calling pin_threads() twice yields the same pinned state as once."""
        # Existing values are expected to be preserved (see
        # test_does_not_override_existing_env), so the variables must be
        # cleared first; otherwise the == "1" assertions below would depend
        # on whatever the developer or CI environment happens to export.
        with self._env_unset(*self._THREAD_ENV_VARS), self._pyarrow_pools_restored():
            determinism.pin_threads()
            determinism.pin_threads()
            assert pa.cpu_count() == 1
            for name in self._THREAD_ENV_VARS:
                assert os.environ[name] == "1"