| |
| """ |
| Train TBX5 classifier using both forward and reverse complement embeddings. |
| This script combines embeddings from both strands to improve classification accuracy. |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from torch.utils.data import DataLoader, TensorDataset |
| from sklearn.model_selection import train_test_split |
| from sklearn.preprocessing import StandardScaler |
| from sklearn.metrics import ( |
| roc_auc_score, |
| accuracy_score, |
| precision_recall_fscore_support, |
| confusion_matrix, |
| ) |
| import json |
| import pickle |
| from tqdm import tqdm |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from datetime import datetime |
|
|
| |
# Put the sibling ``finetuning`` directory (relative to this file) on the
# import path so modules that live there can be imported by name.
_finetuning_dir = os.path.join(os.path.dirname(__file__), '..', 'finetuning')
sys.path.append(_finetuning_dir)
|
|
class TBX5ClassifierWithRC(nn.Module):
    """Feed-forward binary classifier over concatenated forward + reverse-complement embeddings.

    Layout: ``input_dim -> 2048 -> 512 -> 128 -> 1``. Every hidden stage applies
    Linear, ReLU, BatchNorm1d and Dropout, in that order. The single output
    unit is passed through a sigmoid, so ``forward`` returns one value in
    (0, 1) per input row.

    Args:
        input_dim: Width of the input vectors. The default of 8192 corresponds
            to a 4096-wide forward embedding concatenated with a 4096-wide
            reverse-complement embedding.
        dropout_rate: Probability used by all three Dropout layers.
    """

    def __init__(self, input_dim=8192, dropout_rate=0.5):
        super().__init__()

        # The attribute names (fc1/bn1/dropout1, ...) and their creation order
        # are deliberately the same as a hand-written layer list: they define
        # the state_dict keys and the order in which weights are initialised.
        widths = (input_dim, 2048, 512, 128)
        for stage, (n_in, n_out) in enumerate(zip(widths, widths[1:]), start=1):
            setattr(self, f"fc{stage}", nn.Linear(n_in, n_out))
            setattr(self, f"bn{stage}", nn.BatchNorm1d(n_out))
            setattr(self, f"dropout{stage}", nn.Dropout(dropout_rate))

        self.fc4 = nn.Linear(128, 1)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Map a ``(batch, input_dim)`` tensor to ``(batch, 1)`` probabilities."""
        hidden_stages = (
            (self.fc1, self.bn1, self.dropout1),
            (self.fc2, self.bn2, self.dropout2),
            (self.fc3, self.bn3, self.dropout3),
        )
        for linear, norm, drop in hidden_stages:
            x = drop(norm(self.relu(linear(x))))

        return self.sigmoid(self.fc4(x))
|
|
def _load_chromosome_arrays(chrom, embeddings_dir, rc_embeddings_dir):
    """Read one chromosome's forward and reverse-complement arrays from disk.

    Returns ``None`` when either ``.npz`` file is missing. Otherwise returns a
    dict holding the forward embeddings, the reverse-complement embeddings and
    a ``start -> row index`` map built from the forward file's ``starts``.
    """
    forward_file = os.path.join(embeddings_dir, f"{chrom}_tbx5_embeddings_arrays.npz")
    rc_file = os.path.join(rc_embeddings_dir, f"{chrom}_tbx5_embeddings_rc_arrays.npz")

    if not (os.path.exists(forward_file) and os.path.exists(rc_file)):
        print(f" Warning: Missing embedding files for {chrom}")
        return None

    print(f" Loading {chrom}...")
    # np.load returns a lazy NpzFile that keeps the archive open; read the
    # arrays we need and let the context managers close both files.
    with np.load(forward_file) as forward_data, np.load(rc_file) as rc_data:
        forward_embeddings = forward_data['embeddings']
        forward_starts = forward_data['starts']
        rc_embeddings = rc_data['embeddings']

    # NOTE(review): forward and RC rows are paired purely by position, i.e.
    # both files are assumed to list the same regions in the same order.
    if len(rc_embeddings) != len(forward_embeddings):
        print(
            f" Warning: {chrom} has {len(forward_embeddings)} forward but "
            f"{len(rc_embeddings)} reverse complement embeddings"
        )

    # One dictionary per chromosome replaces a full array scan per CSV row.
    # setdefault keeps the FIRST row for a repeated start, like np.where(...)[0][0].
    start_to_index = {}
    for row_idx, start in enumerate(forward_starts.tolist()):
        start_to_index.setdefault(start, row_idx)

    return {
        'forward_embeddings': forward_embeddings,
        'rc_embeddings': rc_embeddings,
        'start_to_index': start_to_index,
    }


