Filter invalid LLaVA image records for real runs
Browse files- code/train_production.py +42 -20
code/train_production.py
CHANGED
|
@@ -360,21 +360,18 @@ def tokenize_prompt_and_target(
|
|
| 360 |
|
| 361 |
|
| 362 |
def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
with
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
|
| 376 |
-
except Exception:
|
| 377 |
-
img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
|
| 378 |
|
| 379 |
pil_image = img
|
| 380 |
resized = pil_image.resize((img_size, img_size), Image.BICUBIC)
|
|
@@ -429,6 +426,7 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
|
|
| 429 |
print("Loading LLaVA-Pretrain dataset...")
|
| 430 |
dataset_root = None
|
| 431 |
images_zip_path = None
|
|
|
|
| 432 |
try:
|
| 433 |
data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
|
| 434 |
except Exception as exc:
|
|
@@ -440,11 +438,16 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
|
|
| 440 |
)
|
| 441 |
json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
|
| 442 |
images_zip_path = os.path.join(dataset_root, "images.zip")
|
|
|
|
|
|
|
|
|
|
| 443 |
data = load_dataset("json", data_files={"train": json_path}, split="train")
|
| 444 |
if max_samples:
|
| 445 |
data = data.select(range(min(max_samples, len(data))))
|
| 446 |
|
| 447 |
-
|
|
|
|
|
|
|
| 448 |
text = ""
|
| 449 |
if "conversations" in sample:
|
| 450 |
parts = []
|
|
@@ -459,6 +462,9 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
|
|
| 459 |
text = "Describe this image."
|
| 460 |
|
| 461 |
image_obj = sample.get("image")
|
|
|
|
|
|
|
|
|
|
| 462 |
if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj):
|
| 463 |
candidate_paths = [
|
| 464 |
image_obj,
|
|
@@ -468,12 +474,24 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
|
|
| 468 |
resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
|
| 469 |
if resolved_path:
|
| 470 |
image_obj = resolved_path
|
| 471 |
-
elif images_zip_path and os.path.exists(images_zip_path):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
image_obj = {
|
| 473 |
"zip_path": images_zip_path,
|
| 474 |
-
"member":
|
| 475 |
}
|
|
|
|
|
|
|
|
|
|
| 476 |
|
|
|
|
| 477 |
return {
|
| 478 |
"image": image_obj,
|
| 479 |
"prompt_text": "Describe this image.",
|
|
@@ -482,9 +500,13 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
|
|
| 482 |
"source_config": "llava_pretrain",
|
| 483 |
}
|
| 484 |
|
| 485 |
-
records = [
|
| 486 |
normalized = HFDataset.from_list(records)
|
| 487 |
-
print(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
return normalized
|
| 489 |
|
| 490 |
|
|
|
|
| 360 |
|
| 361 |
|
| 362 |
def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
|
| 363 |
+
if isinstance(img, str):
|
| 364 |
+
img = Image.open(img).convert("RGB")
|
| 365 |
+
elif isinstance(img, dict) and "bytes" in img:
|
| 366 |
+
img = Image.open(BytesIO(img["bytes"])).convert("RGB")
|
| 367 |
+
elif isinstance(img, dict) and "zip_path" in img and "member" in img:
|
| 368 |
+
with zipfile.ZipFile(img["zip_path"], "r") as archive:
|
| 369 |
+
with archive.open(img["member"], "r") as member_file:
|
| 370 |
+
img = Image.open(member_file).convert("RGB")
|
| 371 |
+
elif isinstance(img, Image.Image):
|
| 372 |
+
img = img.convert("RGB")
|
| 373 |
+
else:
|
| 374 |
+
raise ValueError(f"Unsupported image payload type: {type(img)!r}")
|
|
|
|
|
|
|
|
|
|
| 375 |
|
| 376 |
pil_image = img
|
| 377 |
resized = pil_image.resize((img_size, img_size), Image.BICUBIC)
|
|
|
|
| 426 |
print("Loading LLaVA-Pretrain dataset...")
|
| 427 |
dataset_root = None
|
| 428 |
images_zip_path = None
|
| 429 |
+
zip_members = None
|
| 430 |
try:
|
| 431 |
data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
|
| 432 |
except Exception as exc:
|
|
|
|
| 438 |
)
|
| 439 |
json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
|
| 440 |
images_zip_path = os.path.join(dataset_root, "images.zip")
|
| 441 |
+
if os.path.exists(images_zip_path):
|
| 442 |
+
with zipfile.ZipFile(images_zip_path, "r") as archive:
|
| 443 |
+
zip_members = set(archive.namelist())
|
| 444 |
data = load_dataset("json", data_files={"train": json_path}, split="train")
|
| 445 |
if max_samples:
|
| 446 |
data = data.select(range(min(max_samples, len(data))))
|
| 447 |
|
| 448 |
+
stats = defaultdict(int)
|
| 449 |
+
|
| 450 |
+
def normalize(sample: Dict[str, object], idx: int) -> Optional[Dict[str, object]]:
|
| 451 |
text = ""
|
| 452 |
if "conversations" in sample:
|
| 453 |
parts = []
|
|
|
|
| 462 |
text = "Describe this image."
|
| 463 |
|
| 464 |
image_obj = sample.get("image")
|
| 465 |
+
if image_obj is None:
|
| 466 |
+
stats["missing_image_ref"] += 1
|
| 467 |
+
return None
|
| 468 |
if isinstance(image_obj, str) and dataset_root and not os.path.isabs(image_obj):
|
| 469 |
candidate_paths = [
|
| 470 |
image_obj,
|
|
|
|
| 474 |
resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
|
| 475 |
if resolved_path:
|
| 476 |
image_obj = resolved_path
|
| 477 |
+
elif images_zip_path and os.path.exists(images_zip_path) and zip_members:
|
| 478 |
+
member_name = None
|
| 479 |
+
if image_obj in zip_members:
|
| 480 |
+
member_name = image_obj
|
| 481 |
+
elif f"images/{image_obj}" in zip_members:
|
| 482 |
+
member_name = f"images/{image_obj}"
|
| 483 |
+
if member_name is None:
|
| 484 |
+
stats["missing_backing_image"] += 1
|
| 485 |
+
return None
|
| 486 |
image_obj = {
|
| 487 |
"zip_path": images_zip_path,
|
| 488 |
+
"member": member_name,
|
| 489 |
}
|
| 490 |
+
else:
|
| 491 |
+
stats["missing_backing_image"] += 1
|
| 492 |
+
return None
|
| 493 |
|
| 494 |
+
stats["kept"] += 1
|
| 495 |
return {
|
| 496 |
"image": image_obj,
|
| 497 |
"prompt_text": "Describe this image.",
|
|
|
|
| 500 |
"source_config": "llava_pretrain",
|
| 501 |
}
|
| 502 |
|
| 503 |
+
records = [record for i in range(len(data)) if (record := normalize(data[i], i)) is not None]
|
| 504 |
normalized = HFDataset.from_list(records)
|
| 505 |
+
print(
|
| 506 |
+
f"Loaded {len(normalized)} LLaVA samples "
|
| 507 |
+
f"(kept={stats['kept']}, missing_image_ref={stats['missing_image_ref']}, "
|
| 508 |
+
f"missing_backing_image={stats['missing_backing_image']})"
|
| 509 |
+
)
|
| 510 |
return normalized
|
| 511 |
|
| 512 |
|