123123aa123 committed on
Commit
fdf498d
·
verified ·
1 Parent(s): 66faffa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -98,7 +98,7 @@ def load_models():
98
 
99
  if vggt_model is None:
100
  print("Loading VGGT...")
101
- vggt_model = VGGT.from_pretrained(VGGT_PATH).to(device).eval()
102
 
103
  if wan_pipe is None:
104
  print("Loading Wan...")
@@ -133,10 +133,10 @@ def load_models():
133
  wan_pipe.load_lora(wan_pipe.dit, state_dict=lora_sd, alpha=1)
134
  wan_pipe.dit.load_state_dict(adapter_sd, strict=False)
135
 
136
- wan_pipe.to(device)
137
- wan_pipe.to(dtype=torch.bfloat16)
138
-
139
 
 
140
  # =========================
141
  # Renderer
142
  # =========================
@@ -268,8 +268,6 @@ def build_estimate_rel(x, y, z, phi, theta):
268
  @spaces.GPU
269
  def infer(image, prompt, seed):
270
 
271
- load_models()
272
-
273
 
274
  img = image.convert("RGB")
275
 
 
98
 
99
  if vggt_model is None:
100
  print("Loading VGGT...")
101
+ vggt_model = VGGT.from_pretrained(VGGT_PATH, torch_dtype=dtype).to(device).eval()
102
 
103
  if wan_pipe is None:
104
  print("Loading Wan...")
 
133
  wan_pipe.load_lora(wan_pipe.dit, state_dict=lora_sd, alpha=1)
134
  wan_pipe.dit.load_state_dict(adapter_sd, strict=False)
135
 
136
+ #wan_pipe.to(device)
137
+ #wan_pipe.to(dtype=torch.bfloat16)
 
138
 
139
+ load_models()
140
  # =========================
141
  # Renderer
142
  # =========================
 
268
  @spaces.GPU
269
  def infer(image, prompt, seed):
270
 
 
 
271
 
272
  img = image.convert("RGB")
273