def _load_split(df, embeddings_dir, rc_embeddings_dir, chrom_cache):
    """Collect combined embeddings and per-sample metadata for one CSV split.

    ``chrom_cache`` maps ``"chr<N>"`` to the dict produced by
    ``_load_chromosome_arrays`` (or ``None``); it is filled in place so a
    chromosome is read from disk only once across all splits.
    """
    all_embeddings = []
    all_labels = []
    all_starts = []
    all_ends = []
    all_tbx5_scores = []
    all_chromosomes = []

    total_samples = len(df)
    missing_files = 0    # samples whose chromosome has no embedding files
    missing_samples = 0  # samples whose start is absent from the embeddings

    # Iterate column-wise: DataFrame.iterrows() builds one Series per row and
    # upcasts mixed int/float rows to float, which would turn chromosome 1
    # into the string "chr1.0" and make every file lookup fail.
    columns = (df['chromosome'], df['start'], df['end'], df['label'], df['tbx5_score'])
    for chrom_num, start, end, label, tbx5_score in zip(*columns):
        chrom = f"chr{chrom_num}"

        if chrom not in chrom_cache:
            chrom_cache[chrom] = _load_chromosome_arrays(chrom, embeddings_dir, rc_embeddings_dir)

        chrom_data = chrom_cache[chrom]
        if chrom_data is None:
            missing_files += 1
            continue

        emb_idx = chrom_data['start_to_index'].get(start)
        if emb_idx is None or emb_idx >= len(chrom_data['rc_embeddings']):
            missing_samples += 1
            continue

        # Forward vector first, reverse-complement vector second.
        combined_emb = np.concatenate([
            chrom_data['forward_embeddings'][emb_idx],
            chrom_data['rc_embeddings'][emb_idx],
        ])

        all_embeddings.append(combined_emb)
        all_labels.append(label)
        all_starts.append(start)
        all_ends.append(end)
        all_tbx5_scores.append(tbx5_score)
        all_chromosomes.append(chrom)

    found_samples = len(all_embeddings)
    print(f" Summary: {found_samples}/{total_samples} samples loaded")
    print(f" Missing files: {missing_files} samples")
    print(f" Missing embeddings: {missing_samples} samples")

    return (
        np.array(all_embeddings),
        np.array(all_labels),
        np.array(all_starts),
        np.array(all_ends),
        np.array(all_tbx5_scores),
        all_chromosomes,
    )


