File size: 845 Bytes
5f5f544
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from sapiens.registry import MODELS
from torch.utils.data import default_collate


@MODELS.register_module()
def eval_collate(batch: list):
    """Collate a batch while leaving the ``"data_samples"`` entries un-collated.

    Each element of ``batch`` is read as a mapping. The value stored under
    ``"data_samples"`` is taken from every element and gathered, in batch
    order, into a plain Python list. All remaining fields are handed to
    ``torch.utils.data.default_collate``.

    Args:
        batch: Per-sample mappings. Every sample must contain the
            ``"data_samples"`` key; a missing key raises ``KeyError``.

    Returns:
        The object produced by ``default_collate`` for the remaining fields,
        updated in place so that ``"data_samples"`` maps to the list of
        per-sample values.
    """
    passthrough_keys = {"data_samples"}

    # Split every sample once into (fields to collate, fields kept verbatim).
    # The input mappings themselves are never modified.
    halves = [
        (
            {
                name: value
                for name, value in sample.items()
                if name not in passthrough_keys
            },
            {name: sample[name] for name in passthrough_keys},
        )
        for sample in batch
    ]

    merged = default_collate([stackable for stackable, _ in halves])
    # Re-attach the held-back fields as simple per-sample lists.
    merged.update(
        {name: [kept[name] for _, kept in halves] for name in passthrough_keys}
    )
    return merged