| """ |
| Download script for CSI-4CAST datasets. |
| |
| This script downloads all available datasets from the CSI-4CAST Hugging Face organization |
| by checking for all possible combinations of channel models, delay spreads, and speeds. |
| |
Usage:
    python3 download.py [--output-dir OUTPUT_DIR] [--dry-run]

If no arguments are provided, datasets are downloaded to a 'datasets' folder.
| """ |
|
|
import argparse
import itertools
from pathlib import Path

from huggingface_hub import HfApi, snapshot_download
from tqdm import tqdm
|
|
| |
# Hugging Face organization that hosts all CSI-4CAST dataset repositories.
ORG = "CSI-4CAST"


# Regular (train/test) parameter grid.
LIST_CHANNEL_MODEL = ["A", "C", "D"]         # channel model letters
LIST_DELAY_SPREAD = [30e-9, 100e-9, 300e-9]  # delay spreads in seconds
LIST_MIN_SPEED = [1, 10, 30]                 # minimum speeds in km/h


# Generalization-test grid: wider coverage of models, spreads, and speeds.
LIST_CHANNEL_MODEL_GEN = ["A", "B", "C", "D", "E"]
LIST_DELAY_SPREAD_GEN = [30e-9, 50e-9, 100e-9, 200e-9, 300e-9, 400e-9]
# Speeds 3..45 km/h in steps of 3, plus 1 and 10 from the regular grid, ascending.
LIST_MIN_SPEED_GEN = sorted([*range(3, 46, 3), 1, 10])
|
|
def make_folder_name(cm: str, ds: float, ms: int, **kwargs) -> str:
    """Build the canonical dataset folder name for one parameter combination.

    Args:
        cm (str): Channel model letter (e.g. 'A', 'B', 'C', 'D', 'E')
        ds (float): Delay spread in seconds (e.g. 30e-9, 100e-9, 300e-9)
        ms (int): Minimum speed in km/h (e.g. 1, 10, 30)
        **kwargs: Ignored; accepted so callers may pass extra fields freely

    Returns:
        str: Name of the form 'cm_{cm}_ds_{DDD}_ms_{MMM}', where DDD is the
        delay spread converted to whole nanoseconds and MMM the minimum
        speed, each zero-padded to three digits.

    Example:
        >>> make_folder_name('A', 30e-9, 10)
        'cm_A_ds_030_ms_010'
    """
    # Convert seconds to the nearest whole nanosecond before formatting.
    ds_ns = round(ds * 1e9)
    return f"cm_{cm}_ds_{ds_ns:03d}_ms_{ms:03d}"
|
|
def check_repo_exists(api: HfApi, repo_id: str) -> bool:
    """Return True when ``repo_id`` is an existing dataset repo, else False.

    Any exception raised by the lookup (missing repo, auth or network
    failure) is treated as "does not exist".
    """
    try:
        api.repo_info(repo_id, repo_type="dataset")
    except Exception:
        return False
    return True
|
|
def generate_dataset_combinations():
    """Generate all possible dataset repository names.

    Returns:
        list[str]: "stats" first, then every train_regular, test_regular and
        test_generalization name built from the corresponding parameter grid,
        in grid order (channel model outermost, speed innermost).
    """
    # (repo prefix, channel models, delay spreads, min speeds) per dataset family.
    families = (
        ("train_regular", LIST_CHANNEL_MODEL, LIST_DELAY_SPREAD, LIST_MIN_SPEED),
        ("test_regular", LIST_CHANNEL_MODEL, LIST_DELAY_SPREAD, LIST_MIN_SPEED),
        ("test_generalization", LIST_CHANNEL_MODEL_GEN, LIST_DELAY_SPREAD_GEN, LIST_MIN_SPEED_GEN),
    )

    # "stats" is a standalone repo, not tied to any parameter grid.
    combinations = ["stats"]

    for prefix, cms, dss, mss in families:
        # itertools.product reproduces the original nested-loop ordering.
        combinations.extend(
            f"{prefix}_{make_folder_name(cm, ds, ms)}"
            for cm, ds, ms in itertools.product(cms, dss, mss)
        )

    return combinations
