# NeAR / tests/test_app_gsplat_architecture.py
# Commit c98c836 (luh1124): app: CPU preload Hunyuan+NeAR, cuda move in GPU
# callbacks; drop gsplat warmup
from __future__ import annotations
import ast
import unittest
from pathlib import Path
# Absolute path of the app under test: ``app_gsplat.py`` located in the
# directory two levels up from this resolved test file (parents[1]).
APP_PATH = Path(__file__).resolve().parents[1].joinpath("app_gsplat.py")
def _load_tree() -> ast.Module:
    """Parse ``app_gsplat.py`` into an AST without importing or executing it.

    Returns:
        The module node produced by :func:`ast.parse` for ``APP_PATH``.

    Raises:
        FileNotFoundError: if ``APP_PATH`` does not exist.
        SyntaxError: if the app source does not parse. The error names
            ``APP_PATH`` instead of the default ``<unknown>``.
    """
    source = APP_PATH.read_text(encoding="utf-8")
    # Passing filename makes SyntaxError tracebacks point at the real file.
    return ast.parse(source, filename=str(APP_PATH))
def _get_function(tree: ast.Module, name: str) -> ast.FunctionDef:
for node in tree.body:
if isinstance(node, ast.FunctionDef) and node.name == name:
return node
raise AssertionError(f"Function {name!r} not found in app_gsplat.py")
def _called_names(function_node: ast.FunctionDef) -> set[str]:
names: set[str] = set()
for node in ast.walk(function_node):
if isinstance(node, ast.Call):
if isinstance(node.func, ast.Name):
names.add(node.func.id)
elif isinstance(node.func, ast.Attribute):
names.add(node.func.attr)
return names
class AppGsplatArchitectureTests(unittest.TestCase):
    """Static architecture checks on ``app_gsplat.py``.

    Every test only reads or parses the app's source text. The app module
    itself is never imported or executed.
    """

    # Top-level packages this test treats as part of the NeAR stack;
    # app_gsplat.py must not import any of them.
    _FORBIDDEN_PACKAGES = ("trellis", "hy3dshape")

    @staticmethod
    def _absolute_imports(tree: ast.Module) -> dict[str, list[int]]:
        """Map each absolutely-imported top-level package to its import lines.

        ``import a.b, c`` records ``a`` and ``c``; ``from a.b import x``
        records ``a``. Relative imports (``from . import x``, ``from .a import
        x``) are skipped. ``ast.walk`` visits every node, so imports nested in
        functions, ``try`` or ``if`` blocks are included.

        Matching on the first dotted component, rather than on a text prefix,
        means a package such as ``trellis_utils`` is not mistaken for
        ``trellis``.
        """
        found: dict[str, list[int]] = {}
        for node in ast.walk(tree):
            if isinstance(node, ast.Import):
                modules = [alias.name for alias in node.names]
            elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
                modules = [node.module]
            else:
                continue
            for module in modules:
                found.setdefault(module.partition(".")[0], []).append(node.lineno)
        return found

    def test_no_near_stack_imports(self) -> None:
        """Fail if the app mentions the NeAR pipeline or imports a forbidden package."""
        source = APP_PATH.read_text(encoding="utf-8")
        # Raw substring check: any mention of the class, even in a comment, fails.
        self.assertNotIn("NeARImageToRelightable3DPipeline", source)
        imports = self._absolute_imports(ast.parse(source, filename=str(APP_PATH)))
        for package in self._FORBIDDEN_PACKAGES:
            lines = sorted(imports.get(package, []))
            self.assertFalse(
                lines,
                msg=f"Unexpected {package} import on line(s) {lines} of app_gsplat.py",
            )

    def test_raster_helper_calls_gsplat_rasterization(self) -> None:
        """Fail unless ``_raster_rgb_ed`` contains a call named ``rasterize``."""
        # NOTE(review): gsplat's public entry point is ``gsplat.rasterization``;
        # this check presumably targets a local wrapper or alias named
        # ``rasterize``. Confirm against app_gsplat.py.
        raster = _get_function(_load_tree(), "_raster_rgb_ed")
        called = _called_names(raster)
        self.assertIn("rasterize", called)

    def test_run_probe_is_gpu_decorated_and_logs_entry(self) -> None:
        """Fail unless the app source contains the GPU/run_once/CUDA markers."""
        # These are whole-file substring checks. They confirm the markers exist
        # somewhere in app_gsplat.py, not that ``@GPU`` decorates ``run_once``
        # specifically; a commented-out ``# @GPU`` would also satisfy them.
        source = APP_PATH.read_text(encoding="utf-8")
        self.assertIn("@GPU", source)
        self.assertIn("def run_once", source)
        self.assertIn("torch.cuda.is_available()", source)
if __name__ == "__main__":
    # Allow running this module directly as a script, in addition to
    # discovery by a test runner.
    unittest.main()