jounery-d commited on
Commit
f20f66f
·
verified ·
1 Parent(s): f0fa630

Upload 2 files

Browse files
Files changed (2) hide show
  1. python/requirements.txt +2 -0
  2. python/run_axmodel.py +79 -0
python/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ numpy
2
+ opencv-python
python/run_axmodel.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import cv2
3
+ import numpy as np
4
+ import axengine as axe
5
+
6
+ def from_numpy(x):
7
+ return x if isinstance(x, np.ndarray) else np.array(x)
8
+
9
+ def post_process(raw_color, orig):
10
+ color_np = np.asarray(raw_color)
11
+ orig_np = np.asarray(orig)
12
+ color_yuv = cv2.cvtColor(color_np, cv2.COLOR_RGB2YUV)
13
+ # do a black and white transform first to get better luminance values
14
+ orig_yuv = cv2.cvtColor(orig_np, cv2.COLOR_RGB2YUV)
15
+ hires = np.copy(orig_yuv)
16
+ hires[:, :, 1:3] = color_yuv[:, :, 1:3]
17
+ final = cv2.cvtColor(hires, cv2.COLOR_YUV2RGB)
18
+ return final
19
+
20
+ def main(args):
21
+ # Initialize the model
22
+ session = axe.InferenceSession(args.model_path)
23
+ output_names = [x.name for x in session.get_outputs()]
24
+ input_name = session.get_inputs()[0].name
25
+ print(input_name)
26
+ print(output_names)
27
+
28
+ ori_image = cv2.imread(args.input_path)
29
+ h, w = ori_image.shape[:2]
30
+ image = cv2.resize(ori_image, (512, 512))
31
+ image = (image[..., ::-1] /255.0).astype(np.float32)
32
+
33
+ mean = [0.485, 0.456, 0.406]
34
+ std = [0.229, 0.224, 0.225]
35
+ image = ((image - mean) / std).astype(np.float32)
36
+
37
+ #image = (image /1.0).astype(np.float32)
38
+ image = np.transpose(np.expand_dims(np.ascontiguousarray(image), axis=0), (0,3,1,2))
39
+ print(image.shape)
40
+
41
+
42
+ # Use the model to generate super-resolved images
43
+ sr = session.run(output_names, {input_name: image})
44
+
45
+ if isinstance(sr, (list, tuple)):
46
+ sr = from_numpy(sr[0]) if len(sr) == 1 else [from_numpy(x) for x in sr]
47
+ else:
48
+ sr = from_numpy(sr)
49
+
50
+ #sr_y_image = imgproc.array_to_image(sr)
51
+ sr = np.transpose(sr.squeeze(0), (1,2,0))
52
+ sr = (sr*std + mean).astype(np.float32)
53
+
54
+ # Save image
55
+ ndarr = np.clip((sr*255.0), 0, 255.0).astype(np.uint8)
56
+ ndarr = cv2.resize(ndarr[..., ::-1], (w, h))
57
+ out_image = post_process(ndarr, ori_image)
58
+
59
+ cv2.imwrite(args.output_path, out_image)
60
+ print(f"Color image save to `{args.output_path}`")
61
+
62
+
63
+ if __name__ == "__main__":
64
+ parser = argparse.ArgumentParser(description="Using the model generator super-resolution images.")
65
+ parser.add_argument("--input_path",
66
+ type=str,
67
+ default="./input.png",
68
+ help="origin image path.")
69
+ parser.add_argument("--output_path",
70
+ type=str,
71
+ default="./sr_colorized.jpg",
72
+ help="colorized image path.")
73
+ parser.add_argument("--model_path",
74
+ type=str,
75
+ default="./colorize_stable.axmodel",
76
+ help="model path.")
77
+ args = parser.parse_args()
78
+
79
+ main(args)