File size: 4,816 Bytes
bc9ef03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""inference.py — Sampling / inference for NSGF and NSGF++.

Implements:
  - NSGF Euler-step inference (standard model)
  - NSGF++ two-phase inference (NSGF → phase transition → NSF)

Reference: arXiv:2401.14069, Section 4.4, Appendix D
"""

import torch
import torch.nn as nn
from typing import Optional, Tuple, List
from dataset_loader import DatasetLoader


class NSGFSampler:
    """Sampler using a trained NSGF velocity field model.

    Starting from ``data_loader.sample_source(n, device)``, the learned
    velocity field is integrated from t=0 towards t=1 with ``num_steps``
    explicit Euler steps of size ``1 / num_steps``.
    """

    def __init__(self, model: nn.Module, data_loader: DatasetLoader,
                 num_steps: int = 10, device: str = "cpu"):
        """Move ``model`` to ``device`` in eval mode and store the settings.

        Args:
            model: velocity network called as ``model(X, t)``.
            data_loader: object providing ``sample_source(n, device)``.
            num_steps: number of Euler steps (must be positive).
            device: torch device string used for the model and all tensors.

        Raises:
            ValueError: if ``num_steps`` is not positive. Zero would divide
                by zero when sampling, and a negative count would silently
                return the untouched source samples.
        """
        if num_steps <= 0:
            raise ValueError(f"num_steps must be positive, got {num_steps}")
        self.model = model.to(device)
        self.model.eval()
        self.data_loader = data_loader
        self.num_steps = num_steps
        self.device = device

    def _euler(self, X: torch.Tensor, n: int,
               trajectory: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Apply ``num_steps`` Euler steps to ``X``; record each state if asked.

        Only called from the ``torch.no_grad()`` public methods below.
        """
        dt = 1.0 / self.num_steps
        for step in range(self.num_steps):
            # One time value per sample, matching the batch size n.
            t = torch.full((n,), step * dt, device=self.device)
            X = X + dt * self.model(X, t)
            if trajectory is not None:
                trajectory.append(X.clone())
        return X

    @torch.no_grad()
    def sample(self, n: int) -> torch.Tensor:
        """Draw ``n`` source samples and return their state after integration."""
        X = self.data_loader.sample_source(n, self.device)
        return self._euler(X, n)

    @torch.no_grad()
    def sample_trajectory(self, n: int) -> List[torch.Tensor]:
        """Like ``sample`` but return every state.

        The returned list holds ``num_steps + 1`` tensors: the source samples
        first, then the state after each Euler step.
        """
        X = self.data_loader.sample_source(n, self.device)
        trajectory = [X.clone()]
        self._euler(X, n, trajectory)
        return trajectory


