yeliudev committed on
Commit
3ffab64
·
verified ·
1 Parent(s): b0d1738

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -14
app.py CHANGED
@@ -19,11 +19,11 @@ import pandas as pd
19
  TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
20
 
21
  TITLE_MD = '<h1 align="center">🌀R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
22
- DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding. Please find more details in our <a href="https://arxiv.org/abs/2404.00801" target="_blank">Tech Report</a> and <a href="https://github.com/yeliudev/R2-Tuning" target="_blank">GitHub Repo</a>.'
23
- GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the moment retrieval and highlight detection results on the right.'
24
 
25
  CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
26
- WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_qvhighlights-ed516355.pth'
27
 
28
  # yapf:disable
29
  EXAMPLES = [
@@ -45,7 +45,7 @@ def load_video(video_path, cfg):
45
  decord.bridge.set_bridge('torch')
46
 
47
  vr = VideoReader(video_path)
48
- stride = vr.get_avg_fps() / cfg.data.val.fps
49
  fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
50
  video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
51
 
@@ -75,6 +75,8 @@ def init_model(config, checkpoint):
75
 
76
 
77
  def main(video, query, model, cfg):
 
 
78
  if len(query) == 0:
79
  raise gr.Error('Text query can not be empty.')
80
 
@@ -86,19 +88,16 @@ def main(video, query, model, cfg):
86
  query = clip.tokenize(query, truncate=True)
87
 
88
  device = next(model.parameters()).device
89
- data = dict(video=video.to(device), query=query.to(device), fps=[cfg.data.val.fps])
90
 
91
  with torch.inference_mode():
92
  pred = model(data)
93
 
94
- mr = pred['_out']['boundary'][:5].cpu().tolist()
95
- mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
96
-
97
  hd = pred['_out']['saliency'].cpu()
98
  hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
99
- hd = pd.DataFrame(dict(x=range(0, len(hd) * 2, 2), y=hd))
100
 
101
- return mr, hd
102
 
103
 
104
  model, cfg = init_model(CONFIG, WEIGHT)
@@ -121,8 +120,6 @@ with gr.Blocks(title=TITLE) as demo:
121
  submit_btn = gr.Button(value='🚀 Submit')
122
 
123
  with gr.Column():
124
- mr = gr.DataFrame(
125
- headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
126
  hd = gr.LinePlot(
127
  x='x',
128
  y='y',
@@ -131,6 +128,6 @@ with gr.Blocks(title=TITLE) as demo:
131
  label='Highlight Detection')
132
 
133
  random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
134
- submit_btn.click(fn, [video, query], [mr, hd])
135
 
136
- demo.launch()
 
19
  TITLE = '🌀R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
20
 
21
  TITLE_MD = '<h1 align="center">🌀R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
22
+ DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding.'
23
+ GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the highlight detection results on the right.'
24
 
25
  CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
26
+ WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_youtube_sur-d384d8b2.pth'
27
 
28
  # yapf:disable
29
  EXAMPLES = [
 
45
  decord.bridge.set_bridge('torch')
46
 
47
  vr = VideoReader(video_path)
48
+ stride = vr.get_avg_fps() / 1
49
  fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
50
  video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
51
 
 
75
 
76
 
77
  def main(video, query, model, cfg):
78
+ query = 'surfing'
79
+
80
  if len(query) == 0:
81
  raise gr.Error('Text query can not be empty.')
82
 
 
88
  query = clip.tokenize(query, truncate=True)
89
 
90
  device = next(model.parameters()).device
91
+ data = dict(video=video.to(device), query=query.to(device), fps=[1])
92
 
93
  with torch.inference_mode():
94
  pred = model(data)
95
 
 
 
 
96
  hd = pred['_out']['saliency'].cpu()
97
  hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
98
+ hd = pd.DataFrame(dict(x=range(len(hd) * 1 -1, -1, -1), y=hd))
99
 
100
+ return hd
101
 
102
 
103
  model, cfg = init_model(CONFIG, WEIGHT)
 
120
  submit_btn = gr.Button(value='🚀 Submit')
121
 
122
  with gr.Column():
 
 
123
  hd = gr.LinePlot(
124
  x='x',
125
  y='y',
 
128
  label='Highlight Detection')
129
 
130
  random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
131
+ submit_btn.click(fn, [video, query], hd)
132
 
133
+ demo.launch()