"""
Project configuration for the GAP-CLIP scripts.

This module defines the compute device, embedding dimensions, default training
hyper-parameters, dataset column names, and file paths used by the training and
evaluation scripts. Edit the values locally as needed.
"""
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
| import torch |
|
|
|
|
| def _detect_device() -> torch.device: |
| if torch.cuda.is_available(): |
| return torch.device("cuda") |
| if torch.backends.mps.is_available(): |
| return torch.device("mps") |
| return torch.device("cpu") |
|
|
|
|
# Absolute directory containing this config file; the default data/model
# paths below are resolved relative to it.
ROOT_DIR: Path = Path(__file__).resolve().parents[0]
|
|
| |
# Compute device, chosen once when this module is imported.
device: torch.device = _detect_device()
|
|
| |
# Embedding dimensions. print_config() reports main_emb_dim as the "total".
# NOTE(review): presumably the color/hierarchy sizes describe sub-spaces of the
# main embedding — confirm against the model code.
color_emb_dim: int = 16
hierarchy_emb_dim: int = 64
main_emb_dim: int = 512
|
|
| |
# Default training hyper-parameters.
# NOTE(review): DEFAULT_TEMPERATURE is presumably the contrastive-loss
# temperature — confirm in the training script.
DEFAULT_BATCH_SIZE: int = 64
DEFAULT_LEARNING_RATE: float = 1.5e-5
DEFAULT_TEMPERATURE: float = 0.09
|
|
| |
# Column names looked up in the dataset table (presumably the CSV at
# local_dataset_path — confirm against the loaders).
text_column: str = "text"
color_column: str = "color"
hierarchy_column: str = "hierarchy"
column_local_image_path: str = "local_image_path"
column_url_image: str = "image_url"
|
|
| |
# Default file locations, all anchored at ROOT_DIR and stored as plain strings.
local_dataset_path: str = str(ROOT_DIR.joinpath("data", "data.csv"))
color_model_path: str = str(ROOT_DIR.joinpath("models", "color_model.pt"))
hierarchy_model_path: str = str(ROOT_DIR.joinpath("models", "hierarchy_model.pth"))
main_model_path: str = str(ROOT_DIR.joinpath("models", "gap_clip.pth"))
images_dir: str = str(ROOT_DIR.joinpath("data", "images"))
fashion_mnist_csv: str = str(ROOT_DIR.joinpath("data", "fashion-mnist_test.csv"))
|
|
|
|
|
|
def print_config() -> None:
    """Print a short summary of the core configuration to stdout.

    The summary covers the selected device, the embedding dimensions, the
    dataset path, and the three model checkpoint paths, one item per line.
    """
    summary = (
        "GAP-CLIP Configuration",
        f" device: {device}",
        f" dims: color={color_emb_dim}, hierarchy={hierarchy_emb_dim}, total={main_emb_dim}",
        f" dataset: {local_dataset_path}",
        f" color model: {color_model_path}",
        f" hierarchy model: {hierarchy_model_path}",
        f" main model: {main_model_path}",
    )
    for row in summary:
        print(row)
|
|
|
|
def validate_paths() -> dict[str, bool]:
    """Report whether each key dataset/model file exists on disk.

    Returns:
        A dict mapping the config variable name to ``True`` if that path
        exists, ``False`` otherwise.
    """
    tracked = {
        "local_dataset_path": local_dataset_path,
        "color_model_path": color_model_path,
        "hierarchy_model_path": hierarchy_model_path,
        "main_model_path": main_model_path,
    }
    return {label: Path(target).exists() for label, target in tracked.items()}
|
|
|
|