import torch


def get_sample_align_fn(sample_align_model):
    r"""
    Wrap `sample_align_model` into a gradient function for guided sampling.

    Code is adapted from https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/scripts/classifier_sample.py#L54-L61
    """
    def sample_align_fn(x, *args, **kwargs):
        r"""
        Calculates `grad(log(p(y|x)))`, i.e., the gradient of the alignment
        score w.r.t. the input `x`.
        This uses the conditioning strategy from Sohl-Dickstein et al. (2015).

        Parameters
        ----------
        x: torch.Tensor
            Input to `sample_align_model`. It is detached first, so `x`
            itself and any graph it belongs to are left untouched.

        Returns
        -------
        grad: torch.Tensor
            Gradient of the summed model outputs w.r.t. `x`; same shape as `x`.
        """
        with torch.enable_grad():
            # Work on a detached copy that tracks gradients.
            x_in = x.detach().requires_grad_(True)
            logits = sample_align_model(x_in, *args, **kwargs)
            # Summing over the batch yields all per-sample gradients in a
            # single backward pass, since samples do not interact.
            grad = torch.autograd.grad(logits.sum(), x_in, allow_unused=True)[0]
        return grad
    return sample_align_fn
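

# Usage sketch (illustrative, not part of the original pipeline): any model
# that maps a batch of inputs to one scalar score per sample can be wrapped.
# For example, with a hypothetical toy scorer:
#
#     toy_scorer = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(16, 1))
#     align_fn = get_sample_align_fn(toy_scorer)
#     grad = align_fn(torch.randn(4, 4, 4))  # grad.shape == (4, 4, 4)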


def get_alignment_kwargs_avg_x(context_seq=None, target_seq=None):
    r"""
    Customize this function to generate the knowledge `avg_x_gt` that guides
    the inference. For demonstration, this implementation uses 2.0 times the
    ground-truth future average intensity as `avg_x_gt`.

    Parameters
    ----------
    context_seq: torch.Tensor, aka "y"
    target_seq: torch.Tensor, aka "x"

    Returns
    -------
    alignment_kwargs: Dict
    """
    multiplier = 2.0
    batch_size = target_seq.shape[0]
    # Per-sample mean over all non-batch dims, scaled by `multiplier`;
    # result has shape (batch_size, 1).
    ret = torch.mean(target_seq.view(batch_size, -1),
                     dim=1, keepdim=True) * multiplier
    return {"avg_x_gt": ret}