# personal_math / download_data.py
# psidharth567 — commit dcd2bd2 (verified): "Sync full project: code, checkpoints, datasets, logs"
import os
import requests
import zipfile
import tarfile
from tqdm import tqdm
# Root folder that receives every downloaded archive and its extracted contents.
BASE_DIR = "./datasets"

# Dataset name -> archive URL. main() saves each archive as
# BASE_DIR/<name>.<ext> and unpacks it into BASE_DIR/<name>/.
# Disabled high-resolution training set (re-add the entry to enable it):
#   "Flickr2K": "https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar",
DATASETS = dict(
    # Standard benchmark test sets, taken from the FFDNet GitHub repository archive.
    Test_Datasets="https://github.com/cszn/FFDNet/archive/refs/heads/master.zip",
)
def download_file(url, dest_path, timeout=30, chunk_size=1024 * 1024):
    """Download ``url`` to ``dest_path`` with a progress bar.

    The response body is streamed into a sibling ``<dest_path>.part`` file.
    That file is renamed to ``dest_path`` only after the transfer finishes.
    An interrupted run therefore never leaves a truncated file at
    ``dest_path``, which a later run would otherwise skip as "already exists".

    Args:
        url: Address to fetch.
        dest_path: Final location of the downloaded file. If it already
            exists, nothing is downloaded.
        timeout: Timeout passed to ``requests.get``.
        chunk_size: Number of bytes requested per iteration of the stream.

    Returns:
        True if ``dest_path`` exists afterwards (already present or freshly
        downloaded). False if the request failed; the error is printed and
        no partial file is left on disk.

    Raises:
        Non-network errors (e.g. ``OSError`` while writing) propagate. The
        partial ``.part`` file is still removed.
    """
    file_name = os.path.basename(dest_path)
    if os.path.exists(dest_path):
        print(f"[*] {file_name} already exists. Skipping download.")
        return True
    print(f"[*] Downloading {url}...")
    # Disguise the script as a standard web browser
    headers = {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
    }
    part_path = dest_path + ".part"
    try:
        # Using the response as a context manager closes it (and releases the
        # pooled connection) on both success and error.
        with requests.get(url, stream=True, headers=headers, timeout=timeout) as response:
            response.raise_for_status()
            # A missing or zero Content-Length means "size unknown". Pass None so
            # tqdm shows a plain byte counter instead of a zero-sized total.
            total_size = int(response.headers.get('content-length', 0)) or None
            with open(part_path, 'wb') as file, tqdm(
                total=total_size, unit='B', unit_scale=True, desc=file_name
            ) as bar:
                for data in response.iter_content(chunk_size):
                    file.write(data)
                    bar.update(len(data))
        # Publish the file only once it is complete.
        os.replace(part_path, dest_path)
    except requests.exceptions.RequestException as e:
        print(f"\n[!] The server rejected the connection: {e}")
        print(f"[!] Skipping {file_name}. You can proceed without it.")
        return False
    finally:
        # Remove the partial file after any failure (network error, disk error,
        # Ctrl-C). After a successful rename it no longer exists.
        if os.path.exists(part_path):
            os.remove(part_path)
    return True
def extract_file(file_path, extract_to):
    """Extract a zip or tar archive into ``extract_to``.

    The format is chosen from the file-name suffix:

    * ``.zip`` is read with zipfile, which already strips absolute paths and
      ``..`` components from member names.
    * ``.tar``, ``.tar.gz``, ``.tgz``, ``.tar.bz2`` and ``.tar.xz`` are read
      with tarfile in mode ``'r:*'``, which detects the compression itself.

    Tar members are never written outside ``extract_to``:

    * On Pythons with extraction filters, the ``"data"`` filter is used.
    * Otherwise, member names are checked for absolute paths and ``..``
      escapes before extracting.

    Args:
        file_path: Archive to extract.
        extract_to: Destination directory.

    Returns:
        True if the archive was extracted. False if the suffix is not
        recognised; a warning is printed and nothing is extracted.

    Raises:
        zipfile.BadZipFile: If a ``.zip`` file is not a valid zip archive.
        tarfile.TarError: If a tar file is unreadable or contains a member
            that would be written outside ``extract_to``.
    """
    print(f"[*] Extracting {os.path.basename(file_path)}...")
    if file_path.endswith(".zip"):
        with zipfile.ZipFile(file_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
        return True
    if file_path.endswith((".tar", ".tar.gz", ".tgz", ".tar.bz2", ".tar.xz")):
        with tarfile.open(file_path, 'r:*') as tar_ref:
            if hasattr(tarfile, "data_filter"):
                # Python 3.12+ (and security backports). The "data" filter rejects
                # absolute paths, '..' escapes, links pointing outside the
                # destination, and device/special files.
                tar_ref.extractall(extract_to, filter="data")
            else:
                # Older Pythons: refuse any member whose resolved path leaves
                # the destination directory (path traversal).
                root = os.path.realpath(extract_to)
                for member in tar_ref.getmembers():
                    target = os.path.realpath(os.path.join(root, member.name))
                    if os.path.commonpath([root, target]) != root:
                        raise tarfile.TarError(
                            f"Refusing to extract {member.name!r} outside {extract_to!r}"
                        )
                tar_ref.extractall(extract_to)
        return True
    print(f"[!] Unknown file format for {file_path}")
    return False
def main():
    """Download and extract every archive listed in ``DATASETS`` into ``BASE_DIR``.

    Each archive is saved as ``BASE_DIR/<name>.<ext>`` and unpacked into
    ``BASE_DIR/<name>/``. The archive is then deleted, whether or not
    extraction succeeded.

    A dataset is skipped, and the loop moves on, when either:

    * its archive is missing after ``download_file`` returns (the download
      failed), or
    * the archive is corrupt (zip/tar error).

    The final summary lists any skipped datasets.
    """
    from urllib.parse import urlparse

    os.makedirs(BASE_DIR, exist_ok=True)
    skipped = []
    for name, url in DATASETS.items():
        print(f"\n--- Processing {name} ---")
        # Pick the archive suffix from the URL *path* only, so a host name or
        # query string containing ".tar" cannot mislead it. A ".tar" name is
        # opened with mode 'r:*', which also handles gzip/bz2/xz tarballs.
        url_path = urlparse(url).path.lower()
        is_tar = ".tar" in url_path or url_path.endswith(".tgz")
        ext = ".tar" if is_tar else ".zip"
        file_name = f"{name}{ext}"
        download_path = os.path.join(BASE_DIR, file_name)

        # Download
        download_file(url, download_path)
        if not os.path.exists(download_path):
            # download_file already printed why. Extracting a missing file would crash.
            skipped.append(name)
            continue

        # Extract
        extract_dir = os.path.join(BASE_DIR, name)
        os.makedirs(extract_dir, exist_ok=True)
        try:
            extract_file(download_path, extract_dir)
        except (zipfile.BadZipFile, tarfile.TarError) as e:
            print(f"[!] Could not extract {file_name}: {e}")
            skipped.append(name)

        # Remove the archive either way. On success this saves disk space. On
        # failure it makes the next run download it again, instead of
        # skipping the corrupt file as "already exists".
        print(f"[*] Cleaning up archive {file_name}...")
        os.remove(download_path)

    if skipped:
        print(f"\n[!] Finished with problems; skipped: {', '.join(skipped)}")
    else:
        print("\n[+] All datasets downloaded and extracted successfully!")
    print(f"[+] Look inside the '{BASE_DIR}' folder.")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()