def load_tbx5_embeddings_with_rc_from_csv(embeddings_dir, rc_embeddings_dir, processed_data_dir):
    """
    Load TBX5 embeddings using the train/val/test splits defined by CSV files.

    The CSVs ``train_tbx5_data_new.csv``, ``val_tbx5_data_new.csv`` and
    ``test_tbx5_data_new.csv`` are read from ``processed_data_dir``; the
    columns ``chromosome``, ``start``, ``end``, ``label`` and ``tbx5_score``
    are used. Each row is matched to an embedding by ``start`` within
    ``chr<chromosome>``; rows without a match are skipped and counted.

    Args:
        embeddings_dir: Directory containing ``chr*_tbx5_embeddings_arrays.npz``
        rc_embeddings_dir: Directory containing ``chr*_tbx5_embeddings_rc_arrays.npz``
        processed_data_dir: Directory containing the three split CSV files

    Returns:
        A 19-tuple: for train, then val, then test the six values
        ``(X, y, starts, ends, tbx5_scores, chromosomes)`` (five arrays and a
        list of ``"chr<N>"`` strings), followed by a metadata dict with sample
        counts, positive counts and the embedding dimension.

    Raises:
        ValueError: If any of the three splits ends up with no samples.
    """
    print(f"Loading data using CSV splits from: {processed_data_dir}")
    print(f"Loading forward embeddings from: {embeddings_dir}")
    print(f"Loading reverse complement embeddings from: {rc_embeddings_dir}")

    train_df = pd.read_csv(os.path.join(processed_data_dir, 'train_tbx5_data_new.csv'))
    val_df = pd.read_csv(os.path.join(processed_data_dir, 'val_tbx5_data_new.csv'))
    test_df = pd.read_csv(os.path.join(processed_data_dir, 'test_tbx5_data_new.csv'))

    print(f"Train samples: {len(train_df)}")
    print(f"Val samples: {len(val_df)}")
    print(f"Test samples: {len(test_df)}")

    # Shared by the three splits so each chromosome is read from disk once.
    chrom_cache = {}

    print("Loading train data...")
    X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train = _load_split(
        train_df, embeddings_dir, rc_embeddings_dir, chrom_cache
    )

    print("Loading validation data...")
    X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val = _load_split(
        val_df, embeddings_dir, rc_embeddings_dir, chrom_cache
    )

    print("Loading test data...")
    X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test = _load_split(
        test_df, embeddings_dir, rc_embeddings_dir, chrom_cache
    )

    print(f"\nLoaded data:")
    print(f"Train: {len(X_train)} samples")
    print(f"Val: {len(X_val)} samples")
    print(f"Test: {len(X_test)} samples")

    # Validate before touching X_train.shape[1]: an empty split is a 1-D
    # array, so reading shape[1] first would raise an unhelpful IndexError.
    if len(X_train) == 0:
        raise ValueError("No training data loaded! Check embedding files and CSV data.")
    if len(X_val) == 0:
        raise ValueError("No validation data loaded! Check embedding files and CSV data.")
    if len(X_test) == 0:
        raise ValueError("No test data loaded! Check embedding files and CSV data.")

    print(f"Embedding dimension: {X_train.shape[1]}")
    print(f"Train positive samples: {np.sum(y_train)}")
    print(f"Val positive samples: {np.sum(y_val)}")
    print(f"Test positive samples: {np.sum(y_test)}")

    print(f"\nData quality check:")
    print(f"Train positive ratio: {np.mean(y_train):.3f}")
    print(f"Val positive ratio: {np.mean(y_val):.3f}")
    print(f"Test positive ratio: {np.mean(y_test):.3f}")

    metadata = {
        "total_samples": len(X_train) + len(X_val) + len(X_test),
        "embedding_dim": X_train.shape[1],
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "train_positive": int(np.sum(y_train)),
        "val_positive": int(np.sum(y_val)),
        "test_positive": int(np.sum(y_test)),
        "sequence_type": "forward_and_reverse_complement"
    }

    return (
        X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
        X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
        X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
        metadata
    )
|
|
def prepare_data_with_scaling(X_train, X_val, X_test, y_train, y_val, y_test):
    """
    Standardise the three feature matrices with statistics from the training set.

    The scaler is fitted on ``X_train`` only and then applied to all three
    splits. The label arguments are accepted but not used.

    Returns:
        ``(X_train_scaled, X_val_scaled, X_test_scaled, scaler)``
    """
    print("Scaling features...")

    scaler = StandardScaler().fit(X_train)
    scaled_train, scaled_val, scaled_test = (
        scaler.transform(features) for features in (X_train, X_val, X_test)
    )

    return scaled_train, scaled_val, scaled_test, scaler
|
|
def _predict_and_score(model, loader, criterion, device):
    """Run ``model`` over ``loader`` in eval mode with gradients disabled.

    Returns ``(mean_batch_loss, accuracy_at_0.5, predictions, labels)`` where
    the last two are flat lists with one entry per sample.
    """
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    predictions = []
    labels = []

    with torch.no_grad():
        for batch_embeddings, batch_labels in loader:
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            # reshape(-1) always yields a 1-D tensor. A bare squeeze() would
            # turn a final batch of one sample into a 0-d tensor, which
            # BCELoss rejects and list.extend cannot iterate.
            outputs = model(batch_embeddings).reshape(-1)

            loss_sum += criterion(outputs, batch_labels).item()
            n_correct += ((outputs > 0.5).float() == batch_labels).sum().item()
            n_seen += batch_labels.size(0)

            predictions.extend(outputs.cpu().numpy())
            labels.extend(batch_labels.cpu().numpy())

    return loss_sum / len(loader), n_correct / n_seen, predictions, labels


