import os
import sys
import json
import logging
import time
import pickle
import copy

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader, Subset
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score
from matplotlib import pyplot as plt

_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
if _PROJECT_ROOT not in sys.path:
    sys.path.insert(0, _PROJECT_ROOT)

# The stage 2 tokenizer class must be importable so the pickled tokenizer can be restored
# and texts_to_sequences behaves exactly as it did during preprocessing.
from src.stage2_preprocessing import KerasStyleTokenizer  # noqa: F401 (needed for unpickling)

logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(name)s | %(levelname)s | %(message)s")
logger = logging.getLogger("lstm_model")


# ── Architecture ──────────────────────────────────────

class SpatialDropout1D(nn.Module):
    """Drops entire embedding channels, mirroring Keras SpatialDropout1D."""

    def __init__(self, p=0.3):
        super().__init__()
        self.p = p

    def forward(self, x):
        if not self.training or self.p == 0:
            return x
        # x is (batch, seq_len, embed_dim); convert to (batch, embed_dim, seq_len)
        x = x.permute(0, 2, 1)
        # 1D spatial dropout is equivalent to 2D dropout with height 1:
        # F.dropout2d drops entire channels, which here are the embedding dimensions.
        x = x.unsqueeze(3)
        x = F.dropout2d(x, p=self.p, training=self.training)
        x = x.squeeze(3)
        return x.permute(0, 2, 1)


class BiLSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embedding_matrix=None):
        super().__init__()
        # Embedding(vocab_size, 100)
        self.embedding = nn.Embedding(vocab_size, 100, padding_idx=0)
        if embedding_matrix is not None:
            self.embedding.weight.data.copy_(torch.from_numpy(embedding_matrix))
            self.embedding.weight.requires_grad = False
        self.spatial_drop = SpatialDropout1D(0.3)
        # Bi-LSTM(100 -> 128, bidirectional)
        self.lstm1 = nn.LSTM(100, 128, bidirectional=True, batch_first=True)
        # Bi-LSTM(256 -> 64, bidirectional)
        self.lstm2 = nn.LSTM(256, 64, bidirectional=True, batch_first=True)
        # Linear(128, 64) + ReLU
        self.fc1 = nn.Linear(128, 64)
        self.dropout = nn.Dropout(0.4)
        # Linear(64, 1); the sigmoid is applied implicitly by BCEWithLogitsLoss during
        # training and explicitly (torch.sigmoid) at inference time.
        self.fc2 = nn.Linear(64, 1)

    def forward(self, x):
        h = self.embedding(x)
        h = self.spatial_drop(h)
        h, _ = self.lstm1(h)
        # Like Keras `return_sequences=False` on the second LSTM: keep only the final
        # hidden state of each direction.
        _, (h_n, _) = self.lstm2(h)
        # h_n shape for a single-layer Bi-LSTM: (2, batch, hidden_size).
        # Concatenate forward and backward final states -> (batch, 128).
        h_concat = torch.cat((h_n[-2, :, :], h_n[-1, :, :]), dim=1)
        out = F.relu(self.fc1(h_concat))
        out = self.dropout(out)
        logits = self.fc2(out)
        return logits.squeeze(1)
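
# Illustrative shape trace for the forward pass (a quick sanity sketch; the vocab size,
# batch size, and sequence length below are arbitrary placeholders, not pipeline values).
# It only confirms that a (batch, seq_len) tensor of token ids maps to (batch,) logits.
def _shape_sanity_check():
    model = BiLSTMClassifier(vocab_size=1000)        # no pretrained embedding matrix
    dummy_ids = torch.randint(0, 1000, (4, 512))     # (batch=4, seq_len=512) token ids
    with torch.no_grad():
        logits = model(dummy_ids)
    # embedding -> (4, 512, 100); lstm1 -> (4, 512, 256); lstm2 h_n -> (2, 4, 64)
    # concat -> (4, 128); fc1 -> (4, 64); fc2 -> (4, 1) -> squeeze -> (4,)
    assert logits.shape == (4,)
    return logits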

# ── Utilities ──────────────────────────────────────

def pad_sequences(sequences, maxlen=512, padding='post'):
    padded = np.zeros((len(sequences), maxlen), dtype=np.int64)
    for i, seq in enumerate(sequences):
        seq = seq[:maxlen]
        if len(seq) == 0:
            continue  # leave all-zero row; avoids an invalid slice for pre-padding
        if padding == 'post':
            padded[i, :len(seq)] = seq
        else:
            padded[i, -len(seq):] = seq
    return padded


def load_glove_embeddings(glove_path, word_index, embed_dim=100):
    logger.info(f"Loading GloVe embeddings from {glove_path}...")
    embeddings_index = {}
    with open(glove_path, "r", encoding="utf-8") as f:
        for line in f:
            values = line.split()
            word = values[0]
            coefs = np.asarray(values[1:], dtype='float32')
            embeddings_index[word] = coefs
    vocab_size = len(word_index) + 1  # +1 for the padding index
    embedding_matrix = np.zeros((vocab_size, embed_dim), dtype=np.float32)
    hits, misses = 0, 0
    for word, i in word_index.items():
        embedding_vector = embeddings_index.get(word)
        if embedding_vector is not None:
            embedding_matrix[i] = embedding_vector
            hits += 1
        else:
            misses += 1
    logger.info(f"GloVe mapped: {hits} hits, {misses} misses.")
    return embedding_matrix, vocab_size


def plot_and_save_cm(y_true, y_pred, path):
    cm = confusion_matrix(y_true, (np.array(y_pred) > 0.5).astype(int))
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.matshow(cm, cmap=plt.cm.Blues, alpha=0.3)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(x=j, y=i, s=cm[i, j], va='center', ha='center', size='xx-large')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Bi-LSTM Confusion Matrix')
    plt.tight_layout()
    plt.savefig(path)
    plt.close()


# ── Training Loop ──────────────────────────────────────

def train_epoch(model, loader, optimizer, criterion, device):
    model.train()
    total_loss = 0
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        optimizer.zero_grad()
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x_batch.size(0)
    return total_loss / len(loader.dataset)


@torch.no_grad()
def eval_model(model, loader, criterion, device):
    model.eval()
    total_loss = 0
    all_preds = []
    for x_batch, y_batch in loader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        logits = model(x_batch)
        loss = criterion(logits, y_batch)
        total_loss += loss.item() * x_batch.size(0)
        probas = torch.sigmoid(logits).cpu().numpy()
        all_preds.extend(probas)
    return total_loss / len(loader.dataset), np.array(all_preds)
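
# Minimal wiring sketch for train_epoch / eval_model on toy data (arbitrary shapes and
# hyper-parameters; the real loaders are built inside train_lstm_logic below).
def _toy_training_example():
    device = torch.device("cpu")
    x = torch.randint(0, 1000, (32, 64))          # 32 toy "documents", 64 tokens each
    y = torch.randint(0, 2, (32,)).float()        # binary labels
    loader = DataLoader(TensorDataset(x, y), batch_size=8, shuffle=True)
    model = BiLSTMClassifier(vocab_size=1000).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.BCEWithLogitsLoss()
    train_loss = train_epoch(model, loader, optimizer, criterion, device)
    eval_loss, probas = eval_model(model, loader, criterion, device)
    return train_loss, eval_loss, probas          # probas has shape (32,)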

def train_lstm_logic(cfg, splits_dir, save_dir, glove_path):
    os.makedirs(save_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load splits and tokenizer
    train_df = pd.read_csv(os.path.join(splits_dir, "df_train.csv"))
    val_df = pd.read_csv(os.path.join(splits_dir, "df_val.csv"))
    y_train = train_df["binary_label"].values.astype(np.float32)
    y_val = val_df["binary_label"].values.astype(np.float32)

    with open(os.path.join(_PROJECT_ROOT, cfg["paths"]["models_dir"], "tokenizer.pkl"), "rb") as f:
        tokenizer = pickle.load(f)

    maxlen = cfg.get("preprocessing", {}).get("lstm_max_len", 512)
    batch_size = cfg.get("training", {}).get("lstm_batch_size", 64)
    epochs = cfg.get("training", {}).get("lstm_epochs", 10)

    logger.info("Transforming texts to padded sequences...")
    X_train_seq = tokenizer.texts_to_sequences(train_df["clean_text"].fillna(""))
    X_val_seq = tokenizer.texts_to_sequences(val_df["clean_text"].fillna(""))
    X_train_pad = pad_sequences(X_train_seq, maxlen=maxlen, padding='post')
    X_val_pad = pad_sequences(X_val_seq, maxlen=maxlen, padding='post')

    # Embedding matrix
    emb_matrix, vocab_size = load_glove_embeddings(glove_path, tokenizer.word_index)

    # Class imbalance: BCEWithLogitsLoss pos_weight = n_negative / n_positive
    class_counts = np.bincount(y_train.astype(int))
    pos_weight = torch.tensor([class_counts[0] / class_counts[1]], dtype=torch.float32).to(device)

    # Datasets
    train_tensor = TensorDataset(torch.from_numpy(X_train_pad).long(), torch.from_numpy(y_train))
    val_tensor = TensorDataset(torch.from_numpy(X_val_pad).long(), torch.from_numpy(y_val))
    val_loader = DataLoader(val_tensor, batch_size=batch_size, shuffle=False)

    # --- 5-Fold OOF Predictions ---
    logger.info("Starting 5-Fold OOF generation...")
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    oof_preds = np.zeros_like(y_train, dtype=np.float32)
    criterion_kfold = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    for fold, (t_idx, v_idx) in enumerate(skf.split(X_train_pad, y_train)):
        logger.info(f"OOF Fold {fold + 1}/5")
        fold_train_ds = Subset(train_tensor, t_idx)
        fold_val_ds = Subset(train_tensor, v_idx)
        fold_train_loader = DataLoader(fold_train_ds, batch_size=batch_size, shuffle=True)
        fold_val_loader = DataLoader(fold_val_ds, batch_size=batch_size, shuffle=False)

        model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)

        best_val_loss = float('inf')
        patience_counter = 0
        best_weights = copy.deepcopy(model.state_dict())

        for ep in range(epochs):  # optionally cap OOF folds at 3-4 epochs to save time
            t_loss = train_epoch(model, fold_train_loader, optimizer, criterion_kfold, device)
            v_loss, v_preds = eval_model(model, fold_val_loader, criterion_kfold, device)
            scheduler.step(v_loss)
            if v_loss < best_val_loss:
                best_val_loss = v_loss
                best_weights = copy.deepcopy(model.state_dict())
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= 3:
                    break

        # Predict the held-out fold with the best checkpoint
        model.load_state_dict(best_weights)
        _, fold_best_preds = eval_model(model, fold_val_loader, criterion_kfold, device)
        oof_preds[v_idx] = fold_best_preds

    np.save(os.path.join(save_dir, "lstm_oof.npy"), oof_preds)
    logger.info("Saved OOF predictions (lstm_oof.npy).")

    # --- Final Training on ALL Data ---
    logger.info("Starting final model training on full Train split...")
    train_loader = DataLoader(train_tensor, batch_size=batch_size, shuffle=True)

    model = BiLSTMClassifier(vocab_size, emb_matrix).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=1, factor=0.5)

    best_val_loss = float('inf')
    best_weights = copy.deepcopy(model.state_dict())
    patience_counter = 0

    for ep in range(epochs):
        t_loss = train_epoch(model, train_loader, optimizer, criterion_kfold, device)
        v_loss, v_preds = eval_model(model, val_loader, criterion_kfold, device)
        scheduler.step(v_loss)
        logger.info(f"  Epoch {ep + 1}/{epochs} | Train Loss: {t_loss:.4f} | Val Loss: {v_loss:.4f}")
        if v_loss < best_val_loss:
            best_val_loss = v_loss
            best_weights = copy.deepcopy(model.state_dict())
            patience_counter = 0
        else:
            patience_counter += 1
            if patience_counter >= 3:
                logger.info("  EarlyStopping triggered.")
                break

    model.load_state_dict(best_weights)
    torch.save(model.state_dict(), os.path.join(save_dir, "model.pt"))
    logger.info("Saved final LSTM weights.")

    # Evaluate the validation split
    _, val_preds_probas = eval_model(model, val_loader, criterion_kfold, device)
    val_preds_binary = (val_preds_probas >= 0.5).astype(int)
    logger.info("Validation Classification Report:\n" + classification_report(y_val, val_preds_binary))
    roc_auc = roc_auc_score(y_val, val_preds_probas)
    logger.info(f"ROC-AUC: {roc_auc:.4f}")
    plot_and_save_cm(y_val, val_preds_probas, os.path.join(save_dir, "cm.png"))

    # Accuracy per text-length bucket
    bucket_acc = {}
    for b in ["short", "medium", "long"]:
        b_mask = (val_df["text_length_bucket"] == b).values
        if b_mask.sum() > 0:
            acc = (val_preds_binary[b_mask] == y_val[b_mask]).mean()
            bucket_acc[b] = acc

    metrics = {
        "roc_auc": float(roc_auc),
        "bucket_accuracy": {k: float(v) for k, v in bucket_acc.items()},
    }
    with open(os.path.join(save_dir, "metrics.json"), "w") as f:
        json.dump(metrics, f, indent=2)
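
# Sketch of how the artifacts written by train_lstm_logic could be reloaded downstream
# (paths assume the save_dir layout above; vocab_size and emb_matrix must match the
# training run, so in practice they are rebuilt from tokenizer.pkl and the GloVe file).
def _load_artifacts_example(save_dir, vocab_size, emb_matrix):
    oof = np.load(os.path.join(save_dir, "lstm_oof.npy"))      # OOF probabilities
    with open(os.path.join(save_dir, "metrics.json")) as f:
        metrics = json.load(f)
    model = BiLSTMClassifier(vocab_size, emb_matrix)
    model.load_state_dict(torch.load(os.path.join(save_dir, "model.pt"), map_location="cpu"))
    model.eval()
    return oof, metrics, model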

if __name__ == "__main__":
    import yaml

    cfg_path = os.path.join(_PROJECT_ROOT, "config", "config.yaml")
    with open(cfg_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    s_dir = os.path.join(_PROJECT_ROOT, config["paths"]["splits_dir"])
    m_dir = os.path.join(_PROJECT_ROOT, config["paths"]["models_dir"], "lstm_model")
    g_path = os.path.join(_PROJECT_ROOT, config["paths"]["glove_path"])

    t0 = time.time()
    train_lstm_logic(config, s_dir, m_dir, g_path)
    print(f"Total time: {time.time() - t0:.2f}s")
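
# For reference, the config keys this script reads (values are illustrative placeholders;
# the repo's config/config.yaml is the source of truth):
#
#   paths:
#     splits_dir: data/splits
#     models_dir: models
#     glove_path: data/glove.6B.100d.txt
#   preprocessing:
#     lstm_max_len: 512
#   training:
#     lstm_batch_size: 64
#     lstm_epochs: 10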