akhaliq HF Staff committed on
Commit
f009ec7
·
1 Parent(s): f1f0cc8

feat: add text-only inference support and conditional downsample_mode parameter for model generation

Browse files
Files changed (1) hide show
  1. app.py +19 -8
app.py CHANGED
@@ -40,8 +40,15 @@ def predict(message: str, file: FileData = None, downsample_mode: str = "16x"):
40
  General inference endpoint for both image and video.
41
  """
42
  if file is None:
43
- # Text-only inference (standard LLM behavior)
44
  messages = [{"role": "user", "content": [{"type": "text", "text": message}]}]
 
 
 
 
 
 
 
45
  else:
46
  file_path = file["path"]
47
  is_video = any(file_path.lower().endswith(ext) for ext in ['.mp4', '.mkv', '.mov', '.avi'])
@@ -85,13 +92,17 @@ def predict(message: str, file: FileData = None, downsample_mode: str = "16x"):
85
  ).to(model.device)
86
 
87
  with torch.no_grad():
88
- generated_ids = model.generate(
89
- **inputs,
90
- downsample_mode=downsample_mode,
91
- max_new_tokens=1024,
92
- do_sample=True,
93
- temperature=0.7
94
- )
 
 
 
 
95
 
96
  generated_ids_trimmed = [
97
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
 
40
  General inference endpoint for both image and video.
41
  """
42
  if file is None:
43
+ # Text-only inference
44
  messages = [{"role": "user", "content": [{"type": "text", "text": message}]}]
45
+ inputs = processor.apply_chat_template(
46
+ messages,
47
+ tokenize=True,
48
+ add_generation_prompt=True,
49
+ return_dict=True,
50
+ return_tensors="pt"
51
+ ).to(model.device)
52
  else:
53
  file_path = file["path"]
54
  is_video = any(file_path.lower().endswith(ext) for ext in ['.mp4', '.mkv', '.mov', '.avi'])
 
92
  ).to(model.device)
93
 
94
  with torch.no_grad():
95
+ generate_kwargs = {
96
+ **inputs,
97
+ "max_new_tokens": 1024,
98
+ "do_sample": True,
99
+ "temperature": 0.7
100
+ }
101
+
102
+ if file is not None:
103
+ generate_kwargs["downsample_mode"] = downsample_mode
104
+
105
+ generated_ids = model.generate(**generate_kwargs)
106
 
107
  generated_ids_trimmed = [
108
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)