def train_model(
    model,
    train_loader,
    val_loader,
    test_loader,
    device,
    output_dir,
    num_epochs=500,
    learning_rate=1e-4,
    patience=100,
    lr_patience=20,
    min_lr=1e-6,
    gradient_clip=1.0,
    save_every=5,
):
    """
    Train ``model`` with Adam + BCELoss, then evaluate the best checkpoint on the test set.

    Per epoch: one pass over ``train_loader`` (with gradient-norm clipping), one
    evaluation pass over ``val_loader``, a ``ReduceLROnPlateau`` step on the
    validation loss, and a checkpoint to ``best_model.pth`` whenever validation
    AUC improves. On the first epoch and every ``save_every`` epochs the current
    weights are also saved and scored on ``test_loader``. Training stops early
    after ``patience`` epochs without a validation-AUC improvement.

    Afterwards the best checkpoint is reloaded, scored on the test set, and
    ``test_results.json``, ``training_history.json`` and ``training_history.png``
    are written to ``output_dir``.

    Args:
        model: Module mapping a batch of embeddings to probabilities; must
            expose ``fc1.in_features`` (stored in the checkpoints).
        train_loader, val_loader, test_loader: Loaders yielding
            ``(embeddings, labels)`` batches.
        device: Device the batches are moved to.
        output_dir: Existing directory for checkpoints and reports.
        num_epochs: Maximum number of epochs (must be >= 1).
        learning_rate: Initial Adam learning rate.
        patience: Early-stopping patience in epochs.
        lr_patience: Patience of the learning-rate scheduler.
        min_lr: Lower bound for the learning rate.
        gradient_clip: Max gradient norm passed to ``clip_grad_norm_``.
        save_every: Interval (in epochs) of the periodic checkpoint/test pass.

    Returns:
        ``(results, test_results_by_epoch)``. ``results`` holds the final test
        metrics plus the ``train_losses`` / ``val_losses`` / ``val_aucs``
        curves; ``test_results_by_epoch`` maps a 1-based epoch number to the
        metrics dict from ``evaluate_model_simple``.

    Raises:
        ValueError: If ``num_epochs`` is smaller than 1.
    """
    if num_epochs < 1:
        # Without at least one epoch there is no checkpoint to reload below.
        raise ValueError(f"num_epochs must be at least 1, got {num_epochs}")

    print(f"Training model with learning rate {learning_rate}")
    print(f"Early stopping patience: {patience}")
    print(f"Learning rate reduction patience: {lr_patience}")

    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_patience, min_lr=min_lr
    )

    train_losses = []
    val_losses = []
    val_aucs = []
    test_results_by_epoch = {}
    # -inf (rather than 0.0) guarantees the first epoch always writes
    # best_model.pth, so the reload after the loop cannot hit a missing file.
    best_val_auc = float('-inf')
    best_epoch = 0
    epochs_without_improvement = 0
    epochs_run = 0
    best_model_path = os.path.join(output_dir, 'best_model.pth')

    print(f"Starting training for {num_epochs} epochs...")

    for epoch in range(num_epochs):
        epochs_run = epoch + 1

        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_correct = 0
        train_total = 0

        for batch_embeddings, batch_labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            batch_embeddings = batch_embeddings.to(device)
            batch_labels = batch_labels.to(device).float()

            optimizer.zero_grad()
            # See _predict_and_score for why this is reshape(-1), not squeeze().
            outputs = model(batch_embeddings).reshape(-1)
            loss = criterion(outputs, batch_labels)
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)

            optimizer.step()

            train_loss += loss.item()
            predicted = (outputs > 0.5).float()
            train_correct += (predicted == batch_labels).sum().item()
            train_total += batch_labels.size(0)

        train_loss /= len(train_loader)
        train_acc = train_correct / train_total

        # ---- validation pass ----
        val_loss, val_acc, val_predictions, val_labels = _predict_and_score(
            model, val_loader, criterion, device
        )
        val_auc = roc_auc_score(val_labels, val_predictions)

        # The scheduler follows validation loss; model selection follows AUC.
        scheduler.step(val_loss)

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        val_aucs.append(val_auc)

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_epoch = epoch
            epochs_without_improvement = 0

            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, best_model_path)

            print(f"New best model saved! Val AUC: {val_auc:.4f}")
        else:
            epochs_without_improvement += 1

        if (epoch + 1) % save_every == 0 or epoch == 0:
            # Periodic snapshot plus a test-set score for the epoch analysis.
            epoch_model_path = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epoch': epoch + 1,
                'val_auc': val_auc,
                'val_loss': val_loss,
                'input_dim': model.fc1.in_features,
            }, epoch_model_path)

            test_results = evaluate_model_simple(model, test_loader, device)
            test_results_by_epoch[epoch + 1] = test_results

            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"Test AUC: {test_results['auc']:.4f}")

        elif (epoch + 1) % 10 == 0:
            current_lr = optimizer.param_groups[0]['lr']
            print(f"Epoch {epoch+1:3d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                  f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val AUC: {val_auc:.4f}, "
                  f"LR: {current_lr:.2e}")

        if epochs_without_improvement >= patience:
            print(f"Early stopping at epoch {epoch+1} (no improvement for {patience} epochs)")
            break

    print(f"Training completed! Best validation AUC: {best_val_auc:.4f} at epoch {best_epoch+1}")

    # ---- final evaluation with the best checkpoint ----
    checkpoint = torch.load(best_model_path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    test_loss, test_acc, test_predictions, test_labels = _predict_and_score(
        model, test_loader, criterion, device
    )
    test_auc = roc_auc_score(test_labels, test_predictions)

    binary_predictions = [1 if p > 0.5 else 0 for p in test_predictions]
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, binary_predictions, average='binary'
    )
    cm = confusion_matrix(test_labels, binary_predictions)

    results = {
        'test_auc': float(test_auc),
        'test_accuracy': float(test_acc),
        'test_loss': float(test_loss),
        'test_precision': float(precision),
        'test_recall': float(recall),
        'test_f1': float(f1),
        'confusion_matrix': cm.tolist(),
        'best_val_auc': float(best_val_auc),
        'best_epoch': int(best_epoch + 1),
        'total_epochs': int(epochs_run),
        'sequence_type': 'forward_and_reverse_complement',
        'predictions': [float(p) for p in test_predictions],
        'labels': [float(l) for l in test_labels]
    }

    with open(os.path.join(output_dir, 'test_results.json'), 'w') as f:
        json.dump(results, f, indent=2)

    history = {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'val_aucs': val_aucs,
        'best_epoch': best_epoch + 1,
        'best_val_auc': best_val_auc
    }

    with open(os.path.join(output_dir, 'training_history.json'), 'w') as f:
        json.dump(history, f, indent=2)

    # ---- training curves ----
    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 2)
    plt.plot(val_aucs, label='Val AUC', color='green')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.title('Validation AUC')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.subplot(1, 3, 3)
    plt.plot(range(len(train_losses)), train_losses, label='Train Loss')
    plt.plot(range(len(val_losses)), val_losses, label='Val Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--', alpha=0.7, label=f'Best Epoch ({best_epoch+1})')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Comparison')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'training_history.png'), dpi=300, bbox_inches='tight')
    plt.close()

    print(f"\n=== Test Results ===")
    print(f"Test AUC: {test_auc:.4f}")
    print(f"Test Accuracy: {test_acc:.4f}")
    print(f"Test Precision: {precision:.4f}")
    print(f"Test Recall: {recall:.4f}")
    print(f"Test F1: {f1:.4f}")
    print(f"Confusion Matrix:\n{cm}")

    # Hand the curves back to the caller too. They are added only after
    # test_results.json has been written, so that file's layout is unchanged.
    results['train_losses'] = train_losses
    results['val_losses'] = val_losses
    results['val_aucs'] = val_aucs

    return results, test_results_by_epoch
