| import torch |
| import torch.nn as nn |
| from src.layers import EGNNLayer |
|
|
class TimeEmbedding(nn.Module):
    """
    Converts a time scalar 't' into a vector embedding.
    This allows the neural network to understand the noise level (time step).

    The embedding is a small MLP: Linear(1, dim) -> SiLU -> Linear(dim, dim).
    """

    def __init__(self, dim: int):
        """
        Args:
            dim: Size of the produced time embedding.
        """
        super().__init__()
        self.dim = dim
        self.linear_1 = nn.Linear(1, dim)
        self.linear_2 = nn.Linear(dim, dim)
        self.act = nn.SiLU()

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        """
        Args:
            t (Tensor): Time values of shape (Batch_Size, 1). A 0-d scalar or a
                1-D tensor of shape (Batch_Size,) is also accepted and is
                reshaped to (Batch_Size, 1). Integer time steps are cast to the
                floating-point dtype of the layer weights.

        Returns:
            Tensor: Time embeddings of shape (Batch_Size, dim).
        """
        # nn.Linear(1, dim) needs a trailing feature axis of size 1; without
        # this a (Batch_Size,) tensor would be read as Batch_Size features.
        if t.dim() < 2:
            t = t.reshape(-1, 1)
        # nn.Linear rejects integer inputs (e.g. discrete step indices).
        t = t.to(self.linear_1.weight.dtype)

        x = self.act(self.linear_1(t))
        return self.linear_2(x)
|
|
class CrystalDiffusionModel(nn.Module):
    """
    E(n)-Equivariant Diffusion Model for Crystal Generation.
    Predicts the denoised coordinates given a noisy input.

    Atom types are embedded into node features, the time embedding is added to
    those features, and the result is passed through a stack of ``EGNNLayer``
    modules that update node features and coordinates together.
    """

    def __init__(self, hidden_dim: int = 64, num_layers: int = 3, max_atom_type: int = 100):
        """
        Args:
            hidden_dim: Width of the node features and of the time embedding.
            num_layers: Number of stacked ``EGNNLayer`` modules.
            max_atom_type: Number of rows in the atom-type embedding table.
                Valid atom indices are ``0 .. max_atom_type - 1``; with the
                default of 100, atomic numbers >= 100 cannot be embedded.
        """
        super().__init__()

        # Learned feature vector per atom type (indexed by atomic number).
        self.atom_embed = nn.Embedding(max_atom_type, hidden_dim)

        # Maps the time step / noise level to a hidden_dim vector.
        self.time_embed = TimeEmbedding(hidden_dim)

        # Message-passing stack; every layer keeps the feature width hidden_dim.
        self.layers = nn.ModuleList([
            EGNNLayer(c_in=hidden_dim, c_out=hidden_dim)
            for _ in range(num_layers)
        ])

    @staticmethod
    def _node_time_embedding(t_emb: torch.Tensor, batch: "torch.Tensor | None" = None) -> torch.Tensor:
        """Select the time embedding that each node is conditioned on.

        Args:
            t_emb: Time embeddings, one row per graph, shape (Batch_Size, dim).
            batch: Optional graph index of every node, shape (N,), with values
                in ``[0, Batch_Size)``.

        Returns:
            ``t_emb[batch]`` with shape (N, dim) when ``batch`` is given, so each
            node sees its own graph's time step. Without ``batch`` the legacy
            behaviour is kept: the rows are averaged into shape (1, dim) and
            broadcast to all nodes, which is only exact when Batch_Size == 1 or
            every graph in the batch shares the same time step.
        """
        if batch is None:
            return t_emb.mean(dim=0, keepdim=True)
        return t_emb[batch]

    def forward(
        self,
        x: torch.Tensor,
        z: torch.Tensor,
        t: torch.Tensor,
        edge_index: torch.Tensor,
        batch: "torch.Tensor | None" = None,
    ) -> torch.Tensor:
        """
        Forward pass of the diffusion model.

        Args:
            x (Tensor): Noisy atom positions. Shape (N, 3).
            z (Tensor): Atomic numbers. Shape (N,).
            t (Tensor): Time step/Noise level. Shape (Batch_Size, 1).
            edge_index (Tensor): Graph connectivity (Adjacency list). Shape (2, E).
            batch (Tensor, optional): Graph index of each atom, shape (N,).
                When several graphs with different time steps are batched
                together, pass this so each atom receives its own graph's
                time embedding. If omitted, the time embeddings are averaged
                over the batch (previous behaviour).

        Returns:
            Tensor: Denoised atom positions. Shape (N, 3).
        """
        h = self.atom_embed(z)
        t_emb = self.time_embed(t)

        # Inject the noise level into every node's features.
        h = h + self._node_time_embedding(t_emb, batch)

        # Each layer updates both the features and the coordinates.
        for layer in self.layers:
            h, x = layer(h, x, edge_index)

        return x
|
|
| |