"""Generate a crystal structure with a trained CrystalDiffusionModel.

Starting from Gaussian noise, atom positions are repeatedly pulled toward the
model's prediction, and XYZ snapshots are written along the way so the
denoising trajectory can be inspected in a molecular viewer.
"""

import torch
import numpy as np

from src.model import CrystalDiffusionModel

# --- CONFIGURATION ---
# Select the atoms you want to generate (by atomic number)!
# Example: BaTiO3 (Barium Titanate)    -> Ba=56, Ti=22, O=8
# Example: SrTiO3 (Strontium Titanate) -> Sr=38, Ti=22, O=8
# Example: CaTiO3 (Calcium Titanate)   -> Ca=20, Ti=22, O=8
# Let's try generating Strontium Titanate (SrTiO3) this time.
TARGET_ATOMS = [38, 22, 8, 8, 8]  # Sr, Ti, O, O, O
MODEL_PATH = "model_weights.pth"
STEPS = 50  # Number of reverse-diffusion steps
STEP_FRACTION = 0.1  # Fraction of the way positions move toward the prediction per step
SNAPSHOT_EVERY = 10  # Write an intermediate XYZ snapshot every N steps

# Atomic number -> element symbol for common perovskite elements.
# Add more entries if you generate other materials; unknown atomic numbers
# are written as "X".
ELEMENT_SYMBOLS = {
    8: "O",
    20: "Ca",
    22: "Ti",
    26: "Fe",
    38: "Sr",
    40: "Zr",
    56: "Ba",
    82: "Pb",
}


def save_xyz(pos, z, filename):
    """Save the crystal in XYZ format for visualization.

    Args:
        pos: Per-atom coordinates indexable as ``pos[i][k]`` for k = 0, 1, 2
            (a torch tensor on any device, a NumPy array, or nested lists).
        z: Per-atom atomic numbers, at least as long as ``pos``.
        filename: Output path; an existing file is overwritten.
    """
    # Copy tensors to host Python lists once, instead of formatting each
    # coordinate as a 0-d (possibly GPU) tensor, which syncs per value.
    if isinstance(pos, torch.Tensor):
        pos = pos.detach().cpu().tolist()
    if isinstance(z, torch.Tensor):
        z = z.detach().cpu().tolist()

    # Build the whole file first so a formatting error can't leave a
    # truncated file behind.
    lines = [f"{len(pos)}\n", "Generated by CrystalDiff\n"]
    for i, row in enumerate(pos):
        symbol = ELEMENT_SYMBOLS.get(int(z[i]), "X")  # Default to X if unknown
        lines.append(f"{symbol} {row[0]:.4f} {row[1]:.4f} {row[2]:.4f}\n")

    with open(filename, "w", encoding="utf-8") as f:
        f.writelines(lines)


def _load_model(model_path, device):
    """Build a CrystalDiffusionModel on ``device`` and load weights from ``model_path``.

    Returns the model in eval mode, or ``None`` (after printing an error) if
    the weights file does not exist.
    """
    model = CrystalDiffusionModel().to(device)
    try:
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints you trust. On torch >= 1.13 consider weights_only=True.
        state_dict = torch.load(model_path, map_location=device)
    except FileNotFoundError:
        print(f"❌ Error: Could not find '{model_path}'. Did you run train.py?")
        return None
    model.load_state_dict(state_dict)
    model.eval()
    return model


def _fully_connected_edges(num_atoms, device):
    """Return a (2, num_atoms * (num_atoms - 1)) edge index over every
    ordered pair of distinct atoms (no self-loops)."""
    atom_ids = torch.arange(num_atoms, device=device)
    row = torch.repeat_interleave(atom_ids, num_atoms)
    col = atom_ids.repeat(num_atoms)
    mask = row != col
    return torch.stack([row[mask], col[mask]], dim=0)


def generate(*, target_atoms=None, steps=None, model_path=None):
    """Run reverse diffusion from pure noise and save the generated crystal.

    Writes ``gen_step_00.xyz`` with the initial noise, a snapshot every
    ``SNAPSHOT_EVERY`` steps named by the number of steps completed, and
    ``gen_final.xyz`` with the final positions.

    Args:
        target_atoms: Atomic numbers to generate; defaults to ``TARGET_ATOMS``.
        steps: Number of denoising steps; defaults to ``STEPS``.
        model_path: Path to the trained weights; defaults to ``MODEL_PATH``.

    Raises:
        ValueError: If ``steps`` is less than 1 or ``target_atoms`` is empty.
    """
    # Resolve defaults at call time so edits to the module constants are honored.
    target_atoms = TARGET_ATOMS if target_atoms is None else target_atoms
    steps = STEPS if steps is None else steps
    model_path = MODEL_PATH if model_path is None else model_path
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")
    if len(target_atoms) == 0:
        raise ValueError("target_atoms must contain at least one atom")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"--- 💎 Generating Crystal on {device} ---")

    # 1. Load Model
    model = _load_model(model_path, device)
    if model is None:
        return

    # 2. Setup Target Chemistry
    z = torch.tensor(target_atoms, device=device)
    num_atoms = len(z)
    print(f"Target Atoms: {z.tolist()}")
    edge_index = _fully_connected_edges(num_atoms, device)

    # 3. Start with Pure Noise (The "Chaos")
    # Standard-normal positions; assumed to match the noise scale used in
    # training -- TODO confirm against train.py.
    x = torch.randn(num_atoms, 3, device=device)
    print("Initial State: Random Gas Cloud")
    width = max(2, len(str(steps)))  # zero-pad so snapshot files sort in order
    save_xyz(x, z, f"gen_step_{0:0{width}d}.xyz")

    # 4. The Reverse Diffusion Loop
    dt = 1.0 / steps
    with torch.no_grad():
        for i in range(steps):
            # t runs 1.0, 1 - dt, ..., dt; t = 0 itself is never evaluated.
            t_tensor = torch.tensor([[1.0 - i * dt]], device=device)

            # Predict where the atoms SHOULD be
            x_pred = model(x, z, t_tensor, edge_index)

            # Relax a fixed fraction of the way toward the prediction. This is
            # not scaled by dt, so the total pull depends on `steps`.
            x = x + (x_pred - x) * STEP_FRACTION

            done = i + 1
            if done % SNAPSHOT_EVERY == 0:
                print(f"Step {done}/{steps}: Denoising...")
                # Name snapshots by steps completed so the step-0 noise file
                # is never overwritten.
                save_xyz(x, z, f"gen_step_{done:0{width}d}.xyz")

    # 5. Final Save
    print("✅ Final Structure Generated!")
    save_xyz(x, z, "gen_final.xyz")
    print("Check 'gen_final.xyz' to see your crystal.")


if __name__ == "__main__":
    generate()