|
|
def evaluate_model_simple(model, test_loader, device):
    """Score ``model`` on ``test_loader`` and return basic binary-classification metrics.

    The model is switched to eval mode and run without gradients. Predictions
    are thresholded at 0.5 for accuracy, precision, recall and F1; AUC uses
    the raw outputs.

    Returns:
        Dict with the float entries ``auc``, ``accuracy``, ``precision``,
        ``recall`` and ``f1``.
    """
    model.eval()
    test_preds = []
    test_labels = []

    with torch.no_grad():
        for batch_X, batch_y in test_loader:
            batch_X = batch_X.to(device)
            # reshape(-1) keeps a 1-D result even for a batch of one sample;
            # squeeze() would give a 0-d tensor that extend() cannot iterate.
            outputs = model(batch_X).reshape(-1)
            test_preds.extend(outputs.cpu().numpy())
            test_labels.extend(batch_y.cpu().numpy())

    test_preds = np.array(test_preds)
    test_labels = np.array(test_labels)

    test_auc = roc_auc_score(test_labels, test_preds)
    test_preds_binary = (test_preds > 0.5).astype(int)
    test_acc = accuracy_score(test_labels, test_preds_binary)
    precision, recall, f1, _ = precision_recall_fscore_support(
        test_labels, test_preds_binary, average="binary"
    )

    # Plain floats keep the result JSON-serialisable whatever scalar type the
    # metric functions return.
    return {
        "auc": float(test_auc),
        "accuracy": float(test_acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
    }
|
|
def save_epoch_analysis(test_results_by_epoch, output_dir):
    """Save and print a summary of the periodic test-set evaluations.

    Writes ``epoch_analysis.csv`` (one row per evaluated epoch) and
    ``epoch_analysis.json`` (the raw mapping) to ``output_dir``, prints the
    epochs with the best test AUC / F1 and a coarse trend comparing the first
    and last evaluated epoch.

    Args:
        test_results_by_epoch: Mapping ``epoch -> metrics dict`` with the keys
            ``auc``, ``accuracy``, ``precision``, ``recall`` and ``f1``.
        output_dir: Existing directory the two files are written to.

    Returns:
        The summary ``DataFrame`` (empty, but with all columns, when no epoch
        was evaluated).
    """
    epochs = sorted(test_results_by_epoch.keys())

    column_for_metric = {
        "test_auc": "auc",
        "test_accuracy": "accuracy",
        "test_precision": "precision",
        "test_recall": "recall",
        "test_f1": "f1",
    }
    summary_data = []
    for epoch in epochs:
        results = test_results_by_epoch[epoch]
        row = {"epoch": epoch}
        row.update({column: results[metric] for column, metric in column_for_metric.items()})
        summary_data.append(row)

    # Explicit columns keep the frame (and CSV header) well-formed when empty.
    df = pd.DataFrame(summary_data, columns=["epoch", *column_for_metric])

    csv_path = os.path.join(output_dir, "epoch_analysis.csv")
    df.to_csv(csv_path, index=False)

    json_path = os.path.join(output_dir, "epoch_analysis.json")
    with open(json_path, "w") as f:
        json.dump(test_results_by_epoch, f, indent=2)

    print("\n" + "=" * 50)
    print("EPOCH-WISE TEST PERFORMANCE ANALYSIS")
    print("=" * 50)

    if df.empty:
        # idxmax() on an empty column raises; there is nothing to rank.
        print("No per-epoch test results were recorded.")
        return df

    best_auc_epoch = df.loc[df["test_auc"].idxmax()]
    best_f1_epoch = df.loc[df["test_f1"].idxmax()]

    # A row taken with .loc is a float Series (int and float columns are
    # upcast together), so cast the epoch back to int to avoid "Epoch 5.0".
    print(
        f"Best Test AUC: {best_auc_epoch['test_auc']:.4f} at Epoch {int(best_auc_epoch['epoch'])}"
    )
    print(
        f"Best Test F1: {best_f1_epoch['test_f1']:.4f} at Epoch {int(best_f1_epoch['epoch'])}"
    )
    print()
    print("Epoch-wise Performance:")
    print(df.to_string(index=False, float_format="%.4f"))

    if len(epochs) >= 2:
        # Trend = last evaluated epoch minus first; +/-0.01 is the dead band.
        auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0]
        if auc_trend < -0.01:
            print(
                f"\n⚠️ OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        elif auc_trend > 0.01:
            print(
                f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )
        else:
            print(
                f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}"
            )

    return df
