Spaces:
Running on Zero
Running on Zero
| # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| """Constraint conditioning: build index and data dicts from constraint sets for the denoiser.""" | |
| from collections import defaultdict | |
| import torch | |
def build_condition_dicts(constraints_lst: list):
    """Gather index and data entries from every constraint into two dicts.

    Both containers are ``defaultdict(list)``. Each constraint, in list order,
    receives them through ``update_constraints(data_dict, index_dict)`` and
    fills them in place. Presumably it appends per-key entries; confirm
    against the constraint classes.

    Args:
        constraints_lst: Objects exposing ``update_constraints(data_dict, index_dict)``.

    Returns:
        Tuple ``(index_dict, data_dict)``. Note this is the reverse of the
        argument order passed to ``update_constraints``.
    """
    data_dict, index_dict = defaultdict(list), defaultdict(list)
    for c in constraints_lst:
        # Each constraint mutates the shared containers; nothing is returned.
        c.update_constraints(data_dict, index_dict)
    return index_dict, data_dict
def get_unique_index_and_data(indices_lst, data):
    """Deduplicate index rows and keep the data entry of each row's first occurrence.

    ``torch.unique(..., dim=0)`` deduplicates the rows of ``indices_lst`` and
    returns them sorted lexicographically. If rows are ``(t, j)`` pairs, as
    the original comments suggest, this orders them by ``t`` first; confirm
    against callers. For every unique row, the entry of ``data`` at that
    row's FIRST position in ``indices_lst`` is kept.

    Args:
        indices_lst: Tensor whose slices along dim 0 are the index keys.
        data: Tensor indexed along dim 0 in lockstep with ``indices_lst``.
            It is assumed to have the same size along dim 0; verify at call
            sites.

    Returns:
        Tuple ``(unique_indices, unique_data)``, both aligned along dim 0.
    """
    indices_unique, inverse = torch.unique(indices_lst, dim=0, return_inverse=True)
    positions = torch.arange(len(inverse), device=inverse.device)
    # Plain scatter_ with repeated targets keeps an arbitrary writer: in
    # practice the last one on CPU, and nondeterministic on CUDA. Reducing
    # with "amin" makes each unique row map deterministically to its first
    # occurrence. include_self=False ignores the zero placeholders; every
    # slot is written because each unique row occurs at least once.
    first_idx = torch.zeros(indices_unique.size(0), dtype=torch.long, device=inverse.device)
    first_idx.scatter_reduce_(0, inverse, positions, reduce="amin", include_self=False)
    # Sanity check: the chosen positions reproduce the unique rows.
    assert (indices_lst[first_idx] == indices_unique).all()
    return indices_lst[first_idx], data[first_idx]