|
|
| import torch |
| import wandb |
| from transformers import Trainer |
|
|
|
|
class ORPOTrainer(Trainer):
    """Trainer that adds an odds-ratio preference penalty to the NLL loss.

    Each batch is read through the keys ``positive_input_ids`` /
    ``positive_attention_mask`` (chosen sequence), ``negative_input_ids`` /
    ``negative_attention_mask`` (rejected sequence) and ``attention_mask``
    (passed on as the prompt mask). The final loss is
    ``mean(nll_chosen - alpha * log(sigmoid(log_odds)))``.
    """

    def __init__(self, alpha, pad, *args, **kwargs):
        """Create the trainer.

        Args:
            alpha: Weight applied to the log-sigmoid odds-ratio term.
            pad: Token id treated as padding; such positions are excluded
                from the chosen-sequence NLL.
            *args, **kwargs: Forwarded unchanged to ``Trainer.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.pad = pad
        self.alpha = alpha
        # Per-token losses are required so they can be averaged per sequence.
        self.loss_fct = torch.nn.CrossEntropyLoss(reduction='none')
        print("Pad Token ID: ", self.pad)

    def compute_custom_loss(self, logits, labels):
        """Return the per-sequence mean next-token cross-entropy.

        Args:
            logits: Model logits; indexed as (batch, sequence, vocab).
            labels: Target token ids aligned with ``logits``. Positions equal
                to the criterion's ``ignore_index`` (-100 by default) are
                excluded from both the sum and the averaging denominator.

        Returns:
            A tensor with one loss value per sequence. When no label is
            ignored this equals a plain mean over the shifted positions.

        Raises:
            ValueError: If ``labels`` is ``None``.
        """
        if labels is None:
            # Previously this fell through to an unbound local at `return`.
            raise ValueError("compute_custom_loss requires `labels`, got None")

        logits = logits.contiguous()
        labels = labels.to(logits.device)

        # Shift so that the logits at position t are scored against token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        # CrossEntropyLoss expects the class dimension second.
        token_loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels)

        # Ignored positions already contribute 0 to the sum; leave them out
        # of the denominator too so padding does not dilute the average.
        valid_count = shift_labels.ne(self.loss_fct.ignore_index).sum(dim=-1)
        return token_loss.sum(dim=-1) / valid_count.clamp(min=1)

    def compute_logps(self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits):
        """Return the mean per-token log-probability over masked positions.

        The positions averaged are those where
        ``chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]`` is
        non-zero (presumably the response tokens that follow the prompt;
        verify against the data collator).

        Args:
            prompt_attention_mask: Mask subtracted from the sequence mask.
            chosen_inputs: Token ids of the sequence being scored.
            chosen_attention_mask: Attention mask of that sequence.
            logits: Model logits for that sequence, (batch, sequence, vocab).

        Returns:
            A float64 tensor with one value per row. A row whose mask sums
            to zero yields NaN (0 / 0).
        """
        mask = chosen_attention_mask[:, :-1] - prompt_attention_mask[:, 1:]
        # Masked-out positions gather index 0; they are zeroed again below.
        per_token_logps = torch.gather(
            logits[:, :-1, :].log_softmax(-1),
            dim=2,
            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
        ).squeeze(2)

        # Upcast before reducing so the sum is not accumulated in low precision.
        weights = mask.to(dtype=torch.float64)
        summed = (per_token_logps.to(dtype=torch.float64) * weights).sum(dim=1)
        return summed / weights.sum(dim=1)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Compute the combined NLL and odds-ratio loss for one batch.

        Args:
            model: Model called with ``input_ids`` / ``attention_mask``.
            inputs: Batch dict with the keys listed in the class docstring.
            return_outputs: When true, also return the chosen-sequence
                model outputs.
            num_items_in_batch: Accepted so that Trainer versions which pass
                this keyword do not fail; it is not used.

        Returns:
            ``loss``, or ``(loss, outputs_pos)`` when ``return_outputs``.
        """
        # Kept for its side effect on `inputs`; the popped value is not used.
        if self.label_smoother is not None and "labels" in inputs:
            inputs.pop("labels")

        pos_labels = inputs['positive_input_ids'].clone()
        pos_labels[pos_labels == self.pad] = -100

        # Only the logits of the rejected pass are used, so do not ask the
        # model for its own LM loss or its hidden states.
        outputs_neg = model(input_ids=inputs['negative_input_ids'],
                            attention_mask=inputs['negative_attention_mask'])
        # The chosen pass is returned to the caller, so its call is unchanged.
        outputs_pos = model(input_ids=inputs['positive_input_ids'],
                            attention_mask=inputs['positive_attention_mask'],
                            labels=pos_labels,
                            output_hidden_states=True)

        # Use the pad-masked labels so padding is not trained on.
        pos_loss = self.compute_custom_loss(logits=outputs_pos.logits, labels=pos_labels)

        pos_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
                                      chosen_inputs=inputs['positive_input_ids'],
                                      chosen_attention_mask=inputs['positive_attention_mask'],
                                      logits=outputs_pos.logits)
        neg_prob = self.compute_logps(prompt_attention_mask=inputs['attention_mask'],
                                      chosen_inputs=inputs['negative_input_ids'],
                                      chosen_attention_mask=inputs['negative_attention_mask'],
                                      logits=outputs_neg.logits)

        # log(p / (1 - p)) difference; log1p(-exp(x)) is the stable form of
        # log(1 - exp(x)) when exp(x) is small.
        log_odds = (pos_prob - neg_prob) - (
            torch.log1p(-torch.exp(pos_prob)) - torch.log1p(-torch.exp(neg_prob))
        )
        # logsigmoid avoids log(0) when the sigmoid underflows.
        ratio = torch.nn.functional.logsigmoid(log_odds)

        # NOTE(review): the bfloat16 cast is hard-coded; presumably it matches
        # the training precision -- confirm before running in fp16/fp32.
        loss = torch.mean(pos_loss - self.alpha * ratio).to(dtype=torch.bfloat16)

        # wandb.log fails without an active run (e.g. on non-main ranks).
        if wandb.run is not None:
            wandb.log({'Positive Geometric Mean': torch.mean(pos_prob).item(),
                       'Negative Geometric Mean': torch.mean(neg_prob).item(),
                       'Log Odds Ratio': torch.mean(ratio).item(),
                       'Log Odds': torch.mean(log_odds).item()})

        return (loss, outputs_pos) if return_outputs else loss