johnpolat committed on
Commit
44052d1
·
verified ·
1 Parent(s): a2b3944

Upload dataset_loading_script.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. dataset_loading_script.py +82 -0
dataset_loading_script.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from datasets import Dataset, DatasetDict, Features, Value, Image, Sequence
4
+ from PIL import Image as PILImage
5
+
6
def _read_xyz_file(xyz_path):
    """Parse an XYZ file into parallel ``(elements, coords)`` lists.

    The first two lines are skipped as header lines. Every later line with at
    least four whitespace-separated fields contributes its first field as the
    element symbol and fields 2-4 as float coordinates; shorter lines are
    ignored.
    """
    elements = []
    coords = []
    with open(xyz_path, "r", encoding="utf-8") as f:
        for line_index, line in enumerate(f):
            if line_index < 2:
                continue  # Skip header lines
            parts = line.strip().split()
            if len(parts) >= 4:
                elements.append(parts[0])
                coords.append([float(x) for x in parts[1:4]])
    return elements, coords


def load_crysmtm_dataset(data_dir, split="train"):
    """Load one split of the CrysMTM dataset.

    Args:
        data_dir: Root directory. The metadata CSV is read from
            ``<data_dir>/metadata/<split>_metadata.csv`` and the image, XYZ
            and text paths listed in it are resolved relative to ``data_dir``.
        split: Name of the split whose metadata file should be read.

    Returns:
        A ``Dataset`` with one example per metadata row. Every example has
        the same keys; a modality whose path is missing in the CSV or absent
        on disk is stored as ``None``.

    Raises:
        FileNotFoundError: If the metadata CSV for ``split`` does not exist.
    """
    # Load metadata
    metadata_path = os.path.join(data_dir, "metadata", f"{split}_metadata.csv")
    df = pd.read_csv(metadata_path)

    # Hoisted out of the per-row loader: the list never changes between rows.
    regression_properties = ["HOMO", "LUMO", "Eg", "Ef", "Et", "Eta", "disp", "vol", "bond"]

    def resolve(relative_path):
        """Return the on-disk path for a metadata entry, or None if unusable."""
        if pd.isna(relative_path):
            return None
        full_path = os.path.join(data_dir, relative_path)
        return full_path if os.path.exists(full_path) else None

    def load_example(row):
        """Load a single example with all modalities."""
        # Load image. convert() returns an independent copy, so the source
        # file can be closed as soon as the conversion is done.
        image = None
        image_path = resolve(row["image_path"])
        if image_path is not None:
            with PILImage.open(image_path) as raw_image:
                image = raw_image.convert("RGB")

        # Load XYZ coordinates
        elements = None
        coords = None
        xyz_path = resolve(row["xyz_path"])
        if xyz_path is not None:
            elements, coords = _read_xyz_file(xyz_path)

        # Load text
        text = None
        text_path = resolve(row["text_path"])
        if text_path is not None:
            with open(text_path, "r", encoding="utf-8") as f:
                text = f.read()

        # Always emit every key: Dataset.from_list derives its columns from
        # the first example, so a modality missing there would otherwise be
        # dropped for the entire split.
        return {
            "phase": row["phase"],
            "temperature": row["temperature"],
            "rotation": row["rotation"],
            "split": row["split"],
            "image": image,
            "xyz_coordinates": coords,
            "elements": elements,
            "text": text,
            "regression_labels": [row[prop] for prop in regression_properties],
            "classification_label": row["label"],
        }

    # Create dataset
    dataset = Dataset.from_list([load_example(row) for _, row in df.iterrows()])

    return dataset
64
+
65
def load_dataset(data_dir):
    """Load every available CrysMTM split from ``data_dir``.

    Tries the ``train``, ``test_id`` and ``test_ood`` splits in that order.
    A split whose loading raises ``FileNotFoundError`` is skipped after a
    warning is printed.

    Returns:
        A ``DatasetDict`` mapping each successfully loaded split name to its
        dataset.
    """
    loaded_splits = {}

    for split in ("train", "test_id", "test_ood"):
        try:
            split_dataset = load_crysmtm_dataset(data_dir, split)
        except FileNotFoundError:
            print(f"Warning: {split} split not found")
            continue
        loaded_splits[split] = split_dataset

    return DatasetDict(loaded_splits)
78
+
79
# Zero-argument entry point. NOTE(review): the original comment says the
# Hugging Face Hub calls this function - confirm against the Hub tooling.
def load_crysmtm():
    """Load the complete CrysMTM dataset rooted at the current directory."""
    dataset_root = "."
    return load_dataset(dataset_root)