| |
"""Extract images and labels from Parquet files and save them into
subfolders by label.

Usage:
    python extract_images.py [--train] [--test] [--output OUTPUT_DIR]

Options:
    --train   process the training data (train-00000-of-00001.parquet)
    --test    process the test data (test-00000-of-00001.parquet)
    --output  output directory (default: TrainData, relative to the
              current working directory)

If neither --train nor --test is given, both splits are processed.
The Parquet files are read from YogaDataSet/data next to this script.
"""
| import os |
| import sys |
| import argparse |
| from pathlib import Path |
| import pyarrow.parquet as pq |
|
|
|
|
def _split_image_name(original_path, idx):
    """Return ``(stem, ext)`` for an output file, with safe fallbacks.

    Only the final path component of *original_path* is used. A missing or
    empty path yields the stem ``image_<idx>``; a missing extension yields
    ``.jpg``.
    """
    name = os.path.basename(original_path) if original_path else ""
    stem, ext = os.path.splitext(name)
    return (stem or f"image_{idx}"), (ext or ".jpg")


def extract_images_from_parquet(parquet_path, output_dir, split_name):
    """Extract images from a Parquet file and save them into label folders.

    The table must have a ``label`` column and an ``image`` column whose
    values are mappings with a ``bytes`` entry (the encoded image) and an
    optional ``path`` entry (the original file name). Each image is written
    to ``output_dir/split_name/label_<label>/<name><ext>``; name collisions
    are resolved by appending ``_1``, ``_2``, ...

    Args:
        parquet_path: Path of the Parquet file to read.
        output_dir: Root output directory (``str`` or ``Path``).
        split_name: Sub-folder name for this split, e.g. ``"train"``.

    Returns:
        ``True`` if at least one image was written; ``False`` if the file
        could not be read, lacks the required columns, or no image could be
        saved.
    """
    split_dir = Path(output_dir) / split_name

    print(f"Processing {parquet_path}...")

    try:
        df = pq.read_table(parquet_path).to_pandas()
    except Exception as e:
        print(f"Failed to read parquet file: {e}")
        return False

    missing_columns = [col for col in ("image", "label") if col not in df.columns]
    if missing_columns:
        print(f"Parquet file is missing required columns: {missing_columns}")
        return False

    print(f"Found {len(df)} images")

    # tolist() yields plain Python scalars so the labels print cleanly.
    unique_labels = sorted(df['label'].drop_duplicates().tolist())
    print(f"Label classes: {unique_labels}")

    for label in unique_labels:
        label_dir = split_dir / f"label_{label}"
        label_dir.mkdir(parents=True, exist_ok=True)
        print(f"Created folder: {label_dir}")

    success_count = 0
    error_count = 0

    # Iterate the two needed columns directly; iterrows() would build a
    # Series for every row.
    for idx, image_struct, label in zip(df.index, df['image'], df['label']):
        # Best-effort per image: one bad row must not abort the whole split.
        try:
            image_bytes = image_struct['bytes']
            # Validate before opening the file so a bad row does not leave
            # an empty file behind.
            if not image_bytes:
                raise ValueError("image has no bytes")

            base_name, ext = _split_image_name(image_struct.get('path'), idx)

            label_dir = split_dir / f"label_{label}"
            output_path = label_dir / f"{base_name}{ext}"
            counter = 1
            while output_path.exists():
                output_path = label_dir / f"{base_name}_{counter}{ext}"
                counter += 1

            with open(output_path, 'wb') as f:
                f.write(image_bytes)

            success_count += 1
            if success_count % 100 == 0:
                print(f"Processed {success_count} images...")

        except Exception as e:
            print(f"Error processing image {idx}: {e}")
            error_count += 1

    print(f"Done! Success: {success_count}, Failed: {error_count}")

    # Counts reflect the folder contents, including files from earlier runs.
    print("\nImage count per label:")
    for label in unique_labels:
        label_dir = split_dir / f"label_{label}"
        count = sum(1 for _ in label_dir.glob("*"))
        print(f"  label {label}: {count} images")

    return success_count > 0
|
|
|
|
def main():
    """Parse command-line options and extract the requested dataset splits.

    With neither ``--train`` nor ``--test`` given, both splits are processed.
    Parquet files are read from ``YogaDataSet/data`` next to this script and
    images are written below ``--output`` (default ``TrainData``).

    Exits with status 1 if a requested Parquet file is missing or its
    extraction fails.
    """
    parser = argparse.ArgumentParser(description="Extract images from Parquet files and organize by label")
    parser.add_argument("--train", action="store_true", help="process training data")
    parser.add_argument("--test", action="store_true", help="process test data")
    parser.add_argument("--output", "-o", default="TrainData", help="output directory")

    args = parser.parse_args()

    # No split selected means "process everything".
    if not args.train and not args.test:
        args.train = True
        args.test = True

    script_dir = Path(__file__).parent
    yoga_data_dir = script_dir / "YogaDataSet" / "data"
    output_dir = Path(args.output)

    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Output directory: {output_dir.absolute()}")

    # (requested?, split folder name, name used in messages, parquet file)
    splits = [
        (args.train, "train", "Training", "train-00000-of-00001.parquet"),
        (args.test, "test", "Test", "test-00000-of-00001.parquet"),
    ]

    success = True
    processed = []

    for requested, split_name, title, file_name in splits:
        if not requested:
            continue
        parquet_file = yoga_data_dir / file_name
        if not parquet_file.exists():
            print(f"{title} parquet file not found: {parquet_file}")
            success = False
            continue
        if extract_images_from_parquet(parquet_file, output_dir, split_name):
            processed.append(split_name)
        else:
            success = False

    if not success:
        print("\n❌ Errors occurred during extraction")
        sys.exit(1)

    print("\n✅ All images extracted!")
    print(f"Images saved to: {output_dir.absolute()}")
    print("Directory structure:")
    # Show the directory actually used rather than the default name.
    print(f"{output_dir.resolve().name}/")
    for position, split_name in enumerate(processed):
        is_last = position == len(processed) - 1
        branch = "└── " if is_last else "├── "
        indent = "    " if is_last else "│   "
        print(f"{branch}{split_name}/")
        print(f"{indent}├── label_0/")
        print(f"{indent}├── label_1/")
        print(f"{indent}└── ...")
|
|
|
|
# Run the extraction only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|