omar-ah committed on
Commit
2dd6eee
·
1 Parent(s): ea9b821

Filter invalid LLaVA image records for real runs

Browse files
Files changed (1) hide show
  1. code/train_production.py +42 -20
code/train_production.py CHANGED
@@ -360,21 +360,18 @@ def tokenize_prompt_and_target(
360
 
361
 
362
  def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
363
- try:
364
- if isinstance(img, str):
365
- img = Image.open(img).convert("RGB")
366
- elif isinstance(img, dict) and "bytes" in img:
367
- img = Image.open(BytesIO(img["bytes"])).convert("RGB")
368
- elif isinstance(img, dict) and "zip_path" in img and "member" in img:
369
- with zipfile.ZipFile(img["zip_path"], "r") as archive:
370
- with archive.open(img["member"], "r") as member_file:
371
- img = Image.open(member_file).convert("RGB")
372
- elif isinstance(img, Image.Image):
373
- img = img.convert("RGB")
374
- else:
375
- img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
376
- except Exception:
377
- img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
378
 
379
  pil_image = img
380
  resized = pil_image.resize((img_size, img_size), Image.BICUBIC)
@@ -429,6 +426,7 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
429
  print("Loading LLaVA-Pretrain dataset...")
430
  dataset_root = None
431
  images_zip_path = None
 
432
  try:
433
  data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
434
  except Exception as exc:
@@ -440,11 +438,16 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
440
  )
441
  json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
442
  images_zip_path = os.path.join(dataset_root, "images.zip")
 
 
 
443
  data = load_dataset("json", data_files={"train": json_path}, split="train")
444
  if max_samples:
445
  data = data.select(range(min(max_samples, len(data))))
446
 
447
- def normalize(sample: Dict[str, object], idx: int) -> Dict[str, object]:
 
 
448
  text = ""
449
  if "conversations" in sample:
450
  parts = []
@@ -459,6 +462,9 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
459
  text = "Describe this image."
460
 
461
  image_obj = sample.get("image")
 
 
 
462
  if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj):
463
  candidate_paths = [
464
  image_obj,
@@ -468,12 +474,24 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
468
  resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
469
  if resolved_path:
470
  image_obj = resolved_path
471
- elif images_zip_path and os.path.exists(images_zip_path):
 
 
 
 
 
 
 
 
472
  image_obj = {
473
  "zip_path": images_zip_path,
474
- "member": image_obj,
475
  }
 
 
 
476
 
 
477
  return {
478
  "image": image_obj,
479
  "prompt_text": "Describe this image.",
@@ -482,9 +500,13 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
482
  "source_config": "llava_pretrain",
483
  }
484
 
485
- records = [normalize(data[i], i) for i in range(len(data))]
486
  normalized = HFDataset.from_list(records)
487
- print(f"Loaded {len(normalized)} LLaVA samples")
 
 
 
 
488
  return normalized
489
 
490
 
 
360
 
361
 
362
  def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
363
+ if isinstance(img, str):
364
+ img = Image.open(img).convert("RGB")
365
+ elif isinstance(img, dict) and "bytes" in img:
366
+ img = Image.open(BytesIO(img["bytes"])).convert("RGB")
367
+ elif isinstance(img, dict) and "zip_path" in img and "member" in img:
368
+ with zipfile.ZipFile(img["zip_path"], "r") as archive:
369
+ with archive.open(img["member"], "r") as member_file:
370
+ img = Image.open(member_file).convert("RGB")
371
+ elif isinstance(img, Image.Image):
372
+ img = img.convert("RGB")
373
+ else:
374
+ raise ValueError(f"Unsupported image payload type: {type(img)!r}")
 
 
 
375
 
376
  pil_image = img
377
  resized = pil_image.resize((img_size, img_size), Image.BICUBIC)
 
426
  print("Loading LLaVA-Pretrain dataset...")
427
  dataset_root = None
428
  images_zip_path = None
429
+ zip_members = None
430
  try:
431
  data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
432
  except Exception as exc:
 
438
  )
439
  json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
440
  images_zip_path = os.path.join(dataset_root, "images.zip")
441
+ if os.path.exists(images_zip_path):
442
+ with zipfile.ZipFile(images_zip_path, "r") as archive:
443
+ zip_members = set(archive.namelist())
444
  data = load_dataset("json", data_files={"train": json_path}, split="train")
445
  if max_samples:
446
  data = data.select(range(min(max_samples, len(data))))
447
 
448
+ stats = defaultdict(int)
449
+
450
+ def normalize(sample: Dict[str, object], idx: int) -> Optional[Dict[str, object]]:
451
  text = ""
452
  if "conversations" in sample:
453
  parts = []
 
462
  text = "Describe this image."
463
 
464
  image_obj = sample.get("image")
465
+ if image_obj is None:
466
+ stats["missing_image_ref"] += 1
467
+ return None
468
  if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj):
469
  candidate_paths = [
470
  image_obj,
 
474
  resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
475
  if resolved_path:
476
  image_obj = resolved_path
477
+ elif images_zip_path and os.path.exists(images_zip_path) and zip_members:
478
+ member_name = None
479
+ if image_obj in zip_members:
480
+ member_name = image_obj
481
+ elif f"images/{image_obj}" in zip_members:
482
+ member_name = f"images/{image_obj}"
483
+ if member_name is None:
484
+ stats["missing_backing_image"] += 1
485
+ return None
486
  image_obj = {
487
  "zip_path": images_zip_path,
488
+ "member": member_name,
489
  }
490
+ else:
491
+ stats["missing_backing_image"] += 1
492
+ return None
493
 
494
+ stats["kept"] += 1
495
  return {
496
  "image": image_obj,
497
  "prompt_text": "Describe this image.",
 
500
  "source_config": "llava_pretrain",
501
  }
502
 
503
+ records = [record for i in range(len(data)) if (record := normalize(data[i], i)) is not None]
504
  normalized = HFDataset.from_list(records)
505
+ print(
506
+ f"Loaded {len(normalized)} LLaVA samples "
507
+ f"(kept={stats['kept']}, missing_image_ref={stats['missing_image_ref']}, "
508
+ f"missing_backing_image={stats['missing_backing_image']})"
509
+ )
510
  return normalized
511
 
512