from typing import Any, TypedDict

from transformers import PretrainedConfig
from transformers.utils.backbone_utils import verify_backbone_config_arguments


class STAConfig(TypedDict):
    """Typed dictionary holding one set of STA attention settings.

    NOTE(review): "STA" presumably stands for sliding-tile attention; confirm
    against the attention implementation that consumes this dictionary.
    """

    # Kernel size; the defaults used by LSPDetrConfig set this to 5.
    kernel: int
    # Tile size on the query side (presumably in tokens -- confirm with the consumer).
    q_tile: int
    # Tile size on the key/value side (presumably in tokens -- confirm with the consumer).
    kv_tile: int


class LSPDetrConfig(PretrainedConfig):
    """Configuration for models registered under the ``lsp_detr`` model type.

    The constructor validates the backbone-related arguments with
    ``verify_backbone_config_arguments``, stores every named argument as an
    attribute of the same name, and forwards the remaining keyword arguments
    to ``PretrainedConfig.__init__``.

    Args:
        use_timm_backbone: Passed to the backbone argument check and stored.
        use_pretrained_backbone: Passed to the backbone argument check and
            stored.
        backbone: Backbone identifier; defaults to a SwinV2-tiny checkpoint
            name. NOTE(review): the backbone check is expected to reject
            ``backbone`` and ``backbone_config`` being set together, so
            callers supplying ``backbone_config`` likely need to pass
            ``backbone=None`` -- confirm.
        backbone_kwargs: Extra backbone keyword arguments. When both this and
            ``backbone_config`` are ``None``, it defaults to requesting the
            ``stage1``..``stage4`` output features.
        backbone_config: Optional explicit backbone configuration.
        dim: Stored as ``self.dim`` (presumably the model hidden size --
            confirm with the model code).
        num_heads: Stored as ``self.num_heads``.
        num_classes: Stored as ``self.num_classes``.
        query_block_size: Stored as ``self.query_block_size``.
        feature_levels: Stored as ``self.feature_levels``.
        num_radial_distances: Stored as ``self.num_radial_distances``.
        self_sta_config: STA settings stored as ``self.self_sta_config``;
            defaults to ``{"kernel": 5, "q_tile": 4, "kv_tile": 4}``.
        cross_sta_config: Sequence of STA settings stored (as a tuple of
            per-instance copies) in ``self.cross_sta_config``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: Propagated from ``verify_backbone_config_arguments`` when
            the backbone arguments are mutually inconsistent.
    """

    model_type = "lsp_detr"

    def __init__(
        self,
        use_timm_backbone: bool = False,
        use_pretrained_backbone: bool = True,
        backbone: str = "microsoft/swinv2-tiny-patch4-window16-256",
        backbone_kwargs: dict[str, Any] | None = None,
        backbone_config: Any | None = None,
        dim: int = 384,
        num_heads: int = 12,
        num_classes: int = 1,
        query_block_size: float = 8,
        feature_levels: tuple[int, ...] = (2, 1, 0, 2, 1, 0),
        num_radial_distances: int = 64,
        self_sta_config: STAConfig | None = None,
        cross_sta_config: tuple[STAConfig, ...] = (
            {"kernel": 5, "q_tile": 4, "kv_tile": 8},
            {"kernel": 5, "q_tile": 4, "kv_tile": 4},
            {"kernel": 5, "q_tile": 4, "kv_tile": 2},
        ),
        **kwargs,
    ) -> None:
        # Build the default per call so instances never share one dict.
        if self_sta_config is None:
            self_sta_config = {"kernel": 5, "q_tile": 4, "kv_tile": 4}

        # Only fall back to the default kwargs when no explicit backbone
        # config is given: the backbone check rejects a non-empty
        # `backbone_kwargs` combined with `backbone_config`, so defaulting
        # unconditionally made `backbone_config` impossible to use.
        if backbone_kwargs is None and backbone_config is None:
            backbone_kwargs = {"out_features": ["stage1", "stage2", "stage3", "stage4"]}

        verify_backbone_config_arguments(
            use_timm_backbone=use_timm_backbone,
            use_pretrained_backbone=use_pretrained_backbone,
            backbone=backbone,
            backbone_config=backbone_config,
            backbone_kwargs=backbone_kwargs,
        )

        self.use_timm_backbone = use_timm_backbone
        self.use_pretrained_backbone = use_pretrained_backbone
        self.backbone = backbone
        self.backbone_config = backbone_config
        self.backbone_kwargs = backbone_kwargs
        self.dim = dim
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.query_block_size = query_block_size
        self.feature_levels = feature_levels
        self.num_radial_distances = num_radial_distances
        self.self_sta_config = self_sta_config
        # Copy each entry: the default tuple's dicts live on the function
        # object, so storing them directly would let a mutation on one
        # instance leak into every later default-constructed instance.
        self.cross_sta_config = tuple(dict(level) for level in cross_sta_config)

        # Attributes are assigned before delegating the remaining kwargs to
        # the base class, matching the original statement order.
        super().__init__(**kwargs)