|
|
def download_dataset(api: HfApi, org: str, repo_name: str, output_dir: Path, dry_run: bool = False) -> bool:
    """Download a single dataset into ``output_dir`` if its repo exists.

    Args:
        api: Hugging Face API client used for the existence check.
        org: Organization the repository belongs to.
        repo_name: Repository name (without the org prefix).
        output_dir: Directory under which a ``repo_name`` folder is created.
        dry_run: When True, write an empty placeholder file instead of
            downloading the snapshot.

    Returns:
        bool: True if the repo exists and was processed successfully;
        False if it does not exist or the download failed.
    """
    repo_id = f"{org}/{repo_name}"

    # Most grid combinations have no repo; skip them quietly.
    if not check_repo_exists(api, repo_id):
        return False

    try:
        target_dir = output_dir / repo_name
        target_dir.mkdir(parents=True, exist_ok=True)

        if dry_run:
            # Leave an empty marker so the would-be layout can be inspected.
            (target_dir / "placeholder.txt").write_text("")
            print(f"✅ Dry run - Created placeholder: {repo_name}")
        else:
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=target_dir,
                # NOTE(review): deprecated/ignored in recent huggingface_hub
                # releases; kept for compatibility with older installs.
                local_dir_use_symlinks=False,
            )
            print(f"✅ Downloaded: {repo_name}")

        return True

    except Exception as e:
        # Best-effort batch download: report the failure and move on.
        print(f"❌ Error downloading {repo_name}: {e}")
        return False
|
|
def main():
    """Parse CLI arguments, then download (or dry-run) every existing dataset."""
    parser = argparse.ArgumentParser(description="Download all CSI-4CAST datasets from Hugging Face")
    parser.add_argument("--output-dir", "-o", default="datasets",
                        help="Output directory for downloaded datasets (default: 'datasets')")
    parser.add_argument("--dry-run", action="store_true",
                        help="Dry run mode: create empty placeholder files instead of downloading")

    args = parser.parse_args()

    output_dir = Path(args.output_dir).resolve()
    org = ORG

    mode = "Dry run" if args.dry_run else "Downloading"
    print(f"{mode} datasets from organization: {org}")
    print(f"Output directory: {output_dir}")
    print()

    output_dir.mkdir(parents=True, exist_ok=True)

    # Anonymous client is enough: the datasets are public.
    api = HfApi()

    print("Generating dataset combinations...")
    combinations = generate_dataset_combinations()
    print(f"Total possible combinations: {len(combinations)}")
    print()

    action = "Checking and creating placeholders for" if args.dry_run else "Checking and downloading"
    print(f"{action} existing datasets...")
    downloaded_count = 0
    skipped_count = 0

    # Each combination is probed individually; non-existent repos are skipped.
    for repo_name in tqdm(combinations, desc="Processing datasets"):
        if download_dataset(api, org, repo_name, output_dir, args.dry_run):
            downloaded_count += 1
        else:
            skipped_count += 1

    print()
    if args.dry_run:
        print("🎉 Dry run complete!")
        print(f"✅ Created placeholders: {downloaded_count} datasets")
        print(f"⏭️ Skipped: {skipped_count} datasets (not found)")
        print(f"📁 Placeholders saved to: {output_dir}")
    else:
        print("🎉 Download complete!")
        print(f"✅ Downloaded: {downloaded_count} datasets")
        print(f"⏭️ Skipped: {skipped_count} datasets (not found)")
        print(f"📁 Datasets saved to: {output_dir}")
        print()
        print("To reconstruct the original folder structure, run:")
        print(f"python3 reconstruction.py --input-dir {output_dir}")
|
|
# Run the CLI entry point only when executed as a script.
if __name__ == "__main__":
    main()
|
|