| import os |
| import argparse |
| import threading |
| import random |
| from concurrent.futures import ThreadPoolExecutor, as_completed |
| from pydrive2.auth import GoogleAuth |
| from pydrive2.drive import GoogleDrive |
| from tqdm import tqdm |
| from pathlib import Path |
|
|
# Per-thread storage: each download worker caches its own authenticated
# GoogleDrive client here (see _get_thread_drive).
thread_local = threading.local()
|
|
|
|
def _get_thread_drive(service_account_json: str) -> GoogleDrive:
    """Return this thread's cached GoogleDrive client, authenticating on first use."""
    if not hasattr(thread_local, "drive"):
        # First call on this thread: build and cache a fresh client.
        thread_local.drive = authenticate(service_account_json)
    return thread_local.drive
|
|
|
|
def authenticate(service_account_json):
    """Authenticate PyDrive2 with a service account and return a GoogleDrive client."""
    auth = GoogleAuth()
    # Use service-account ("service") credentials instead of the default
    # interactive browser OAuth flow.
    auth.settings.update(
        {
            "client_config_backend": "service",
            "service_config": {
                "client_json_file_path": service_account_json,
                "client_user_email": "drive-bot@web-design-396514.iam.gserviceaccount.com",
            },
        }
    )
    auth.ServiceAuth()
    return GoogleDrive(auth)
|
|
|
|
def list_files_with_paths(drive, folder_id, prefix=""):
    """Recursively collect metadata for every non-folder file under *folder_id*.

    Returns a list of dicts with keys ``id``, ``rel_path``, ``size``, ``md5``
    and ``mimeType``; ``rel_path`` is relative to the starting folder.
    """
    folder_mime = "application/vnd.google-apps.folder"
    listing = drive.ListFile(
        {
            "q": f"'{folder_id}' in parents and trashed=false",
            "maxResults": 1000,
            # Restrict the response to only the fields we consume below.
            "fields": "items(id,title,mimeType,fileSize,md5Checksum),nextPageToken",
        }
    ).GetList()

    collected = []
    for entry in listing:
        child_path = f"{prefix}/{entry['title']}" if prefix else entry["title"]
        if entry["mimeType"] == folder_mime:
            # Descend into sub-folders, extending the relative path.
            collected.extend(list_files_with_paths(drive, entry["id"], child_path))
        else:
            collected.append(
                {
                    "id": entry["id"],
                    "rel_path": child_path,
                    # Entries without a fileSize field are recorded as size 0.
                    "size": int(entry["fileSize"]) if "fileSize" in entry else 0,
                    "md5": entry.get("md5Checksum", ""),
                    "mimeType": entry["mimeType"],
                }
            )
    return collected
|
|
|
|
def download_folder(folder_id, dest, service_account_json, workers: int):
    """Download every file under a Drive folder into *dest*, in parallel.

    Files that already exist locally with a matching size are skipped.

    Raises:
        RuntimeError: if any individual download fails.
    """
    drive = authenticate(service_account_json)
    Path(dest).mkdir(parents=True, exist_ok=True)

    print(f"Listing files in folder {folder_id}...")
    files_with_paths = list_files_with_paths(drive, folder_id)
    total = len(files_with_paths)
    print(f"Found {total} files. Planning downloads...")

    # Build the work list, skipping files whose on-disk size already matches.
    tasks = []
    skipped = 0
    for meta in files_with_paths:
        out_path = Path(dest) / meta["rel_path"]
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if (
            meta["size"] > 0
            and out_path.exists()
            and out_path.stat().st_size == meta["size"]
        ):
            skipped += 1
            continue
        tasks.append((meta["id"], str(out_path)))

    print(f"Skipping {skipped} existing files; {len(tasks)} to download.")

    def _download_one(file_id: str, out_path: str):
        # Each worker thread uses its own authenticated client.
        d = _get_thread_drive(service_account_json)
        f = d.CreateFile({"id": file_id})
        f.GetContentFile(out_path)

    if len(tasks) == 0:
        print("All files are up to date.")
        return

    failures = []
    with ThreadPoolExecutor(max_workers=workers) as ex:
        future_to_task = {
            ex.submit(_download_one, fid, path): (fid, path) for fid, path in tasks
        }
        for fut in tqdm(
            as_completed(future_to_task),
            total=len(future_to_task),
            desc="Downloading",
            unit="file",
        ):
            # BUG FIX: the original never called .result(), so exceptions
            # raised inside workers were silently discarded and failed
            # downloads looked successful.
            try:
                fut.result()
            except Exception as exc:
                _, path = future_to_task[fut]
                failures.append((path, exc))

    if failures:
        for path, exc in failures:
            print(f"FAILED: {path}: {exc}")
        raise RuntimeError(f"{len(failures)} of {len(tasks)} downloads failed")
