import os
import logging
import random
import re
from typing import Optional, Tuple, List, Dict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import spacy
import matplotlib.pyplot as plt
from nltk import sent_tokenize
from sklearn.metrics import accuracy_score, precision_score, f1_score
from sklearn.model_selection import train_test_split
from torch.optim import AdamW
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup
from peft import PeftModel, LoraConfig, get_peft_model
from datasets import load_dataset, DatasetDict, load_from_disk
from tqdm.auto import tqdm
from accelerate import Accelerator

from models import GraphAugmentedNLIModel, GraphAugmentedFinNLIModel
from preprocess_data import SpanExtractionChunkedDataset, process_data, chunk_transcript, span_collate_fn

# =============================
# Configuration Constants
# =============================
from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp

# MODEL_NAME = "bert-base-uncased"
BATCH_SIZE = 16
# MAX_LENGTH = 128
# OVERLAP = 32
LEARNING_RATE = 2e-5
EPOCHS = 5
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# PREPROCESSED_DIR = "preprocessed_snli"
MIXED_PRECISION = "fp16"

# label mapping
label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}

# =============================
# Logging & Reproducibility
# =============================
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")


def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


# =============================
# Tokenizer & NLP Model
# =============================
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
nlp = spacy.load("en_core_web_sm")


# =============================
# Dependency Graph Helpers
# =============================
def build_dependency_graph(sentence: str):
    """Parse a sentence with spaCy and return its tokens plus bidirectional dependency edges."""
    doc = nlp(sentence)
    tokens = [token.text for token in doc]
    edges = []
    for token in doc:
        if token.head.i != token.i:
            edges.append((token.i, token.head.i))
            edges.append((token.head.i, token.i))
    return tokens, edges


# =============================
# Token Alignment
# =============================
def align_tokens(spacy_tokens, wp_tokens):
    """Map each spaCy token to the index of its first WordPiece (skipping '##' continuations)."""
    node_indices = []
    wp_idx = 1  # after [CLS]
    for _ in spacy_tokens:
        if wp_idx >= len(wp_tokens) - 1:
            break
        node_indices.append(wp_idx)
        wp_idx += 1
        while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"):
            wp_idx += 1
    return node_indices
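# Illustrative sketch (not part of the pipeline): how the dependency-graph helpers above are
# expected to combine with the WordPiece tokenizer. The sentence is a made-up example; only
# `build_dependency_graph`, `align_tokens`, and the module-level `tokenizer` are assumed here.
def _demo_graph_alignment():
    sentence = "The quick brown fox jumps over the lazy dog."
    graph_tokens, graph_edges = build_dependency_graph(sentence)
    encoded = tokenizer(sentence, return_tensors="pt")
    wp_tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    node_indices = align_tokens(graph_tokens, wp_tokens)
    # Each spaCy token is paired with the WordPiece position of its first sub-word.
    for tok, idx in zip(graph_tokens, node_indices):
        print(f"{tok!r} -> wordpiece index {idx} ({wp_tokens[idx]!r})")
    print(f"{len(graph_edges)} bidirectional dependency edges")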
# =============================
# Data Collation
# =============================
def my_collate_fn(batch):
    input_ids = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in batch]
    attention_mask = [torch.tensor(ex["attention_mask"], dtype=torch.long) for ex in batch]
    labels = [ex.get("labels", None) for ex in batch]
    premise_graph_tokens = [ex.get("premise_graph_tokens") for ex in batch]
    premise_graph_edges = [ex.get("premise_graph_edges") for ex in batch]
    premise_node_indices = [ex.get("premise_node_indices") for ex in batch]
    hypothesis_graph_tokens = [ex.get("hypothesis_graph_tokens") for ex in batch]
    hypothesis_graph_edges = [ex.get("hypothesis_graph_edges") for ex in batch]
    hypothesis_node_indices = [ex.get("hypothesis_node_indices") for ex in batch]

    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True,
                                                padding_value=tokenizer.pad_token_id)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.tensor(labels, dtype=torch.long) if labels and labels[0] is not None else None

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "premise_graph_tokens": premise_graph_tokens,
        "premise_graph_edges": premise_graph_edges,
        "premise_node_indices": premise_node_indices,
        "hypothesis_graph_tokens": hypothesis_graph_tokens,
        "hypothesis_graph_edges": hypothesis_graph_edges,
        "hypothesis_node_indices": hypothesis_node_indices,
    }


# =============================
# Training Loop
# =============================
def train_model(epochs: int = EPOCHS, batch_size: int = BATCH_SIZE, lr: float = LEARNING_RATE,
                save_model: bool = False, save_path: str = 'gnn_model_weights_3.pt'):
    set_seed()
    process_data()

    logging.info("Loading preprocessed dataset...")
    snli = load_from_disk(PREPROCESSED_DIR)
    snli.set_format("python", output_all_columns=True)

    train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
    val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn)

    model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
    if hasattr(model.bert, 'gradient_checkpointing_enable'):
        model.bert.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing on BERT.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    num_training_steps = epochs * len(train_loader)
    lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000,
                                                   num_training_steps=num_training_steps)

    accelerator = Accelerator(mixed_precision=MIXED_PRECISION)
    model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, lr_scheduler
    )

    model.train()
    all_losses = []
    epoch_losses = []
    best_val_loss = float('inf')
    best_epoch = 0

    for epoch in range(1, epochs + 1):
        epoch_loss = []
        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False)
        for batch in progress:
            labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
            outputs = model(
                input_ids=batch["input_ids"].to(DEVICE),
                attention_mask=batch["attention_mask"].to(DEVICE),
                premise_graph_tokens=batch["premise_graph_tokens"],
                premise_graph_edges=batch["premise_graph_edges"],
                premise_node_indices=batch["premise_node_indices"],
                hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
                hypothesis_graph_edges=batch["hypothesis_graph_edges"],
                hypothesis_node_indices=batch["hypothesis_node_indices"],
                labels=labels
            )
            loss = outputs.get("loss") if isinstance(outputs, dict) else outputs

            optimizer.zero_grad()
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()

            loss_val = loss.item()
            epoch_loss.append(loss_val)
            all_losses.append(loss_val)
            progress.set_postfix({"loss": f"{loss_val:.4f}"})

        avg_epoch_loss = np.mean(epoch_loss)
        epoch_losses.append(avg_epoch_loss)
        logging.info(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}")

        # Validation
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
                outputs = model(
                    input_ids=batch["input_ids"].to(DEVICE),
                    attention_mask=batch["attention_mask"].to(DEVICE),
                    premise_graph_tokens=batch["premise_graph_tokens"],
                    premise_graph_edges=batch["premise_graph_edges"],
                    premise_node_indices=batch["premise_node_indices"],
                    hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
                    hypothesis_graph_edges=batch["hypothesis_graph_edges"],
                    hypothesis_node_indices=batch["hypothesis_node_indices"],
                    labels=labels
                )
                loss_item = outputs.get("loss").item() if isinstance(outputs, dict) else outputs.item()
                val_losses.append(loss_item)
        avg_val_loss = np.mean(val_losses) if val_losses else float('inf')
        logging.info(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            best_epoch = epoch
            if save_model:
                logging.info(f"Saving best model at epoch {epoch} with val loss {avg_val_loss:.4f}")
                # Unwrap the accelerate-prepared model so the saved keys match a plain model.
                torch.save(accelerator.unwrap_model(model).state_dict(), save_path)
        model.train()

    # Plot losses
    plt.figure()
    plt.plot(all_losses)
    plt.xlabel('Training steps')
    plt.ylabel('Loss')
    plt.title('Step-wise Training Loss')
    plt.show()

    plt.figure()
    plt.plot(range(1, epochs + 1), epoch_losses, marker='o')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Epoch-wise Training Loss')
    plt.show()

    logging.info(f"Training complete. Best validation loss {best_val_loss:.4f} at epoch {best_epoch}.")
    return model


def predict_nli(premise, hypothesis, tokenizer=tokenizer, model_path='gnn_model_checkpoint.pt'):
    # 1) Instantiate the model exactly as during training
    model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)

    # 2) Load the checkpoint, then hand only the model weights to load_state_dict
    ckpt = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(ckpt["model_state_dict"])
    model.eval()

    # 3) Tokenize & build graphs (as before)
    encoded = tokenizer(
        premise, hypothesis,
        truncation=True,
        padding="max_length",
        max_length=MAX_LENGTH,
        return_tensors="pt"
    )
    input_ids = encoded["input_ids"]
    attention_mask = encoded["attention_mask"]

    # Build dependency graphs
    p_tokens, p_edges = build_dependency_graph(premise)
    h_tokens, h_edges = build_dependency_graph(hypothesis)

    # Convert ids back to tokens for alignment
    wp_tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    p_node_indices = align_tokens(p_tokens, wp_tokens)
    h_node_indices = align_tokens(h_tokens, wp_tokens)

    # Move tensors to the same device as the model
    device = next(model.parameters()).device
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    # Prepare inputs for the model: the model expects lists for the graph fields,
    # mirroring the custom collate_fn logic.
    premise_graph_tokens = [p_tokens]
    premise_graph_edges = [p_edges]
    premise_node_indices = [p_node_indices]
    hypothesis_graph_tokens = [h_tokens]
    hypothesis_graph_edges = [h_edges]
    hypothesis_node_indices = [h_node_indices]

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            premise_graph_tokens=premise_graph_tokens,
            premise_graph_edges=premise_graph_edges,
            premise_node_indices=premise_node_indices,
            hypothesis_graph_tokens=hypothesis_graph_tokens,
            hypothesis_graph_edges=hypothesis_graph_edges,
            hypothesis_node_indices=hypothesis_node_indices
        )

    logits = outputs["logits"]
    probs = F.softmax(logits, dim=-1).cpu().numpy()[0]

    # Get predicted label
    predicted_label_id = torch.argmax(logits, dim=-1).item()
    predicted_label = label_map[predicted_label_id]

    prob_map = {cls_label: probs[i] for i, cls_label in label_map.items()}
    return predicted_label, prob_map
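# Illustrative usage sketch for predict_nli. It assumes a trained checkpoint already exists
# at the default 'gnn_model_checkpoint.pt' path; the premise/hypothesis strings are made up.
def _demo_predict_nli():
    label, prob_map = predict_nli(
        premise="A man is playing a guitar on stage.",
        hypothesis="Someone is performing music.",
    )
    print(f"Predicted label: {label}")
    for cls_label, prob in prob_map.items():
        print(f"  {cls_label}: {prob:.4f}")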
def predict_fin_nli(
    premise: str,
    hypothesis: str,
    tokenizer=tokenizer,
    model_path: str = 'gnn_model_checkpoint.pt',
    adapter_dir: str = './lora_finance_adapter',
) -> Tuple[str, List[float]]:
    # 1) Load base GraphAugmentedFinNLIModel and its checkpoint
    base_model = GraphAugmentedFinNLIModel(MODEL_NAME).to(DEVICE)
    ckpt = torch.load(model_path, map_location=DEVICE)
    base_model.load_state_dict(ckpt['model_state_dict'])

    # 2) Wrap with the same LoRA config used in training
    lora_cfg = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        bias='none',
        task_type='SEQ_CLS',
        target_modules=['query', 'value']
    )
    model = get_peft_model(base_model, lora_cfg).to(DEVICE)

    # 3) Load the adapter checkpoint (the .pt under lora_finance_adapter/).
    #    This checkpoint contains the same 'model_state_dict' keys, so load it leniently.
    adapter_ckpt = torch.load(os.path.join(adapter_dir, 'training_checkpoint.pt'), map_location=DEVICE)
    model.load_state_dict(adapter_ckpt['model_state_dict'], strict=False)
    model.eval()

    # 4) Tokenize
    enc = tokenizer(
        premise, hypothesis,
        truncation=True,
        padding='max_length',
        max_length=MAX_LENGTH,
        return_tensors='pt'
    )
    input_ids = enc['input_ids'].to(DEVICE)
    attention_mask = enc['attention_mask'].to(DEVICE)

    # 5) Build & align the dependency graphs
    p_toks, p_edges = build_dependency_graph(premise)
    h_toks, h_edges = build_dependency_graph(hypothesis)
    wp = tokenizer.convert_ids_to_tokens(input_ids[0])
    p_idx = align_tokens(p_toks, wp)
    h_idx = align_tokens(h_toks, wp)

    premise_graph_tokens = [p_toks]
    premise_graph_edges = [p_edges]
    premise_node_indices = [p_idx]
    hypothesis_graph_tokens = [h_toks]
    hypothesis_graph_edges = [h_edges]
    hypothesis_node_indices = [h_idx]

    # 6) Forward
    with torch.no_grad():
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            premise_graph_tokens=premise_graph_tokens,
            premise_graph_edges=premise_graph_edges,
            premise_node_indices=premise_node_indices,
            hypothesis_graph_tokens=hypothesis_graph_tokens,
            hypothesis_graph_edges=hypothesis_graph_edges,
            hypothesis_node_indices=hypothesis_node_indices
        )

    logits = out['logits'][0]  # shape [3]
    probs = torch.softmax(logits, dim=-1).cpu().numpy()

    # 7) Collapse to entailment vs. contradiction (ignore neutral)
    entail, neutral, contra = probs
    s = entail + contra + 1e-12
    scores = [entail / s, contra / s]
    label = 'entailment' if entail >= contra else 'contradiction'
    return label, scores
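# Illustrative usage sketch for predict_fin_nli. It assumes both the base checkpoint and the
# LoRA adapter checkpoint exist at their default paths; the sentences below are made up.
def _demo_predict_fin_nli():
    label, scores = predict_fin_nli(
        premise="The company reported a 12% increase in quarterly revenue.",
        hypothesis="Revenue grew during the quarter.",
    )
    entail_score, contra_score = scores
    print(f"{label} (entailment={entail_score:.3f}, contradiction={contra_score:.3f})")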
def train_model_with_chkpt(epochs: int = 5, batch_size: int = 16, lr: float = 2e-5,
                           save_model: bool = False, save_path: str = 'gnn_model_checkpoint.pt',
                           resume: bool = False):
    """
    Train with mixed precision, gradient checkpointing, and resume support.
    If resume=True and save_path exists, picks up from the last saved epoch.
    """
    set_seed()
    process_data()

    logging.info("Loading preprocessed dataset…")
    snli = load_from_disk(PREPROCESSED_DIR)
    snli.set_format("python", output_all_columns=True)

    train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn)
    val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn)

    model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    total_steps = epochs * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps)

    # --- Resume checkpoint if requested ---
    start_epoch = 1
    if resume and os.path.isfile(save_path):
        ckpt = torch.load(save_path, map_location=DEVICE)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        start_epoch = ckpt.get("epoch", 1) + 1
        logging.info(f"Resuming from epoch {start_epoch}")

    # Mixed precision setup
    if hasattr(model.bert, "gradient_checkpointing_enable"):
        model.bert.gradient_checkpointing_enable()
        logging.info("Enabled gradient checkpointing on BERT.")

    accelerator = Accelerator(mixed_precision=MIXED_PRECISION)
    model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, val_loader, scheduler
    )

    best_val_loss = float("inf")
    for epoch in range(start_epoch, epochs + 1):
        model.train()
        train_losses = []
        for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            optimizer.zero_grad()
            outputs = model(
                input_ids=batch["input_ids"].to(DEVICE),
                attention_mask=batch["attention_mask"].to(DEVICE),
                premise_graph_tokens=batch["premise_graph_tokens"],
                premise_graph_edges=batch["premise_graph_edges"],
                premise_node_indices=batch["premise_node_indices"],
                hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
                hypothesis_graph_edges=batch["hypothesis_graph_edges"],
                hypothesis_node_indices=batch["hypothesis_node_indices"],
                labels=batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
            )
            loss = outputs["loss"] if isinstance(outputs, dict) else outputs
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            train_losses.append(loss.item())
        avg_train = np.mean(train_losses)
        logging.info(f"Epoch {epoch} train loss: {avg_train:.4f}")

        # Validation
        model.eval()
        val_losses = []
        with torch.no_grad():
            for batch in val_loader:
                outputs = model(
                    input_ids=batch["input_ids"].to(DEVICE),
                    attention_mask=batch["attention_mask"].to(DEVICE),
                    premise_graph_tokens=batch["premise_graph_tokens"],
                    premise_graph_edges=batch["premise_graph_edges"],
                    premise_node_indices=batch["premise_node_indices"],
                    hypothesis_graph_tokens=batch["hypothesis_graph_tokens"],
                    hypothesis_graph_edges=batch["hypothesis_graph_edges"],
                    hypothesis_node_indices=batch["hypothesis_node_indices"],
                    labels=batch["labels"].to(DEVICE) if batch.get("labels") is not None else None
                )
                v_loss = outputs["loss"].item() if isinstance(outputs, dict) else outputs.item()
                val_losses.append(v_loss)
        avg_val = np.mean(val_losses) if val_losses else float("inf")
        logging.info(f"Epoch {epoch} val loss: {avg_val:.4f}")

        # Save checkpoint (unwrap the accelerate-prepared model so keys match a plain model)
        ckpt = {
            "epoch": epoch,
            "model_state_dict": accelerator.unwrap_model(model).state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
        }
        torch.save(ckpt, save_path)
        logging.info(f"Saved checkpoint: {save_path}")

        if avg_val < best_val_loss:
            best_val_loss = avg_val

    logging.info(f"Training complete. Best val loss: {best_val_loss:.4f}")
    return model


def extract_sentences_by_intent(
    text: str,
    intent: str,
    adapter_dir: str = "./lora_finance_adapter",
    threshold: float = 0.7,
    top_k: int = None,
    min_words: int = 4,
    convo_focus: str = None
):
    """
    Splits `text` into sentences, embeds them (and the `intent`) under the LoRA-adapted
    BERT, and returns those whose cosine similarity ≥ `threshold`.
    Loads the adapter from the single `training_checkpoint.pt` in `adapter_dir`.
    """
    # 1) Sentence split & cleanup, optionally restricted to one speaker
    if convo_focus is None:
        sentences = [sent.text.strip() for sent in nlp(text).sents if sent.text.strip()]
    elif convo_focus == "customer":
        # Only consider lines spoken by the customer
        customer_lines = [
            line.strip() for line in text.splitlines()
            if line.strip().lower().startswith("customer:")
        ]
        # Sentence-split each customer line
        sentences = []
        for cust_line in customer_lines:
            for sent in nlp(cust_line).sents:
                s = sent.text.strip()
                if s and len(s.split(' ')) > 6:
                    sentences.append(s)
    else:
        # Otherwise only consider lines spoken by the agent
        agent_lines = [
            line.strip() for line in text.splitlines()
            if line.strip().lower().startswith("agent:")
        ]
        # Sentence-split each agent line
        sentences = []
        for agent_line in agent_lines:
            for sent in nlp(agent_line).sents:
                s = sent.text.strip()
                if s and len(s.split(' ')) > 6:
                    sentences.append(s)

    # 2) Load base BERT + wrap in the same LoRA config
    base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)
    lora_cfg = LoraConfig(
        r=8,
        lora_alpha=32,
        lora_dropout=0.1,
        bias="none",
        task_type="CAUSAL_LM",  # must match the fine-tune setting
    )
    model = get_peft_model(base_model, lora_cfg).to(DEVICE)

    # 3) Load the adapter checkpoint
    chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt")
    if not os.path.isfile(chkpt_path):
        raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}")
    ckpt = torch.load(chkpt_path, map_location=DEVICE)
    # ckpt["model_state_dict"] contains both base + LoRA weights; strict=False
    model.load_state_dict(ckpt["model_state_dict"], strict=False)
    model.eval()

    # helper: get [CLS] embedding under LoRA-BERT
    def embed(text_str):
        toks = tokenizer(
            text_str,
            truncation=True,
            padding="longest",
            return_tensors="pt"
        ).to(DEVICE)
        em_args = {
            "input_ids": toks["input_ids"],
            "attention_mask": toks["attention_mask"],
        }
        if "token_type_ids" in toks:
            em_args["token_type_ids"] = toks["token_type_ids"]
        # unwrap PEFT to call only the base BertModel
        hf_model = getattr(model, "base_model", model)
        with torch.no_grad():
            last_hidden = hf_model(
                input_ids=em_args["input_ids"],
                attention_mask=em_args["attention_mask"],
                **({"token_type_ids": em_args["token_type_ids"]} if "token_type_ids" in em_args else {})
            ).last_hidden_state
        return last_hidden[:, 0, :]

    # 4) Embed the intent and each sentence with the helper above
    intent_emb = embed(intent)
    results = []
    with torch.no_grad():
        for sent in sentences:
            clean = re.sub(r'^(Agent|Customer):\s*', "", sent)
            if len(clean.split()) < min_words:
                continue
            sent_emb = embed(clean)
            sim = F.cosine_similarity(sent_emb, intent_emb, dim=1).item()
            if sim >= threshold:
                results.append((clean, sim))

    # 5) Sort by similarity and trim to top_k
    results.sort(key=lambda x: x[1], reverse=True)
    return results[:top_k] if top_k else results
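# Illustrative usage sketch for extract_sentences_by_intent. The transcript and intent strings
# are made up, and a trained LoRA adapter is assumed to exist in ./lora_finance_adapter.
def _demo_extract_by_intent():
    transcript = (
        "Agent: Thank you for calling, how can I help you today?\n"
        "Customer: I noticed an unexpected overdraft fee on my checking account last week.\n"
        "Customer: I would like to understand why it was charged and whether it can be reversed.\n"
        "Agent: Let me pull up your account details right away.\n"
    )
    matches = extract_sentences_by_intent(
        text=transcript,
        intent="customer disputes a fee on their account",
        convo_focus="customer",
        threshold=0.6,
        top_k=3,
    )
    for sentence, similarity in matches:
        print(f"{similarity:.3f}  {sentence}")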
def train_sentence_extractor(
    model: nn.Module,
    dataset: torch.utils.data.Dataset,
    output_dir: str,
    val_split: float = 0.2,
    epochs: int = 3,
    batch_size: int = 16,
    lr: float = 2e-5,
    device: str = "cpu",
    unfreeze_after_epoch: int = 1,
    threshold: float = 0.5
):
    """
    Fine-tune `model` on `dataset`, holding out `val_split` for validation.
    Computes loss, accuracy, precision, and F1 each epoch, saves the best checkpoint,
    and plots all four metrics at the end.
    """
    # Split
    total = len(dataset)
    val_n = int(total * val_split)
    train_n = total - val_n
    train_ds, val_ds = random_split(dataset, [train_n, val_n])

    # Oversample train
    train_labels = [train_ds[i]['label'].item() for i in range(len(train_ds))]
    counts = torch.bincount(torch.tensor(train_labels, dtype=torch.long))
    weights = (1.0 / counts.float()).tolist()
    sample_weights = [weights[int(l)] for l in train_labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)

    train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    model.to(device)
    # initially freeze backbone
    for p in model.bert.parameters():
        p.requires_grad = False

    optimizer = AdamW(model.parameters(), lr=lr)
    total_steps = epochs * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )
    criterion = nn.BCEWithLogitsLoss()

    # storage for metrics
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    train_precs, val_precs = [], []
    train_f1s, val_f1s = [], []
    best_val_loss = float('inf')

    for epoch in range(1, epochs + 1):
        # —— TRAIN ——
        model.train()
        epoch_loss = 0.0
        preds, labels = [], []
        for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}"):
            inputs = batch['input_ids'].to(device)
            masks = batch['attention_mask'].to(device)
            labs = batch['label'].to(device)

            optimizer.zero_grad()
            logits = model(inputs, masks)  # raw logits
            loss = criterion(logits, labs)
            loss.backward()
            optimizer.step()
            scheduler.step()

            epoch_loss += loss.item()
            probs = torch.sigmoid(logits)
            batch_preds = (probs >= threshold).long()
            preds.extend(batch_preds.cpu().tolist())
            labels.extend(labs.cpu().long().tolist())

        avg_train = epoch_loss / len(train_loader)
        train_losses.append(avg_train)
        train_accs.append(accuracy_score(labels, preds))
        train_precs.append(precision_score(labels, preds, zero_division=0))
        train_f1s.append(f1_score(labels, preds, zero_division=0))
        print(f"→ Epoch {epoch} Train — loss {avg_train:.4f}, acc {train_accs[-1]:.4f}, "
              f"prec {train_precs[-1]:.4f}, f1 {train_f1s[-1]:.4f}")

        # unfreeze if needed
        if epoch == unfreeze_after_epoch:
            for p in model.bert.parameters():
                p.requires_grad = True
            optimizer = AdamW([
                {"params": model.classifier.parameters(), "lr": 1e-3},
                {"params": model.bert.parameters(), "lr": 1e-5},
            ], weight_decay=1e-2)
            scheduler = get_linear_schedule_with_warmup(
                optimizer,
                num_warmup_steps=int(0.1 * total_steps),
                num_training_steps=total_steps
            )

        # —— VALIDATION ——
        model.eval()
        epoch_loss = 0.0
        preds, labels = [], []
        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"  Val {epoch}/{epochs}"):
                inputs = batch['input_ids'].to(device)
                masks = batch['attention_mask'].to(device)
                labs = batch['label'].to(device)

                logits = model(inputs, masks)
                loss = criterion(logits, labs)
                epoch_loss += loss.item()
                probs = torch.sigmoid(logits)
                batch_preds = (probs >= threshold).long()
                preds.extend(batch_preds.cpu().tolist())
                labels.extend(labs.cpu().long().tolist())

        avg_val = epoch_loss / len(val_loader)
        val_losses.append(avg_val)
        val_accs.append(accuracy_score(labels, preds))
        val_precs.append(precision_score(labels, preds, zero_division=0))
        val_f1s.append(f1_score(labels, preds, zero_division=0))
        print(f"→ Epoch {epoch} Val — loss {avg_val:.4f}, acc {val_accs[-1]:.4f}, "
              f"prec {val_precs[-1]:.4f}, f1 {val_f1s[-1]:.4f}")

        # checkpoints
        os.makedirs(output_dir, exist_ok=True)
        ckpt = os.path.join(output_dir, f"epo{epoch}_val{avg_val:.4f}.pth")
        torch.save(model.state_dict(), ckpt)
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth"))
            print(f"🎉 New best model saved (val loss {best_val_loss:.4f})")

    print(f"✔️ Training complete — best val loss: {best_val_loss:.4f}")

    # —— PLOT METRICS ——
    epoch_ids = list(range(1, epochs + 1))
    save_metric_plot(epoch_ids, train_losses, val_losses,
                     metric_name="Loss", output_path="results/Loss_Plot.png")
    save_metric_plot(epoch_ids, train_accs, val_accs,
                     metric_name="Accuracy", output_path="results/Accuracy_Plot.png", threshold=0.5)
    save_metric_plot(epoch_ids, train_precs, val_precs,
                     metric_name="Precision", output_path="results/Precision_Plot.png", threshold=0.5)
    save_metric_plot(epoch_ids, train_f1s, val_f1s,
                     metric_name="F1 Score", output_path="results/F1Score_Plot.png", threshold=0.5)


def save_metric_plot(
    epochs,
    train_vals,
    val_vals,
    metric_name: str,
    output_path: str,
    threshold: float = None
):
    """
    epochs       – list of epoch indices
    train_vals   – list of train metric values
    val_vals     – list of validation metric values
    metric_name  – e.g. "Loss", "Accuracy", "Precision", "F1 Score"
    output_path  – where to save the PNG
    threshold    – optional horizontal line to draw, e.g. 0.5
    """
    fig, ax = plt.subplots(figsize=(8, 5))
    ax.plot(epochs, train_vals, marker='o', linewidth=2, label=f'Train {metric_name}')
    ax.plot(epochs, val_vals, marker='s', linewidth=2, label=f'Val {metric_name}')
    if threshold is not None:
        ax.axhline(threshold, color='gray', linestyle='--', linewidth=1,
                   label=f'Threshold = {threshold}')
    ax.set_title(f'{metric_name} over Epochs', fontsize=14, pad=10)
    ax.set_xlabel('Epoch', fontsize=12)
    ax.set_ylabel(metric_name, fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.4)
    ax.legend(loc='best', frameon=True, fontsize=10)
    fig.tight_layout()
    fig.savefig(output_path, dpi=300)
    plt.close(fig)
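# Illustrative training sketch for train_sentence_extractor. The arguments are placeholders:
# `sentence_dataset` stands in for any Dataset whose items expose 'input_ids',
# 'attention_mask', and a float 'label' tensor, and `sentence_classifier` for a module with
# `.bert` and `.classifier` attributes that returns raw logits.
def _demo_train_sentence_extractor(sentence_dataset, sentence_classifier):
    train_sentence_extractor(
        model=sentence_classifier,
        dataset=sentence_dataset,
        output_dir="results/sentence_extractor",
        val_split=0.2,
        epochs=3,
        batch_size=16,
        lr=2e-5,
        device="cuda" if torch.cuda.is_available() else "cpu",
        unfreeze_after_epoch=1,
        threshold=0.5,
    )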
def demo_on_random_val(
    model,
    tokenizer,
    excel_path: str,
    ckpt_path: str,
    max_length: int = 128,
    device: str = "cpu",
    temperature: float = 1.0
):
    """
    Run extraction on one random validation transcript, using a dynamic threshold
    instead of a fixed one:
      1) Compute sigmoid(logits / temperature) for each sentence
      2) Sort probabilities descending
      3) Find the largest gap between adjacent probs
      4) Set dynamic_threshold = midpoint of that gap
      5) Extract all sentences with prob >= dynamic_threshold
    """
    # load model
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device).eval()

    # sample one transcript from the validation split
    df = pd.read_excel(excel_path)
    _, val_df = train_test_split(df, test_size=0.2, random_state=42)
    row = val_df.sample(n=1, random_state=random.randint(0, 999)).iloc[0]
    transcript = str(row['Claude_Call'])
    print(f"\n── Transcript (val sample idx={row['idx']}):\n{transcript}\n")

    # split into sentences & run inference
    sentences, probs = [], []
    for sent in sent_tokenize(transcript):
        enc = tokenizer.encode_plus(
            sent,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        with torch.no_grad():
            logits = model(enc['input_ids'].to(device), enc['attention_mask'].to(device))
        prob = torch.sigmoid(logits / temperature).item()
        sentences.append(sent)
        probs.append(prob)

    # print all
    print("Sentence probabilities:")
    for s, p in zip(sentences, probs):
        print(f"  → {p:.4f} → {s}")

    # if there is no variation, fall back to 0.5
    if len(probs) < 2 or max(probs) - min(probs) < 1e-3:
        dynamic_thr = 0.5
    else:
        # find the elbow in the sorted probabilities
        sorted_probs = sorted(probs, reverse=True)
        diffs = [sorted_probs[i] - sorted_probs[i + 1] for i in range(len(sorted_probs) - 1)]
        idx = max(range(len(diffs)), key=lambda i: diffs[i])
        # threshold is the midpoint of the largest gap
        dynamic_thr = (sorted_probs[idx] + sorted_probs[idx + 1]) / 2.0

    print(f"\nDynamic threshold = {dynamic_thr:.4f}\n")
    print("Extracted sentences:")
    for s, p in zip(sentences, probs):
        if p >= dynamic_thr:
            print(f"  • {p:.4f} → {s}")
    print()


def batch_predict_and_save(
    model,
    tokenizer,
    excel_path: str,
    ckpt_path: str,
    output_path: str,
    n_samples: int = 40,
    max_length: int = 128,
    device: str = "cpu",
    temperature: float = 1.0,
    random_state: int = None
):
    """
    1) Loads the best checkpoint
    2) Samples `n_samples` rows
    3) For each transcript:
         - tokenize into sentences
         - compute p = sigmoid(logits / temperature)
         - compute the elbow threshold on the sorted p's
         - extract all sentences with p >= elbow
         - if none, pick the highest-p sentence
    4) Saves a new Excel file with columns:
         - 'Claude_Call'
         - 'Predicted Sel_K' (list of extracted sentences)
    """
    # load model
    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device).eval()

    # sample rows
    df = pd.read_excel(excel_path)
    sampled = df.sample(n=n_samples, random_state=random_state) \
        if random_state is not None else df.sample(n=n_samples)

    records = []
    for _, row in tqdm(sampled.iterrows(), total=len(sampled), desc="Running Predictions"):
        transcript = str(row['Claude_Call'])
        sentences = sent_tokenize(transcript)

        # compute probabilities
        probs = []
        for sent in sentences:
            enc = tokenizer.encode_plus(
                sent,
                max_length=max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            with torch.no_grad():
                logits = model(enc['input_ids'].to(device), enc['attention_mask'].to(device))
            p = torch.sigmoid(logits / temperature).item()
            probs.append(p)

        # dynamic threshold via elbow detection
        if len(probs) >= 2 and max(probs) - min(probs) > 1e-3:
            sp = sorted(probs, reverse=True)
            diffs = [sp[i] - sp[i + 1] for i in range(len(sp) - 1)]
            idx = max(range(len(diffs)), key=lambda i: diffs[i])
            thr = (sp[idx] + sp[idx + 1]) / 2.0
        else:
            thr = 0.5  # fallback

        # collect all above threshold, else top-1
        extracted = [s for s, p in zip(sentences, probs) if p >= thr]
        if not extracted and sentences:
            best_idx = int(max(range(len(probs)), key=lambda i: probs[i]))
            extracted = [sentences[best_idx]]

        records.append({
            'Claude_Call': transcript,
            'Predicted Sel_K': extracted
        })

    # save
    out_df = pd.DataFrame(records)
    os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)
    out_df.to_excel(output_path, index=False)
    print(f"➡️ Saved {len(out_df)} rows to {output_path}")
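# Worked example of the elbow-threshold rule used by demo_on_random_val and
# batch_predict_and_save, run on a made-up list of per-sentence probabilities.
def _demo_elbow_threshold():
    probs = [0.91, 0.88, 0.52, 0.31, 0.12]                     # hypothetical sigmoid outputs
    sp = sorted(probs, reverse=True)
    diffs = [sp[i] - sp[i + 1] for i in range(len(sp) - 1)]    # gaps: 0.03, 0.36, 0.21, 0.19
    idx = max(range(len(diffs)), key=lambda i: diffs[i])       # largest gap is 0.88 → 0.52
    thr = (sp[idx] + sp[idx + 1]) / 2.0                        # midpoint → 0.70
    extracted = [p for p in probs if p >= thr]                 # keeps 0.91 and 0.88
    print(f"threshold={thr:.2f}, extracted={extracted}")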