gap-clip / config.py
Leacb4's picture
Upload config.py with huggingface_hub
bc323e6 verified
"""
Project configuration for GAP-CLIP scripts.
This module provides default paths, column names, and runtime constants used by
training/evaluation scripts. Values can be edited locally as needed.
"""
from __future__ import annotations
from pathlib import Path
import torch
def _detect_device() -> torch.device:
    """Select the torch device to run on, preferring CUDA, then MPS, then CPU.

    Returns:
        ``torch.device("cuda")`` when ``torch.cuda.is_available()`` is true;
        otherwise ``torch.device("mps")`` when the MPS backend is present and
        reports itself available; otherwise ``torch.device("cpu")``.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Some PyTorch builds do not expose ``torch.backends.mps`` at all; look it
    # up defensively so importing this module never fails with AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")
    return torch.device("cpu")
# Directory that contains this file; every default path below is built from it.
ROOT_DIR = Path(__file__).resolve().parent

# Torch device, chosen once when this module is imported.
device = _detect_device()

# Embedding sizes
color_emb_dim = 16
hierarchy_emb_dim = 64
main_emb_dim = 512

# Training hyperparameter defaults
DEFAULT_BATCH_SIZE = 64
DEFAULT_LEARNING_RATE = 1.5e-5
DEFAULT_TEMPERATURE = 0.09

# Dataset column names
text_column = "text"
color_column = "color"
hierarchy_column = "hierarchy"
column_local_image_path = "local_image_path"
column_url_image = "image_url"

# Default file and directory locations (kept as plain strings)
local_dataset_path = str(ROOT_DIR.joinpath("data", "data.csv"))
color_model_path = str(ROOT_DIR.joinpath("models", "color_model.pt"))
hierarchy_model_path = str(ROOT_DIR.joinpath("models", "hierarchy_model.pth"))
main_model_path = str(ROOT_DIR.joinpath("models", "gap_clip.pth"))
images_dir = str(ROOT_DIR.joinpath("data", "images"))
fashion_mnist_csv = str(ROOT_DIR.joinpath("data", "fashion-mnist_test.csv"))
def print_config() -> None:
    """Print the device, embedding sizes, dataset path, and model paths.

    Writes one header line followed by one line per setting to standard
    output. Values are read from the module-level settings at call time.
    """
    report = (
        "GAP-CLIP Configuration",
        f" device: {device}",
        f" dims: color={color_emb_dim}, hierarchy={hierarchy_emb_dim}, total={main_emb_dim}",
        f" dataset: {local_dataset_path}",
        f" color model: {color_model_path}",
        f" hierarchy model: {hierarchy_model_path}",
        f" main model: {main_model_path}",
    )
    for entry in report:
        print(entry)
def validate_paths() -> dict[str, bool]:
    """Check whether the configured dataset and model files exist.

    Returns:
        A dict keyed by setting name (``"local_dataset_path"``,
        ``"color_model_path"``, ``"hierarchy_model_path"``,
        ``"main_model_path"``) whose value is ``True`` when the configured
        path exists on disk and ``False`` otherwise.
    """
    targets = (
        ("local_dataset_path", local_dataset_path),
        ("color_model_path", color_model_path),
        ("hierarchy_model_path", hierarchy_model_path),
        ("main_model_path", main_model_path),
    )
    return {label: Path(location).exists() for label, location in targets}