Fix LLaVA fallback image loading in Stage 1

Browse files — code/train_production.py (+25 −8)

code/train_production.py (CHANGED)
|
@@ -14,6 +14,7 @@ import json
 import math
 import os
 import time
+import zipfile
 from collections import defaultdict
 from dataclasses import dataclass
 from io import BytesIO
|
|
@@ -359,13 +360,20 @@ def tokenize_prompt_and_target(


 def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
-    (6 removed lines — their text was not preserved in this extraction; one fragment read "img")
     img = Image.new("RGB", (img_size, img_size), (128, 128, 128))

     pil_image = img
|
|
@@ -420,6 +428,7 @@ class NormalizedVisionLanguageDataset(Dataset):
 def build_llava_records(max_samples: Optional[int]) -> HFDataset:
     print("Loading LLaVA-Pretrain dataset...")
     dataset_root = None
     try:
         data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
     except Exception as exc:
|
|
@@ -430,6 +439,7 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
         allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
     )
     json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
     data = load_dataset("json", data_files={"train": json_path}, split="train")
     if max_samples:
         data = data.select(range(min(max_samples, len(data))))
|
|
@@ -455,7 +465,14 @@ def build_llava_records(max_samples: Optional[int]) -> HFDataset:
             os.path.join(dataset_root, image_obj),
             os.path.join(dataset_root, "images", image_obj),
         ]
-        (1 removed line — its text was not preserved in this extraction)

         return {
             "image": image_obj,
|
|
|
|
After (lines 14–20):
 import math
 import os
 import time
+import zipfile
 from collections import defaultdict
 from dataclasses import dataclass
 from io import BytesIO
|
|
|
|
After (lines 360–379):


 def preprocess_image_for_student(img: object, img_size: int) -> Tuple[torch.Tensor, Image.Image]:
+    try:
+        if isinstance(img, str):
+            img = Image.open(img).convert("RGB")
+        elif isinstance(img, dict) and "bytes" in img:
+            img = Image.open(BytesIO(img["bytes"])).convert("RGB")
+        elif isinstance(img, dict) and "zip_path" in img and "member" in img:
+            with zipfile.ZipFile(img["zip_path"], "r") as archive:
+                with archive.open(img["member"], "r") as member_file:
+                    img = Image.open(member_file).convert("RGB")
+        elif isinstance(img, Image.Image):
+            img = img.convert("RGB")
+        else:
+            img = Image.new("RGB", (img_size, img_size), (128, 128, 128))
+    except Exception:
         img = Image.new("RGB", (img_size, img_size), (128, 128, 128))

     pil_image = img
|
|
|
|
After (lines 428–434):
 def build_llava_records(max_samples: Optional[int]) -> HFDataset:
     print("Loading LLaVA-Pretrain dataset...")
     dataset_root = None
+    images_zip_path = None
     try:
         data = load_dataset("liuhaotian/LLaVA-Pretrain", split="train")
     except Exception as exc:

After (lines 439–445):
         allow_patterns=["blip_laion_cc_sbu_558k.json", "images.zip"],
     )
     json_path = os.path.join(dataset_root, "blip_laion_cc_sbu_558k.json")
+    images_zip_path = os.path.join(dataset_root, "images.zip")
     data = load_dataset("json", data_files={"train": json_path}, split="train")
     if max_samples:
         data = data.select(range(min(max_samples, len(data))))
|
|
|
|
After (lines 465–478):
             os.path.join(dataset_root, image_obj),
             os.path.join(dataset_root, "images", image_obj),
         ]
+        resolved_path = next((path for path in candidate_paths if os.path.exists(path)), None)
+        if resolved_path:
+            image_obj = resolved_path
+        elif images_zip_path and os.path.exists(images_zip_path):
+            image_obj = {
+                "zip_path": images_zip_path,
+                "member": image_obj,
+            }

         return {
             "image": image_obj,
|