File size: 3,013 Bytes
4e4ff14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
from mp_api.client import MPRester
import os
from pathlib import Path
from dotenv import load_dotenv
from tqdm import tqdm

# --- CONFIGURATION ---
# Read the local .env file first so the API key lookup below can see it.
load_dotenv()
API_KEY = os.environ.get("MPI_API_KEY")

# Output file lives in the `data/` folder of the project root, where the
# project root is the directory one level above this script's own folder.
repo_root = Path(__file__).resolve().parents[1]
SAVE_PATH = repo_root.joinpath("data", "perovskite_dataset.pt")

def fetch_data(limit=2000):
    """
    Download stable 5-site crystals from the Materials Project and save the
    ones whose formula contains "O3" (an ABO3-perovskite heuristic).

    The query requests every entry with ``is_stable=True`` and ``nsites=5``.
    Each entry that passes the formula filter becomes one dict:

        "id"      -- material id as a string
        "formula" -- the entry's ``formula_pretty``
        "z"       -- atomic numbers, one per site (``torch.long``)
        "pos"     -- Cartesian site coordinates (``torch.float32``), shifted
                     so that their unweighted mean is the origin

    The list of dicts is written to ``SAVE_PATH`` with ``torch.save``; the
    parent directory is created if it does not exist.

    Args:
        limit: Maximum number of records to keep. It is applied AFTER the
            "O3" filter, so it caps the saved dataset, not the number of
            entries downloaded.

    Returns:
        None. The dataset is only written to disk.
    """
    print("Connecting to Materials Project...")

    with MPRester(API_KEY) as mpr:
        # 1. Broad Search: all stable materials with exactly 5 sites in the
        # unit cell. This implicitly targets ABO3 structures (1+1+3 = 5).
        docs = mpr.materials.summary.search(
            is_stable=True,
            nsites=5,
            fields=["structure", "material_id", "formula_pretty"],
        )

    print(f"Found {len(docs)} stable 5-atom crystals. Processing...")

    # 2. Filter and Process
    dataset = []
    for doc in tqdm(docs):
        # The list length is the running count of accepted entries, so no
        # separate counter is needed.
        if len(dataset) >= limit:
            break

        formula = doc.formula_pretty

        # Heuristic: perovskites have three oxygens per formula unit. This is
        # a plain substring test, so any 5-site formula containing "O3"
        # passes; it does not verify the actual structure type.
        if "O3" not in formula:
            continue

        structure = doc.structure

        # A. Atomic Numbers (Integers) -> The "Identity"
        z_tensor = torch.tensor(
            [site.specie.number for site in structure], dtype=torch.long
        )

        # B. Coordinates (Floats) -> The "Geometry"
        # `cart_coords` is a single array holding the same values as the
        # per-site `coords`. Building the tensor from one array avoids the
        # slow list-of-ndarrays path in torch.tensor (and its warning).
        r_tensor = torch.tensor(structure.cart_coords, dtype=torch.float32)

        # C. Centering (CRITICAL for Diffusion)
        # Subtract the unweighted mean position (geometric centroid, not a
        # mass-weighted center) so the crystal is centered at (0,0,0) and the
        # model does not have to learn absolute positions.
        r_tensor = r_tensor - r_tensor.mean(dim=0, keepdim=True)

        dataset.append(
            {
                "id": str(doc.material_id),
                "formula": formula,
                "z": z_tensor,    # Features
                "pos": r_tensor,  # Positions (Centered)
            }
        )

    # 3. Save to Disk
    # Ensure directory exists
    SAVE_PATH.parent.mkdir(parents=True, exist_ok=True)

    torch.save(dataset, SAVE_PATH)
    print(f"✅ Successfully saved {len(dataset)} crystals to {SAVE_PATH}")
    print("   (Filtered for 5-atom unit cells containing 'O3')")

# Running this file directly builds and saves the dataset, capped at 2000 entries.
if __name__ == "__main__":
    fetch_data(2000)