| """ |
| CrysMTM Dataset Loading Script |
| |
| To use this dataset: |
| |
| 1. Download the dataset files from: https://huggingface.co/datasets/johnpolat/CrysMTM |
| 2. Place this script in the same directory as the downloaded files |
| 3. Run: python load_dataset.py |
| |
| Or use the Hugging Face datasets library directly: |
| from datasets import load_dataset |
| dataset = load_dataset("johnpolat/CrysMTM", use_auth_token=True) |
| """ |
|
|
| import os |
| import pandas as pd |
| from datasets import Dataset, DatasetDict |
| from PIL import Image as PILImage |
|
|
def _read_xyz_file(xyz_path):
    """Parse an XYZ file into element symbols and coordinate triples.

    The first two lines are skipped (presumably the atom-count and comment
    lines of the usual XYZ layout). Every remaining line with at least four
    whitespace-separated fields contributes one atom: field 0 is the element
    symbol and fields 1-3 are parsed as floats. Shorter lines are ignored.

    Args:
        xyz_path: Path of the XYZ file to read.

    Returns:
        A ``(elements, coords)`` tuple of parallel lists, where ``coords``
        holds ``[x, y, z]`` float lists.

    Raises:
        ValueError: If a coordinate field cannot be converted to float.
    """
    elements = []
    coords = []
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(xyz_path, "r", encoding="utf-8") as f:
        lines = f.readlines()[2:]
    for line in lines:
        parts = line.strip().split()
        if len(parts) >= 4:
            elements.append(parts[0])
            coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def load_crysmtm_dataset(data_dir, split="train"):
    """Load one split of the CrysMTM dataset from a local directory.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per CSV row. Modality files referenced by the ``image_path``, ``xyz_path``
    and ``text_path`` columns are resolved relative to ``data_dir``.

    Every example always carries the keys ``image``, ``xyz_coordinates``,
    ``elements`` and ``text``; a key is ``None`` when its path is NaN or the
    file does not exist. Emitting the keys unconditionally matters because
    ``Dataset.from_list`` derives the column set from the first example only,
    so a first row with a missing file would otherwise drop that modality for
    the whole split.

    Args:
        data_dir: Root directory containing ``metadata/`` and the modality
            files.
        split: Split name used to pick the metadata CSV.

    Returns:
        A ``datasets.Dataset`` with one row per metadata row.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist.
        KeyError: If the metadata CSV lacks one of the columns read here.
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Hoisted out of the per-row loader: the list is the same for every row.
    regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]

    def load_example(row):
        """Build the example dict for a single metadata row."""
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            # Optional modalities default to None so all examples share
            # the same set of keys.
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        if pd.notna(row["image_path"]):
            image_path = os.path.join(data_dir, row["image_path"])
            if os.path.exists(image_path):
                # convert() returns a new in-memory image, so the source
                # file can be closed as soon as the block exits.
                with PILImage.open(image_path) as img:
                    example["image"] = img.convert("RGB")

        if pd.notna(row["xyz_path"]):
            xyz_path = os.path.join(data_dir, row["xyz_path"])
            if os.path.exists(xyz_path):
                elements, coords = _read_xyz_file(xyz_path)
                example["xyz_coordinates"] = coords
                example["elements"] = elements

        if pd.notna(row["text_path"]):
            text_path = os.path.join(data_dir, row["text_path"])
            if os.path.exists(text_path):
                with open(text_path, "r", encoding="utf-8") as f:
                    example["text"] = f.read()

        # Regression targets are packed into one list, in the fixed order
        # given by regression_properties.
        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]

        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])
|
|
def load_dataset(data_dir="."):
    """Load every available CrysMTM split found under ``data_dir``.

    Attempts the ``train``, ``test_id`` and ``test_ood`` splits in that
    order. A split whose loading raises ``FileNotFoundError`` is reported
    with a warning on stdout and left out of the result.

    Args:
        data_dir: Directory passed through to ``load_crysmtm_dataset``.

    Returns:
        A ``DatasetDict`` keyed by the names of the splits that loaded.
    """
    loaded = {}

    for split in ("train", "test_id", "test_ood"):
        try:
            split_dataset = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")
            continue

        loaded[split] = split_dataset
        print(f"Loaded {split} split: {len(split_dataset)} samples")

    return DatasetDict(loaded)
|
|
if __name__ == "__main__":
    # Smoke test: load all splits from the current directory and print a
    # short summary of the first example of the first split.
    print("Loading CrysMTM dataset...")
    dataset = load_dataset(".")

    if not dataset:
        # load_dataset skips every split whose files are missing, so an empty
        # result means nothing was loaded. Report that instead of success.
        print("\nNo splits could be loaded. Expected metadata/<split>_metadata.csv under the current directory.")
        raise SystemExit(1)

    print("\nDataset loaded successfully!")
    print(f"Available splits: {list(dataset.keys())}")

    first_split = next(iter(dataset))
    # Guard against a metadata CSV with a header but no rows; indexing [0]
    # on an empty split would raise IndexError.
    if len(dataset[first_split]) > 0:
        sample = dataset[first_split][0]
        print(f"\nSample from {first_split} split:")
        print(f"  Phase: {sample['phase']}")
        print(f"  Temperature: {sample['temperature']}K")
        print(f"  Rotation: {sample['rotation']}")
        if sample.get('image') is not None:
            print(f"  Image size: {sample['image'].size}")
        if 'regression_labels' in sample:
            print(f"  Regression labels: {sample['regression_labels']}")
        if 'classification_label' in sample:
            print(f"  Classification label: {sample['classification_label']}")

    print("\n✅ Dataset ready to use!")