| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, Union, Tuple |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from transformers import PreTrainedModel, PretrainedConfig |
| from transformers.modeling_outputs import ImageClassifierOutput |
|
|
| from .CSATv2 import CSATv2 |
|
|
|
|
class CSATv2Config(PretrainedConfig):
    """Hugging Face configuration for the CSATv2 image-classification model.

    Stores the hyperparameters the `CSATv2` backbone needs (input size,
    channel count, stochastic-depth rate, head init scale) plus the label
    count handled by `PretrainedConfig`.
    """

    model_type = "csatv2"

    def __init__(
        self,
        image_size: int = 512,
        num_channels: int = 3,
        num_labels: int = 1000,
        drop_path_rate: float = 0.0,
        head_init_scale: float = 1.0,
        **kwargs,
    ):
        # num_labels is owned by PretrainedConfig; forward it there.
        super().__init__(num_labels=num_labels, **kwargs)

        self.image_size = image_size
        self.num_channels = num_channels
        self.drop_path_rate = drop_path_rate
        self.head_init_scale = head_init_scale

        # If no label maps came in through kwargs, synthesize placeholder
        # names so downstream tooling always has consistent mappings.
        if self.id2label is None or self.label2id is None:
            names = [f"LABEL_{i}" for i in range(num_labels)]
            self.id2label = dict(enumerate(names))
            self.label2id = {name: idx for idx, name in enumerate(names)}
|
|
|
|
class CSATv2ForImageClassification(PreTrainedModel):
    """Hugging Face image-classification wrapper around the CSATv2 backbone.

    - backbone: project-local ``CSATv2`` model, which maps pixel values
      directly to classification logits.
    - ``forward(pixel_values, labels=None)`` returns an
      :class:`ImageClassifierOutput` (or a plain tuple).
    """

    config_class = CSATv2Config

    def __init__(self, config: CSATv2Config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # The backbone owns the classification head; it is configured
        # entirely from the HF config object.
        self.backbone = CSATv2(
            img_size=config.image_size,
            num_classes=config.num_labels,
            drop_path_rate=config.drop_path_rate,
            head_init_scale=config.head_init_scale,
        )

        # Standard HF hook: weight init / final processing.
        self.post_init()

    def forward(
        self,
        # Fixed: default is None, so the annotation must be Optional.
        pixel_values: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_hidden_states: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[ImageClassifierOutput, Tuple]:
        """Run the backbone and optionally compute a cross-entropy loss.

        Args:
            pixel_values: ``(batch, channels, H, W)`` image tensor,
                already normalized (e.g. ImageNet statistics). Required;
                a ``ValueError`` is raised if omitted.
            labels: ``(batch,)`` class indices; when given, a
                ``CrossEntropyLoss`` over the logits is returned as well.
            output_hidden_states: Accepted for HF API compatibility but
                ignored — the backbone does not expose hidden states here.
            output_attentions: Accepted for HF API compatibility but
                ignored — the backbone does not expose attentions here.
            return_dict: Whether to return an ``ImageClassifierOutput``
                instead of a tuple; defaults to ``config.use_return_dict``.

        Returns:
            ``ImageClassifierOutput`` with ``loss`` (if labels were given)
            and ``logits``, or the equivalent tuple when
            ``return_dict=False``.

        Raises:
            ValueError: If ``pixel_values`` is ``None``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if pixel_values is None:
            raise ValueError("You must provide pixel_values")

        # Backbone maps images straight to (batch, num_labels) logits.
        logits = self.backbone(pixel_values)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels),
                labels.view(-1),
            )

        if not return_dict:
            # Legacy tuple output: (loss, logits) or (logits,).
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return ImageClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )