| from __future__ import annotations |
|
|
| import ast |
| import unittest |
| from pathlib import Path |
|
|
|
|
# Absolute path of the module under test: ``app_gsplat.py`` lives in the
# directory one level above the directory that contains this test file.
APP_PATH = Path(__file__).resolve().parents[1].joinpath("app_gsplat.py")
|
|
|
|
def _load_tree() -> ast.Module:
    """Read ``APP_PATH`` as UTF-8 text and return its parsed module AST.

    The file is parsed, never imported, so none of its imports are executed.
    """
    source_text = APP_PATH.read_text(encoding="utf-8")
    module_ast = ast.parse(source_text)
    return module_ast
|
|
|
|
| def _get_function(tree: ast.Module, name: str) -> ast.FunctionDef: |
| for node in tree.body: |
| if isinstance(node, ast.FunctionDef) and node.name == name: |
| return node |
| raise AssertionError(f"Function {name!r} not found in app_gsplat.py") |
|
|
|
|
| def _called_names(function_node: ast.FunctionDef) -> set[str]: |
| names: set[str] = set() |
| for node in ast.walk(function_node): |
| if isinstance(node, ast.Call): |
| if isinstance(node.func, ast.Name): |
| names.add(node.func.id) |
| elif isinstance(node.func, ast.Attribute): |
| names.add(node.func.attr) |
| return names |
|
|
|
|
class AppGsplatArchitectureTests(unittest.TestCase):
    """Static architecture checks on ``app_gsplat.py``.

    Every test reads or parses the file from disk and never imports it, so the
    app's own dependencies do not need to be importable to run these tests.
    """

    def test_no_near_stack_imports(self) -> None:
        """The app must not reference the NeAR pipeline or import trellis/hy3dshape."""
        source = APP_PATH.read_text(encoding="utf-8")

        # Plain-text check: the pipeline class must not appear anywhere.
        self.assertNotIn("NeARImageToRelightable3DPipeline", source)

        forbidden_roots = {"trellis", "hy3dshape"}
        # Inspect real import statements rather than raw lines. This catches
        # multi-name imports (``import numpy, trellis``) and function-local or
        # conditional imports. It ignores comments and strings, and it does not
        # flag unrelated packages that merely share a prefix (``trellis_utils``).
        for node in ast.walk(ast.parse(source)):
            if isinstance(node, ast.Import):
                modules = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                # Relative imports (level > 0) are skipped, as before.
                modules = [node.module]
            else:
                continue
            for module in modules:
                root = module.split(".", 1)[0]
                self.assertNotIn(
                    root,
                    forbidden_roots,
                    msg=f"Unexpected {root} import at line {node.lineno}: {module!r}",
                )

    def test_raster_helper_calls_gsplat_rasterization(self) -> None:
        """The top-level ``_raster_rgb_ed`` helper must call ``rasterize``."""
        raster = _get_function(_load_tree(), "_raster_rgb_ed")
        called = _called_names(raster)

        # Both ``rasterize(...)`` and ``something.rasterize(...)`` satisfy this,
        # because _called_names records bare names and final attribute names.
        # NOTE(review): the test name mentions "gsplat rasterization" while the
        # check is for the name ``rasterize``; confirm against app_gsplat.py.
        self.assertIn("rasterize", called)

    def test_run_probe_is_gpu_decorated_and_logs_entry(self) -> None:
        """The app must use ``@GPU``, define ``run_once`` and probe CUDA availability."""
        source = APP_PATH.read_text(encoding="utf-8")
        tree = ast.parse(source)

        # NOTE(review): this only checks that ``@GPU`` appears somewhere in the
        # file, not that it decorates ``run_once``. Despite the test name,
        # nothing here checks for a log call on entry.
        self.assertIn("@GPU", source)

        # Require a real definition; a substring match on "def run_once" would
        # also accept ``def run_once_v2`` or a mention inside a comment.
        defines_run_once = any(
            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == "run_once"
            for node in ast.walk(tree)
        )
        self.assertTrue(defines_run_once, msg="run_once is not defined in app_gsplat.py")

        self.assertIn("torch.cuda.is_available()", source)
|
|
|
|
# Support running this module directly as a script in addition to test discovery.
if __name__ == "__main__":
    unittest.main()
|
|