| import torch |
| from mp_api.client import MPRester |
| import os |
| from pathlib import Path |
| from dotenv import load_dotenv |
| from tqdm import tqdm |
|
|
| |
# Read a local .env file (if any) into the environment before looking up the key.
load_dotenv()
# NOTE(review): the variable read here is spelled "MPI_API_KEY"; mp-api's own
# default lookup is believed to use "MP_API_KEY" — confirm the .env matches.
API_KEY = os.environ.get("MPI_API_KEY")

# The dataset is written under <repo root>/data, where the repo root is taken
# to be the directory two levels up from this file's resolved path.
repo_root = Path(__file__).resolve().parents[1]
SAVE_PATH = repo_root.joinpath("data", "perovskite_dataset.pt")
|
|
def fetch_data(limit=2000):
    """Download stable 5-site crystals whose formula contains "O3" and save them.

    Queries the Materials Project summary endpoint for stable entries with
    five sites, keeps those whose pretty formula contains the substring
    ``"O3"``, converts each one to tensors, and writes the resulting list to
    ``SAVE_PATH`` with ``torch.save``.

    Args:
        limit: Maximum number of crystals to keep. ``None`` disables the cap.

    Returns:
        The list that was saved. Each entry is a dict with keys ``"id"``
        (material id as ``str``), ``"formula"`` (pretty formula), ``"z"``
        (``torch.long`` atomic numbers, one per site) and ``"pos"``
        (``torch.float32`` Cartesian coordinates, one row per site, shifted so
        that their mean is zero).
    """
    print("Connecting to Materials Project...")

    with MPRester(API_KEY) as mpr:
        docs = mpr.materials.summary.search(
            is_stable=True,
            nsites=5,
            fields=["structure", "material_id", "formula_pretty"],
        )

    print(f"Found {len(docs)} stable 5-atom crystals. Processing...")

    dataset = []
    for doc in tqdm(docs):
        # The number of kept entries is the only counter needed for the cap.
        if limit is not None and len(dataset) >= limit:
            break

        formula = doc.formula_pretty

        # NOTE(review): this is a plain substring test, so any 5-site formula
        # containing "O3" passes (not only A-B-O3 compositions) — confirm this
        # is loose enough/tight enough for the intended perovskite dataset.
        if "O3" not in formula:
            continue

        structure = doc.structure

        z_tensor = torch.tensor(
            [site.specie.number for site in structure],
            dtype=torch.long,
        )

        # Use the structure's single (n_sites, 3) coordinate array instead of
        # a Python list of per-site arrays: torch.tensor on a list of ndarrays
        # is slow and emits a warning.
        r_tensor = torch.tensor(structure.cart_coords, dtype=torch.float32)

        # Centre the cell on the origin so positions are translation-free.
        r_tensor = r_tensor - r_tensor.mean(dim=0, keepdim=True)

        dataset.append(
            {
                "id": str(doc.material_id),
                "formula": formula,
                "z": z_tensor,
                "pos": r_tensor,
            }
        )

    # Make sure the target directory exists before writing.
    SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

    torch.save(dataset, SAVE_PATH)
    print(f"✅ Successfully saved {len(dataset)} crystals to {SAVE_PATH}")
    print(f" (Filtered for 5-atom unit cells containing 'O3')")

    return dataset
|
|
if __name__ == "__main__":
    # Script entry point: build the dataset, keeping at most 2000 crystals.
    fetch_data(limit=2000)