|
|
def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """Draw loss and validation-AUC curves and save them as ``training_history.png``.

    The left panel shows training and validation loss, the right panel the
    validation AUC; the x axis is the position in each list.
    """
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: both loss curves.
    loss_ax.plot(train_losses, label="Train Loss")
    loss_ax.plot(val_losses, label="Val Loss")
    loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Training and Validation Loss")

    # Right panel: validation AUC.
    auc_ax.plot(val_aucs, label="Val AUC", color="green")
    auc_ax.set(xlabel="Epoch", ylabel="AUC", title="Validation AUC")

    for panel in (loss_ax, auc_ax):
        panel.legend()
        panel.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close(fig)
|
|
def plot_confusion_matrix(cm, output_dir):
    """Render ``cm`` as an annotated heatmap and save it as ``confusion_matrix.png``.

    Rows are labelled as true classes and columns as predicted classes, in
    the order non-binding, TBX5-binding. Cells are formatted as integers.
    """
    class_names = ["Non-binding", "TBX5-binding"]

    fig = plt.figure(figsize=(6, 5))
    heatmap_ax = sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    heatmap_ax.set_title("Confusion Matrix")
    heatmap_ax.set_ylabel("True Label")
    heatmap_ax.set_xlabel("Predicted Label")

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close(fig)
|
|
def main():
    """Command-line entry point: load data, train the classifier, write reports.

    Steps: parse arguments, load the forward + reverse-complement embeddings
    for the CSV-defined splits, standardise the features, train
    ``TBX5ClassifierWithRC`` and save metadata, scaler, epoch analysis and
    plots to ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only X/y of each split are used below; the coordinate, score and
    # chromosome arrays are unpacked to mirror the loader's return layout.
    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )

    # Persist the fitted scaler so the same transform can be re-applied later.
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    train_dataset = TensorDataset(torch.FloatTensor(X_train_scaled), torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val_scaled), torch.LongTensor(y_val))
    test_dataset = TensorDataset(torch.FloatTensor(X_test_scaled), torch.LongTensor(y_test))

    # Only the training data is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")

    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    results, test_results_by_epoch = train_model(
        model, train_loader, val_loader, test_loader, device, args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # train_model already writes training_history.png. Re-plot only when the
    # curves are present in `results`; plotting empty lists would replace
    # that figure with a blank one.
    train_losses = results.get('train_losses', [])
    val_losses = results.get('val_losses', [])
    val_aucs = results.get('val_aucs', [])
    if train_losses or val_losses or val_aucs:
        plot_training_history(train_losses, val_losses, val_aucs, args.output_dir)
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")
|
|
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|
| |
| if len(epochs) >= 2: |
| auc_trend = df["test_auc"].iloc[-1] - df["test_auc"].iloc[0] |
| if auc_trend < -0.01: |
| print( |
| f"\n⚠️ OVERFITTING DETECTED: Test AUC decreased by {abs(auc_trend):.4f} from epoch {epochs[0]} to {epochs[-1]}" |
| ) |
| elif auc_trend > 0.01: |
| print( |
| f"\n✅ GOOD TRAINING: Test AUC improved by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}" |
| ) |
| else: |
| print( |
| f"\n📊 STABLE TRAINING: Test AUC changed by {auc_trend:.4f} from epoch {epochs[0]} to {epochs[-1]}" |
| ) |
|
|
| return df |
|
|
def plot_training_history(train_losses, val_losses, val_aucs, output_dir):
    """Draw loss and validation-AUC curves and save them as ``training_history.png``.

    The left panel shows training and validation loss, the right panel the
    validation AUC; the x axis is the position in each list.
    """
    fig, (loss_ax, auc_ax) = plt.subplots(1, 2, figsize=(12, 4))

    # Left panel: both loss curves.
    loss_ax.plot(train_losses, label="Train Loss")
    loss_ax.plot(val_losses, label="Val Loss")
    loss_ax.set(xlabel="Epoch", ylabel="Loss", title="Training and Validation Loss")

    # Right panel: validation AUC.
    auc_ax.plot(val_aucs, label="Val AUC", color="green")
    auc_ax.set(xlabel="Epoch", ylabel="AUC", title="Validation AUC")

    for panel in (loss_ax, auc_ax):
        panel.legend()
        panel.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "training_history.png"), dpi=100)
    plt.close(fig)
|
|
def plot_confusion_matrix(cm, output_dir):
    """Render ``cm`` as an annotated heatmap and save it as ``confusion_matrix.png``.

    Rows are labelled as true classes and columns as predicted classes, in
    the order non-binding, TBX5-binding. Cells are formatted as integers.
    """
    class_names = ["Non-binding", "TBX5-binding"]

    fig = plt.figure(figsize=(6, 5))
    heatmap_ax = sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    heatmap_ax.set_title("Confusion Matrix")
    heatmap_ax.set_ylabel("True Label")
    heatmap_ax.set_xlabel("Predicted Label")

    fig.tight_layout()
    fig.savefig(os.path.join(output_dir, "confusion_matrix.png"), dpi=100)
    plt.close(fig)
|
|
def main():
    """Command-line entry point: load data, train the classifier, write reports.

    Steps: parse arguments, load the forward + reverse-complement embeddings
    for the CSV-defined splits, standardise the features, train
    ``TBX5ClassifierWithRC`` and save metadata, scaler, epoch analysis and
    plots to ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description="Train TBX5 classifier with forward and reverse complement embeddings")
    parser.add_argument(
        "--embeddings-dir",
        type=str,
        default="tbx5_embeddings",
        help="Directory containing forward embeddings (default: tbx5_embeddings)",
    )
    parser.add_argument(
        "--rc-embeddings-dir",
        type=str,
        default="tbx5_embeddings_reverse_complement",
        help="Directory containing reverse complement embeddings (default: tbx5_embeddings_reverse_complement)",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="result_with_rc",
        help="Output directory for results (default: result_with_rc)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        help="Batch size for training (default: 32)",
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=500,
        help="Number of training epochs (default: 500)",
    )
    parser.add_argument(
        "--learning-rate",
        type=float,
        default=1e-4,
        help="Learning rate (default: 1e-4)",
    )
    parser.add_argument(
        "--patience",
        type=int,
        default=100,
        help="Early stopping patience (default: 100)",
    )
    parser.add_argument(
        "--dropout-rate",
        type=float,
        default=0.5,
        help="Dropout rate (default: 0.5)",
    )
    parser.add_argument(
        "--processed-data-dir",
        type=str,
        default="processed_data_new",
        help="Directory containing train/val/test CSV files (default: processed_data_new)",
    )

    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Only X/y of each split are used below; the coordinate, score and
    # chromosome arrays are unpacked to mirror the loader's return layout.
    print("Loading combined embeddings using CSV splits...")
    (X_train, y_train, starts_train, ends_train, tbx5_scores_train, chromosomes_train,
     X_val, y_val, starts_val, ends_val, tbx5_scores_val, chromosomes_val,
     X_test, y_test, starts_test, ends_test, tbx5_scores_test, chromosomes_test,
     metadata) = load_tbx5_embeddings_with_rc_from_csv(
        args.embeddings_dir, args.rc_embeddings_dir, args.processed_data_dir
    )

    with open(os.path.join(args.output_dir, 'metadata.json'), 'w') as f:
        json.dump(metadata, f, indent=2)

    X_train_scaled, X_val_scaled, X_test_scaled, scaler = prepare_data_with_scaling(
        X_train, X_val, X_test, y_train, y_val, y_test
    )

    # Persist the fitted scaler so the same transform can be re-applied later.
    with open(os.path.join(args.output_dir, 'scaler.pkl'), 'wb') as f:
        pickle.dump(scaler, f)

    train_dataset = TensorDataset(torch.FloatTensor(X_train_scaled), torch.LongTensor(y_train))
    val_dataset = TensorDataset(torch.FloatTensor(X_val_scaled), torch.LongTensor(y_val))
    test_dataset = TensorDataset(torch.FloatTensor(X_test_scaled), torch.LongTensor(y_test))

    # Only the training data is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

    input_dim = X_train_scaled.shape[1]
    print(f"Input dimension: {input_dim}")

    model = TBX5ClassifierWithRC(input_dim=input_dim, dropout_rate=args.dropout_rate).to(device)

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")

    results, test_results_by_epoch = train_model(
        model, train_loader, val_loader, test_loader, device, args.output_dir,
        num_epochs=args.num_epochs,
        learning_rate=args.learning_rate,
        patience=args.patience,
    )

    save_epoch_analysis(test_results_by_epoch, args.output_dir)

    # train_model already writes training_history.png. Re-plot only when the
    # curves are present in `results`; plotting empty lists would replace
    # that figure with a blank one.
    train_losses = results.get('train_losses', [])
    val_losses = results.get('val_losses', [])
    val_aucs = results.get('val_aucs', [])
    if train_losses or val_losses or val_aucs:
        plot_training_history(train_losses, val_losses, val_aucs, args.output_dir)
    plot_confusion_matrix(np.array(results['confusion_matrix']), args.output_dir)

    print(f"\nTraining completed! Results saved to {args.output_dir}")
    print(f"Best test AUC: {results['test_auc']:.4f}")
|
|
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
|
|
|
|
|
|
|
|
|