Spaces:
Sleeping
Sleeping
| import os | |
| from unittest.mock import patch | |
| from server.device import select_device | |
def test_env_override_cuda():
    """An explicit CHATTERBOX_DEVICE=cuda is returned as the selected device."""
    # NOTE(review): the availability probes are not patched here, so this
    # assumes the env override is honoured without probing hardware — confirm.
    with patch.dict(os.environ, CHATTERBOX_DEVICE="cuda"):
        chosen = select_device()
    assert chosen == "cuda"
def test_env_override_mps():
    """The CHATTERBOX_DEVICE override is case-insensitive: "MPS" yields "mps"."""
    with patch.dict(os.environ, CHATTERBOX_DEVICE="MPS"):
        chosen = select_device()
    assert chosen == "mps"
def test_env_override_cpu():
    """An explicit CHATTERBOX_DEVICE=cpu is returned as the selected device."""
    with patch.dict(os.environ, CHATTERBOX_DEVICE="cpu"):
        chosen = select_device()
    assert chosen == "cpu"
def test_invalid_env_falls_through_to_autodetect():
    """An unrecognised CHATTERBOX_DEVICE value is ignored in favour of autodetection.

    With "tpu" requested and CUDA reported available, the autodetected
    "cuda" is returned instead of the invalid override.
    """
    bogus_override = patch.dict(os.environ, CHATTERBOX_DEVICE="tpu")
    cuda_present = patch("server.device._cuda_available", return_value=True)
    with bogus_override, cuda_present:
        chosen = select_device()
    assert chosen == "cuda"
def test_autodetect_prefers_cuda_over_mps():
    """With no override and both backends available, CUDA takes priority over MPS."""
    # clear=True empties os.environ for the duration, so no override is set.
    empty_env = patch.dict(os.environ, clear=True)
    cuda_present = patch("server.device._cuda_available", return_value=True)
    mps_present = patch("server.device._mps_available", return_value=True)
    with empty_env, cuda_present, mps_present:
        chosen = select_device()
    assert chosen == "cuda"
def test_autodetect_uses_mps_when_no_cuda():
    """With no override and CUDA unavailable, MPS is chosen when it is available."""
    # clear=True empties os.environ for the duration, so no override is set.
    empty_env = patch.dict(os.environ, clear=True)
    cuda_absent = patch("server.device._cuda_available", return_value=False)
    mps_present = patch("server.device._mps_available", return_value=True)
    with empty_env, cuda_absent, mps_present:
        chosen = select_device()
    assert chosen == "mps"
def test_autodetect_falls_back_to_cpu():
    """With no override and neither CUDA nor MPS available, CPU is the fallback."""
    # clear=True empties os.environ for the duration, so no override is set.
    empty_env = patch.dict(os.environ, clear=True)
    cuda_absent = patch("server.device._cuda_available", return_value=False)
    mps_absent = patch("server.device._mps_available", return_value=False)
    with empty_env, cuda_absent, mps_absent:
        chosen = select_device()
    assert chosen == "cpu"