| import os | |
| from typing import Optional | |
| from transformers.modeling_utils import PretrainedConfig | |
class CaSEDConfig(PretrainedConfig):
    """Configuration class for CaSED.

    Args:
        index_name (str): Name of the index. Defaults to "cc12m".
        alpha (float): Weight of the vision loss. Defaults to 0.5.
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        cache_dir (str, optional): Path to the cache directory. A leading "~"
            is expanded to the current user's home directory. When None (or
            empty), defaults to "~/.cache/cased", likewise expanded.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier for this config class.
    model_type = "cased"
    # NOTE(review): transformers documents `is_composition` as marking configs
    # composed of sub-configs; none are visible here -- confirm it is intended.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.index_name = index_name
        self.alpha = alpha
        self.retrieval_num_results = retrieval_num_results
        # Expand "~" for caller-supplied paths as well as the default, so a
        # user path like "~/my_cache" is not kept as a literal "~" path.
        self.cache_dir = os.path.expanduser(cache_dir or "~/.cache/cased")