akhaliq HF Staff commited on
Commit
32967e1
·
1 Parent(s): a11cb66

feat: implement manual video frame loading using PyAV to support direct frame passing for video processing

Browse files
Files changed (1) hide show
  1. app.py +40 -5
app.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import torch
3
  import re
 
4
  from PIL import Image
5
  from transformers import AutoModelForImageTextToText, AutoProcessor
6
  from gradio import Server
@@ -17,9 +18,41 @@ model = AutoModelForImageTextToText.from_pretrained(
17
  model_id,
18
  torch_dtype=torch.bfloat16,
19
  trust_remote_code=True,
20
- device_map="cuda"
21
  )
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  # Utility for response normalization
24
  _PATTERN = re.compile(
25
  r'(```[\s\S]*?```|`[^`]+`|\$\$[\s\S]*?\$\$|\$[^$]+\$|\\\([\s\S]*?\\\)|\\\[[\s\S]*?\\\])'
@@ -54,26 +87,29 @@ def predict(message: str, file: FileData = None, downsample_mode: str = "16x") -
54
  is_video = any(file_path.lower().endswith(ext) for ext in ['.mp4', '.mkv', '.mov', '.avi'])
55
 
56
  if is_video:
 
 
 
57
  messages = [
58
  {
59
  "role": "user",
60
  "content": [
61
- {"type": "video", "url": file_path},
62
  {"type": "text", "text": message},
63
  ],
64
  }
65
  ]
66
- # Video specific params
67
  inputs = processor.apply_chat_template(
68
  messages, tokenize=True, add_generation_prompt=True,
69
  return_dict=True, return_tensors="pt",
70
  downsample_mode=downsample_mode,
71
- max_num_frames=64, # Optimized for speed
72
  stack_frames=1,
73
  max_slice_nums=1,
74
  use_image_id=False,
75
  ).to(model.device)
76
  else:
 
77
  messages = [
78
  {
79
  "role": "user",
@@ -83,7 +119,6 @@ def predict(message: str, file: FileData = None, downsample_mode: str = "16x") -
83
  ],
84
  }
85
  ]
86
- # Image specific params
87
  inputs = processor.apply_chat_template(
88
  messages, tokenize=True, add_generation_prompt=True,
89
  return_dict=True, return_tensors="pt",
 
1
  import os
2
  import torch
3
  import re
4
+ import av
5
  from PIL import Image
6
  from transformers import AutoModelForImageTextToText, AutoProcessor
7
  from gradio import Server
 
18
  model_id,
19
  torch_dtype=torch.bfloat16,
20
  trust_remote_code=True,
21
+ device_map="cuda"
22
  )
23
 
24
+ def load_video(video_path, max_frames=64):
25
+ """Utility to load video frames using PyAV."""
26
+ container = av.open(video_path)
27
+ frames = []
28
+ # Get total frames to sample uniformly
29
+ stream = container.streams.video[0]
30
+ total_frames = stream.frames
31
+
32
+ if total_frames <= 0: # Some containers don't report frame count
33
+ print("Frame count unknown, decoding all and sampling...")
34
+ temp_frames = []
35
+ for frame in container.decode(video=0):
36
+ temp_frames.append(frame.to_image())
37
+
38
+ if len(temp_frames) > max_frames:
39
+ indices = [int(i * len(temp_frames) / max_frames) for i in range(max_frames)]
40
+ frames = [temp_frames[i] for i in indices]
41
+ else:
42
+ frames = temp_frames
43
+ else:
44
+ # Sample max_frames uniformly
45
+ indices = [int(i * total_frames / max_frames) for i in range(max_frames)]
46
+ current_idx = 0
47
+ for i, frame in enumerate(container.decode(video=0)):
48
+ if current_idx < len(indices) and i == indices[current_idx]:
49
+ frames.append(frame.to_image())
50
+ current_idx += 1
51
+ if current_idx >= len(indices):
52
+ break
53
+ container.close()
54
+ return frames
55
+
56
  # Utility for response normalization
57
  _PATTERN = re.compile(
58
  r'(```[\s\S]*?```|`[^`]+`|\$\$[\s\S]*?\$\$|\$[^$]+\$|\\\([\s\S]*?\\\)|\\\[[\s\S]*?\\\])'
 
87
  is_video = any(file_path.lower().endswith(ext) for ext in ['.mp4', '.mkv', '.mov', '.avi'])
88
 
89
  if is_video:
90
+ print(f"Processing video: {file_path}")
91
+ # Load video frames manually to avoid torchvision decode error
92
+ frames = load_video(file_path, max_frames=64)
93
  messages = [
94
  {
95
  "role": "user",
96
  "content": [
97
+ {"type": "video", "video": frames}, # Pass frames directly
98
  {"type": "text", "text": message},
99
  ],
100
  }
101
  ]
 
102
  inputs = processor.apply_chat_template(
103
  messages, tokenize=True, add_generation_prompt=True,
104
  return_dict=True, return_tensors="pt",
105
  downsample_mode=downsample_mode,
106
+ max_num_frames=64,
107
  stack_frames=1,
108
  max_slice_nums=1,
109
  use_image_id=False,
110
  ).to(model.device)
111
  else:
112
+ print(f"Processing image: {file_path}")
113
  messages = [
114
  {
115
  "role": "user",
 
119
  ],
120
  }
121
  ]
 
122
  inputs = processor.apply_chat_template(
123
  messages, tokenize=True, add_generation_prompt=True,
124
  return_dict=True, return_tensors="pt",