from __future__ import annotations import ast import unittest from pathlib import Path INIT_PATH = Path(__file__).resolve().parents[1] / "trellis" / "models" / "__init__.py" def _from_pretrained_function() -> ast.FunctionDef: tree = ast.parse(INIT_PATH.read_text(encoding="utf-8")) for node in tree.body: if isinstance(node, ast.FunctionDef) and node.name == "from_pretrained": return node raise AssertionError("from_pretrained not found") class TrellisFromPretrainedCpuLoadTests(unittest.TestCase): def test_from_pretrained_avoids_cuda_probe_and_loads_pt_on_cpu(self) -> None: fn = _from_pretrained_function() source_segment = ast.get_source_segment(INIT_PATH.read_text(encoding="utf-8"), fn) assert source_segment is not None self.assertNotIn("torch.cuda.is_available", source_segment) self.assertIn('map_location="cpu"', source_segment) if __name__ == "__main__": unittest.main()