Leacb4 committed on
Commit
48572da
·
verified ·
1 Parent(s): 8a7c966

Upload evaluation/utils/datasets.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. evaluation/utils/datasets.py +54 -20
evaluation/utils/datasets.py CHANGED
@@ -14,6 +14,7 @@ from __future__ import annotations
14
 
15
  import difflib
16
  import hashlib
 
17
  import sys
18
  from pathlib import Path
19
  from io import BytesIO
@@ -33,17 +34,13 @@ if str(_PROJECT_ROOT) not in sys.path:
33
  sys.path.insert(0, str(_PROJECT_ROOT))
34
 
35
  from config import ( # type: ignore
 
36
  column_local_image_path,
37
  fashion_mnist_csv,
38
  local_dataset_path,
39
  images_dir,
40
  )
41
 
42
- _VALID_COLORS = [
43
- "beige", "black", "blue", "brown", "green",
44
- "orange", "pink", "purple", "red", "white", "yellow",
45
- ]
46
-
47
  # ---------------------------------------------------------------------------
48
  # Fashion-MNIST helpers
49
  # ---------------------------------------------------------------------------
@@ -215,7 +212,6 @@ class KaggleDataset(Dataset):
215
 
216
  self.transform = transforms.Compose([
217
  transforms.Resize((224, 224)),
218
- transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
219
  transforms.ToTensor(),
220
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
221
  ])
@@ -244,17 +240,40 @@ class KaggleDataset(Dataset):
244
  return image, description, color
245
 
246
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
  def load_kaggle_marqo_dataset(
248
  max_samples: int = 5000,
249
  include_hierarchy: bool = False,
 
250
  ) -> KaggleDataset:
251
- """Download and prepare the KAGL Marqo HuggingFace dataset."""
252
- from datasets import load_dataset # type: ignore
253
 
254
- print("Loading KAGL Marqo dataset...")
255
- dataset = load_dataset("Marqo/KAGL")
256
- df = dataset["data"].to_pandas()
257
- print(f"Dataset loaded: {len(df)} samples, columns: {list(df.columns)}")
 
 
 
 
 
 
 
258
 
259
  df = df.dropna(subset=["text", "image"])
260
 
@@ -269,8 +288,8 @@ def load_kaggle_marqo_dataset(
269
  })
270
 
271
  kaggle_df = kaggle_df.dropna(subset=["color"])
272
- kaggle_df = kaggle_df[kaggle_df["color"].isin(_VALID_COLORS)]
273
- print(f"After color filtering: {len(kaggle_df)} samples, colors: {sorted(kaggle_df['color'].unique())}")
274
 
275
  return KaggleDataset(kaggle_df, include_hierarchy=include_hierarchy)
276
 
@@ -289,7 +308,6 @@ class LocalDataset(Dataset):
289
 
290
  self.transform = transforms.Compose([
291
  transforms.Resize((224, 224)),
292
- transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
293
  transforms.ToTensor(),
294
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
295
  ])
@@ -301,7 +319,9 @@ class LocalDataset(Dataset):
301
  row = self.dataframe.iloc[idx]
302
  try:
303
  image_path = row.get(column_local_image_path) if hasattr(row, "get") else None
304
- if isinstance(image_path, str) and image_path and Path(image_path).exists():
 
 
305
  image = Image.open(image_path).convert("RGB")
306
  else:
307
  # Fallback: download image from URL (and cache).
@@ -339,10 +359,21 @@ class LocalDataset(Dataset):
339
  def load_local_validation_dataset(
340
  max_samples: int = 5000,
341
  include_hierarchy: bool = False,
 
342
  ) -> LocalDataset:
343
- """Load and prepare the internal local validation dataset."""
344
- print("Loading local validation dataset...")
345
- df = pd.read_csv(local_dataset_path)
 
 
 
 
 
 
 
 
 
 
346
  print(f"Dataset loaded: {len(df)} samples")
347
 
348
  if column_local_image_path in df.columns:
