| """ |
| Utility module for loading and running the finetuned ModernBERT reward model. |
| |
| The model mirrors the architecture defined in `mosaic_bert_training.py`: |
| - base encoder: answerdotai/ModernBERT-base (8k context support) |
| - pooling: attention-mask-weighted mean pooling |
| - head: single linear layer + sigmoid to output a score in [0, 1] |
| """ |
|
|
| import os |
| from typing import Optional |
|
|
| import torch |
| from transformers import AutoModel |
|
|
|
|
class BERTRewardModel(torch.nn.Module):
    """ModernBERT encoder with a sigmoid regression head.

    Final hidden states are mean-pooled with attention-mask weighting, then
    mapped through a single linear unit + sigmoid to one score per sequence
    in [0, 1].
    """

    def __init__(self, model_name: str = "answerdotai/ModernBERT-base"):
        super().__init__()
        # Eager attention + reference_compile=False keep loading simple and
        # avoid torch.compile / flash-attention requirements on the host.
        self.bert = AutoModel.from_pretrained(
            model_name,
            reference_compile=False,
            attn_implementation="eager",
        )
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, input_ids, attention_mask, labels: Optional[torch.Tensor] = None):
        """Encode a batch and score each sequence.

        Returns a dict with:
            scores: sigmoid outputs in [0, 1], shape (batch,)
            logits: raw pre-sigmoid values, shape (batch, 1)
            loss:   scaled MSE against `labels`, or None when no labels given
        """
        encoded = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        hidden = encoded.last_hidden_state

        # Mask-weighted mean pooling: zero out padded positions, then divide
        # the token-sum by the count of real tokens (clamped to avoid /0 for
        # an all-zero mask row).
        mask = attention_mask.unsqueeze(-1).expand(hidden.size()).float()
        token_totals = torch.clamp(mask.sum(dim=1), min=1e-9)
        pooled = (hidden * mask).sum(dim=1) / token_totals

        logits = self.classifier(pooled)
        scores = self.sigmoid(logits).squeeze(-1)

        loss = None
        if labels is not None:
            # x100 scaling mirrors the loss magnitude used by the training
            # script this module mirrors.
            loss = torch.nn.MSELoss()(scores, labels) * 100

        return {"loss": loss, "scores": scores, "logits": logits}
|
|
|
|
def load_finetuned_model(model_dir: str, device: Optional[str] = None) -> "BERTRewardModel":
    """
    Load the finetuned ModernBERT reward model from `model_dir`.

    Args:
        model_dir: Path containing model.safetensors (preferred) or pytorch_model.bin.
        device: Optional torch device string. Defaults to CUDA if available else CPU.

    Returns:
        The model in eval mode, moved to the requested device.

    Raises:
        FileNotFoundError: if neither weight file exists in `model_dir`.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_device = torch.device(device)

    safetensors_path = os.path.join(model_dir, "model.safetensors")
    bin_path = os.path.join(model_dir, "pytorch_model.bin")

    # Resolve the weights BEFORE constructing the model: BERTRewardModel()
    # pulls the base encoder from the hub, which is wasted work (and a
    # confusingly late error) when `model_dir` holds no weights at all.
    state_dict = None
    if os.path.exists(safetensors_path):
        try:
            from safetensors.torch import load_file

            state_dict = load_file(safetensors_path)
        except ImportError as exc:
            print(f"⚠️ safetensors not available ({exc}); falling back to pytorch_model.bin if present.")

    if state_dict is None and os.path.exists(bin_path):
        # weights_only=True: a state_dict contains only tensors, so refuse
        # arbitrary pickled-object deserialization from the checkpoint file.
        state_dict = torch.load(bin_path, map_location=torch_device, weights_only=True)

    if state_dict is None:
        raise FileNotFoundError(
            f"Could not find model weights in {model_dir}. "
            "Expected either model.safetensors or pytorch_model.bin."
        )

    model = BERTRewardModel()
    model.to(torch_device)
    model.load_state_dict(state_dict)
    model.eval()
    return model
|
|
|
|