import torch
import os
import math
import torch.nn as nn
import torch.nn.functional as F
from peft import PeftModel, LoraConfig, get_peft_model
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
from torch.nn import MultiheadAttention, GELU


MODEL_NAME = "bert-base-uncased"
BATCH_SIZE = 16
MAX_LENGTH = 128
LEARNING_RATE = 2e-5
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PREPROCESSED_DIR = "preprocessed_snli"
MIXED_PRECISION = "fp16"
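

# Illustrative only: a hedged sketch of how these constants might feed an optimiser,
# linear warmup schedule, and fp16 gradient scaling during training. The zero warmup
# steps and the steps_per_epoch argument are assumptions; the real training loop is
# not part of this file.
def _demo_training_setup(model, steps_per_epoch):
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=steps_per_epoch * EPOCHS,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=(MIXED_PRECISION == "fp16"))
    return optimizer, scheduler, scaler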


class SimpleGNN(nn.Module):
    """One-layer graph network: mean-aggregates neighbour embeddings over a directed
    adjacency matrix built from the edge list, then applies a linear layer with ReLU."""

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, hidden_dim)

    def forward(self, node_embeddings, edges):
        # Empty graph: return a single zero vector so downstream pooling still works.
        if node_embeddings.size(0) == 0:
            return torch.zeros(1, self.fc.out_features, device=node_embeddings.device)
        num_nodes = node_embeddings.size(0)
        # Dense adjacency matrix; edges pointing outside the graph are ignored.
        adj = torch.zeros((num_nodes, num_nodes), device=node_embeddings.device)
        for (src, dst) in edges:
            if src < num_nodes and dst < num_nodes:
                adj[src, dst] = 1.0
        # Row-normalise by out-degree (epsilon guards isolated nodes against division by zero).
        deg = adj.sum(dim=1, keepdim=True) + 1e-10
        adj_norm = adj / deg
        agg_embeddings = adj_norm @ node_embeddings
        return F.relu(self.fc(agg_embeddings))
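

# Illustrative only: a minimal smoke test for SimpleGNN on a made-up 4-node graph.
# The feature width matches BERT's hidden size; the edges are arbitrary.
def _demo_simple_gnn():
    gnn = SimpleGNN(input_dim=768, hidden_dim=128)
    nodes = torch.randn(4, 768)                # toy token embeddings
    edges = [(0, 1), (1, 2), (2, 3), (3, 0)]   # small directed cycle
    out = gnn(nodes, edges)
    print(out.shape)                           # torch.Size([4, 128])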


class GraphAugmentedNLIModel(nn.Module):
    """BERT encoder fused with two small GNNs, one over the premise's token graph and
    one over the hypothesis's. The [CLS] embedding is concatenated with the mean-pooled
    GNN outputs and passed to a linear classifier over the NLI labels."""

    def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
        super().__init__()
        config = AutoConfig.from_pretrained(base_model_name)
        config.num_labels = num_labels
        self.bert = AutoModel.from_pretrained(base_model_name, config=config)
        self.dropout = nn.Dropout(0.1)

        self.gnn_premise = SimpleGNN(hidden_dim, gnn_dim)
        self.gnn_hypothesis = SimpleGNN(hidden_dim, gnn_dim)

        self.classifier = nn.Linear(hidden_dim + gnn_dim * 2, num_labels)

    def forward(self, input_ids, attention_mask, premise_graph_tokens, premise_graph_edges, premise_node_indices,
                hypothesis_graph_tokens, hypothesis_graph_edges, hypothesis_node_indices, labels=None):
        outputs = self.bert(input_ids, attention_mask=attention_mask)
        # [CLS] embedding summarises the joint premise-hypothesis encoding.
        cls_embedding = outputs.last_hidden_state[:, 0, :]

        batch_size = input_ids.size(0)
        gnn_p_outputs = []
        gnn_h_outputs = []

        # Graphs are ragged (node counts differ per example), so run the GNNs one
        # instance at a time and mean-pool each graph to a fixed-size vector.
        for i in range(batch_size):
            instance_hidden = outputs.last_hidden_state[i]

            p_edges = premise_graph_edges[i]
            p_indices = premise_node_indices[i]
            h_edges = hypothesis_graph_edges[i]
            h_indices = hypothesis_node_indices[i]

            # Gather the contextual embeddings of the tokens acting as graph nodes.
            p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
            h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)

            # SimpleGNN already returns a zero vector of width gnn_dim for empty graphs,
            # so no extra guard (or hard-coded dimension) is needed here.
            p_gnn_out = self.gnn_premise(p_nodes, p_edges)
            h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges)

            p_mean = p_gnn_out.mean(dim=0, keepdim=True)
            h_mean = h_gnn_out.mean(dim=0, keepdim=True)

            gnn_p_outputs.append(p_mean)
            gnn_h_outputs.append(h_mean)

        gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0)
        gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0)

        fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
        fused = self.dropout(fused)
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
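

# Illustrative only: a hedged sketch of a single forward pass through
# GraphAugmentedNLIModel. The sentence pair, graph edges, and node indices below are
# made up; the list-of-lists format is an assumption based on how forward() indexes
# its per-instance arguments, not the output of the real preprocessing pipeline.
def _demo_graph_nli_forward():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
    enc = tokenizer(["A man plays guitar."], ["A person makes music."],
                    padding="max_length", truncation=True,
                    max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
    out = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        premise_graph_tokens=[["man", "plays", "guitar"]],
        premise_graph_edges=[[(0, 1), (1, 2)]],
        premise_node_indices=[[2, 3, 4]],          # assumed token positions of premise nodes
        hypothesis_graph_tokens=[["person", "makes", "music"]],
        hypothesis_graph_edges=[[(0, 1), (1, 2)]],
        hypothesis_node_indices=[[7, 8, 9]],       # assumed token positions of hypothesis nodes
        labels=torch.tensor([0], device=DEVICE),
    )
    print(out["logits"].shape, out["loss"].item())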


class SimpleFinGNN(SimpleGNN):
    """Identical to SimpleGNN (same one-layer mean-aggregation and projection); kept
    under its own name for the financial NLI model below."""
    pass


class GraphAugmentedFinNLIModel(nn.Module):
    """Variant of GraphAugmentedNLIModel for the financial NLI data. It takes keyword
    arguments, accepts inputs_embeds, and tolerates extra **kwargs so it can sit behind
    wrappers (e.g. PEFT/LoRA or Trainer-style callers) that pass additional arguments."""

    def __init__(self, base_model_name, num_labels=3, hidden_dim=768, gnn_dim=128):
        super().__init__()
        config = AutoConfig.from_pretrained(base_model_name)
        config.num_labels = num_labels
        self.bert = AutoModel.from_pretrained(base_model_name, config=config)
        self.dropout = nn.Dropout(0.1)

        self.gnn_premise = SimpleFinGNN(hidden_dim, gnn_dim)
        self.gnn_hypothesis = SimpleFinGNN(hidden_dim, gnn_dim)

        self.classifier = nn.Linear(hidden_dim + gnn_dim * 2, num_labels)
        self.config = self.bert.config
        self.config.num_labels = num_labels

    def forward(self,
                input_ids=None,
                attention_mask=None,
                premise_graph_tokens=None,
                hypothesis_graph_tokens=None,
                premise_graph_edges=None,
                hypothesis_graph_edges=None,
                premise_node_indices=None,
                hypothesis_node_indices=None,
                labels=None,
                inputs_embeds=None,
                **kwargs):

        # Forward only the kwargs that the backbone's forward() actually accepts.
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask,
                            inputs_embeds=inputs_embeds,
                            **{k: v for k, v in kwargs.items() if k in self.bert.forward.__code__.co_varnames})

        cls_embedding = outputs.last_hidden_state[:, 0, :]

        batch_size = input_ids.size(0) if input_ids is not None else outputs.last_hidden_state.size(0)
        gnn_p_outputs = []
        gnn_h_outputs = []

        for i in range(batch_size):
            instance_hidden = outputs.last_hidden_state[i]

            p_edges = premise_graph_edges[i]
            p_indices = premise_node_indices[i]
            h_edges = hypothesis_graph_edges[i]
            h_indices = hypothesis_node_indices[i]

            p_nodes = instance_hidden[p_indices] if len(p_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)
            h_nodes = instance_hidden[h_indices] if len(h_indices) > 0 else torch.empty(0, instance_hidden.size(-1), device=instance_hidden.device)

            # As above, the GNN handles empty graphs itself, so no guard or hard-coded width is needed.
            p_gnn_out = self.gnn_premise(p_nodes, p_edges)
            h_gnn_out = self.gnn_hypothesis(h_nodes, h_edges)

            p_mean = p_gnn_out.mean(dim=0, keepdim=True)
            h_mean = h_gnn_out.mean(dim=0, keepdim=True)

            gnn_p_outputs.append(p_mean)
            gnn_h_outputs.append(h_mean)

        gnn_p_outputs = torch.cat(gnn_p_outputs, dim=0)
        gnn_h_outputs = torch.cat(gnn_h_outputs, dim=0)

        fused = torch.cat([cls_embedding, gnn_p_outputs, gnn_h_outputs], dim=-1)
        # Dropout on the fused representation before classification, mirroring GraphAugmentedNLIModel.
        fused = self.dropout(fused)
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
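

# Illustrative only: a hedged sketch of attaching LoRA adapters to the Fin model's BERT
# backbone with peft. The target_modules and FEATURE_EXTRACTION task type are assumptions
# chosen for a bare BERT encoder; the saved "finexbert" checkpoint used by
# SentenceExtractionModel below was produced with its own configuration.
def _demo_add_lora_to_fin_backbone():
    model = GraphAugmentedFinNLIModel(MODEL_NAME).to(DEVICE)
    lora_cfg = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",
        target_modules=["query", "value"],   # BERT attention projections
        task_type="FEATURE_EXTRACTION",
    )
    # Wrap only the encoder; the GNN heads and classifier outside the wrapper stay trainable.
    model.bert = get_peft_model(model.bert, lora_cfg)
    model.bert.print_trainable_parameters()
    return model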


class SentenceExtractionModel(nn.Module):
    def __init__(self,
                 base_model_name: str,
                 dropout_prob: float = 0.1,
                 adapter_dir: str = "./lora_finance_adapter",
                 backbone: str = 'default',
                 init_pos_frac: float = None
                 ):
        """
        backbone:
          - 'default'   → plain AutoModel.from_pretrained(base_model_name)
          - 'finexbert' → reload the LoRA-adapted BERT encoder (the .bert backbone from
                          your fine-tuned GraphAugmentedFinNLIModel) from adapter_dir
        """
        super().__init__()

        config = AutoConfig.from_pretrained(base_model_name)

        if backbone == 'default':
            self.bert = AutoModel.from_pretrained(base_model_name, config=config)

        elif backbone == 'finexbert':
            # Rebuild a LoRA-wrapped encoder, then load the fine-tuned adapter weights
            # from the saved checkpoint.
            base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
            lora_cfg = LoraConfig(
                r=8,
                lora_alpha=32,
                lora_dropout=0.1,
                bias="none",
                task_type="SEQ_CLS"
            )
            full = get_peft_model(base_model, lora_cfg).to(DEVICE)
            chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
            if not os.path.isfile(chkpt_path):
                raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
            ckpt = torch.load(chkpt_path, map_location=DEVICE)
            # strict=False tolerates checkpoint keys that do not exist on this encoder-only wrapper.
            full.load_state_dict(ckpt["model_state_dict"], strict=False)

            # full.base_model is the LoRA-injected encoder wrapper; attribute access and
            # forward calls fall through to the underlying BERT, so .config and
            # .pooler_output remain available.
            self.bert = full.base_model

        else:
            raise ValueError(f"Unknown backbone {backbone}")

        hidden_size = self.bert.config.hidden_size

        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(hidden_size, 1)

        # Optionally start the classifier bias at the log-odds of the expected positive
        # fraction so the untrained model matches the class prior.
        if init_pos_frac is not None:
            b0 = float(math.log(init_pos_frac / (1.0 - init_pos_frac)))
            self.classifier.bias.data.fill_(b0)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)
        x = self.dropout(outputs.pooler_output)
        logits = self.classifier(x).squeeze(-1)
        return logits
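

# Illustrative only: a hedged sketch of scoring candidate sentences with
# SentenceExtractionModel and computing a binary extract/skip loss. The example
# sentences, labels, and init_pos_frac value are made up.
def _demo_sentence_extraction():
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = SentenceExtractionModel(MODEL_NAME, init_pos_frac=0.25).to(DEVICE)
    sentences = ["Revenue grew 12% year over year.",
                 "The meeting was held in Chicago."]
    enc = tokenizer(sentences, padding=True, truncation=True,
                    max_length=MAX_LENGTH, return_tensors="pt").to(DEVICE)
    logits = model(enc["input_ids"], enc["attention_mask"])   # shape: (2,)
    labels = torch.tensor([1.0, 0.0], device=DEVICE)
    loss = F.binary_cross_entropy_with_logits(logits, labels)
    print(loss.item())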