Spaces:
Running on Zero
Running on Zero
Rawal Khirodkar
Initial sapiens2-pointmap Space (HF download at startup, all 4 sizes, 3D viewer)
bff20b3 | # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
| from sapiens.engine.runners import BaseRunner | |
| ## left-right flip for pose val and test | |
| class PoseRunner(BaseRunner): | |
| def test(self) -> None: | |
| if self.accelerator.is_main_process: | |
| self.logger.info(f"\033[95mStarting test...\033[0m") | |
| self.model.eval() | |
| self.evaluator.reset() | |
| for i, data_batch in enumerate(self.val_dataloader): | |
| data_batch = self.data_preprocessor(data_batch) # preprocess | |
| inputs, data_samples = data_batch["inputs"], data_batch["data_samples"] | |
| with torch.no_grad(): | |
| pred = self.model(inputs) # forward | |
| if self.val_cfg.get("flip_test", False): | |
| with torch.no_grad(): | |
| pred_flipped = self.model(inputs.flip(-1)) # forward | |
| flip_indices = data_samples[0]["meta"]["flip_indices"] | |
| pred_flipped = pred_flipped.flip(-1) ## B x K x heatmap_H x heatmap_W | |
| assert len(flip_indices) == pred_flipped.shape[1] ## K | |
| pred_flipped = pred_flipped[:, flip_indices] | |
| pred = (pred + pred_flipped) / 2.0 | |
| if self.accelerator.is_main_process and i > 0 and i % 100 == 0: | |
| self.logger.info( | |
| f"\033[95mTest: {i}/{len(self.val_dataloader)}: batch_size: {len(data_batch['inputs'])}\033[0m" | |
| ) | |
| self.evaluator.process( | |
| pred, data_samples, accelerator=self.accelerator | |
| ) ## accelerator used to gather and dedup in val | |
| # metrics eval on main process | |
| metrics = self.evaluator.evaluate( | |
| logger=self.logger, accelerator=self.accelerator | |
| ) | |
| if self.accelerator.is_main_process: | |
| self.logger.info( | |
| f"\033[95mTest: {', '.join([f'{k}: {v:.4f}' for k, v in metrics.items()])}\033[0m" | |
| ) | |
| self.logger.info(f"\033[95mTesting finished ✔\033[0m") | |
| # ------------------------------------------------------------------------- | |
| def val(self) -> None: | |
| self.model.eval() | |
| if self.accelerator.is_main_process: | |
| self.logger.info(f"\033[95mValidating iter {self.iter}\033[0m") | |
| self.evaluator.reset() | |
| for i, data_batch in enumerate(self.val_dataloader): | |
| data_batch = self.data_preprocessor(data_batch) # preprocess | |
| inputs, data_samples = data_batch["inputs"], data_batch["data_samples"] | |
| with torch.no_grad(): | |
| pred = self.model(inputs) # forward | |
| if self.val_cfg.get("flip_test", False): | |
| with torch.no_grad(): | |
| pred_flipped = self.model(inputs.flip(-1)) # forward | |
| flip_indices = data_samples[0]["meta"]["flip_indices"] | |
| pred_flipped = pred_flipped.flip(-1) ## B x K x heatmap_H x heatmap_W | |
| assert len(flip_indices) == pred_flipped.shape[1] ## K | |
| pred_flipped = pred_flipped[:, flip_indices] | |
| pred = (pred + pred_flipped) / 2.0 | |
| if self.accelerator.is_main_process and i > 0 and i % 100 == 0: | |
| self.logger.info( | |
| f"\033[95mVal: {i}/{len(self.val_dataloader)}: batch_size: {len(data_batch['inputs'])}\033[0m" | |
| ) | |
| self.evaluator.process(pred, data_samples, accelerator=self.accelerator) | |
| metric = self.evaluator.evaluate( | |
| logger=self.logger, accelerator=self.accelerator | |
| ) | |
| self.model.train() | |
| return metric | |