| import os |
| import torch |
| import numpy as np |
| from torch.utils.tensorboard import SummaryWriter |
| import json |
| import models.tae as tae |
| import options.option_tae as option_tae |
| import utils.utils_model as utils_model |
| import utils.eval_trans as eval_trans |
| from humanml3d_272 import dataset_eval_tae |
| import sys |
| import warnings |
# Silence all Python warnings for the rest of the run.
warnings.filterwarnings('ignore')

# Switch into the evaluator directory and put it first on the import path.
# NOTE(review): presumably this is what makes the `mld` package imported later
# resolvable. Every relative path used after this point (args.out_dir,
# args.resume_pth, 'epoch=99.ckpt', dataset paths) now resolves inside
# Evaluator_272 - confirm that is intended.
os.chdir('Evaluator_272')
sys.path.insert(0, os.getcwd())

# Prefer the GPU, but fall back to CPU instead of failing at the first
# `.to(comp_device)` on machines without CUDA.
comp_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
| |
# Parse the command-line options and seed torch's RNGs from --seed.
args = option_tae.get_args_parser()
torch.manual_seed(args.seed)

# Everything for this run goes under <out_dir>/<exp_name>. args.out_dir is
# rebound to that sub-directory so later code can use it directly.
# NOTE(review): this runs after os.chdir('Evaluator_272'), so a relative
# out_dir ends up inside Evaluator_272 - confirm that is intended.
run_dir = os.path.join(args.out_dir, f'{args.exp_name}')
os.makedirs(run_dir, exist_ok=True)
args.out_dir = run_dir
|
|
| |
# The logger and the TensorBoard writer both target the run directory. The
# full argument namespace is logged once as indented JSON with sorted keys.
logger = utils_model.get_logger(args.out_dir)
writer = SummaryWriter(args.out_dir)
args_json = json.dumps(vars(args), indent=4, sort_keys=True)
logger.info(args_json)

# Validation loader for args.dataname. The positional True / 32 are passed
# through unchanged - TODO confirm their meaning in dataset_eval_tae.DATALoader
# (32 is presumably the batch size).
val_loader = dataset_eval_tae.DATALoader(args.dataname, True, 32)
|
|
| |
# Passed to the TAE as `clip_range`. Its exact semantics live in models.tae -
# TODO confirm (looks like a lower/upper clamp pair).
clip_range = [-30, 20]

# Architecture hyper-parameters come straight from the CLI options, so they
# must match the ones the checkpoint was trained with (load is strict).
tae_kwargs = dict(
    hidden_size=args.hidden_size,
    down_t=args.down_t,
    stride_t=args.stride_t,
    depth=args.depth,
    dilation_growth_rate=args.dilation_growth_rate,
    activation='relu',
    latent_dim=args.latent_dim,
    clip_range=clip_range,
)
net = tae.Causal_HumanTAE(**tae_kwargs)

# Load weights on CPU first (device-agnostic), then switch to eval mode and
# move the model to the compute device.
print(f'loading checkpoint from {args.resume_pth}')
ckpt = torch.load(args.resume_pth, map_location='cpu')
net.load_state_dict(ckpt['net'], strict=True)
net.eval()
net.to(comp_device)
|
|
|
|
| |
| import torch |
| from mld.models.architectures.temos.textencoder.distillbert_actor import DistilbertActorAgnosticEncoder |
| from mld.models.architectures.temos.motionencoder.actor import ActorAgnosticEncoder |
|
|
# Evaluator networks: a DistilBERT-based text encoder and a motion encoder.
# Both are restored from one combined checkpoint.
modelpath = 'distilbert-base-uncased'

textencoder = DistilbertActorAgnosticEncoder(modelpath, num_layers=4, latent_dim=256)
# NOTE(review): nfeats=272 presumably matches the 272-dim motion representation
# used by this evaluator - confirm.
motionencoder = ActorAgnosticEncoder(nfeats=272, vae=True, num_layers=4, latent_dim=256, max_len=300)


def _submodule_state_dict(state_dict, prefix):
    """Return the entries of *state_dict* belonging to sub-module *prefix*.

    An entry is selected when the first dotted component of its key equals
    *prefix*. That leading ``prefix.`` is removed from the key. Only the
    leading component is stripped: unlike ``str.replace``, a later
    occurrence of the same text inside the key is left intact.

    Args:
        state_dict: mapping of dotted parameter names to tensors.
        prefix: top-level sub-module name, e.g. ``"textencoder"``.

    Returns:
        A new dict suitable for ``nn.Module.load_state_dict``.
    """
    selected = {}
    for key, value in state_dict.items():
        head, _, rest = key.partition('.')
        if head == prefix:
            selected[rest] = value
    return selected


# map_location='cpu' makes the load independent of the device the checkpoint
# was saved on. Weights are moved to comp_device after loading.
# NOTE(review): presumably a PyTorch Lightning checkpoint (weights under
# 'state_dict'). On PyTorch >= 2.6 torch.load defaults to weights_only=True -
# confirm this file still loads under that default.
ckpt = torch.load('epoch=99.ckpt', map_location='cpu')

evaluator = [textencoder, motionencoder]
for encoder, prefix in zip(evaluator, ('textencoder', 'motionencoder')):
    encoder.load_state_dict(_submodule_state_dict(ckpt['state_dict'], prefix), strict=True)
    encoder.eval()
    encoder.to(comp_device)
|
|
# Per-run result lists (a single evaluation is run here).
fid = []
mpjpe = []

try:
    # evaluation_tae_single returns the best FID / MPJPE plus the (possibly
    # updated) writer and logger, which replace the current ones.
    best_fid, best_mpjpe, writer, logger = eval_trans.evaluation_tae_single(
        args.out_dir, val_loader, net, logger, writer, evaluator=evaluator, device=comp_device
    )
    fid.append(best_fid)
    mpjpe.append(best_mpjpe)

    logger.info('final result:')
    logger.info(f'fid: {fid}')
    # The "(mm)" unit label is reproduced as-is - TODO confirm the unit
    # returned by evaluation_tae_single.
    logger.info(f'mpjpe: {mpjpe} (mm)')
finally:
    # Flush pending TensorBoard events and release the event file, even if
    # the evaluation raises.
    writer.close()