omar-ah committed on
Commit
ea9b821
·
1 Parent(s): 0d77b0a

Fix LLaVA fallback image loading in Stage 1

Browse files
Files changed (1) hide show
  1. code/train_production.py +25 -8
code/train_production.py CHANGED
@@ -14,6 +14,7 @@ import json
14
  import math
15
  import os
16
  import time
 
17
  from collections import defaultdict
18
  from dataclasses import dataclass
19
  from io import BytesIO
@@ -359,13 +360,20 @@ def tokenize_prompt_and_target(
359
 
360
 
361
  def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
362
- if isinstance(img, str):
363
- img = Image.open(img).convert("RGB")
364
- elif isinstance(img, dict) and "bytes" in img:
365
- img = Image.open(BytesIO(img["bytes"])).convert("RGB")
366
- elif isinstance(img, Image.Image):
367
- img = img.convert("RGB")
368
- else:
 
 
 
 
 
 
 
369
  img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
370
 
371
  pil_image = img
@@ -420,6 +428,7 @@ class NormalizedVisionLanguageDataset(Dataset):
420
  def build_llava_records(max_samples: Optional[int]) -> HFDataset:
421
  print("Loading LLaVA-Pretrain dataset...")
422
  dataset_root = None
 
423
  try:
424
  data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
425
  except Exception as exc:
@@ -430,6 +439,7 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
430
  allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
431
  )
432
  json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
 
433
  data = load_dataset("json", data_files={"train": json_path}, split="train")
434
  if max_samples:
435
  data = data.select(range(min(max_samples, len(data))))
@@ -455,7 +465,14 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
455
  os.path.join(dataset_root, image_obj),
456
  os.path.join(dataset_root, "images", image_obj),
457
  ]
458
- image_obj = next((path for path in candidate_paths if os.path.exists(path)), image_obj)
 
 
 
 
 
 
 
459
 
460
  return {
461
  "image": image_obj,
 
14
  import math
15
  import os
16
  import time
17
+ import zipfile
18
  from collections import defaultdict
19
  from dataclasses import dataclass
20
  from io import BytesIO
 
360
 
361
 
362
  def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
363
+ try:
364
+ if isinstance(img, str):
365
+ img = Image.open(img).convert("RGB")
366
+ elif isinstance(img, dict) and "bytes" in img:
367
+ img = Image.open(BytesIO(img["bytes"])).convert("RGB")
368
+ elif isinstance(img, dict) and "zip_path" in img and "member" in img:
369
+ with zipfile.ZipFile(img["zip_path"], "r") as archive:
370
+ with archive.open(img["member"], "r") as member_file:
371
+ img = Image.open(member_file).convert("RGB")
372
+ elif isinstance(img, Image.Image):
373
+ img = img.convert("RGB")
374
+ else:
375
+ img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
376
+ except Exception:
377
  img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
378
 
379
  pil_image = img
 
428
  def build_llava_records(max_samples: Optional[int]) -> HFDataset:
429
  print("Loading LLaVA-Pretrain dataset...")
430
  dataset_root = None
431
+ images_zip_path = None
432
  try:
433
  data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
434
  except Exception as exc:
 
439
  allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
440
  )
441
  json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
442
+ images_zip_path = os.path.join(dataset_root, "images.zip")
443
  data = load_dataset("json", data_files={"train": json_path}, split="train")
444
  if max_samples:
445
  data = data.select(range(min(max_samples, len(data))))
 
465
  os.path.join(dataset_root, image_obj),
466
  os.path.join(dataset_root, "images", image_obj),
467
  ]
468
+ resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
469
+ if resolved_path:
470
+ image_obj = resolved_path
471
+ elif images_zip_path and os.path.exists(images_zip_path):
472
+ image_obj = {
473
+ "zip_path": images_zip_path,
474
+ "member": image_obj,
475
+ }
476
 
477
  return {
478
  "image": image_obj,