sapiens2-normal / sapiens /pose /src /runners /pose_runner.py
Rawal Khirodkar
Initial sapiens2-normal Space (HF download at startup, all 4 sizes)
ba23d94
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from sapiens.engine.runners import BaseRunner
## left-right flip for pose val and test
class PoseRunner(BaseRunner):
    """Runner whose ``test``/``val`` loops support left-right flip averaging.

    When ``self.val_cfg["flip_test"]`` is truthy, every batch is additionally
    run through the model flipped along its last dimension; the flipped
    prediction is flipped back, its channels are permuted with the
    ``flip_indices`` found in the sample metadata, and the result is averaged
    with the un-flipped prediction.
    """

    def _predict_with_flip(self, inputs, data_samples):
        """Run the model without gradients, averaging with a flipped pass if enabled.

        Args:
            inputs: Preprocessed model input; must support ``.flip(-1)``.
            data_samples: Indexable batch metadata. ``flip_indices`` is read from
                ``data_samples[0]["meta"]`` and applied to the whole batch
                (assumes all samples share one keypoint layout -- TODO confirm).

        Returns:
            The model prediction, or the mean of the plain and flip-corrected
            predictions when ``flip_test`` is enabled.

        Raises:
            ValueError: If the number of flip indices differs from the size of
                dimension 1 of the prediction.
        """
        with torch.no_grad():
            pred = self.model(inputs)  # forward
            if not self.val_cfg.get("flip_test", False):
                return pred
            pred_flipped = self.model(inputs.flip(-1))  # forward on mirrored input

        flip_indices = data_samples[0]["meta"]["flip_indices"]
        # Undo the mirroring on the output (B x K x heatmap_H x heatmap_W per
        # the original author's note).
        pred_flipped = pred_flipped.flip(-1)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(flip_indices) != pred_flipped.shape[1]:
            raise ValueError(
                f"flip_indices has {len(flip_indices)} entries but the prediction "
                f"has {pred_flipped.shape[1]} channels"
            )
        # Swap channels so that left/right counterparts line up again.
        pred_flipped = pred_flipped[:, flip_indices]
        return (pred + pred_flipped) / 2.0

    def _run_pose_eval(self, tag: str):
        """Reset the evaluator, feed it every val batch, and return its metrics.

        Args:
            tag: Prefix used in the periodic progress log line ("Test" or "Val").

        Returns:
            Whatever ``self.evaluator.evaluate`` returns.
        """
        self.evaluator.reset()
        for i, data_batch in enumerate(self.val_dataloader):
            data_batch = self.data_preprocessor(data_batch)  # preprocess
            inputs = data_batch["inputs"]
            data_samples = data_batch["data_samples"]

            pred = self._predict_with_flip(inputs, data_samples)

            # Progress line every 100 batches, main process only.
            if self.accelerator.is_main_process and i > 0 and i % 100 == 0:
                self.logger.info(
                    f"\033[95m{tag}: {i}/{len(self.val_dataloader)}: batch_size: {len(inputs)}\033[0m"
                )

            # accelerator used to gather and dedup in val
            self.evaluator.process(pred, data_samples, accelerator=self.accelerator)

        # metrics eval on main process
        return self.evaluator.evaluate(
            logger=self.logger, accelerator=self.accelerator
        )

    def test(self) -> None:
        """Evaluate the model on ``self.val_dataloader`` and log the metrics.

        The model is switched to eval mode and is left in eval mode afterwards.
        """
        if self.accelerator.is_main_process:
            self.logger.info("\033[95mStarting test...\033[0m")
        self.model.eval()

        metrics = self._run_pose_eval("Test")

        if self.accelerator.is_main_process:
            self.logger.info(
                f"\033[95mTest: {', '.join([f'{k}: {v:.4f}' for k, v in metrics.items()])}\033[0m"
            )
            self.logger.info("\033[95mTesting finished ✔\033[0m")

    # -------------------------------------------------------------------------
    def val(self) -> dict:
        """Evaluate the model mid-training and return the evaluator's metrics.

        The model is put in eval mode for the duration of the loop and is
        switched back to train mode afterwards, including when evaluation raises.

        Returns:
            The result of ``self.evaluator.evaluate`` (``test`` treats the same
            result as a mapping of metric name to value).
        """
        self.model.eval()
        try:
            if self.accelerator.is_main_process:
                self.logger.info(f"\033[95mValidating iter {self.iter}\033[0m")
            metric = self._run_pose_eval("Val")
        finally:
            # Always hand the model back in training mode.
            self.model.train()
        return metric