| from transformers import PreTrainedModel
|
| from sproto.model.multi_proto import MultiProtoModule
|
| from .configuration_sproto import SprotoConfig
|
|
|
class SprotoModel(PreTrainedModel):
    """Expose a ``MultiProtoModule`` through the ``PreTrainedModel`` interface.

    The wrapped module is built from the fields of a ``SprotoConfig`` and
    stored as ``self.module``. ``forward`` repackages its keyword arguments
    into the single batch dict that the module is called with.
    """

    config_class = SprotoConfig
    base_model_prefix = "sproto"

    def __init__(self, config: SprotoConfig):
        """Build the wrapped ``MultiProtoModule`` from ``config``.

        Args:
            config: Configuration object. Every name in ``module_fields``
                is read from it and handed to ``MultiProtoModule`` as the
                keyword argument of the same name.
        """
        super().__init__(config)

        # Config attributes forwarded one-to-one to MultiProtoModule.
        # The order matches the order in which they are read from `config`.
        module_fields = (
            "pretrained_model",
            "num_classes",
            "label_order_path",
            "use_sigmoid",
            "use_cuda",
            "lr_prototypes",
            "lr_features",
            "lr_others",
            "num_training_steps",
            "num_warmup_steps",
            "loss",
            "save_dir",
            "use_attention",
            "use_global_attention",
            "dot_product",
            "normalize",
            "final_layer",
            "reduce_hidden_size",
            "use_prototype_loss",
            "prototype_vector_path",
            "attention_vector_path",
            "eval_buckets",
            "seed",
            "num_prototypes_per_class",
            "batch_size",
        )
        module_kwargs = {name: getattr(config, name) for name in module_fields}
        self.module = MultiProtoModule(**module_kwargs)

        # Inherited post-construction hook; must run after `self.module` exists.
        self.post_init()

    def _init_weights(self, module):
        """Weight-initialisation hook; intentionally leaves every module untouched.

        Args:
            module: The submodule offered for initialisation.
        """
        if isinstance(module, MultiProtoModule):
            # The wrapped module is skipped explicitly.
            # NOTE(review): presumably it sets up its own parameters - confirm.
            return None
        # No other module type receives any initialisation here either.
        return None

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        targets=None,
        tokens=None,
        sample_ids=None,
        **kwargs,
    ):
        """Run the wrapped module on a batch assembled from the arguments.

        Args:
            input_ids: Stored in the batch under ``"input_ids"``.
            attention_mask: Stored in the batch under ``"attention_masks"``
                (note the plural key).
            token_type_ids: Stored in the batch under ``"token_type_ids"``.
            targets: Stored in the batch under ``"targets"``.
            tokens: Stored in the batch under ``"tokens"``.
            sample_ids: Stored in the batch under ``"sample_ids"``.
            **kwargs: Accepted but not forwarded to the wrapped module.

        Returns:
            A dict with keys ``"logits"``, ``"max_indices"`` and
            ``"metadata"`` holding the three values returned by the wrapped
            module, in that order.
        """
        model_inputs = {
            "input_ids": input_ids,
            # The wrapped module is given the mask under the plural key.
            "attention_masks": attention_mask,
            "token_type_ids": token_type_ids,
            "targets": targets,
            "tokens": tokens,
            "sample_ids": sample_ids,
        }

        scores, prototype_indices, extra = self.module(model_inputs)

        result = {}
        result["logits"] = scores
        result["max_indices"] = prototype_indices
        result["metadata"] = extra
        return result