Daankular commited on
Commit
f2728af
·
verified ·
1 Parent(s): a03943e

Add RIFE model/ directory files

Browse files
Files changed (1) hide show
  1. model/warplayer.py +24 -0
model/warplayer.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
5
+ backwarp_tenGrid = {}
6
+
7
+
8
+ def warp(tenInput, tenFlow):
9
+ k = (str(tenFlow.device), str(tenFlow.size()))
10
+ if k not in backwarp_tenGrid:
11
+ tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
12
+ 1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
13
+ tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2], device=tenFlow.device).view(
14
+ 1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
15
+ backwarp_tenGrid[k] = torch.cat(
16
+ [tenHorizontal, tenVertical], 1).to(tenFlow.device)
17
+
18
+ tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
19
+ tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
20
+
21
+ grid = backwarp_tenGrid[k].type_as(tenFlow)
22
+
23
+ g = (grid + tenFlow).permute(0, 2, 3, 1)
24
+ return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)