| from .configuration_clip_camembert import CLIPTextCamembertConfig |
| from transformers import ( |
| CamembertModel, |
| CLIPTextModelWithProjection, |
| ) |
| from transformers.models.clip.modeling_clip import CLIPTextModelOutput |
| import torch |
| from torch import nn |
| from typing import Any, Optional, Tuple, Union |
|
|
|
|
class CLIPTextCamembertModelWithProjection(CLIPTextModelWithProjection):
    r"""
    Text model with a projection head whose encoder is a `CamembertModel`.

    The encoder's pooled output (element 1 of its output) is passed through a
    bias-free linear layer mapping `config.hidden_size` to
    `config.projection_dim`; the result is exposed as `text_embeds`.
    """

    config_class = CLIPTextCamembertConfig

    def __init__(self, config: CLIPTextCamembertConfig):
        r"""
        Build the CamemBERT encoder and the projection layer from `config`.

        Args:
            config: Configuration used for the parent constructor, for
                `CamembertModel`, and for the projection sizes
                (`hidden_size`, `projection_dim`).
        """
        super().__init__(config)

        # Assigned after the parent constructor has run, so these two
        # attributes are the ones the instance ends up with.
        self.text_model = CamembertModel(config)
        self.text_projection = nn.Linear(
            config.hidden_size,
            config.projection_dim,
            bias=False,
        )

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CLIPTextModelOutput]:
        r"""
        Encode the inputs with `self.text_model` and project the pooled output.

        Args:
            input_ids: Forwarded unchanged to `self.text_model`.
            attention_mask: Forwarded unchanged to `self.text_model`.
            position_ids: Forwarded unchanged to `self.text_model`.
            output_attentions: Forwarded unchanged to `self.text_model`.
            output_hidden_states: Forwarded unchanged to `self.text_model`.
            return_dict: Selects the return format; when `None`,
                `self.config.use_return_dict` is used instead.

        Returns:
            If `return_dict` is falsy, a tuple made of `text_embeds`, element 0
            of the encoder output, and the encoder output from index 2 onward,
            with `None` entries removed. Otherwise a `CLIPTextModelOutput`
            carrying `text_embeds`, `last_hidden_state`, `hidden_states` and
            `attentions`.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_out = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Element 1 of the encoder output is treated as the pooled
        # representation and is the only input to the projection.
        text_embeds = self.text_projection(encoder_out[1])

        if return_dict:
            return CLIPTextModelOutput(
                text_embeds=text_embeds,
                last_hidden_state=encoder_out.last_hidden_state,
                hidden_states=encoder_out.hidden_states,
                attentions=encoder_out.attentions,
            )

        # Tuple format: the projected embedding takes the pooled output's slot
        # and entries that are None are dropped.
        candidates = (text_embeds, encoder_out[0]) + encoder_out[2:]
        return tuple(item for item in candidates if item is not None)

    def converter_weight(
        self, path_model="airesearch/wangchanberta-base-att-spm-uncased"
    ):
        r"""
        Copy the weights of a pretrained `CamembertModel` into `self.text_model`.

        Args:
            path_model: Identifier or path given to
                `CamembertModel.from_pretrained`; defaults to
                airesearch/wangchanberta-base-att-spm-uncased.

        Only `self.text_model` is loaded here; `self.text_projection` is not
        touched. `load_state_dict` is called with its default arguments and its
        result is discarded.
        """
        donor = CamembertModel.from_pretrained(path_model)
        self.text_model.load_state_dict(donor.state_dict())