| import os |
| import pandas as pd |
| from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence |
| from PIL import Image as PILImage |
|
|
def load_crysmtm_dataset(data_dir, split="train"):
    """Load one split of the CrysMTM dataset as a ``datasets.Dataset``.

    Reads ``<data_dir>/metadata/<split>_metadata.csv`` and builds one example
    per row. Each row may reference an image, an XYZ structure file and a text
    file through its ``image_path``, ``xyz_path`` and ``text_path`` columns;
    those paths are joined onto ``data_dir``.

    Args:
        data_dir: Root directory of the dataset.
        split: Split name; selects which ``*_metadata.csv`` file is read.

    Returns:
        A ``Dataset`` with one example per metadata row. Every example carries
        the same keys. A modality whose path cell is empty, or whose file does
        not exist, is stored as ``None`` rather than omitted.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist
            (raised by ``pd.read_csv``).
    """
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Metadata columns collected, in this order, into "regression_labels".
    regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]

    def resolve(rel_path):
        """Return the on-disk path for a metadata path cell, or None if unusable."""
        if pd.isna(rel_path):
            return None
        path = os.path.join(data_dir, rel_path)
        return path if os.path.exists(path) else None

    def load_example(row):
        """Build a single example dict with all modalities for one metadata row."""
        # All keys are always present. Dataset.from_list derives its columns
        # from the first example's keys, so a modality missing in row 0 would
        # otherwise silently drop that column for the whole split.
        example = {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            "image": None,
            "xyz_coordinates": None,
            "elements": None,
            "text": None,
        }

        image_path = resolve(row["image_path"])
        if image_path is not None:
            # The context manager closes the source file. convert() returns a
            # new, fully loaded image, so it stays valid after the file closes.
            with PILImage.open(image_path) as img:
                example["image"] = img.convert("RGB")

        xyz_path = resolve(row["xyz_path"])
        if xyz_path is not None:
            with open(xyz_path, "r", encoding="utf-8") as f:
                # Skip the first two lines of the XYZ file (in the XYZ format
                # these are the atom count and a comment line).
                lines = f.readlines()[2:]
            coords = []
            elements = []
            for line in lines:
                parts = line.split()
                # Only lines with at least "element x y z" are kept; blank or
                # short lines are ignored.
                if len(parts) >= 4:
                    elements.append(parts[0])
                    coords.append([float(x) for x in parts[1:4]])
            example["xyz_coordinates"] = coords
            example["elements"] = elements

        text_path = resolve(row["text_path"])
        if text_path is not None:
            with open(text_path, "r", encoding="utf-8") as f:
                example["text"] = f.read()

        example["regression_labels"] = [row[prop] for prop in regression_properties]
        example["classification_label"] = row["label"]
        return example

    return Dataset.from_list([load_example(row) for _, row in df.iterrows()])
|
|
def load_dataset(data_dir):
    """Load every available CrysMTM split into a ``DatasetDict``.

    The splits ``train``, ``test_id`` and ``test_ood`` are tried in that
    order. A split whose load raises ``FileNotFoundError`` (for example, a
    missing metadata CSV) is skipped with a printed warning instead of
    aborting the whole load.

    Args:
        data_dir: Root directory of the dataset.

    Returns:
        A ``DatasetDict`` keyed by split name, containing only the splits that
        loaded successfully.
    """
    loaded = {}
    for split_name in ("train", "test_id", "test_ood"):
        try:
            split_ds = load_crysmtm_dataset(data_dir, split_name)
        except FileNotFoundError:
            print(f"Warning: {split_name} split not found")
            continue
        loaded[split_name] = split_ds
    return DatasetDict(loaded)
|
|
| |
def load_crysmtm(data_dir="."):
    """Main entry point: load the complete CrysMTM dataset.

    Args:
        data_dir: Dataset root directory. Defaults to ``"."`` (the current
            working directory), which matches the previous hard-coded
            behavior.

    Returns:
        The ``DatasetDict`` produced by ``load_dataset(data_dir)``.
    """
    return load_dataset(data_dir)