# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from sapiens.engine.runners import BaseRunner
## left-right flip for pose val and test
class PoseRunner(BaseRunner):
    """Runner whose test/val loops support left-right flip averaging.

    When ``self.val_cfg`` has a truthy ``"flip_test"`` entry, every batch is
    run through the model twice (as-is and mirrored along the last axis) and
    the two predictions are averaged before being handed to the evaluator.

    The attributes used here (``model``, ``evaluator``, ``val_dataloader``,
    ``data_preprocessor``, ``val_cfg``, ``accelerator``, ``logger``, ``iter``)
    are not defined in this class; they are expected to come from
    ``BaseRunner``.
    """

    def _pose_forward(self, inputs, data_samples):
        """Return the model prediction for ``inputs``, flip-averaged if enabled.

        Args:
            inputs: Batched model input; must support ``.flip(-1)``.
            data_samples: Per-sample records. Only
                ``data_samples[0]["meta"]["flip_indices"]`` is read, and only
                when flip testing is enabled.

        Returns:
            The plain prediction, or the mean of the plain prediction and the
            un-mirrored, channel-remapped prediction of the mirrored input.

        Raises:
            ValueError: If the number of flip indices differs from the size of
                dimension 1 of the prediction.
        """
        with torch.no_grad():
            pred = self.model(inputs)  # forward

        if not self.val_cfg.get("flip_test", False):
            return pred

        with torch.no_grad():
            pred_flipped = self.model(inputs.flip(-1))  # forward on mirrored input

        # NOTE(review): the permutation is taken from the first sample only,
        # which presumes every sample in the batch shares it -- confirm.
        flip_indices = data_samples[0]["meta"]["flip_indices"]

        # Mirror the output back so it lines up spatially with `pred`.
        pred_flipped = pred_flipped.flip(-1)  ## B x K x heatmap_H x heatmap_W

        num_channels = pred_flipped.shape[1]  ## K
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(flip_indices) != num_channels:
            raise ValueError(
                f"flip_indices has {len(flip_indices)} entries but the "
                f"prediction has {num_channels} channels"
            )

        # Reorder channels with the dataset-provided permutation so mirrored
        # channels line up with their counterparts in `pred` before averaging.
        pred_flipped = pred_flipped[:, flip_indices]
        return (pred + pred_flipped) / 2.0

    def _pose_eval_loop(self, tag):
        """Feed every batch of ``self.val_dataloader`` to the evaluator.

        Args:
            tag: Prefix for the periodic progress log line ("Test" or "Val").

        Returns:
            Whatever ``self.evaluator.evaluate`` returns.
        """
        self.evaluator.reset()
        num_batches = len(self.val_dataloader)

        for i, data_batch in enumerate(self.val_dataloader):
            data_batch = self.data_preprocessor(data_batch)  # preprocess
            inputs = data_batch["inputs"]
            data_samples = data_batch["data_samples"]

            pred = self._pose_forward(inputs, data_samples)

            # Progress line every 100 batches, main process only.
            if self.accelerator.is_main_process and i > 0 and i % 100 == 0:
                self.logger.info(
                    f"\033[95m{tag}: {i}/{num_batches}: batch_size: {len(inputs)}\033[0m"
                )

            ## accelerator used to gather and dedup in val
            self.evaluator.process(
                pred, data_samples, accelerator=self.accelerator
            )

        return self.evaluator.evaluate(
            logger=self.logger, accelerator=self.accelerator
        )

    def test(self) -> None:
        """Evaluate the model over ``self.val_dataloader`` and log the metrics.

        The model is switched to eval mode and is not switched back. Metrics
        are logged on the main process only.
        """
        if self.accelerator.is_main_process:
            self.logger.info("\033[95mStarting test...\033[0m")

        self.model.eval()
        metrics = self._pose_eval_loop("Test")

        # Only the main process formats and reports the metrics.
        if self.accelerator.is_main_process:
            summary = ", ".join(f"{k}: {v:.4f}" for k, v in metrics.items())
            self.logger.info(f"\033[95mTest: {summary}\033[0m")
            self.logger.info("\033[95mTesting finished ✔\033[0m")

    # -------------------------------------------------------------------------
    def val(self):
        """Run validation, restore train mode, and return the metrics.

        Returns:
            The result of ``self.evaluator.evaluate``. The previous ``-> None``
            annotation was dropped because this method does return a value.
        """
        self.model.eval()
        if self.accelerator.is_main_process:
            self.logger.info(f"\033[95mValidating iter {self.iter}\033[0m")

        metric = self._pose_eval_loop("Val")

        # Validation is followed by more training, so restore train mode.
        self.model.train()
        return metric
|