| from configuration_vgs import VGSConfig
|
| from transformers import Qwen2PreTrainedModel, Qwen2Model
|
| from transformers.modeling_outputs import SequenceClassifierOutputWithPast
|
| from transformers.cache_utils import Cache
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from typing import List, Optional, Tuple, Union
|
| from dataclasses import dataclass
|
|
|
|
|
@dataclass
class CustomSequenceClassifierOutputWithPast(SequenceClassifierOutputWithPast):
    """``SequenceClassifierOutputWithPast`` with one extra optional field.

    Attributes:
        success_probs: Optional tensor of success probabilities produced
            alongside ``logits``; stays ``None`` when not populated.
    """

    # Defaults to None so outputs that carry only loss/logits remain valid.
    success_probs: Union[torch.FloatTensor, None] = None
|
|
|
|
|
class VGSModel(Qwen2PreTrainedModel):
    """Qwen2 backbone topped with a two-layer MLP scoring head.

    The head maps every final hidden state to ``config.num_labels`` logits, so
    ``logits`` has shape ``[bs, seqlen, num_labels]``. Training uses either a
    masked per-token classification loss against one label per sequence, or
    (when ``train_bt_model`` is True) a pairwise Bradley-Terry loss on the
    last-position logit.
    """

    config_class = VGSConfig

    def __init__(self, config):
        super().__init__(config)
        num_labels = config.num_labels
        self.model = Qwen2Model(config)
        # Scoring head: hidden -> hidden -> num_labels, with a ReLU in between.
        self.score = nn.Sequential(
            nn.Linear(config.hidden_size, config.hidden_size, bias=config.use_bias),
            nn.ReLU(),
            nn.Linear(config.hidden_size, num_labels, bias=config.use_bias),
        )
        # The head reuses the backbone's attention dropout rate; it is only
        # applied on the training path of forward().
        self.p_dropout = config.attention_dropout
        self.score_dropout = nn.Dropout(self.p_dropout)
        # NOTE(review): not read by the training / single-eval paths; presumably
        # selects the continuation-inference strategy — confirm.
        self.inference_impl = "naive"
        # When True, forward() trains with the pairwise Bradley-Terry loss
        # instead of the masked per-token loss.
        self.train_bt_model = False
        self.num_labels = num_labels

        # Hugging Face hook: initialises weights and applies final processing.
        self.post_init()

    def _bradley_terry_loss(
        self, logits: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Pairwise Bradley-Terry loss on the last-position logit.

        Rows must alternate chosen (label 1) and rejected (label 0) examples.

        Returns:
            ``(loss, last_logits)`` where ``last_logits`` has shape ``[bs]``.
        """
        bs = logits.shape[0]
        assert self.num_labels == 1, f"BT model should have 1 label. Got {self.num_labels}."
        assert bs % 2 == 0, f"Batch size should be even for BT model. Got {bs}."
        last_logits = logits[:, -1, 0]

        assert torch.all(labels[::2] == 1), f"Labels should be 1 for chosen logits. Got {labels[::2]}."
        assert torch.all(labels[1::2] == 0), f"Labels should be 0 for rejected logits. Got {labels[1::2]}."
        chosen_logits = last_logits[::2]
        reject_logits = last_logits[1::2]
        loss = (-F.logsigmoid(chosen_logits - reject_logits)).mean()
        return loss, last_logits

    def _masked_token_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        loss_mask: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """Per-token classification loss against a per-sequence label.

        Every position of a sequence is supervised with that sequence's label.
        The per-token losses are averaged over ``loss_mask`` within each
        sequence (a sequence with an all-zero mask contributes 0) and then
        averaged over the batch.

        Args:
            logits: ``[bs, seqlen, num_labels]`` float logits.
            labels: one label per sequence, shape ``[bs]`` or ``[bs, 1]``.
            loss_mask: ``[bs, seqlen]`` weights selecting supervised positions.

        Raises:
            ValueError: if ``loss_mask`` is None.
        """
        if loss_mask is None:
            raise ValueError("loss_mask is required for the per-token training loss.")
        bs, seqlen, _ = logits.shape
        # Normalise [bs] and [bs, 1] labels to a single [bs, 1] column that
        # broadcasts across the sequence dimension.
        seq_labels = labels.reshape(bs, 1)

        if self.num_labels == 1:
            # Binary head: drop the trailing size-1 class dim so the loss is
            # [bs, seqlen] and lines up with loss_mask; BCE needs float targets.
            targets = seq_labels.to(logits.dtype).expand(bs, seqlen)
            elemwise_loss = F.binary_cross_entropy_with_logits(
                logits.squeeze(-1), targets, reduction="none"
            )
        else:
            # cross_entropy expects the class dimension at dim 1: [bs, C, seqlen].
            targets = seq_labels.long().expand(bs, seqlen)
            elemwise_loss = F.cross_entropy(
                logits.transpose(1, 2), targets, reduction="none"
            )

        mask_sum = loss_mask.sum(1).float()
        # Avoid 0/0 for sequences without any supervised position.
        safe_denom = torch.where(mask_sum > 0, mask_sum, torch.ones_like(mask_sum))
        per_seq_loss = torch.where(
            mask_sum > 0, (elemwise_loss * loss_mask).sum(1) / safe_denom, mask_sum
        )
        return per_seq_loss.mean()

    @staticmethod
    def _success_probs(logits: torch.Tensor) -> torch.Tensor:
        """Map head logits ``[..., num_labels]`` to one success probability per position."""
        if logits.shape[-1] > 1:
            # Multi-class head: class index 1 is treated as "success".
            return F.softmax(logits, dim=-1)[..., 1]
        assert logits.shape[-1] == 1, f"Expected logits to have 1 output, got {logits.shape}."
        return logits.squeeze(-1).sigmoid()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        loss_mask: Optional[torch.Tensor] = None,
        continuation_ids: Optional[torch.LongTensor] = None,
        continuation_attention_mask: Optional[torch.Tensor] = None,
    ) -> Union[Tuple, CustomSequenceClassifierOutputWithPast]:
        """Run the backbone and scoring head.

        During training (``labels`` is not None):
            - input_ids: [bs, seqlen]
            - labels: one label per sequence, shape [bs] or [bs, 1]
            - loss_mask: [bs, seqlen]; required unless ``self.train_bt_model``.
            Returns ``loss`` and ``logits``. With ``self.train_bt_model`` the
            rows must alternate chosen (label 1) / rejected (label 0) and the
            returned logits are the last-position logits, shape [bs].

        During inference (``labels`` is None; the model must be in eval mode):
            - If ``continuation_ids`` is None, ``input_ids`` is scored once and
              ``logits``, per-position ``success_probs`` and the backbone's
              ``past_key_values`` are returned.
            - Otherwise ``continuation_ids`` is [bs, N, c_len]. If input_ids is
              [bs, seqlen], this is the prefill stage. Otherwise input_ids is
              [bs, c_len] and holds the continuation chosen at the last step,
              and the kv cache is updated. attention_mask should be [bs, q_len]
              where q_len is seqlen + the length of continuations so far.
              This case is not handled by the training or single-eval
              branches below.

        Only ``return_dict=True`` is supported.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        assert return_dict, "Only return_dict=True is supported."
        is_training = labels is not None
        is_single_eval = continuation_ids is None
        if not is_training:
            assert not self.training, "Model should not be in training mode during inference."

        if is_training:
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]
            # Loss math is done in float32 regardless of the backbone dtype.
            logits = self.score(self.score_dropout(hidden_states)).float()
            if self.train_bt_model:
                loss, logits = self._bradley_terry_loss(logits, labels)
            else:
                loss = self._masked_token_loss(logits, labels, loss_mask)
            return CustomSequenceClassifierOutputWithPast(loss=loss, logits=logits)

        elif is_single_eval:
            # Forward the same backbone inputs the training path accepts so
            # caller-supplied position_ids / inputs_embeds are not dropped.
            transformer_outputs = self.model(
                input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
            )
            hidden_states = transformer_outputs[0]
            logits = self.score(hidden_states).float()
            success_probs = self._success_probs(logits)
            return CustomSequenceClassifierOutputWithPast(
                logits=logits,
                success_probs=success_probs,
                past_key_values=transformer_outputs.past_key_values,
            )
|
|
|