|
|
|
|
def pull(args=None):
    """CLI entry point: mirror a Google Drive folder to a local directory."""
    parser = argparse.ArgumentParser(
        description="Download a full Google Drive folder using a service account"
    )
    add = parser.add_argument
    add(
        "--folder-id",
        dest="folder_id",
        default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p",
        help="Google Drive folder ID",
    )
    add(
        "--output-dir",
        dest="output_dir",
        default="dataset/",
        help="Directory to save files",
    )
    add(
        "--service-account",
        default="secrets/drive-json.json",
        help="Path to your Google service account JSON key file",
    )
    add(
        "--workers",
        type=int,
        default=8,
        help="Number of parallel download workers",
    )
    opts = parser.parse_args(args=args)

    download_folder(opts.folder_id, opts.output_dir, opts.service_account, opts.workers)
|
|
|
|
| def _index_numeric_pairs(images_dir: Path, masks_dir: Path): |
| assert images_dir.exists() and images_dir.is_dir(), ( |
| f"Missing images_dir: {images_dir}" |
| ) |
| assert masks_dir.exists() and masks_dir.is_dir(), f"Missing masks_dir: {masks_dir}" |
| img_files = sorted([p for p in images_dir.glob("*.jpg") if p.is_file()]) |
| img_files += sorted([p for p in images_dir.glob("*.jpeg") if p.is_file()]) |
| assert len(img_files) > 0, f"No .jpg/.jpeg images in {images_dir}" |
| ids = [] |
| for p in img_files: |
| stem = p.stem |
| assert stem.isdigit(), f"Non-numeric filename encountered: {p.name}" |
| ids.append(int(stem)) |
| ids = sorted(ids) |
| pairs = [] |
| for i in ids: |
| ip_jpg = images_dir / f"{i}.jpg" |
| ip_jpeg = images_dir / f"{i}.jpeg" |
| ip = ip_jpg if ip_jpg.exists() else ip_jpeg |
| assert ip.exists(), f"Missing image for {i}: {ip_jpg} or {ip_jpeg}" |
| mp = masks_dir / f"{i}.png" |
| assert mp.exists(), f"Missing mask for {i}: {mp}" |
| pairs.append((ip, mp)) |
| assert len(pairs) > 0, "No numeric pairs found" |
| return pairs |
|
|
|
|
def split_test_train_val(args=None):
    """CLI entry point: split image/mask pairs into train/val/test = 85/5/10.

    Creates ``<out-dir>/{train,val,test}/{images,gts}`` and places each pair
    via symlink (default) or hard-link/copy, shuffled with a fixed seed.
    """
    parser = argparse.ArgumentParser(
        description="Split dataset into train/val/test = 85/5/10 with numeric pairs"
    )
    parser.add_argument("--images-dir", required=True, help="Path to images directory")
    parser.add_argument("--masks-dir", required=True, help="Path to masks directory")
    parser.add_argument(
        "--out-dir",
        required=True,
        help="Output root dir where train/ val/ test/ will be created",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--link-method",
        choices=["symlink", "copy"],
        default="symlink",
        help="How to place files into splits",
    )
    parsed = parser.parse_args(args=args)

    images_dir = Path(parsed.images_dir)
    masks_dir = Path(parsed.masks_dir)
    out_root = Path(parsed.out_dir)
    pairs = _index_numeric_pairs(images_dir, masks_dir)

    # Deterministic shuffled 85/5/10 partition of pair indices.
    n = len(pairs)
    n_train = int(0.85 * n)
    n_val = int(0.05 * n)
    rng = random.Random(parsed.seed)
    idxs = list(range(n))
    rng.shuffle(idxs)
    train_idx = idxs[:n_train]
    val_idx = idxs[n_train : n_train + n_val]
    test_idx = idxs[n_train + n_val :]

    def _ensure_dirs(root: Path):
        # Each split root holds images/ and gts/ (ground-truth masks).
        (root / "images").mkdir(parents=True, exist_ok=True)
        (root / "gts").mkdir(parents=True, exist_ok=True)

    def _place(src: Path, dst: Path):
        # Place *src* into the split, replacing any stale entry at *dst*.
        if parsed.link_method == "symlink":
            try:
                if dst.exists() or dst.is_symlink():
                    dst.unlink()
                # BUG FIX: link to the resolved absolute path. The original
                # passed src through as-is; a CWD-relative src is interpreted
                # relative to the link's own directory, yielding broken links.
                os.symlink(str(src.resolve()), str(dst))
            except FileExistsError:
                pass
        else:
            if dst.exists():
                dst.unlink()
            try:
                # Hard links are cheap; fall back to a real copy when the
                # source and destination are on different filesystems.
                os.link(str(src), str(dst))
            except OSError:
                import shutil

                shutil.copy2(str(src), str(dst))

    for split_name, split_ids in (
        ("train", train_idx),
        ("val", val_idx),
        ("test", test_idx),
    ):
        root = out_root / split_name
        _ensure_dirs(root)
        for k in split_ids:
            img_p, mask_p = pairs[k]
            _place(img_p, root / "images" / img_p.name)
            _place(mask_p, root / "gts" / mask_p.name)
    print(
        f"Split written to {out_root} | train={len(train_idx)} val={len(val_idx)} test={len(test_idx)}"
    )
