| import torch |
| import torch.optim as optim |
| import random |
| import os |
| from src.model import CrystalDiffusionModel |
|
|
| |
| |
# Path to the pre-serialized dataset: expected to be a sequence of dicts,
# each with "z" (atomic numbers) and "pos" (atom coordinates) tensors
# (see get_random_batch below for the fields actually read).
DATA_PATH = "data/perovskite_dataset.pt"
# Number of optimization steps; each "epoch" trains on ONE randomly drawn crystal.
EPOCHS = 3000
# Adam step size.
LEARNING_RATE = 1e-3
|
|
def load_dataset():
    """Load the serialized crystal dataset from DATA_PATH.

    Returns:
        The deserialized object stored in the ``.pt`` file — used elsewhere
        as a sequence of per-crystal dicts (indexable by ``random.choice``
        and measurable with ``len``).

    Raises:
        FileNotFoundError: if DATA_PATH does not exist on disk.
    """
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"β Could not find dataset at {DATA_PATH}. Check spelling!")

    # NOTE(review): torch.load unpickles arbitrary objects and can execute code
    # from an untrusted file — confirm this .pt file is always produced in-house.
    data = torch.load(DATA_PATH)
    print(f"β
Loaded {len(data)} crystals for training.")
    return data
|
|
def get_random_batch(dataset, device):
    """
    Draw ONE random crystal from the dataset and package it for the model.

    Sampling a fresh structure every step is what pushes the network toward
    learning general rules instead of memorizing a single geometry.

    Returns:
        x_real (Tensor): float atom positions on `device`, shape (N, 3).
        z (Tensor): long atomic numbers on `device`, shape (N,).
        edge_index (Tensor): (2, N*(N-1)) fully-connected directed edge
            list with self-loops removed, on `device`.
    """
    crystal = random.choice(dataset)

    z = crystal["z"].to(device).long()
    x_real = crystal["pos"].to(device).float()

    n = z.shape[0]

    # Cartesian product of atom indices: src repeats each index n times,
    # dst cycles through all indices for every src.
    idx = torch.arange(n)
    src = torch.repeat_interleave(idx, n)
    dst = idx.repeat(n)

    # Drop the diagonal (i -> i) so atoms don't message themselves.
    keep = src != dst
    edge_index = torch.stack((src[keep], dst[keep])).to(device)

    return x_real, z, edge_index
|
|
def train():
    """Train CrystalDiffusionModel to denoise atom positions.

    Loop (EPOCHS iterations, batch size 1):
      1. Draw one random crystal (positions, atomic numbers, edge list).
      2. Sample a noise level t ~ U(0, 1) and corrupt positions additively.
      3. Predict the clean positions and minimize MSE against them.

    Side effects: prints progress every 200 epochs and writes the final
    weights to ``model_weights.pth`` in the working directory.
    """
    # Prefer GPU when available; all tensors are moved to this device.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- π Training on {device} ---")

    dataset = load_dataset()

    model = CrystalDiffusionModel().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    model.train()

    print(f"--- Starting Training Loop ({EPOCHS} Epochs) ---")

    for epoch in range(1, EPOCHS + 1):
        optimizer.zero_grad()

        # One random crystal per step acts as the (size-1) batch.
        x_real, z, edge_index = get_random_batch(dataset, device)

        # Continuous noise level in [0, 1); shape (1, 1) so it broadcasts
        # over the (N, 3) position tensor below.
        t = torch.rand(1, 1, device=device)

        # Gaussian noise matching the position tensor's shape/device.
        noise = torch.randn_like(x_real)

        # Simple additive corruption: noise scaled linearly by t.
        x_noisy = x_real + (noise * t)

        # Model is conditioned on atom types, noise level, and connectivity.
        # NOTE(review): assumes the model predicts CLEAN positions (x0-style),
        # not the noise — the loss below targets x_real directly.
        x_pred = model(x_noisy, z, t, edge_index)

        # Mean-squared reconstruction error against the uncorrupted positions.
        loss = torch.mean((x_pred - x_real)**2)

        loss.backward()
        optimizer.step()

        # Periodic progress report.
        if epoch % 200 == 0:
            print(f"Epoch {epoch} | Loss: {loss.item():.6f}")

    # Persist only the weights (state_dict), not the full module object.
    torch.save(model.state_dict(), "model_weights.pth")
    print("β
 Training Complete. Model saved to model_weights.pth!")
|
|
| if __name__ == "__main__": |
| train() |