| from __future__ import annotations | |
| import ast | |
| import unittest | |
| from pathlib import Path | |
# Absolute path of the package initializer inspected by the tests in this file.
# The repository root is taken to be the parent of this file's directory.
INIT_PATH = Path(__file__).resolve().parents[1].joinpath("trellis", "models", "__init__.py")
def _from_pretrained_function() -> ast.FunctionDef:
    """Return the top-level ``def from_pretrained`` node parsed from ``INIT_PATH``.

    Only the module's top-level statements are scanned, so nested definitions
    and ``async def from_pretrained`` are not matched. The first match wins.

    Raises:
        AssertionError: if no matching top-level definition exists. This is
            ``unittest.TestCase.failureException``, so callers see a test failure.
    """
    module = ast.parse(INIT_PATH.read_text(encoding="utf-8"))
    candidates = (
        stmt
        for stmt in module.body
        if isinstance(stmt, ast.FunctionDef) and stmt.name == "from_pretrained"
    )
    found = next(candidates, None)
    if found is None:
        raise AssertionError("from_pretrained not found")
    return found
class TrellisFromPretrainedCpuLoadTests(unittest.TestCase):
    """Static checks on ``from_pretrained`` in ``trellis/models/__init__.py``.

    The module source is parsed with :mod:`ast` rather than imported, so these
    checks run without executing any of the package's code.
    """

    @staticmethod
    def _dotted_name(node: ast.AST) -> str | None:
        """Return the dotted path of a ``Name``/``Attribute`` chain.

        For the expression ``torch.cuda.is_available`` this returns
        ``"torch.cuda.is_available"``. Chains that are not rooted at a plain
        name (e.g. ``f().x`` or ``obj[0].y``) return ``None``.
        """
        parts: list[str] = []
        while isinstance(node, ast.Attribute):
            parts.append(node.attr)
            node = node.value
        if not isinstance(node, ast.Name):
            return None
        parts.append(node.id)
        return ".".join(reversed(parts))

    def test_from_pretrained_avoids_cuda_probe_and_loads_pt_on_cpu(self) -> None:
        """``from_pretrained`` must not probe CUDA and must request CPU loading."""
        fn = _from_pretrained_function()
        # Inspect the parsed code rather than substring-matching the source text:
        # comments and docstrings that merely mention these names no longer
        # decide the outcome, and quoting/spacing style
        # (``map_location='cpu'`` vs ``map_location = "cpu"``) no longer matters.
        nodes = list(ast.walk(fn))

        referenced = {
            self._dotted_name(node)
            for node in nodes
            if isinstance(node, ast.Attribute)
        }
        self.assertNotIn(
            "torch.cuda.is_available",
            referenced,
            "from_pretrained must not reference torch.cuda.is_available",
        )

        # ast.walk visits ast.keyword nodes, i.e. the ``name=value`` arguments of
        # every call inside the function body.
        passes_cpu_map_location = any(
            isinstance(node, ast.keyword)
            and node.arg == "map_location"
            and isinstance(node.value, ast.Constant)
            and node.value.value == "cpu"
            for node in nodes
        )
        self.assertTrue(
            passes_cpu_map_location,
            'no call in from_pretrained passes map_location="cpu"',
        )
if __name__ == "__main__":
    # Run this file directly with `python <this file>`; "__main__" is
    # unittest.main's default module and is spelled out here only for clarity.
    unittest.main(module="__main__")