Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,11 +19,11 @@ import pandas as pd
|
|
| 19 |
TITLE = '๐R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
|
| 20 |
|
| 21 |
TITLE_MD = '<h1 align="center">๐R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
|
| 22 |
-
DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding.
|
| 23 |
-
GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the
|
| 24 |
|
| 25 |
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
|
| 26 |
-
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/
|
| 27 |
|
| 28 |
# yapf:disable
|
| 29 |
EXAMPLES = [
|
|
@@ -45,7 +45,7 @@ def load_video(video_path, cfg):
|
|
| 45 |
decord.bridge.set_bridge('torch')
|
| 46 |
|
| 47 |
vr = VideoReader(video_path)
|
| 48 |
-
stride = vr.get_avg_fps() /
|
| 49 |
fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
|
| 50 |
video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
|
| 51 |
|
|
@@ -75,6 +75,8 @@ def init_model(config, checkpoint):
|
|
| 75 |
|
| 76 |
|
| 77 |
def main(video, query, model, cfg):
|
|
|
|
|
|
|
| 78 |
if len(query) == 0:
|
| 79 |
raise gr.Error('Text query can not be empty.')
|
| 80 |
|
|
@@ -86,19 +88,16 @@ def main(video, query, model, cfg):
|
|
| 86 |
query = clip.tokenize(query, truncate=True)
|
| 87 |
|
| 88 |
device = next(model.parameters()).device
|
| 89 |
-
data = dict(video=video.to(device), query=query.to(device), fps=[
|
| 90 |
|
| 91 |
with torch.inference_mode():
|
| 92 |
pred = model(data)
|
| 93 |
|
| 94 |
-
mr = pred['_out']['boundary'][:5].cpu().tolist()
|
| 95 |
-
mr = [[convert_time(p[0]), convert_time(p[1]), round(p[2], 2)] for p in mr]
|
| 96 |
-
|
| 97 |
hd = pred['_out']['saliency'].cpu()
|
| 98 |
hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
|
| 99 |
-
hd = pd.DataFrame(dict(x=range(
|
| 100 |
|
| 101 |
-
return
|
| 102 |
|
| 103 |
|
| 104 |
model, cfg = init_model(CONFIG, WEIGHT)
|
|
@@ -121,8 +120,6 @@ with gr.Blocks(title=TITLE) as demo:
|
|
| 121 |
submit_btn = gr.Button(value='๐ Submit')
|
| 122 |
|
| 123 |
with gr.Column():
|
| 124 |
-
mr = gr.DataFrame(
|
| 125 |
-
headers=['Start Time', 'End Time', 'Score'], label='Moment Retrieval')
|
| 126 |
hd = gr.LinePlot(
|
| 127 |
x='x',
|
| 128 |
y='y',
|
|
@@ -131,6 +128,6 @@ with gr.Blocks(title=TITLE) as demo:
|
|
| 131 |
label='Highlight Detection')
|
| 132 |
|
| 133 |
random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
|
| 134 |
-
submit_btn.click(fn, [video, query],
|
| 135 |
|
| 136 |
-
demo.launch()
|
|
|
|
| 19 |
TITLE = '๐R2-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding'
|
| 20 |
|
| 21 |
TITLE_MD = '<h1 align="center">๐R<sup>2</sup>-Tuning: Efficient Image-to-Video Transfer Learning for Video Temporal Grounding</h1>'
|
| 22 |
+
DESCRIPTION_MD = 'R<sup>2</sup>-Tuning is a parameter- and memory-efficient transfer learning method for video temporal grounding.'
|
| 23 |
+
GUIDE_MD = '### User Guide:\n1. Upload a video or click "random" to sample one.\n2. Input a text query. A good practice is to write a sentence with 5~15 words.\n3. Click "submit" and you\'ll see the highlight detection results on the right.'
|
| 24 |
|
| 25 |
CONFIG = 'configs/qvhighlights/r2_tuning_qvhighlights.py'
|
| 26 |
+
WEIGHT = 'https://huggingface.co/yeliudev/R2-Tuning/resolve/main/checkpoints/r2_tuning_youtube_sur-d384d8b2.pth'
|
| 27 |
|
| 28 |
# yapf:disable
|
| 29 |
EXAMPLES = [
|
|
|
|
| 45 |
decord.bridge.set_bridge('torch')
|
| 46 |
|
| 47 |
vr = VideoReader(video_path)
|
| 48 |
+
stride = vr.get_avg_fps() / 1
|
| 49 |
fm_idx = [min(round(i), len(vr) - 1) for i in np.arange(0, len(vr), stride).tolist()]
|
| 50 |
video = vr.get_batch(fm_idx).permute(0, 3, 1, 2).float() / 255
|
| 51 |
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
def main(video, query, model, cfg):
|
| 78 |
+
query = 'surfing'
|
| 79 |
+
|
| 80 |
if len(query) == 0:
|
| 81 |
raise gr.Error('Text query can not be empty.')
|
| 82 |
|
|
|
|
| 88 |
query = clip.tokenize(query, truncate=True)
|
| 89 |
|
| 90 |
device = next(model.parameters()).device
|
| 91 |
+
data = dict(video=video.to(device), query=query.to(device), fps=[1])
|
| 92 |
|
| 93 |
with torch.inference_mode():
|
| 94 |
pred = model(data)
|
| 95 |
|
|
|
|
|
|
|
|
|
|
| 96 |
hd = pred['_out']['saliency'].cpu()
|
| 97 |
hd = ((hd - hd.min()) / (hd.max() - hd.min()) * 0.9 + 0.05).tolist()
|
| 98 |
+
hd = pd.DataFrame(dict(x=range(len(hd) * 1 -1, -1, -1), y=hd))
|
| 99 |
|
| 100 |
+
return hd
|
| 101 |
|
| 102 |
|
| 103 |
model, cfg = init_model(CONFIG, WEIGHT)
|
|
|
|
| 120 |
submit_btn = gr.Button(value='๐ Submit')
|
| 121 |
|
| 122 |
with gr.Column():
|
|
|
|
|
|
|
| 123 |
hd = gr.LinePlot(
|
| 124 |
x='x',
|
| 125 |
y='y',
|
|
|
|
| 128 |
label='Highlight Detection')
|
| 129 |
|
| 130 |
random_btn.click(lambda: random.sample(EXAMPLES, 1)[0], None, [video, query])
|
| 131 |
+
submit_btn.click(fn, [video, query], hd)
|
| 132 |
|
| 133 |
+
demo.launch()
|