| import os |
| import logging |
| import random |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from nltk import sent_tokenize  # requires the NLTK 'punkt' tokenizer data (nltk.download('punkt')) |
| from sklearn.metrics import accuracy_score, precision_score, f1_score |
| from sklearn.model_selection import train_test_split |
| from torch.utils.data import DataLoader, random_split, WeightedRandomSampler |
| from transformers import AutoTokenizer, AutoModel, AutoConfig, get_linear_schedule_with_warmup |
| from peft import PeftModel, LoraConfig, get_peft_model |
| from datasets import load_dataset, DatasetDict, load_from_disk |
| import spacy |
| import re |
| from tqdm.auto import tqdm |
| from accelerate import Accelerator |
| import matplotlib.pyplot as plt |
| from torch.optim import AdamW |
| import pandas as pd |
| from typing import Optional, Tuple, List, Dict |
|
|
| from models import GraphAugmentedNLIModel, GraphAugmentedFinNLIModel |
| from preprocess_data import SpanExtractionChunkedDataset, process_data, chunk_transcript, span_collate_fn |
|
|
| |
| |
| |
| from config import MODEL_NAME, MAX_LENGTH, OVERLAP, PREPROCESSED_DIR, tokenizer, nlp |
|
|
| |
| BATCH_SIZE = 16 |
| |
| |
| LEARNING_RATE = 2e-5 |
| EPOCHS = 5 |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| |
| MIXED_PRECISION = "fp16" if torch.cuda.is_available() else "no"  # fp16 autocast needs a GPU; fall back to full precision on CPU |
|
|
| |
| label_map = {0: "entailment", 1: "neutral", 2: "contradiction"}  # SNLI label ids -> names |
|
|
| |
| |
| |
| logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") |
| def set_seed(seed: int = 42): |
| random.seed(seed) |
| np.random.seed(seed) |
| torch.manual_seed(seed) |
| if torch.cuda.is_available(): |
| torch.cuda.manual_seed_all(seed) |
|
|
| |
| |
| |
| # Note: these re-initialisations shadow the `tokenizer` and `nlp` imported from config above. |
| tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) |
| nlp = spacy.load("en_core_web_sm") |
|
|
| |
| |
| |
| def build_dependency_graph(sentence: str): |
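|     """ |
|     Parse `sentence` with spaCy and return (tokens, edges), where `edges` contains a |
|     bidirectional (child, head) / (head, child) pair for every dependency arc; the |
|     root token (its own head) gets no self-loop. |
|     e.g. "The cat sat" -> (["The", "cat", "sat"], [(0, 1), (1, 0), (1, 2), (2, 1)]) |
|     """ |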
| doc = nlp(sentence) |
| tokens = [token.text for token in doc] |
| edges = [] |
| for token in doc: |
| if token.head.i != token.i: |
| edges.append((token.i, token.head.i)) |
| edges.append((token.head.i, token.i)) |
| return tokens, edges |
|
|
| |
| |
| |
| def align_tokens(spacy_tokens, wp_tokens): |
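|     """ |
|     Heuristically map each spaCy token to the index of its first WordPiece in |
|     `wp_tokens`: start right after [CLS] and skip "##" continuation pieces. |
|     Note: alignment always starts at position 1, so for a paired premise/hypothesis |
|     encoding the indices produced for the second segment are only approximate. |
|     """ |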
| node_indices = [] |
| wp_idx = 1 |
| for _ in spacy_tokens: |
| if wp_idx >= len(wp_tokens) - 1: |
| break |
| node_indices.append(wp_idx) |
| wp_idx += 1 |
| while wp_idx < len(wp_tokens) - 1 and wp_tokens[wp_idx].startswith("##"): |
| wp_idx += 1 |
| return node_indices |
|
|
| |
| |
| |
| def my_collate_fn(batch): |
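|     """ |
|     Collate function for the graph-augmented NLI dataset: pads `input_ids` and |
|     `attention_mask` to the longest sequence in the batch and passes the |
|     dependency-graph fields through as plain Python lists (one entry per example). |
|     """ |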
| input_ids = [torch.tensor(ex["input_ids"], dtype=torch.long) for ex in batch] |
| attention_mask = [torch.tensor(ex["attention_mask"], dtype=torch.long) for ex in batch] |
| labels = [ex.get("labels", None) for ex in batch] |
|
|
| premise_graph_tokens = [ex.get("premise_graph_tokens") for ex in batch] |
| premise_graph_edges = [ex.get("premise_graph_edges") for ex in batch] |
| premise_node_indices = [ex.get("premise_node_indices") for ex in batch] |
|
|
| hypothesis_graph_tokens = [ex.get("hypothesis_graph_tokens") for ex in batch] |
| hypothesis_graph_edges = [ex.get("hypothesis_graph_edges") for ex in batch] |
| hypothesis_node_indices = [ex.get("hypothesis_node_indices") for ex in batch] |
|
|
| input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id) |
| attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0) |
| labels = torch.tensor(labels, dtype=torch.long) if labels and labels[0] is not None else None |
|
|
| return { |
| "input_ids": input_ids, |
| "attention_mask": attention_mask, |
| "labels": labels, |
| "premise_graph_tokens": premise_graph_tokens, |
| "premise_graph_edges": premise_graph_edges, |
| "premise_node_indices": premise_node_indices, |
| "hypothesis_graph_tokens": hypothesis_graph_tokens, |
| "hypothesis_graph_edges": hypothesis_graph_edges, |
| "hypothesis_node_indices": hypothesis_node_indices, |
| } |
|
|
| |
| |
| |
| def train_model(epochs: int = EPOCHS, |
| batch_size: int = BATCH_SIZE, |
| lr: float = LEARNING_RATE, |
| save_model: bool = False, |
| save_path: str = 'gnn_model_weights_3.pt'): |
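|     """ |
|     Train GraphAugmentedNLIModel on the preprocessed SNLI splits with Accelerate |
|     mixed precision, track step- and epoch-level losses, and optionally save the |
|     weights of the best-validation-loss epoch to `save_path`. Returns the model. |
|     """ |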
| set_seed() |
| process_data() |
| logging.info("Loading preprocessed dataset...") |
| snli = load_from_disk(PREPROCESSED_DIR) |
| snli.set_format("python", output_all_columns=True) |
|
|
| train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) |
| val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn) |
|
|
| model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE) |
|
|
| if hasattr(model.bert, 'gradient_checkpointing_enable'): |
| model.bert.gradient_checkpointing_enable() |
| logging.info("Enabled gradient checkpointing on BERT.") |
|
|
| optimizer = torch.optim.AdamW(model.parameters(), lr=lr) |
| num_training_steps = epochs * len(train_loader) |
| lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=num_training_steps) |
|
|
| accelerator = Accelerator(mixed_precision=MIXED_PRECISION) |
| model, optimizer, train_loader, val_loader, lr_scheduler = accelerator.prepare( |
| model, optimizer, train_loader, val_loader, lr_scheduler |
| ) |
|
|
| model.train() |
| all_losses = [] |
| epoch_losses = [] |
| best_val_loss = float('inf') |
| best_epoch = 0 |
|
|
| for epoch in range(1, epochs + 1): |
| epoch_loss = [] |
| progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", leave=False) |
| for batch in progress: |
| labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None |
| outputs = model( |
| input_ids=batch["input_ids"].to(DEVICE), |
| attention_mask=batch["attention_mask"].to(DEVICE), |
| premise_graph_tokens=batch["premise_graph_tokens"], |
| premise_graph_edges=batch["premise_graph_edges"], |
| premise_node_indices=batch["premise_node_indices"], |
| hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
| hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
| hypothesis_node_indices=batch["hypothesis_node_indices"], |
| labels=labels |
| ) |
| loss = outputs.get("loss") if isinstance(outputs, dict) else outputs |
|
|
| optimizer.zero_grad() |
| accelerator.backward(loss) |
| optimizer.step() |
| lr_scheduler.step() |
|
|
| loss_val = loss.item() |
| epoch_loss.append(loss_val) |
| all_losses.append(loss_val) |
| progress.set_postfix({"loss": f"{loss_val:.4f}"}) |
|
|
| avg_epoch_loss = np.mean(epoch_loss) |
| epoch_losses.append(avg_epoch_loss) |
| logging.info(f"Epoch {epoch} completed. Avg Loss: {avg_epoch_loss:.4f}") |
|
|
| |
| model.eval() |
| val_losses = [] |
| with torch.no_grad(): |
| for batch in val_loader: |
| labels = batch["labels"].to(DEVICE) if batch.get("labels") is not None else None |
| outputs = model( |
| input_ids=batch["input_ids"].to(DEVICE), |
| attention_mask=batch["attention_mask"].to(DEVICE), |
| premise_graph_tokens=batch["premise_graph_tokens"], |
| premise_graph_edges=batch["premise_graph_edges"], |
| premise_node_indices=batch["premise_node_indices"], |
| hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
| hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
| hypothesis_node_indices=batch["hypothesis_node_indices"], |
| labels=labels |
| ) |
| loss_item = outputs.get("loss").item() if isinstance(outputs, dict) else outputs.item() |
| val_losses.append(loss_item) |
| avg_val_loss = np.mean(val_losses) if val_losses else float('inf') |
| logging.info(f"Validation Loss after Epoch {epoch}: {avg_val_loss:.4f}") |
|
|
| if avg_val_loss < best_val_loss: |
| best_val_loss = avg_val_loss |
| best_epoch = epoch |
| if save_model: |
| logging.info(f"Saving best model at epoch {epoch} with val loss {avg_val_loss:.4f}") |
|                 torch.save(accelerator.unwrap_model(model).state_dict(), save_path)  # unwrap so checkpoint keys match the plain (unprepared) model |
| model.train() |
|
|
| |
| plt.figure() |
| plt.plot(all_losses) |
| plt.xlabel('Training steps') |
| plt.ylabel('Loss') |
| plt.title('Step-wise Training Loss') |
| plt.show() |
|
|
| plt.figure() |
| plt.plot(range(1, epochs+1), epoch_losses, marker='o') |
| plt.xlabel('Epochs') |
| plt.ylabel('Loss') |
| plt.title('Epoch-wise Training Loss') |
| plt.show() |
|
|
| logging.info(f"Training complete. Best validation loss {best_val_loss:.4f} at epoch {best_epoch}.") |
| return model |
|
|
|
|
| def predict_nli(premise, hypothesis, tokenizer=tokenizer, model_path='gnn_model_checkpoint.pt'): |
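|     """ |
|     Load a trained GraphAugmentedNLIModel checkpoint and classify a single |
|     premise/hypothesis pair. Returns (predicted_label, {label: probability}). |
|     """ |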
| |
| model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE) |
|
|
| |
| ckpt = torch.load(model_path, map_location=DEVICE) |
| model.load_state_dict(ckpt["model_state_dict"]) |
|
|
| model.eval() |
|
|
| |
| encoded = tokenizer( |
| premise, hypothesis, |
| truncation=True, |
| padding="max_length", |
| max_length=MAX_LENGTH, |
| return_tensors="pt" |
| ) |
|
|
| input_ids = encoded["input_ids"] |
| attention_mask = encoded["attention_mask"] |
|
|
| |
| p_tokens, p_edges = build_dependency_graph(premise) |
| h_tokens, h_edges = build_dependency_graph(hypothesis) |
|
|
| |
| wp_tokens = tokenizer.convert_ids_to_tokens(input_ids[0]) |
|
|
| p_node_indices = align_tokens(p_tokens, wp_tokens) |
| h_node_indices = align_tokens(h_tokens, wp_tokens) |
|
|
| |
| device = next(model.parameters()).device |
| input_ids = input_ids.to(device) |
| attention_mask = attention_mask.to(device) |
|
|
| |
| |
| premise_graph_tokens = [p_tokens] |
| premise_graph_edges = [p_edges] |
| premise_node_indices = [p_node_indices] |
|
|
| hypothesis_graph_tokens = [h_tokens] |
| hypothesis_graph_edges = [h_edges] |
| hypothesis_node_indices = [h_node_indices] |
|
|
| with torch.no_grad(): |
| outputs = model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| premise_graph_tokens=premise_graph_tokens, |
| premise_graph_edges=premise_graph_edges, |
| premise_node_indices=premise_node_indices, |
| hypothesis_graph_tokens=hypothesis_graph_tokens, |
| hypothesis_graph_edges=hypothesis_graph_edges, |
| hypothesis_node_indices=hypothesis_node_indices |
| ) |
|
|
| logits = outputs["logits"] |
| probs = F.softmax(logits, dim=-1).cpu().numpy()[0] |
| |
| predicted_label_id = torch.argmax(logits, dim=-1).item() |
| predicted_label = label_map[predicted_label_id] |
|     prob_map = {cls_label: float(probs[i]) for i, cls_label in label_map.items()} |
| return predicted_label, prob_map |
|
|
|
|
| def predict_fin_nli( |
| premise: str, |
| hypothesis: str, |
| tokenizer=tokenizer, |
| model_path: str = 'gnn_model_checkpoint.pt', |
| adapter_dir: str = './lora_finance_adapter', |
| ) -> Tuple[str, List[float]]: |
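|     """ |
|     Financial-domain NLI prediction: load the base graph-augmented checkpoint, wrap |
|     it with the LoRA adapter from `adapter_dir`, and score one premise/hypothesis |
|     pair. Returns (label, [entailment_score, contradiction_score]) with the two |
|     scores renormalised to exclude the neutral class. |
|     """ |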
| |
| base_model = GraphAugmentedFinNLIModel(MODEL_NAME).to(DEVICE) |
| ckpt = torch.load(model_path, map_location=DEVICE) |
| base_model.load_state_dict(ckpt['model_state_dict']) |
|
|
| |
| lora_cfg = LoraConfig( |
| r=8, |
| lora_alpha=32, |
| lora_dropout=0.1, |
| bias='none', |
| task_type='SEQ_CLS', |
| target_modules=['query', 'value'] |
| ) |
| model = get_peft_model(base_model, lora_cfg).to(DEVICE) |
|
|
| |
| adapter_ckpt = torch.load(os.path.join(adapter_dir, 'training_checkpoint.pt'), map_location=DEVICE) |
| |
| model.load_state_dict(adapter_ckpt['model_state_dict'], strict=False) |
| model.eval() |
|
|
| |
| enc = tokenizer( |
| premise, hypothesis, |
| truncation=True, |
| padding='max_length', |
| max_length=MAX_LENGTH, |
| return_tensors='pt' |
| ) |
| input_ids = enc['input_ids'].to(DEVICE) |
| attention_mask = enc['attention_mask'].to(DEVICE) |
|
|
| |
| p_toks, p_edges = build_dependency_graph(premise) |
| h_toks, h_edges = build_dependency_graph(hypothesis) |
| wp = tokenizer.convert_ids_to_tokens(input_ids[0]) |
| p_idx = align_tokens(p_toks, wp) |
| h_idx = align_tokens(h_toks, wp) |
|
|
| premise_graph_tokens = [p_toks] |
| premise_graph_edges = [p_edges] |
| premise_node_indices = [p_idx] |
| hypothesis_graph_tokens = [h_toks] |
| hypothesis_graph_edges = [h_edges] |
| hypothesis_node_indices = [h_idx] |
|
|
| |
| with torch.no_grad(): |
| out = model( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| premise_graph_tokens=premise_graph_tokens, |
| premise_graph_edges=premise_graph_edges, |
| premise_node_indices=premise_node_indices, |
| hypothesis_graph_tokens=hypothesis_graph_tokens, |
| hypothesis_graph_edges=hypothesis_graph_edges, |
| hypothesis_node_indices=hypothesis_node_indices |
| ) |
|
|
| logits = out['logits'][0] |
| probs = torch.softmax(logits, dim=-1).cpu().numpy() |
|
|
| |
|     # Drop the neutral mass and renormalise over {entailment, contradiction}; |
|     # e.g. probs (0.6, 0.3, 0.1) -> scores [~0.857, ~0.143] with label 'entailment'. |
|     entail, neutral, contra = probs |
|     s = entail + contra + 1e-12 |
|     scores = [float(entail / s), float(contra / s)] |
|     label = 'entailment' if entail >= contra else 'contradiction' |
| return label, scores |
|
|
|
|
| def train_model_with_chkpt(epochs: int = 5, |
| batch_size: int = 16, |
| lr: float = 2e-5, |
| save_model: bool = False, |
| save_path: str = 'gnn_model_checkpoint.pt', |
| resume: bool = False): |
| """ |
| Train with mixed precision, gradient checkpointing, and resume support. |
| If resume=True and save_path exists, picks up from last epoch. |
| """ |
| set_seed() |
| process_data() |
| logging.info("Loading preprocessed dataset…") |
| snli = load_from_disk(PREPROCESSED_DIR) |
| snli.set_format("python", output_all_columns=True) |
|
|
| train_loader = DataLoader(snli["train"], batch_size=batch_size, shuffle=True, collate_fn=my_collate_fn) |
| val_loader = DataLoader(snli["validation"], batch_size=batch_size, collate_fn=my_collate_fn) |
|
|
| model = GraphAugmentedNLIModel(MODEL_NAME).to(DEVICE) |
| optimizer = torch.optim.AdamW(model.parameters(), lr=lr) |
| total_steps = epochs * len(train_loader) |
| scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=1000, num_training_steps=total_steps) |
|
|
| |
| start_epoch = 1 |
| if resume and os.path.isfile(save_path): |
| ckpt = torch.load(save_path, map_location=DEVICE) |
| model.load_state_dict(ckpt["model_state_dict"]) |
| optimizer.load_state_dict(ckpt["optimizer_state_dict"]) |
| scheduler.load_state_dict(ckpt["scheduler_state_dict"]) |
| start_epoch = ckpt.get("epoch", 1) + 1 |
| logging.info(f"Resuming from epoch {start_epoch}") |
|
|
| |
| if hasattr(model.bert, "gradient_checkpointing_enable"): |
| model.bert.gradient_checkpointing_enable() |
| logging.info("Enabled gradient checkpointing on BERT.") |
| accelerator = Accelerator(mixed_precision=MIXED_PRECISION) |
| model, optimizer, train_loader, val_loader, scheduler = accelerator.prepare( |
| model, optimizer, train_loader, val_loader, scheduler |
| ) |
|
|
| best_val_loss = float("inf") |
| for epoch in range(start_epoch, epochs + 1): |
| model.train() |
| train_losses = [] |
| for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"): |
| optimizer.zero_grad() |
| outputs = model( |
| input_ids=batch["input_ids"].to(DEVICE), |
| attention_mask=batch["attention_mask"].to(DEVICE), |
| premise_graph_tokens=batch["premise_graph_tokens"], |
| premise_graph_edges=batch["premise_graph_edges"], |
| premise_node_indices=batch["premise_node_indices"], |
| hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
| hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
| hypothesis_node_indices=batch["hypothesis_node_indices"], |
| labels=batch.get("labels", None).to(DEVICE) if batch.get("labels") is not None else None |
| ) |
| loss = outputs["loss"] if isinstance(outputs, dict) else outputs |
| accelerator.backward(loss) |
| optimizer.step() |
| scheduler.step() |
| train_losses.append(loss.item()) |
| avg_train = np.mean(train_losses) |
| logging.info(f"Epoch {epoch} train loss: {avg_train:.4f}") |
|
|
| |
| model.eval() |
| val_losses = [] |
| with torch.no_grad(): |
| for batch in val_loader: |
| outputs = model( |
| input_ids=batch["input_ids"].to(DEVICE), |
| attention_mask=batch["attention_mask"].to(DEVICE), |
| premise_graph_tokens=batch["premise_graph_tokens"], |
| premise_graph_edges=batch["premise_graph_edges"], |
| premise_node_indices=batch["premise_node_indices"], |
| hypothesis_graph_tokens=batch["hypothesis_graph_tokens"], |
| hypothesis_graph_edges=batch["hypothesis_graph_edges"], |
| hypothesis_node_indices=batch["hypothesis_node_indices"], |
| labels=batch.get("labels", None).to(DEVICE) if batch.get("labels") is not None else None |
| ) |
| v_loss = outputs["loss"].item() if isinstance(outputs, dict) else outputs.item() |
| val_losses.append(v_loss) |
| avg_val = np.mean(val_losses) if val_losses else float("inf") |
| logging.info(f"Epoch {epoch} val loss: {avg_val:.4f}") |
|
|
| |
| ckpt = { |
| "epoch": epoch, |
|             "model_state_dict": accelerator.unwrap_model(model).state_dict(), |
| "optimizer_state_dict": optimizer.state_dict(), |
| "scheduler_state_dict": scheduler.state_dict(), |
| } |
| torch.save(ckpt, save_path) |
| logging.info(f"Saved checkpoint: {save_path}") |
|
|
| if avg_val < best_val_loss: |
| best_val_loss = avg_val |
|
|
| logging.info(f"Training complete. Best val loss: {best_val_loss:.4f}") |
| return model |
|
|
|
|
| def extract_sentences_by_intent( |
| text: str, |
| intent: str, |
| adapter_dir: str = "./lora_finance_adapter", |
| threshold: float = 0.7, |
| top_k: int = None, |
| min_words: int = 4, |
| convo_focus: str = None |
| ): |
|     """ |
|     Split `text` into sentences (optionally restricted to Customer or Agent turns |
|     via `convo_focus`), embed each sentence and the `intent` with the LoRA-adapted |
|     BERT encoder, and return the sentences whose cosine similarity to the intent |
|     embedding is >= `threshold`, sorted descending (truncated to `top_k` if given). |
|     The adapter is loaded from `training_checkpoint.pt` inside `adapter_dir`. |
|     """ |
| |
| |
|
|
|     if convo_focus is None: |
|         sentences = [sent.text.strip() for sent in nlp(text).sents if sent.text.strip()] |
|     else: |
|         # Keep only the turns spoken by the requested party, then sentence-split them. |
|         prefix = "customer:" if convo_focus == "customer" else "agent:" |
|         speaker_lines = [ |
|             line.strip() |
|             for line in text.splitlines() |
|             if line.strip().lower().startswith(prefix) |
|         ] |
|         sentences = [] |
|         for speaker_line in speaker_lines: |
|             for sent in nlp(speaker_line).sents: |
|                 s = sent.text.strip() |
|                 if s and len(s.split()) >= min_words: |
|                     sentences.append(s) |
|
|
| |
| base_model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE) |
| lora_cfg = LoraConfig( |
| r=8, |
| lora_alpha=32, |
| lora_dropout=0.1, |
| bias="none", |
|         task_type="CAUSAL_LM",  # only affects the PEFT wrapper; embeddings are read straight from the encoder's hidden states below |
| ) |
| model = get_peft_model(base_model, lora_cfg).to(DEVICE) |
|
|
| |
| chkpt_path = os.path.join(adapter_dir, "training_checkpoint.pt") |
| if not os.path.isfile(chkpt_path): |
| raise FileNotFoundError(f"No LoRA checkpoint at {chkpt_path}") |
| ckpt = torch.load(chkpt_path, map_location=DEVICE) |
| |
| model.load_state_dict(ckpt["model_state_dict"], strict=False) |
| model.eval() |
|
|
| |
| def embed(text_str): |
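|         """Return the first-token ([CLS]) hidden state from the LoRA-adapted encoder as the sentence embedding.""" |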
| toks = tokenizer( |
| text_str, |
| truncation=True, |
| padding="longest", |
| return_tensors="pt" |
| ).to(DEVICE) |
|
|
| em_args = { |
| "input_ids": toks["input_ids"], |
| "attention_mask": toks["attention_mask"], |
| } |
| if "token_type_ids" in toks: |
| em_args["token_type_ids"] = toks["token_type_ids"] |
|
|
| |
| hf_model = getattr(model, "base_model", model) |
| with torch.no_grad(): |
| last_hidden = hf_model( |
| input_ids=em_args["input_ids"], |
| attention_mask=em_args["attention_mask"], |
| **({"token_type_ids": em_args["token_type_ids"]} if "token_type_ids" in em_args else {}) |
| ).last_hidden_state |
| return last_hidden[:, 0, :] |
|
|
| |
| intent_emb = embed(intent) |
|
|
| results = [] |
| with torch.no_grad(): |
| for sent in sentences: |
| clean = re.sub(r'^(Agent|Customer):\s*', "", sent) |
| if len(clean.split()) < min_words: |
| continue |
|
|
| sent_emb = embed(clean) |
| sim = F.cosine_similarity(sent_emb, intent_emb, dim=1).item() |
| if sim >= threshold: |
| results.append((clean, sim)) |
|
|
| |
| results.sort(key=lambda x: x[1], reverse=True) |
| return results[:top_k] if top_k else results |
|
|
|
|
| def train_sentence_extractor( |
| model: nn.Module, |
| dataset: torch.utils.data.Dataset, |
| output_dir: str, |
| val_split: float = 0.2, |
| epochs: int = 3, |
| batch_size: int = 16, |
| lr: float = 2e-5, |
| device: str = "cpu", |
| unfreeze_after_epoch: int = 1, |
| threshold: float = 0.5 |
| ): |
| """ |
| Fine-tune `model` on `dataset`, hold out `val_split` for val, |
| compute loss + acc + precision + F1 each epoch, save best checkpoint, |
| and plot all four metrics at the end. |
| """ |
| |
| total = len(dataset) |
| val_n = int(total * val_split) |
| train_n = total - val_n |
| train_ds, val_ds = random_split(dataset, [train_n, val_n]) |
|
|
| |
|     # Weight every training example by the inverse frequency of its label so the |
|     # sampler draws roughly class-balanced batches from an imbalanced dataset. |
|     train_labels = [train_ds[i]['label'].item() for i in range(len(train_ds))] |
|     counts = torch.bincount(torch.tensor(train_labels, dtype=torch.long)) |
|     weights = (1.0 / counts.float()).tolist() |
|     sample_weights = [weights[int(l)] for l in train_labels] |
|     sampler = WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True) |
|
|
| train_loader = DataLoader(train_ds, batch_size=batch_size, sampler=sampler, drop_last=True) |
| val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False) |
|
|
| model.to(device) |
| |
|     # Freeze the BERT encoder initially; only the classification head is trained |
|     # until `unfreeze_after_epoch` is reached. |
|     for p in model.bert.parameters(): |
|         p.requires_grad = False |
|
|
| optimizer = AdamW(model.parameters(), lr=lr) |
| total_steps = epochs * len(train_loader) |
| scheduler = get_linear_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=int(0.1 * total_steps), |
| num_training_steps=total_steps |
| ) |
| criterion = nn.BCEWithLogitsLoss() |
|
|
| |
| train_losses, val_losses = [], [] |
| train_accs, val_accs = [], [] |
| train_precs, val_precs = [], [] |
| train_f1s, val_f1s = [], [] |
|
|
| best_val_loss = float('inf') |
|
|
| for epoch in range(1, epochs+1): |
| |
| model.train() |
| epoch_loss = 0.0 |
| preds, labels = [], [] |
| for batch in tqdm(train_loader, desc=f"Train {epoch}/{epochs}"): |
| inputs = batch['input_ids'].to(device) |
| masks = batch['attention_mask'].to(device) |
| labs = batch['label'].to(device) |
|
|
| optimizer.zero_grad() |
| logits = model(inputs, masks) |
| loss = criterion(logits, labs) |
| loss.backward() |
| optimizer.step() |
| scheduler.step() |
|
|
| epoch_loss += loss.item() |
|
|
| probs = torch.sigmoid(logits) |
| batch_preds = (probs >= threshold).long() |
| preds.extend(batch_preds.cpu().tolist()) |
| labels.extend(labs.cpu().long().tolist()) |
|
|
| avg_train = epoch_loss / len(train_loader) |
| train_losses.append(avg_train) |
| train_accs.append( accuracy_score(labels, preds) ) |
| train_precs.append( precision_score(labels, preds, zero_division=0) ) |
| train_f1s.append( f1_score(labels, preds, zero_division=0) ) |
| print(f"→ Epoch {epoch} Train — loss {avg_train:.4f}, acc {train_accs[-1]:.4f}, prec {train_precs[-1]:.4f}, f1 {train_f1s[-1]:.4f}") |
|
|
| |
|         # After the freeze phase, unfreeze BERT and rebuild the optimizer with a |
|         # smaller learning rate for the encoder than for the classifier head. |
|         if epoch == unfreeze_after_epoch: |
|             for p in model.bert.parameters(): |
|                 p.requires_grad = True |
| optimizer = AdamW([ |
| {"params": model.classifier.parameters(), "lr": 1e-3}, |
| {"params": model.bert.parameters(), "lr": 1e-5}, |
| ], weight_decay=1e-2) |
| scheduler = get_linear_schedule_with_warmup( |
| optimizer, |
| num_warmup_steps=int(0.1 * total_steps), |
| num_training_steps=total_steps |
| ) |
|
|
| |
| model.eval() |
| epoch_loss = 0.0 |
| preds, labels = [], [] |
| with torch.no_grad(): |
| for batch in tqdm(val_loader, desc=f" Val {epoch}/{epochs}"): |
| inputs = batch['input_ids'].to(device) |
| masks = batch['attention_mask'].to(device) |
| labs = batch['label'].to(device) |
|
|
| logits = model(inputs, masks) |
| loss = criterion(logits, labs) |
| epoch_loss += loss.item() |
|
|
| probs = torch.sigmoid(logits) |
| batch_preds = (probs >= threshold).long() |
| preds.extend(batch_preds.cpu().tolist()) |
| labels.extend(labs.cpu().long().tolist()) |
|
|
| avg_val = epoch_loss / len(val_loader) |
| val_losses.append(avg_val) |
| val_accs.append( accuracy_score(labels, preds) ) |
| val_precs.append( precision_score(labels, preds, zero_division=0) ) |
| val_f1s.append( f1_score(labels, preds, zero_division=0) ) |
| print(f"→ Epoch {epoch} Val — loss {avg_val:.4f}, acc {val_accs[-1]:.4f}, prec {val_precs[-1]:.4f}, f1 {val_f1s[-1]:.4f}") |
|
|
| |
| os.makedirs(output_dir, exist_ok=True) |
| ckpt = os.path.join(output_dir, f"epo{epoch}_val{avg_val:.4f}.pth") |
| torch.save(model.state_dict(), ckpt) |
| if avg_val < best_val_loss: |
| best_val_loss = avg_val |
| torch.save(model.state_dict(), os.path.join(output_dir, "best_model.pth")) |
| print(f"🎉 New best model saved (val loss {best_val_loss:.4f})") |
|
|
| print(f"✔️ Training complete — best val loss: {best_val_loss:.4f}") |
|
|
| |
| epochs = list(range(1, epochs+1)) |
|
|
| save_metric_plot( |
| epochs, |
| train_losses, |
| val_losses, |
| metric_name="Loss", |
| output_path="results/Loss_Plot.png" |
| ) |
|
|
| save_metric_plot( |
| epochs, |
| train_accs, |
| val_accs, |
| metric_name="Accuracy", |
| output_path="results/Accuracy_Plot.png", |
| threshold=0.5 |
| ) |
|
|
| save_metric_plot( |
| epochs, |
| train_precs, |
| val_precs, |
| metric_name="Precision", |
| output_path="results/Precision_Plot.png", |
| threshold=0.5 |
| ) |
|
|
| save_metric_plot( |
| epochs, |
| train_f1s, |
| val_f1s, |
| metric_name="F1 Score", |
| output_path="results/F1Score_Plot.png", |
| threshold=0.5 |
| ) |
|
|
|
|
| def save_metric_plot( |
| epochs, |
| train_vals, |
| val_vals, |
| metric_name: str, |
| output_path: str, |
| threshold: float = None |
| ): |
| """ |
| epochs – list of epoch indices |
| train_vals – list of train metric values |
| val_vals – list of validation metric values |
| metric_name – e.g. "Loss", "Accuracy", "Precision", "F1 Score" |
| output_path – where to save the PNG |
| threshold – optional horizontal line to draw, e.g. 0.5 |
| """ |
|     os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True)  # ensure e.g. results/ exists before saving |
|     fig, ax = plt.subplots(figsize=(8, 5)) |
| ax.plot(epochs, train_vals, marker='o', linewidth=2, label=f'Train {metric_name}') |
| ax.plot(epochs, val_vals, marker='s', linewidth=2, label=f'Val {metric_name}') |
|
|
| if threshold is not None: |
| ax.axhline(threshold, color='gray', linestyle='--', linewidth=1, label=f'Threshold = {threshold}') |
|
|
| ax.set_title(f'{metric_name} over Epochs', fontsize=14, pad=10) |
| ax.set_xlabel('Epoch', fontsize=12) |
| ax.set_ylabel(metric_name, fontsize=12) |
| ax.grid(True, linestyle='--', alpha=0.4) |
| ax.legend(loc='best', frameon=True, fontsize=10) |
| fig.tight_layout() |
| fig.savefig(output_path, dpi=300) |
| plt.close(fig) |
|
|
|
|
| def demo_on_random_val( |
| model, |
| tokenizer, |
| excel_path: str, |
| ckpt_path: str, |
| max_length: int = 128, |
| device: str = "cpu", |
| temperature: float = 1.0 |
| ): |
| """ |
|     Runs the extractor on one random validation transcript; instead of a fixed threshold: |
| 1) Compute sigmoid(logits / temperature) for each sentence |
| 2) Sort probabilities descending |
| 3) Find the largest gap between adjacent probs |
| 4) Set dynamic_threshold = midpoint of that gap |
| 5) Extract all sentences with prob >= dynamic_threshold |
| """ |
| |
| model.load_state_dict(torch.load(ckpt_path, map_location=device)) |
| model.to(device).eval() |
|
|
| |
| df = pd.read_excel(excel_path) |
| _, val_df = train_test_split(df, test_size=0.2, random_state=42) |
| row = val_df.sample(n=1, random_state=random.randint(0,999)).iloc[0] |
| transcript = str(row['Claude_Call']) |
| print(f"\n── Transcript (val sample idx={row['idx']}):\n{transcript}\n") |
|
|
| |
| sentences, probs = [], [] |
| for sent in sent_tokenize(transcript): |
| enc = tokenizer.encode_plus( |
| sent, |
| max_length=max_length, |
| padding='max_length', |
| truncation=True, |
| return_tensors='pt' |
| ) |
|         with torch.no_grad():  # inference only; avoid building a graph for every sentence |
|             logits = model(enc['input_ids'].to(device), |
|                            enc['attention_mask'].to(device)) |
|         prob = torch.sigmoid(logits / temperature).item() |
| sentences.append(sent) |
| probs.append(prob) |
|
|
| |
| print("Sentence probabilities:") |
| for s,p in zip(sentences, probs): |
| print(f" → {p:.4f} → {s}") |
|
|
| |
| if len(probs) < 2 or max(probs) - min(probs) < 1e-3: |
| dynamic_thr = 0.5 |
| else: |
| |
| sorted_probs = sorted(probs, reverse=True) |
| diffs = [sorted_probs[i] - sorted_probs[i+1] for i in range(len(sorted_probs)-1)] |
| idx = max(range(len(diffs)), key=lambda i: diffs[i]) |
| |
| dynamic_thr = (sorted_probs[idx] + sorted_probs[idx+1]) / 2.0 |
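|         # e.g. probs [0.91, 0.88, 0.35, 0.10]: the largest adjacent gap is 0.88 - 0.35, |
|         # so dynamic_thr = (0.88 + 0.35) / 2 = 0.615 and only the first two sentences survive. |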
|
|
| print(f"\nDynamic threshold = {dynamic_thr:.4f}\n") |
| print("Extracted sentences:") |
| for s,p in zip(sentences, probs): |
| if p >= dynamic_thr: |
| print(f" • {p:.4f} → {s}") |
| print() |
|
|
|
|
| def batch_predict_and_save( |
| model, |
| tokenizer, |
| excel_path: str, |
| ckpt_path: str, |
| output_path: str, |
| n_samples: int = 40, |
| max_length: int = 128, |
| device: str = "cpu", |
| temperature: float = 1.0, |
| random_state: int = None |
| ): |
| """ |
| 1) Loads best checkpoint |
| 2) Samples `n_samples` rows |
| 3) For each transcript: |
| - tokenize into sentences |
| - compute p = sigmoid(logits/temperature) |
| - compute elbow threshold on sorted p’s |
| - extract all sentences with p >= elbow |
| - if none, pick the highest-p sentence |
| 4) Save new Excel with columns: |
| - 'Claude_Call' |
| - 'Predicted Sel_K' (list of extracted sentences) |
| """ |
| |
| model.load_state_dict(torch.load(ckpt_path, map_location=device)) |
| model.to(device).eval() |
|
|
| |
| df = pd.read_excel(excel_path) |
| sampled = df.sample(n=n_samples, random_state=random_state) \ |
| if random_state is not None else df.sample(n=n_samples) |
|
|
| records = [] |
| for _, row in tqdm(sampled.iterrows(), |
| total=len(sampled), |
| desc="Running Predictions"): |
| transcript = str(row['Claude_Call']) |
| sentences = sent_tokenize(transcript) |
|
|
| |
| probs = [] |
| for sent in sentences: |
| enc = tokenizer.encode_plus( |
| sent, |
| max_length=max_length, |
| padding='max_length', |
| truncation=True, |
| return_tensors='pt' |
| ) |
| with torch.no_grad(): |
| logits = model(enc['input_ids'].to(device), |
| enc['attention_mask'].to(device)) |
| p = torch.sigmoid(logits / temperature).item() |
| probs.append(p) |
|
|
| |
| if len(probs) >= 2 and max(probs) - min(probs) > 1e-3: |
| sp = sorted(probs, reverse=True) |
| diffs = [sp[i] - sp[i+1] for i in range(len(sp)-1)] |
| idx = max(range(len(diffs)), key=lambda i: diffs[i]) |
| thr = (sp[idx] + sp[idx+1]) / 2.0 |
| else: |
| thr = 0.5 |
|
|
| |
| extracted = [s for s,p in zip(sentences, probs) if p >= thr] |
| if not extracted and sentences: |
| best_idx = int(max(range(len(probs)), key=lambda i: probs[i])) |
| extracted = [sentences[best_idx]] |
|
|
| records.append({ |
| 'Claude_Call': transcript, |
| 'Predicted Sel_K': extracted |
| }) |
|
|
| |
| out_df = pd.DataFrame(records) |
| os.makedirs(os.path.dirname(output_path) or '.', exist_ok=True) |
| out_df.to_excel(output_path, index=False) |
| print(f"➡️ Saved {len(out_df)} rows to {output_path}") |
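

| # Minimal usage sketch (not part of the training pipeline): the checkpoint path below |
| # is an assumption, presuming a 'gnn_model_checkpoint.pt' produced by |
| # train_model_with_chkpt() exists locally. Adjust paths and inputs to your setup. |
| if __name__ == "__main__": |
|     example_label, example_probs = predict_nli( |
|         premise="The company reported record revenue this quarter.", |
|         hypothesis="Revenue declined this quarter.", |
|         model_path="gnn_model_checkpoint.pt", |
|     ) |
|     print(f"Predicted: {example_label}, probabilities: {example_probs}") |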