# NeAR / tests/test_trellis_from_pretrained_cpu_load.py
# Committed by: luh1124
# fix(load): CPU-only checkpoint load for ZeroGPU main process
# Commit: 31f61c1
from __future__ import annotations
import ast
import unittest
from pathlib import Path
# Absolute path to ``trellis/models/__init__.py``, resolved relative to the
# directory two levels above this file (the parent of ``tests/``).
INIT_PATH = Path(__file__).resolve().parents[1].joinpath("trellis", "models", "__init__.py")
def _from_pretrained_function() -> ast.FunctionDef:
tree = ast.parse(INIT_PATH.read_text(encoding="utf-8"))
for node in tree.body:
if isinstance(node, ast.FunctionDef) and node.name == "from_pretrained":
return node
raise AssertionError("from_pretrained not found")
class TrellisFromPretrainedCpuLoadTests(unittest.TestCase):
    """Static (AST-level) checks on ``trellis.models.from_pretrained``.

    The function is inspected from its parsed source rather than imported, so
    these checks need neither torch nor a GPU.
    """

    @staticmethod
    def _probes_cuda_availability(fn: ast.AST) -> bool:
        """Return True if *fn* references ``<module>.cuda.is_available``.

        Matching on the AST (rather than raw text) ignores mentions inside
        comments and string literals, while still catching aliased imports
        such as ``th.cuda.is_available``.
        """
        return any(
            isinstance(node, ast.Attribute)
            and node.attr == "is_available"
            and isinstance(node.value, ast.Attribute)
            and node.value.attr == "cuda"
            for node in ast.walk(fn)
        )

    @staticmethod
    def _passes_cpu_map_location(fn: ast.AST) -> bool:
        """Return True if some call inside *fn* passes ``map_location="cpu"``.

        Both quote styles parse to the same constant, so ``'cpu'`` and
        ``"cpu"`` are treated alike, unlike a raw substring check.
        """
        return any(
            isinstance(node, ast.Call)
            and any(
                kw.arg == "map_location"
                and isinstance(kw.value, ast.Constant)
                and kw.value.value == "cpu"
                for kw in node.keywords
            )
            for node in ast.walk(fn)
        )

    def test_from_pretrained_avoids_cuda_probe_and_loads_pt_on_cpu(self) -> None:
        fn = _from_pretrained_function()
        # The loader must not query CUDA availability at load time.
        self.assertFalse(
            self._probes_cuda_availability(fn),
            f"from_pretrained in {INIT_PATH} must not call torch.cuda.is_available",
        )
        # Checkpoints must be deserialized onto the CPU.
        self.assertTrue(
            self._passes_cpu_map_location(fn),
            f'from_pretrained in {INIT_PATH} must load with map_location="cpu"',
        )
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    unittest.main(module="__main__")