import os
from PIL import Image
import torch
import argparse
from glob import glob
from sklearn.model_selection import train_test_split
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from dataset import decode_text
from tqdm import tqdm
from datasets import load_metric
|
|
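# Character-error-rate metric loaded from the local cer.py script
# (datasets.load_metric is the legacy API; newer HF versions ship metrics in the `evaluate` package)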
| cer_metric = load_metric("./cer.py") |
|
|
|
|
def compute_metrics(pred_str, label_str):
    """
    Compute the character error rate (CER) and sentence-level accuracy.
    :param pred_str: list of predicted strings
    :param label_str: list of ground-truth strings
    :return: dict with "cer" and "acc"
    """
    cer = cer_metric.compute(predictions=pred_str, references=label_str)
    acc = [pred == label for pred, label in zip(pred_str, label_str)]
    acc = sum(acc) / max(len(acc), 1)  # guard against an empty test set
    return {"cer": cer, "acc": acc}
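# Example (assuming cer.py implements standard character error rate):
#   compute_metrics(["你好"], ["你好"]) -> {"cer": 0.0, "acc": 1.0}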
|
|
|
|
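# Example invocation (script name and paths are illustrative):
#   python eval.py --cust_data_init_weights_path ./cust-data/weights \
#       --dataset_path 'dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*' --CUDA_VISIBLE_DEVICES 0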
| if __name__ == '__main__': |
    parser = argparse.ArgumentParser(description='TrOCR model evaluation')
    parser.add_argument('--cust_data_init_weights_path', default='./cust-data/weights', type=str,
                        help="initial weights for evaluation, i.e. the weights fine-tuned on your own dataset")
    parser.add_argument('--CUDA_VISIBLE_DEVICES', default='-1', type=str, help="GPU selection")
    parser.add_argument('--dataset_path', default='dataset/HW-hand-write/HW_Chinese/*/*.[j|p]*', type=str,
                        help="img path")
    parser.add_argument('--random_state', default=10086, type=int, help="random seed for the train/test split")
|
|
| args = parser.parse_args() |
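    # Must be set before the first CUDA call; '-1' hides all GPUs (CPU/MPS fallback)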
| os.environ["CUDA_VISIBLE_DEVICES"] = args.CUDA_VISIBLE_DEVICES |
| paths = glob(args.dataset_path) |
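    # Re-create the same train/test split as training so evaluation only sees held-out images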
    if args.random_state is not None:
        train_paths, test_paths = train_test_split(paths, test_size=0.05, random_state=args.random_state)
    else:
        train_paths = []
        test_paths = paths
|
|
| print("train num:", len(train_paths), "test num:", len(test_paths)) |
|
|
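    # Processor bundles the image feature extractor and the tokenizer;
    # vocab_inp maps token ids back to tokens for decode_text below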
    processor = TrOCRProcessor.from_pretrained(args.cust_data_init_weights_path)
    vocab = processor.tokenizer.get_vocab()
    vocab_inp = {token_id: token for token, token_id in vocab.items()}
    # Select the evaluation device: CUDA if available, then Apple MPS, else CPU
    device = torch.device("cuda" if torch.cuda.is_available()
                          else "mps" if torch.backends.mps.is_available() else "cpu")
    model = VisionEncoderDecoderModel.from_pretrained(args.cust_data_init_weights_path)
    model.eval()
    model.to(device)
|
|
| pred_str, label_str = [], [] |
    for p in tqdm(test_paths):
        img = Image.open(p).convert('RGB')
        # the ground-truth transcription sits in a .txt file next to each image
        txt_p = os.path.splitext(p)[0] + '.txt'
        with open(txt_p, encoding='utf-8') as f:  # UTF-8 so Chinese labels decode correctly
            label = f.read().strip()
        pixel_values = processor([img], return_tensors="pt").pixel_values
|
|
        with torch.no_grad():
            generated_ids = model.generate(pixel_values.to(device))
|
|
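        # decode_text (from dataset.py) maps the generated token ids back to a text string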
| generated_text = decode_text(generated_ids[0].cpu().numpy(), vocab, vocab_inp) |
| pred_str.append(generated_text) |
| label_str.append(label) |
|
|
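    # Aggregate CER and sentence-level accuracy over the whole test set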
| res = compute_metrics(pred_str, label_str) |
| print(res) |
|
|