Demise307 committed on
Commit
4c17419
·
verified ·
1 Parent(s): 3449b6a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -24
app.py CHANGED
@@ -6,6 +6,8 @@ import os
6
  import torch
7
  import numpy as np
8
  import yaml
 
 
9
  from huggingface_hub import hf_hub_download
10
  #from gradio_imageslider import ImageSlider
11
 
@@ -55,32 +57,60 @@ print("LMHEAD MODEL CKPT:", LM_MODEL)
55
  lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
56
 
57
 
58
def load_img(filename, norm=True):
    """Read an image file as an RGB numpy array.

    When ``norm`` is True the pixel values are rescaled from [0, 255]
    to [0, 1] and cast to float32; otherwise the raw uint8 array is
    returned unchanged.
    """
    pixels = np.array(Image.open(filename).convert("RGB"))
    if norm:
        pixels = pixels / 255.
        pixels = pixels.astype(np.float32)
    return pixels
64
-
65
-
66
def process_img(image, prompt):
    """Restore a PIL image according to a natural-language instruction.

    The image is normalized to [0, 1] float32, run through the
    prompt-conditioned restoration model, and the result is converted
    back to a uint8 PIL.Image.
    """
    arr = (np.array(image) / 255.).astype(np.float32)
    y = torch.tensor(arr).permute(2, 0, 1).unsqueeze(0).to(device)

    # Embed the instruction and move it to the same device as the model.
    lm_embd = language_model(prompt)
    lm_embd = lm_embd.to(device)

    with torch.no_grad():
        text_embd, deg_pred = lm_head(lm_embd)
        x_hat = model(y, text_embd)

    # CHW -> HWC, clamp to valid range, and leave the graph/device.
    restored_img = x_hat.squeeze().permute(1, 2, 0).clamp_(0, 1).cpu().detach().numpy()
    restored_img = np.clip(restored_img, 0., 1.)

    restored_img = (restored_img * 255.0).round().astype(np.uint8)  # float32 to uint8
    return Image.fromarray(restored_img)
84
 
85
 
86
 
@@ -146,16 +176,18 @@ css = """
146
  """
147
 
148
# Gradio UI: single-image restoration driven by a free-form text prompt.
# Fix: output label typo "Ouput" -> "Output".
demo = gr.Interface(
    fn=process_img,
    inputs=[
        gr.Image(type="pil", label="Input", value="images/a4960.jpg"),
        gr.Text(label="Prompt", value="my colors are too off, make it pop so I can use it in instagram"),
    ],
    outputs=[gr.Image(type="pil", label="Output")],
    title=title,
    description=description,
    article=article,
    examples=examples,
    css=css,
)
161
 
 
6
  import torch
7
  import numpy as np
8
  import yaml
9
+ import cv2
10
+ import tempfile
11
  from huggingface_hub import hf_hub_download
12
  #from gradio_imageslider import ImageSlider
13
 
 
57
  lm_head.load_state_dict(torch.load(LM_MODEL, map_location="cpu"), strict=True)
58
 
59
 
60
def process_frame(frame_bgr, prompt):
    """Restore a single video frame with the prompt-conditioned model.

    Accepts a uint8 BGR frame (OpenCV convention), runs the restoration
    network on its RGB float32 representation, and returns the restored
    frame as a uint8 BGR array ready for cv2.VideoWriter.
    """
    # OpenCV delivers BGR; the model expects RGB in [0, 1].
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    normalized = rgb / 255.0
    normalized = normalized.astype(np.float32)

    y = torch.tensor(normalized).permute(2, 0, 1).unsqueeze(0).to(device)

    lm_embd = language_model(prompt).to(device)

    with torch.no_grad():
        text_embd, _ = lm_head(lm_embd)
        x_hat = model(y, text_embd)

    # CHW tensor -> HWC numpy, clamped to the valid range.
    restored = x_hat.squeeze().permute(1, 2, 0).clamp(0, 1).cpu().numpy()

    out_uint8 = (restored * 255).astype(np.uint8)
    return cv2.cvtColor(out_uint8, cv2.COLOR_RGB2BGR)
86
+
87
+
88
def process_video(video_path, prompt):
    """Restore a video frame-by-frame using the prompt-conditioned model.

    Parameters
    ----------
    video_path : str
        Path to the input video file readable by OpenCV.
    prompt : str
        Natural-language restoration instruction applied to every frame.

    Returns
    -------
    str
        Path to a temporary .mp4 file containing the restored video.

    Raises
    ------
    ValueError
        If the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Fail loudly instead of silently emitting an empty output video.
        raise ValueError(f"Could not open video: {video_path}")

    # Some containers report 0 fps; fall back to a sane default so the
    # VideoWriter does not produce a broken file.
    fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    tmp_out = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False)
    out_path = tmp_out.name
    # Close the handle immediately: cv2.VideoWriter reopens the path itself,
    # and a still-open handle breaks writing on Windows.
    tmp_out.close()

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))

    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            writer.write(process_frame(frame, prompt))
    finally:
        # Release even if a frame fails mid-stream so resources are freed
        # and the partial output file is finalized.
        cap.release()
        writer.release()

    return out_path
113
 
 
 
114
 
115
 
116
 
 
176
  """
177
 
178
# Gradio UI: video restoration driven by a free-form text prompt.
demo = gr.Interface(
    fn=process_video,
    inputs=[
        gr.Video(label="Input Video"),
        gr.Text(label="Prompt", value="enhance this video and improve visual quality"),
    ],
    outputs=gr.Video(label="Output Video"),
    title=title,
    description=description,
    article=article,
    css=css,
)
193