"""Tests for ``server.device.select_device``.

Covers the ``CHATTERBOX_DEVICE`` environment override and the automatic
selection path, with the ``_cuda_available`` / ``_mps_available`` probes
replaced by mocks.
"""

import os
from unittest.mock import patch

from server.device import select_device


def test_env_override_cuda():
    """``CHATTERBOX_DEVICE=cuda`` is expected to yield ``"cuda"``."""
    with patch.dict(os.environ, {"CHATTERBOX_DEVICE": "cuda"}):
        chosen = select_device()
        assert chosen == "cuda"


def test_env_override_mps():
    """An upper-case ``MPS`` override is expected to yield lower-case ``"mps"``."""
    with patch.dict(os.environ, {"CHATTERBOX_DEVICE": "MPS"}):
        chosen = select_device()
        assert chosen == "mps"


def test_env_override_cpu():
    """``CHATTERBOX_DEVICE=cpu`` is expected to yield ``"cpu"``."""
    with patch.dict(os.environ, {"CHATTERBOX_DEVICE": "cpu"}):
        chosen = select_device()
        assert chosen == "cpu"


def test_invalid_env_falls_through_to_autodetect():
    """An unrecognised override (``tpu``) should not be returned as-is.

    Only the CUDA probe is stubbed here; with it reporting ``True`` the
    expected result is ``"cuda"``.
    """
    with patch.dict(os.environ, {"CHATTERBOX_DEVICE": "tpu"}), \
            patch("server.device._cuda_available", return_value=True):
        chosen = select_device()
        assert chosen == "cuda"


def test_autodetect_prefers_cuda_over_mps():
    """With an empty environment and both probes true, ``"cuda"`` wins."""
    with patch.dict(os.environ, {}, clear=True), \
            patch("server.device._cuda_available", return_value=True), \
            patch("server.device._mps_available", return_value=True):
        chosen = select_device()
        assert chosen == "cuda"


def test_autodetect_uses_mps_when_no_cuda():
    """With an empty environment and only the MPS probe true, expect ``"mps"``."""
    with patch.dict(os.environ, {}, clear=True), \
            patch("server.device._cuda_available", return_value=False), \
            patch("server.device._mps_available", return_value=True):
        chosen = select_device()
        assert chosen == "mps"


def test_autodetect_falls_back_to_cpu():
    """With an empty environment and both probes false, expect ``"cpu"``."""
    with patch.dict(os.environ, {}, clear=True), \
            patch("server.device._cuda_available", return_value=False), \
            patch("server.device._mps_available", return_value=False):
        chosen = select_device()
        assert chosen == "cpu"