from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import TokenClassifierOutput
import torch
import torch.nn as nn
from torchcrf import CRF  # provided by the `pytorch-crf` package
from typing import Optional, Tuple, Union
import os


class TransformerCRFForTokenClassification(PreTrainedModel):
    """Token classification model with an optional CRF layer on top of a Hugging Face encoder."""

    # Subclass PreTrainedModel (the Auto* classes cannot be subclassed or instantiated
    # directly). The backbone lives under `self.encoder` because `base_model` is already
    # a property on PreTrainedModel; `base_model_prefix` points it at the encoder.
    base_model_prefix = "encoder"

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        # Build the backbone from the config alone; `use_safetensors` is a
        # `from_pretrained` argument and is not accepted by `from_config`.
        self.encoder = AutoModel.from_config(config)
        hidden_size = getattr(config, "hidden_size", 768)

        self.dropout = nn.Dropout(getattr(config, "hidden_dropout_prob", 0.1))
        self.classifier = nn.Linear(hidden_size, config.num_labels)

        # The CRF layer is optional and controlled by `config.use_crf`.
        self.use_crf = getattr(config, "use_crf", False)
        if self.use_crf:
            self.crf = CRF(num_tags=self.num_labels, batch_first=True)
        else:
            self.crf = None
        # CrossEntropyLoss ignores label -100 by default, matching the usual
        # convention for special and sub-word tokens.
        self.loss_fn = nn.CrossEntropyLoss()

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Not every backbone accepts `token_type_ids`, so it is only forwarded when provided.
        extra_inputs = {"token_type_ids": token_type_ids} if token_type_ids is not None else {}
        outputs = self.encoder(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **extra_inputs,
        )

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.crf is not None:
                # Mask out padding and -100 positions, but keep the first timestep on:
                # torchcrf requires every sequence to start unmasked.
                mask = attention_mask.bool() & (labels != -100)
                mask[:, 0] = True
                # Replace -100 with a valid tag index; the CRF indexes its transition
                # matrix with the label values even at masked positions.
                crf_labels = labels.clamp(min=0)
                loss = -self.crf(logits, crf_labels, mask=mask, reduction="mean")
            else:
                # CrossEntropyLoss already ignores positions labelled -100.
                loss = self.loss_fn(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states if output_hidden_states else None,
            attentions=outputs.attentions if output_attentions else None,
        )
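
    # Added sketch: inference helper for the CRF head. `forward` only returns logits,
    # so Viterbi decoding has to happen separately; the method name `decode` and the
    # argmax fallback are choices made here rather than anything required elsewhere
    # in this file.
    @torch.no_grad()
    def decode(self, input_ids=None, attention_mask=None, **kwargs):
        """Return the predicted tag indices for each sequence as plain Python lists."""
        kwargs["return_dict"] = True
        outputs = self.forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs)
        logits = outputs.logits
        if self.crf is not None:
            mask = attention_mask.bool() if attention_mask is not None else None
            # torchcrf's `decode` runs Viterbi decoding and returns a list of tag lists.
            return self.crf.decode(logits, mask=mask)
        # Without a CRF, fall back to a per-token argmax.
        return logits.argmax(dim=-1).tolist()
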
    def save_pretrained(self, save_directory: str, **kwargs):
        """Save model and config along with the custom CRF layer weights."""
        # Record the flag on the config so from_pretrained can rebuild the CRF layer;
        # the parent call below also writes the updated config.json.
        self.config.use_crf = self.use_crf

        kwargs.setdefault("safe_serialization", True)
        super().save_pretrained(save_directory, **kwargs)

        if self.crf is not None:
            crf_path = os.path.join(save_directory, "crf.pt")
            torch.save(self.crf.state_dict(), crf_path)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, *model_args, **kwargs):
        """Load model, rebuilding the custom CRF layer when the config asks for one."""
        if "config" in kwargs:
            config = kwargs.pop("config")
        else:
            config = AutoConfig.from_pretrained(pretrained_model_name_or_path)

        if not hasattr(config, "use_crf"):
            config.use_crf = False

        kwargs.setdefault("use_safetensors", True)
        model = super().from_pretrained(
            pretrained_model_name_or_path, *model_args, config=config, **kwargs
        )

        if config.use_crf:
            # __init__ already created the CRF when config.use_crf is set; load its
            # weights from crf.pt if the checkpoint was written by save_pretrained above.
            if model.crf is None:
                model.crf = CRF(num_tags=config.num_labels, batch_first=True)
            crf_path = os.path.join(pretrained_model_name_or_path, "crf.pt")
            if os.path.exists(crf_path):
                model.crf.load_state_dict(torch.load(crf_path, map_location="cpu"))
        else:
            model.crf = None

        return model
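

# Minimal usage sketch (not part of the model definition). It assumes the
# "bert-base-cased" config is reachable via the Hub and uses a tiny dummy batch;
# the label count (5) and the save path are arbitrary example values, and
# `decode` is the helper added above.
if __name__ == "__main__":
    config = AutoConfig.from_pretrained("bert-base-cased", num_labels=5)
    config.use_crf = True

    model = TransformerCRFForTokenClassification(config)

    # Dummy batch: two sequences of length 8; label -100 marks positions to ignore.
    input_ids = torch.randint(0, config.vocab_size, (2, 8))
    attention_mask = torch.ones(2, 8, dtype=torch.long)
    labels = torch.randint(0, config.num_labels, (2, 8))
    labels[:, -1] = -100  # e.g. a special token that should not contribute to the loss

    out = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    print("loss:", out.loss.item())
    print("decoded:", model.decode(input_ids=input_ids, attention_mask=attention_mask))

    # Round-trip through save_pretrained / from_pretrained.
    model.save_pretrained("./tmp-crf-model")
    reloaded = TransformerCRFForTokenClassification.from_pretrained("./tmp-crf-model")
    print("reloaded with CRF:", reloaded.crf is not None)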