| import numpy as np |
| from PIL import Image |
| import axengine as ort |
| import torch |
| from torchvision.transforms import Normalize, Compose, InterpolationMode, ToTensor, Resize, CenterCrop |
| from tokenizer import SimpleTokenizer |
| import argparse |
|
|
| OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) |
| OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) |
|
|
|
|
def image_transform_v2():
    """Return the v2 image preprocessing pipeline (OpenCLIP-style).

    Resizes the shorter side to 256 with bicubic interpolation, center-crops
    to 256x256, converts to a CHW float tensor in [0, 1], then normalizes
    with the OpenAI CLIP dataset mean/std.
    """
    side = 256
    return Compose([
        Resize(side, interpolation=InterpolationMode.BICUBIC),
        CenterCrop(side),
        ToTensor(),
        Normalize(mean=OPENAI_DATASET_MEAN, std=OPENAI_DATASET_STD),
    ])
| |
| |
def image_transform_v1():
    """Return the v1 image preprocessing pipeline.

    Resizes the shorter side to 256 with bilinear interpolation, center-crops
    to 256x256, and converts to a CHW float tensor in [0, 1]. Unlike v2,
    no mean/std normalization is applied (presumably the model quantization
    already folds it in — TODO confirm against the axmodel export).
    """
    side = 256
    return Compose([
        Resize(side, interpolation=InterpolationMode.BILINEAR),
        CenterCrop(side),
        ToTensor(),
    ])
|
|
|
|
def softmax(x, axis=-1):
    """Apply a numerically stable softmax to a numpy array along one axis.

    Parameters:
        x: numpy array of logits.
        axis: axis along which to normalize (default: last axis, -1).

    Returns:
        Array of the same shape as *x* whose values along *axis* are
        non-negative and sum to 1.
    """
    # Subtracting the per-axis max before exponentiating prevents overflow
    # without changing the result (softmax is shift-invariant).
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)
|
|
if __name__ == "__main__":
    # CLI: paths to the two compiled axmodel encoders, the input image,
    # and the list of candidate captions for zero-shot classification.
    parser = argparse.ArgumentParser()
    parser.add_argument("-ie", "--image_encoder_path", type=str, default="./mobileclip2_s4_image_encoder.axmodel",
                        help="image encoder axmodel path")
    parser.add_argument("-te", "--text_encoder_path", type=str, default="./mobileclip2_s4_text_encoder.axmodel",
                        help="text encoder axmodel path")
    parser.add_argument("-i", "--image", type=str, default="./zebra.jpg",
                        help="input image path")
    parser.add_argument("-t", "--class_text", type=str, nargs='+', default=["a zebra", "a dog", "two zebras"],
                        help='List of captions, e.g.: "a zebra" "a dog" "two zebras"')
    args = parser.parse_args()

    # Preprocess inputs: v1 pipeline (no normalization) for the image,
    # a 77-token CLIP context for each caption.
    preprocess = image_transform_v1()
    tokenizer = SimpleTokenizer(context_length=77)

    pixel_batch = preprocess(Image.open(args.image).convert('RGB')).unsqueeze(0)  # add batch dim
    token_ids = tokenizer(args.class_text).to(torch.int32)  # text encoder expects int32 ids

    # Load both encoders on the axengine runtime (ONNXRuntime-compatible API).
    image_session = ort.InferenceSession(args.image_encoder_path)
    text_session = ort.InferenceSession(args.text_encoder_path)

    # Run the encoders; each returns un-normalized embedding vectors.
    image_features = image_session.run(["unnorm_image_features"], {"image": np.array(pixel_batch)})[0]
    text_features = text_session.run(["unnorm_text_features"], {"text": token_ids.numpy()})[0]

    # L2-normalize so the dot product below is a cosine similarity.
    image_features /= np.linalg.norm(image_features, ord=2, axis=-1, keepdims=True)
    text_features /= np.linalg.norm(text_features, ord=2, axis=-1, keepdims=True)

    # CLIP's logit scale (100.0), then softmax over captions.
    text_probs = softmax(100.0 * image_features @ text_features.T)

    print("Label probs:", text_probs)