# sapiens2-pose/sapiens/engine/evaluators/eval_collate.py
# Rawal Khirodkar
# Pin Python 3.10 + torch 2.1.2; vendor sapiens2 to bypass requires-python
# 5f5f544
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from sapiens.registry import MODELS
from torch.utils.data import default_collate
@MODELS.register_module()
def eval_collate(batch: list):
    """Collate a batch while leaving ``"data_samples"`` entries un-collated.

    Each item of ``batch`` is read as a mapping. Every field except
    ``"data_samples"`` is merged across items with ``default_collate``; the
    ``"data_samples"`` values are collected into a plain list (one entry per
    item, in batch order) and stored on the collated result as-is.

    Args:
        batch: Items to merge; each must expose ``.items()`` and contain a
            ``"data_samples"`` key.

    Returns:
        The ``default_collate`` output for the remaining fields, updated with
        ``"data_samples"`` mapped to the list of per-item values.

    Raises:
        KeyError: If an item has no ``"data_samples"`` entry.
    """
    held_out_keys = {"data_samples"}

    held_out = {}
    for name in held_out_keys:
        held_out[name] = []

    collatable_items = []
    for sample in batch:
        # Strip the held-out fields so default_collate never sees them.
        collatable = {}
        for name, value in sample.items():
            if name in held_out_keys:
                continue
            collatable[name] = value

        # Keep the held-out values untouched, one list slot per sample.
        for name in held_out_keys:
            held_out[name].append(sample[name])

        collatable_items.append(collatable)

    merged = default_collate(collatable_items)
    merged.update(held_out)
    return merged