class NSGFPlusPlusSampler:
    """Sampler for the NSGF++ two-phase model.

    Phase 1 (NSGF): ``nsgf_steps`` (≤5) Euler steps with the Sinkhorn
        velocity field ``nsgf_model``.
    Phase 2 (NSF): ``nsf_steps`` Euler steps along the straight-flow velocity
        field ``nsf_model``.
    Total NFE = nsgf_steps + nsf_steps. When a phase predictor is supplied,
    ``sample`` also calls it once.
    """

    def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module,
                 phase_predictor: Optional[nn.Module], data_loader: DatasetLoader,
                 nsgf_steps: int = 5, nsf_steps: int = 55, device: str = "cpu"):
        """Move all networks to ``device`` in eval mode and store the settings.

        Args:
            nsgf_model: phase-1 velocity network called as ``nsgf_model(X, t)``.
            nsf_model: phase-2 velocity network called as ``nsf_model(X, t)``.
            phase_predictor: optional network mapping the NSGF output to the
                NSF start time; ``None`` means start the NSF phase at t=0.
            data_loader: object providing ``sample_source(n, device)``.
            nsgf_steps: number of NSGF Euler steps (must be positive).
            nsf_steps: number of NSF Euler steps (must be positive).
            device: torch device string used for the models and all tensors.

        Raises:
            ValueError: if either step count is not positive. Zero would
                divide by zero when sampling, and a negative count would
                silently skip that phase.
        """
        if nsgf_steps <= 0 or nsf_steps <= 0:
            raise ValueError(
                "nsgf_steps and nsf_steps must be positive, "
                f"got nsgf_steps={nsgf_steps}, nsf_steps={nsf_steps}"
            )
        self.nsgf_model = nsgf_model.to(device)
        self.nsf_model = nsf_model.to(device)
        self.nsgf_model.eval()
        self.nsf_model.eval()
        self.phase_predictor = None
        if phase_predictor is not None:
            self.phase_predictor = phase_predictor.to(device)
            self.phase_predictor.eval()
        self.data_loader = data_loader
        self.nsgf_steps = nsgf_steps
        self.nsf_steps = nsf_steps
        self.device = device

    # The private helpers below are only called from the torch.no_grad()
    # public methods.

    def _nsgf_phase(self, X: torch.Tensor, n: int,
                    trajectory: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Run the NSGF Euler steps over t in [0, 1); optionally record states."""
        dt = 1.0 / self.nsgf_steps
        for step in range(self.nsgf_steps):
            t = torch.full((n,), step * dt, device=self.device)
            X = X + dt * self.nsgf_model(X, t)
            if trajectory is not None:
                trajectory.append(X.clone())
        return X

    def _phase_start(self, X: torch.Tensor, n: int) -> torch.Tensor:
        """Return the per-sample NSF start time, shape ``(n,)``, in [0, 1]."""
        if self.phase_predictor is None:
            return torch.zeros(n, device=self.device)
        # NOTE(review): assumes the predictor emits one scalar per sample,
        # e.g. (n,) or (n, 1). It is flattened so the NSF model receives the
        # same (n,) time shape as on the other code paths.
        t_start = self.phase_predictor(X).reshape(n)
        # Clamp here rather than only clamping the per-step time. Otherwise an
        # unbounded predictor could make the step span (1 - t_start) negative
        # and integrate backwards.
        return t_start.clamp(0, 1)

    def _nsf_phase(self, X: torch.Tensor, n: int,
                   t_start: Optional[torch.Tensor] = None,
                   trajectory: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
        """Run the NSF Euler steps; optionally record each state.

        With ``t_start=None`` the time grid is the plain ``step / nsf_steps``.
        Otherwise each sample is integrated from its own ``t_start`` to 1,
        so the steps are rescaled by ``(1 - t_start)``.
        """
        dt = 1.0 / self.nsf_steps
        if t_start is not None:
            span = 1.0 - t_start
            # Broadcast the per-sample span over all non-batch dims of X.
            step_scale = dt * span.view(-1, *([1] * (X.dim() - 1)))
        for step in range(self.nsf_steps):
            if t_start is None:
                t = torch.full((n,), step * dt, device=self.device)
                X = X + dt * self.nsf_model(X, t)
            else:
                t = (t_start + step * dt * span).clamp(0, 1)
                X = X + step_scale * self.nsf_model(X, t)
            if trajectory is not None:
                trajectory.append(X.clone())
        return X

    @torch.no_grad()
    def sample(self, n: int) -> torch.Tensor:
        """Full two-phase sampling.

        The NSGF phase runs first. Its output then goes to the phase
        predictor (when present) to choose the NSF start time, and the NSF
        phase integrates from that time to t=1.
        """
        X = self.data_loader.sample_source(n, self.device)
        X = self._nsgf_phase(X, n)
        t_start = self._phase_start(X, n)
        return self._nsf_phase(X, n, t_start)

    @torch.no_grad()
    def sample_simple(self, n: int) -> torch.Tensor:
        """Simplified: NSGF then NSF from t=0 to t=1."""
        X = self.data_loader.sample_source(n, self.device)
        X = self._nsgf_phase(X, n)
        return self._nsf_phase(X, n)

    @torch.no_grad()
    def sample_trajectory(
        self, n: int, *, use_phase_predictor: bool = False
    ) -> Tuple[List[torch.Tensor], int]:
        """Record every intermediate state of two-phase sampling.

        Args:
            n: number of samples.
            use_phase_predictor: if True, start the NSF phase at the time
                chosen by ``_phase_start``, as ``sample`` does. If False (the
                default, matching the previous behaviour), run NSF from t=0
                like ``sample_simple``.

        Returns:
            ``(trajectory, phase_boundary)``. ``trajectory`` holds the source
            samples followed by the state after every NSGF and NSF step.
            ``phase_boundary`` is the index of the last NSGF state.
        """
        X = self.data_loader.sample_source(n, self.device)
        trajectory = [X.clone()]
        X = self._nsgf_phase(X, n, trajectory=trajectory)
        phase_boundary = len(trajectory) - 1
        t_start = self._phase_start(X, n) if use_phase_predictor else None
        self._nsf_phase(X, n, t_start, trajectory=trajectory)
        return trajectory, phase_boundary