@@ -352,7 +383,6 @@ def load_local_validation_dataset(
352
  print(f"Column '{column_local_image_path}' not found; falling back to 'image_url'.")
353
 
354
  if "color" in df.columns:
355
- df = df[df["color"].isin(_VALID_COLORS)]
356
  print(f"After color filtering: {len(df)} samples, colors: {sorted(df['color'].unique())}")
357
 
358
  if len(df) > max_samples:
@@ -376,6 +406,10 @@ def collate_fn_filter_none(batch):
376
  if not batch:
377
  print("Empty batch after filtering None values")
378
  return torch.tensor([]), [], []
 
 
 
 
379
  images, texts, colors = zip(*batch)
380
  return torch.stack(images), list(texts), list(colors)
381
 
 
14
 
15
  import difflib
16
  import hashlib
17
+ import os
18
  import sys
19
  from pathlib import Path
20
  from io import BytesIO
 
34
  sys.path.insert(0, str(_PROJECT_ROOT))
35
 
36
  from config import ( # type: ignore
37
+ ROOT_DIR,
38
  column_local_image_path,
39
  fashion_mnist_csv,
40
  local_dataset_path,
41
  images_dir,
42
  )
43
 
 
 
 
 
 
44
  # ---------------------------------------------------------------------------
45
  # Fashion-MNIST helpers
46
  # ---------------------------------------------------------------------------
 
212
 
213
  self.transform = transforms.Compose([
214
  transforms.Resize((224, 224)),
 
215
  transforms.ToTensor(),
216
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
217
  ])
 
240
  return image, description, color
241
 
242
 
243
def download_kaggle_raw_df() -> pd.DataFrame:
    """Fetch the raw KAGL Marqo dataset from HuggingFace as a DataFrame.

    This is the expensive network operation. Callers may cache the returned
    frame and hand it back to :func:`load_kaggle_marqo_dataset` through its
    *raw_df* argument so the download happens only once.
    """
    # Imported lazily so the module loads even when `datasets` is absent.
    from datasets import load_dataset  # type: ignore

    print("Downloading KAGL Marqo dataset from HuggingFace...")
    raw = load_dataset("Marqo/KAGL")
    frame = raw["data"].to_pandas()
    print(f"KAGL dataset downloaded: {len(frame)} samples, columns: {list(frame.columns)}")
    return frame
257
+
258
+
259
  def load_kaggle_marqo_dataset(
260
  max_samples: int = 5000,
261
  include_hierarchy: bool = False,
262
+ raw_df: Optional[pd.DataFrame] = None,
263
  ) -> KaggleDataset:
264
+ """Download and prepare the KAGL Marqo HuggingFace dataset.
 
265
 
266
+ Args:
267
+ max_samples: Maximum number of samples to return.
268
+ include_hierarchy: If True, dataset tuples include a hierarchy element.
269
+ raw_df: Pre-downloaded DataFrame (from :func:`download_kaggle_raw_df`).
270
+ If *None*, the dataset is downloaded from HuggingFace.
271
+ """
272
+ if raw_df is not None:
273
+ df = raw_df.copy()
274
+ print(f"Using cached KAGL DataFrame: {len(df)} samples")
275
+ else:
276
+ df = download_kaggle_raw_df()
277
 
278
  df = df.dropna(subset=["text", "image"])
279
 
 
288
  })
289
 
290
  kaggle_df = kaggle_df.dropna(subset=["color"])
291
+
292
+ print(f"Colors: {sorted(kaggle_df['color'].unique())}")
293
 
294
  return KaggleDataset(kaggle_df, include_hierarchy=include_hierarchy)
295
 
 
308
 
309
  self.transform = transforms.Compose([
310
  transforms.Resize((224, 224)),
 
311
  transforms.ToTensor(),
312
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
313
  ])
 
319
  row = self.dataframe.iloc[idx]
320
  try:
321
  image_path = row.get(column_local_image_path) if hasattr(row, "get") else None
322
+ if isinstance(image_path, str) and image_path:
323
+ if not os.path.isabs(image_path):
324
+ image_path = str(ROOT_DIR / image_path)
325
  image = Image.open(image_path).convert("RGB")
326
  else:
327
  # Fallback: download image from URL (and cache).
 
359
  def load_local_validation_dataset(
360
  max_samples: int = 5000,
361
  include_hierarchy: bool = False,
362
+ raw_df: Optional[pd.DataFrame] = None,
363
  ) -> LocalDataset:
364
+ """Load and prepare the internal local validation dataset.
365
+
366
+ Args:
367
+ max_samples: Maximum number of samples to return.
368
+ include_hierarchy: If True, dataset tuples include a hierarchy element.
369
+ raw_df: Pre-loaded DataFrame. If *None*, the CSV is read from disk.
370
+ """
371
+ if raw_df is not None:
372
+ df = raw_df.copy()
373
+ print(f"Using cached local DataFrame: {len(df)} samples")
374
+ else:
375
+ print("Loading local validation dataset...")
376
+ df = pd.read_csv(local_dataset_path)
377
  print(f"Dataset loaded: {len(df)} samples")
378
 
379
  if column_local_image_path in df.columns:
 
383
  print(f"Column '{column_local_image_path}' not found; falling back to 'image_url'.")
384
 
385
  if "color" in df.columns:
 
386
  print(f"After color filtering: {len(df)} samples, colors: {sorted(df['color'].unique())}")
387
 
388
  if len(df) > max_samples:
 
406
  if not batch:
407
  print("Empty batch after filtering None values")
408
  return torch.tensor([]), [], []
409
+ # Support both 3-value (image, text, color) and 4-value (image, text, color, hierarchy) tuples
410
+ if len(batch[0]) == 4:
411
+ images, texts, colors, hierarchies = zip(*batch)
412
+ return torch.stack(images), list(texts), list(colors), list(hierarchies)
413
  images, texts, colors = zip(*batch)
414
  return torch.stack(images), list(texts), list(colors)
415