|
|
|
|
if __name__ == "__main__":
    # Make sure the default dataset directory exists before any command runs.
    Path("./dataset").mkdir(exist_ok=True)

    top = argparse.ArgumentParser(description="WireSegHR data utilities")
    subs = top.add_subparsers(dest="cmd", required=True)

    sp_pull = subs.add_parser("pull", help="Download dataset from Google Drive")
    sp_pull.add_argument(
        "--folder-id", dest="folder_id", default="1fgy3wn_yuHEeMNbfiHNVl1-jEdYOfu6p"
    )
    sp_pull.add_argument("--output-dir", dest="output_dir", default="dataset/")
    sp_pull.add_argument("--service-account", default="secrets/drive-json.json")
    sp_pull.add_argument("--workers", type=int, default=8)

    sp_split = subs.add_parser(
        "split_test_train_val", help="Create 85/5/10 train/val/test split"
    )
    sp_split.add_argument("--images-dir", required=True)
    sp_split.add_argument("--masks-dir", required=True)
    sp_split.add_argument("--out-dir", required=True)
    sp_split.add_argument("--seed", type=int, default=42)
    sp_split.add_argument(
        "--link-method", choices=["symlink", "copy"], default="symlink"
    )

    ns = top.parse_args()
    # Re-serialize the parsed options into argv form so each subcommand's own
    # parser (pull / split_test_train_val) can consume them unchanged.
    if ns.cmd == "pull":
        forwarded = [
            "--folder-id", ns.folder_id,
            "--output-dir", ns.output_dir,
            "--service-account", ns.service_account,
            "--workers", str(ns.workers),
        ]
        pull(forwarded)
    elif ns.cmd == "split_test_train_val":
        forwarded = [
            "--images-dir", ns.images_dir,
            "--masks-dir", ns.masks_dir,
            "--out-dir", ns.out_dir,
            "--seed", str(ns.seed),
            "--link-method", ns.link_method,
        ]
        split_test_train_val(forwarded)
|