anonymouscla committed
Commit 26b567b · verified · 1 Parent(s): 6d0e4db

infer: accept HF Hub repo id for --adapter-dir

Files changed (1)
  1. infer.py +413 -0
infer.py ADDED
@@ -0,0 +1,413 @@
+ """Run inference with the anonymous judge LoRA adapter.
+
+ The script can either load files from a local directory or pull them
+ directly from the Hugging Face Hub. By default it points at the
+ companion repository ``anonymouscla/physground-judger9B``:
+
+     # From the Hub (no clone needed):
+     python infer.py --video demo.mp4 --caption "A ball rolls down a ramp." --metric SA
+     python infer.py --video demo.mp4 --caption "A ball rolls down a ramp." --law gravity
+
+     # From a local clone of the model repo:
+     python infer.py --adapter-dir /path/to/local/clone --video demo.mp4 \
+         --caption "A ball rolls down a ramp." --law gravity
+
+ It loads:
+ - adapter_config.json to find the base model
+ - adapter_model.safetensors through PEFT
+ - subq+human.yaml to render the scoring prompt
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import json
+ import re
+ from pathlib import Path
+ from typing import Any
+
+ import torch
+ import yaml
+ from peft import PeftModel
+ from transformers import AutoProcessor
+
+
+ GENERAL_SUB_QUESTIONS: dict[str, list[str]] = {
+     "SA": [
+         "Are the main objects in the caption present in the video?",
+         "Are the key actions or interactions from the caption visible?",
+         "Are important scene attributes and relationships preserved?",
+         "Does the video avoid major contradictions to the caption?",
+     ],
+     "PTV": [
+         "Do causes appear before their effects?",
+         "Do physical events unfold in a plausible temporal order?",
+         "Are motion transitions continuous rather than abrupt jumps or loops?",
+         "Does the sequence avoid impossible reversals or repeated resets?",
+     ],
+     "persistence": [
+         "Do objects maintain consistent existence throughout the video?",
+         "Do objects keep a stable shape, size, color, and texture?",
+         "Do objects avoid disappearing, appearing, or transforming unexpectedly?",
+         "Do objects preserve identity through motion and brief occlusion?",
+     ],
+ }
+
+
+ PHYSICAL_CRITERIA: dict[str, str] = {
+     "gravity": "Do unsupported objects fall downward? Do thrown objects follow a curved trajectory? Does poured liquid fall with gravity?",
+     "inertia": "Do stationary objects remain still unless acted upon? Do moving objects maintain their motion unless stopped by friction, collision, or an obstacle?",
+     "momentum": "After collision, push, or pull, is the direction of motion reasonable? Ignore speed magnitude.",
+     "impenetrability": "Do objects maintain impenetrability -- no passing through each other?",
+     "collision": "After impact, is there reasonable bounce/shatter/deformation? Does response match impact force?",
+     "material": "Does each material respond according to its properties? (glass shatters, rubber bounces, metal is rigid, cloth deforms softly, etc.)",
+     "buoyancy": "Do dense objects sink? Do wood/plastic float?",
+     "displacement": "When you add more liquid or put an object into it, does the liquid level rise in a realistic way? Does it overflow when full?",
+     "flow_dynamics": "Does the liquid's overall motion behave realistically over time -- flowing along surfaces, spreading, draining naturally?",
+     "boundary_interaction": "When the liquid hits a boundary such as a rock face, container wall, or floor, does it respond realistically? Do local splash, rebound, or split patterns on impact look physically plausible?",
+     "fluid_continuity": "Does the liquid avoid disappearing or appearing out of nowhere? Small splashes that briefly break apart are okay.",
+     "reflection": "Does the reflection roughly match objects and colors in the scene, and avoid completely unrelated content?",
+     "shadow": "Are shadow directions consistent with light source? Do shadows move with objects?",
+ }
+
+
+ PHYSICAL_SUB_QUESTIONS: dict[str, list[str]] = {
+     "gravity": [
+         "Do unsupported objects or liquids move downward over time?",
+         "Do thrown or falling objects follow a plausible gravity-driven path?",
+         "Does the video avoid objects floating or rising without support?",
+     ],
+     "inertia": [
+         "Do stationary objects remain still unless a visible force acts on them?",
+         "Do moving objects continue plausibly until friction, collision, or an obstacle changes their motion?",
+         "Does the video avoid unexplained starts, stops, or direction changes?",
+     ],
+     "momentum": [
+         "After contact, push, pull, or collision, are motion directions plausible?",
+         "Does the reacting object move in a direction consistent with the interaction?",
+         "Does the video avoid impossible reversals or unrelated motion changes?",
+     ],
+     "impenetrability": [
+         "Do solid objects avoid passing through one another?",
+         "Do contacts and overlaps remain physically plausible?",
+         "Does the video avoid obvious clipping or penetration artifacts?",
+     ],
+     "collision": [
+         "Does impact cause a plausible bounce, break, deformation, or transfer of motion?",
+         "Is the response direction consistent with the collision?",
+         "Does the response avoid being much too weak, too strong, or unrelated to the impact?",
+     ],
+     "material": [
+         "Do objects respond consistently with their apparent material?",
+         "Are rigid, soft, brittle, elastic, or fluid-like objects animated appropriately?",
+         "Does the video avoid material behavior that contradicts the scene?",
+     ],
+     "buoyancy": [
+         "Do objects sink or float in a way consistent with apparent density?",
+         "Does the floating or sinking behavior stay stable over time?",
+         "Does the video avoid unsupported hovering or impossible underwater motion?",
+     ],
+     "displacement": [
+         "Does liquid level rise when volume is added or an object enters it?",
+         "Does overflow happen only when the container is plausibly full?",
+         "Does the liquid volume remain visually plausible?",
+     ],
+     "flow_dynamics": [
+         "Does liquid flow along surfaces, spread, or drain naturally?",
+         "Does the flow direction follow gravity and boundaries?",
+         "Does the video avoid abrupt stops, reversals, or unsupported uphill flow?",
+     ],
+     "boundary_interaction": [
+         "Does liquid react plausibly when hitting a wall, floor, container, or obstacle?",
+         "Are splash, rebound, or split patterns locally plausible?",
+         "Does the liquid remain consistent after interacting with boundaries?",
+     ],
+     "fluid_continuity": [
+         "Does liquid avoid disappearing or appearing without cause?",
+         "Does the amount of liquid remain broadly consistent?",
+         "Are splashes and separations temporary and physically plausible?",
+     ],
+     "reflection": [
+         "Does the reflection match nearby objects, colors, and motion?",
+         "Does the reflected content stay spatially consistent with the scene?",
+         "Does the video avoid unrelated or impossible reflection content?",
+     ],
+     "shadow": [
+         "Are shadows consistent with the apparent light source direction?",
+         "Do shadows move with the objects that cast them?",
+         "Does the video avoid missing, detached, or contradictory shadows?",
+     ],
+ }
+
+
+ def load_json(path: Path) -> dict[str, Any]:
+     with path.open() as f:
+         return json.load(f)
+
+
+ def load_yaml(path: Path) -> dict[str, Any]:
+     with path.open() as f:
+         return yaml.safe_load(f)
+
+
+ def questions_block(questions: list[str]) -> str:
+     return "\n".join(f"{idx}. {question}" for idx, question in enumerate(questions, 1))
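+
+ # e.g. questions_block(["A?", "B?"]) -> "1. A?\n2. B?"
+
+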
+ def build_prompt(
+     cfg: dict[str, Any],
+     caption: str,
+     *,
+     metric: str | None = None,
+     law: str | None = None,
+     criteria: str | None = None,
+ ) -> tuple[str, str, str]:
+     if metric:
+         if metric not in GENERAL_SUB_QUESTIONS:
+             raise ValueError(f"unknown metric: {metric}")
+         prompt = cfg["eval_prompts"][metric].format(
+             prompt=caption,
+             questions_block=questions_block(GENERAL_SUB_QUESTIONS[metric]),
+         )
+         return cfg["system_prompt"], prompt, metric
+
+     if not law:
+         raise ValueError("either --metric or --law is required")
+     if law not in PHYSICAL_CRITERIA:
+         raise ValueError(f"unknown law: {law}")
+     prompt = cfg["physical_template"].format(
+         prompt=caption,
+         law=law,
+         criteria=criteria or PHYSICAL_CRITERIA[law],
+         questions_block=questions_block(PHYSICAL_SUB_QUESTIONS[law]),
+     )
+     return cfg["system_prompt"], prompt, law
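+
+ # build_prompt returns (system_prompt, user_prompt, score_key): score_key is
+ # the metric or law name and is later used as the JSON key when parsing the
+ # model's score.
+
+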
+ def load_base_model(base_id: str, dtype: torch.dtype, device_map: str):
+     # Try the newest multimodal auto-class first, falling back for older
+     # transformers releases that only expose Vision2Seq or CausalLM.
+     errors: list[str] = []
+     for class_name in (
+         "AutoModelForImageTextToText",
+         "AutoModelForVision2Seq",
+         "AutoModelForCausalLM",
+     ):
+         try:
+             module = __import__("transformers", fromlist=[class_name])
+             model_cls = getattr(module, class_name)
+             return model_cls.from_pretrained(
+                 base_id,
+                 torch_dtype=dtype,
+                 device_map=device_map,
+                 trust_remote_code=True,
+             )
+         except Exception as exc:  # pragma: no cover - depends on local transformers version
+             errors.append(f"{class_name}: {exc}")
+     raise RuntimeError("failed to load base model:\n" + "\n".join(errors))
+
+
+ def resolve_adapter_dir(source: str) -> Path:
+     """Return a local directory holding the adapter files.
+
+     If ``source`` is a directory containing ``adapter_config.json`` it is used
+     as-is. Otherwise ``source`` is interpreted as a HF Hub repo id and the
+     snapshot is downloaded into the local cache.
+     """
+     candidate = Path(source)
+     if candidate.is_dir() and (candidate / "adapter_config.json").exists():
+         return candidate
+     try:
+         from huggingface_hub import snapshot_download
+     except ImportError as exc:
+         raise ImportError(
+             "huggingface_hub is required to fetch the adapter from the Hub. "
+             "Install it with: pip install huggingface_hub"
+         ) from exc
+     return Path(snapshot_download(repo_id=source))
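+
+ # Examples (paths illustrative): resolve_adapter_dir("./my-clone") uses the
+ # directory as-is, while resolve_adapter_dir("anonymouscla/physground-judger9B")
+ # downloads the Hub snapshot into the local HF cache and returns that path.
+
+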
+ def load_model(adapter_source: str, dtype: torch.dtype, device_map: str) -> tuple[Any, Any, Path]:
+     adapter_dir = resolve_adapter_dir(adapter_source)
+     adapter_cfg = load_json(adapter_dir / "adapter_config.json")
+     base_id = adapter_cfg["base_model_name_or_path"]
+     processor = AutoProcessor.from_pretrained(base_id, trust_remote_code=True)
+     base = load_base_model(base_id, dtype=dtype, device_map=device_map)
+     model = PeftModel.from_pretrained(base, adapter_dir)
+     model.eval()
+     return processor, model, adapter_dir
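+
+ # Typical call, mirroring main's defaults:
+ #   processor, model, adapter_dir = load_model(
+ #       "anonymouscla/physground-judger9B", torch.bfloat16, "auto")
+
+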
+ def build_messages(system_prompt: str, user_prompt: str, video_path: Path) -> list[dict[str, Any]]:
+     return [
+         {"role": "system", "content": system_prompt},
+         {
+             "role": "user",
+             "content": [
+                 {"type": "video", "video": str(video_path)},
+                 {"type": "text", "text": user_prompt},
+             ],
+         },
+     ]
+
+
+ def prepare_inputs(
+     processor: Any,
+     messages: list[dict[str, Any]],
+     device: torch.device,
+     *,
+     fps: float,
+     max_pixels: int,
+ ) -> dict[str, Any]:
+     text = processor.apply_chat_template(
+         messages,
+         tokenize=False,
+         add_generation_prompt=True,
+     )
+
+     try:
+         from qwen_vl_utils import process_vision_info
+     except ImportError as exc:
+         raise ImportError(
+             "qwen-vl-utils is required for local video inference. "
+             "Install it with: pip install qwen-vl-utils[decord]"
+         ) from exc
+
+     # Attach sampling hints to every video entry without overriding explicit ones.
+     for msg in messages:
+         content = msg.get("content")
+         if isinstance(content, list):
+             for item in content:
+                 if item.get("type") == "video":
+                     item.setdefault("fps", fps)
+                     item.setdefault("max_pixels", max_pixels)
+
+     # Older qwen-vl-utils releases do not accept return_video_kwargs.
+     try:
+         image_inputs, video_inputs, video_kwargs = process_vision_info(
+             messages,
+             return_video_kwargs=True,
+         )
+     except TypeError:
+         image_inputs, video_inputs = process_vision_info(messages)
+         video_kwargs = {}
+
+     inputs = processor(
+         text=[text],
+         images=image_inputs,
+         videos=video_inputs,
+         padding=True,
+         return_tensors="pt",
+         **video_kwargs,
+     )
+     return inputs.to(device)
+
+
+ def decode_generated(processor: Any, inputs: dict[str, Any], generated_ids: torch.Tensor) -> str:
+     input_len = inputs["input_ids"].shape[1]
+     generated_ids = generated_ids[:, input_len:]
+     return processor.batch_decode(
+         generated_ids,
+         skip_special_tokens=True,
+         clean_up_tokenization_spaces=False,
+     )[0].strip()
+
+
+ def parse_score(text: str, key: str) -> int | None:
+     # Prefer a well-formed JSON object; fall back to a loose `key: digit` match.
+     match = re.search(r"\{.*?\}", text, flags=re.S)
+     if match:
+         try:
+             obj = json.loads(match.group(0))
+             value = obj.get(key)
+             if isinstance(value, int) and 1 <= value <= 5:
+                 return value
+         except json.JSONDecodeError:
+             pass
+     match = re.search(rf'"?{re.escape(key)}"?\s*:\s*([1-5])', text)
+     if match:
+         return int(match.group(1))
+     return None
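+
+ # Examples (score values illustrative):
+ #   parse_score('{"gravity": 4}', "gravity") -> 4
+ #   parse_score("SA: 3 because ...", "SA") -> 3
+ #   parse_score("no usable score", "SA") -> None
+
+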
+ def dtype_from_name(name: str) -> torch.dtype:
+     if name == "bfloat16":
+         return torch.bfloat16
+     if name == "float16":
+         return torch.float16
+     if name == "float32":
+         return torch.float32
+     raise ValueError(f"unsupported dtype: {name}")
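+
+
+ # main() prints a single JSON object to stdout, e.g. (values illustrative):
+ #   {"key": "gravity", "score": 4, "raw": "{\"gravity\": 4}"}
+ # "score" is None whenever the model response cannot be parsed.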
+ def main() -> None:
+     parser = argparse.ArgumentParser(description="Infer with the anonymous judge adapter.")
+     parser.add_argument(
+         "--adapter-dir",
+         default="anonymouscla/physground-judger9B",
+         help=(
+             "Local directory with adapter_config.json + adapter_model.safetensors "
+             "+ subq+human.yaml, or a HF Hub repo id "
+             "(default: anonymouscla/physground-judger9B)."
+         ),
+     )
+     parser.add_argument("--video", required=True, type=Path)
+     parser.add_argument("--caption", required=True)
+     group = parser.add_mutually_exclusive_group(required=True)
+     group.add_argument("--metric", choices=["SA", "PTV", "persistence"])
+     group.add_argument("--law", choices=sorted(PHYSICAL_CRITERIA))
+     parser.add_argument("--criteria", help="Override physical-law criterion text.")
+     parser.add_argument("--max-new-tokens", type=int, default=64)
+     parser.add_argument("--temperature", type=float, default=0.0)
+     parser.add_argument("--fps", type=float, default=2.0)
+     parser.add_argument("--max-pixels", type=int, default=360 * 640)
+     parser.add_argument("--dtype", choices=["bfloat16", "float16", "float32"], default="bfloat16")
+     parser.add_argument("--device-map", default="auto")
+     parser.add_argument("--print-prompt", action="store_true")
+     args = parser.parse_args()
+
+     if not args.video.is_file():
+         raise FileNotFoundError(args.video)
+
+     dtype = dtype_from_name(args.dtype)
+     processor, model, adapter_dir = load_model(
+         args.adapter_dir, dtype=dtype, device_map=args.device_map
+     )
+
+     prompt_cfg = load_yaml(adapter_dir / "subq+human.yaml")
+     system_prompt, user_prompt, score_key = build_prompt(
+         prompt_cfg,
+         args.caption,
+         metric=args.metric,
+         law=args.law,
+         criteria=args.criteria,
+     )
+
+     if args.print_prompt:
+         print("SYSTEM:")
+         print(system_prompt)
+         print("\nUSER:")
+         print(user_prompt)
+         print()
+
+     device = next(model.parameters()).device
+     messages = build_messages(system_prompt, user_prompt, args.video)
+     inputs = prepare_inputs(
+         processor,
+         messages,
+         device,
+         fps=args.fps,
+         max_pixels=args.max_pixels,
+     )
+
+     # Greedy decoding by default; drop the None temperature so generate()
+     # only receives keys that apply.
+     generation_kwargs: dict[str, Any] = {
+         "max_new_tokens": args.max_new_tokens,
+         "do_sample": args.temperature > 0,
+         "temperature": args.temperature if args.temperature > 0 else None,
+     }
+     generation_kwargs = {k: v for k, v in generation_kwargs.items() if v is not None}
+
+     with torch.inference_mode():
+         generated_ids = model.generate(**inputs, **generation_kwargs)
+
+     raw = decode_generated(processor, inputs, generated_ids)
+     score = parse_score(raw, score_key)
+     print(json.dumps({"key": score_key, "score": score, "raw": raw}, ensure_ascii=False, indent=2))
+
+
+ if __name__ == "__main__":
+     main()