algaGPT / llm-metrics-two-files.py
GreenGenomicsLab's picture
Upload 30 files
261b39f verified
#!/usr/bin/env python3
"""
LLM Classification Metrics Generator for Two-File Analysis
This script analyzes the LLM classification results from two separate files:
- One containing algal sequences (true algal samples)
- One containing contaminant sequences (true contaminant samples)
It extracts the predicted tags and calculates comprehensive metrics.
"""
import re
import sys
import argparse
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_fscore_support, accuracy_score, confusion_matrix
from sklearn.metrics import classification_report
def parse_files(algal_file, contaminant_file):
"""
Parse the algal and contaminant files to extract true and predicted labels
Arguments:
algal_file (str): Path to the file containing algal sequences
contaminant_file (str): Path to the file containing contaminant sequences
Returns:
tuple: Lists of true labels and predicted labels
"""
true_labels = []
predicted_labels = []
sequence_ids = []
# Process algal file (all true labels are 'algal')
with open(algal_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
# Skip header or non-data lines
if line.startswith('==>') or line.startswith('('): ## or not re.search(r'-|_', line):
continue
# Extract sequence ID
seq_id_match = re.match(r'^([^\s]+)', line)
if seq_id_match:
seq_id = seq_id_match.group(1)
else:
seq_id = "unknown_id"
# Add to tracking lists
true_labels.append('algal')
sequence_ids.append(seq_id)
# Determine predicted label based on tags
if '@' in line:
predicted_labels.append('algal')
elif '!' in line:
predicted_labels.append('contaminant')
else:
predicted_labels.append('unknown')
#if re.search(r'<@+>', line):
# predicted_labels.append('algal')
#elif re.search(r'<!+>', line):
# predicted_labels.append('contaminant')
#else:
# predicted_labels.append('unknown')
# Process contaminant file (all true labels are 'contaminant')
with open(contaminant_file, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
# Skip header or non-data lines
if line.startswith('==>') or line.startswith('('): ## or not re.search(r'\.|_', line):
continue
# Extract sequence ID
seq_id_match = re.match(r'^([^\s]+)', line)
if seq_id_match:
seq_id = seq_id_match.group(1)
else:
seq_id = "unknown_id"
# Add to tracking lists
true_labels.append('contaminant')
sequence_ids.append(seq_id)
# Determine predicted label based on tags
# if re.search(r'<@+>', line):
# predicted_labels.append('algal')
#elif re.search(r'<!+>', line):
# predicted_labels.append('contaminant')
#e#lse:
# predicted_labels.append('unknown')
# Determine predicted label based on symbols (@ for algal, ! for contaminant)
if '@' in line:
predicted_labels.append('algal')
elif '!' in line:
predicted_labels.append('contaminant')
else:
predicted_labels.append('unknown')
return true_labels, predicted_labels, sequence_ids
def calculate_metrics(true_labels, predicted_labels):
"""
Calculate comprehensive classification metrics
Arguments:
true_labels (list): List of true class labels
predicted_labels (list): List of predicted class labels
Returns:
dict: Dictionary containing all calculated metrics
"""
# Convert labels for sklearn functions
classes = ['algal', 'contaminant']
label_map = {label: i for i, label in enumerate(classes)}
# Convert to numeric form
true_numeric = np.array([label_map.get(label, 2) for label in true_labels])
pred_numeric = np.array([label_map.get(label, 2) for label in predicted_labels])
# Filter out unknowns for main metrics
known_indices = [i for i, pred in enumerate(predicted_labels) if pred != 'unknown']
true_known = [true_labels[i] for i in known_indices]
pred_known = [predicted_labels[i] for i in known_indices]
# Overall accuracy (including unknowns as wrong predictions)
accuracy = sum(t == p for t, p in zip(true_labels, predicted_labels)) / len(true_labels)
if true_known and pred_known:
# Convert to numeric
true_known_numeric = np.array([label_map[label] for label in true_known])
pred_known_numeric = np.array([label_map[label] for label in pred_known])
# Calculate precision, recall, and F1 (excluding unknowns)
precision, recall, f1, support = precision_recall_fscore_support(
true_known_numeric,
pred_known_numeric,
labels=[0, 1], # algal, contaminant
zero_division=0
)
# Create confusion matrix
cm = confusion_matrix(
true_known_numeric,
pred_known_numeric,
labels=[0, 1]
)
# Full classification report
report = classification_report(
true_known_numeric,
pred_known_numeric,
labels=[0, 1],
target_names=classes,
output_dict=True
)
else:
precision = recall = f1 = support = [0, 0]
cm = np.zeros((2, 2))
report = {}
# Count occurrences and calculate per-class metrics
class_metrics = {}
for class_name in classes:
class_indices = [i for i, label in enumerate(true_labels) if label == class_name]
total = len(class_indices)
if total == 0:
class_metrics[class_name] = {
"total": 0,
"correct": 0,
"incorrect": 0,
"unknown": 0,
"accuracy": 0,
"error_rate": 0
}
continue
correct = sum(1 for i in class_indices if predicted_labels[i] == class_name)
unknown = sum(1 for i in class_indices if predicted_labels[i] == "unknown")
incorrect = total - correct - unknown
class_metrics[class_name] = {
"total": total,
"correct": correct,
"incorrect": incorrect,
"unknown": unknown,
"accuracy": correct / total if total > 0 else 0,
"error_rate": (incorrect + unknown) / total if total > 0 else 0
}
# Compile all metrics
metrics = {
"accuracy": accuracy,
"class_metrics": class_metrics,
"confusion_matrix": cm,
"precision": {classes[i]: precision[i] for i in range(len(classes))},
"recall": {classes[i]: recall[i] for i in range(len(classes))},
"f1": {classes[i]: f1[i] for i in range(len(classes))},
"support": {classes[i]: support[i] for i in range(len(classes))},
"classification_report": report,
"macro_f1": np.mean(f1),
"weighted_f1": np.sum(f1 * support) / np.sum(support) if np.sum(support) > 0 else 0,
"total_samples": len(true_labels),
"total_correct": sum(t == p for t, p in zip(true_labels, predicted_labels)),
"total_unknown": predicted_labels.count("unknown")
}
return metrics
def display_results(metrics, output_file=None):
"""
Display comprehensive results and optionally save to file
Arguments:
metrics (dict): Dictionary containing all calculated metrics
output_file (str, optional): Path to save results to
"""
# Start capturing output if needed
if output_file:
import io
output_capture = io.StringIO()
original_stdout = sys.stdout
sys.stdout = output_capture
# Print header
print("\n" + "="*60)
print(" LLM CLASSIFICATION METRICS REPORT")
print("="*60)
# Overall metrics
print("\n=== OVERALL METRICS ===")
print(f"Total samples: {metrics['total_samples']}")
print(f"Correctly classified: {metrics['total_correct']} ({metrics['total_correct']/metrics['total_samples']*100:.2f}%)")
print(f"Unknown predictions: {metrics['total_unknown']} ({metrics['total_unknown']/metrics['total_samples']*100:.2f}%)")
print(f"Overall accuracy: {metrics['accuracy']:.4f}")
print(f"Macro F1: {metrics['macro_f1']:.4f}")
print(f"Weighted F1: {metrics['weighted_f1']:.4f}")
# Confusion matrix
cm = metrics["confusion_matrix"]
class_labels = ["Algal", "Bacterial"]
print("\n=== CONFUSION MATRIX ===")
print(f"{'':15} | {'Predicted Algal':15} | {'Predicted Bacterial':20}")
print("-" * 55)
for i, label in enumerate(class_labels):
print(f"{label:15} | {int(cm[i][0]):15} | {int(cm[i][1]):20}")
# Per-class metrics
print("\n=== PER-CLASS METRICS ===")
print(f"{'Class':10} | {'Precision':10} | {'Recall':10} | {'F1 Score':10} | {'Support':10}")
print("-" * 60)
for class_name in ['algal', 'contaminant']:
precision = metrics['precision'][class_name]
recall = metrics['recall'][class_name]
f1 = metrics['f1'][class_name]
support = metrics['support'][class_name]
print(f"{class_name.capitalize():10} | {precision:.4f} | {recall:.4f} | {f1:.4f} | {int(support):10}")
# Detailed class counts
print("\n=== DETAILED CLASS COUNTS ===")
for class_name, class_data in metrics["class_metrics"].items():
print(f"{class_name.capitalize()} class:")
print(f" Total samples: {class_data['total']}")
if class_data['total'] > 0:
print(f" Correctly classified: {class_data['correct']} ({class_data['correct']/class_data['total']*100:.2f}%)")
print(f" Incorrectly classified: {class_data['incorrect']} ({class_data['incorrect']/class_data['total']*100:.2f}%)")
print(f" Unknown: {class_data['unknown']} ({class_data['unknown']/class_data['total']*100:.2f}%)")
print()
# If saving to file
if output_file:
# Restore stdout
sys.stdout = original_stdout
# Write to file
with open(output_file, 'w') as f:
f.write(output_capture.getvalue())
print(f"Results saved to {output_file}")
def generate_visualizations(metrics, output_prefix=None):
"""
Generate visualizations of the metrics
Arguments:
metrics (dict): Dictionary containing all calculated metrics
output_prefix (str, optional): Prefix for output image files
"""
# Create confusion matrix heatmap
plt.figure(figsize=(8, 6))
cm = metrics["confusion_matrix"]
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = ["Algal", "Bacterial"]
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
# Add text annotations
thresh = cm.max() / 2.0
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(int(cm[i, j]), 'd'),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
if output_prefix:
plt.savefig(f"{output_prefix}_confusion_matrix.png", dpi=300, bbox_inches='tight')
else:
plt.show()
# Create per-class metrics bar chart
plt.figure(figsize=(10, 6))
metrics_names = ['Precision', 'Recall', 'F1-Score']
x = np.arange(len(metrics_names))
width = 0.35
algal_values = [metrics['precision']['algal'], metrics['recall']['algal'], metrics['f1']['algal']]
contaminant_values = [metrics['precision']['contaminant'], metrics['recall']['contaminant'], metrics['f1']['contaminant']]
plt.bar(x - width/2, algal_values, width, label='Algal')
plt.bar(x + width/2, contaminant_values, width, label='Bacterial')
plt.ylabel('Score')
plt.title('Performance Metrics by Class')
plt.xticks(x, metrics_names)
plt.ylim(0, 1.1)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)
if output_prefix:
plt.savefig(f"{output_prefix}_metrics_by_class.png", dpi=300, bbox_inches='tight')
else:
plt.show()
# Create class distribution pie charts
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
# Algal class distribution
algal_data = metrics['class_metrics']['algal']
algal_labels = ['Correct', 'Incorrect', 'Unknown']
algal_values = [algal_data['correct'], algal_data['incorrect'], algal_data['unknown']]
ax1.pie(algal_values, labels=algal_labels, autopct='%1.1f%%', startangle=90)
ax1.set_title('Algal Class Predictions')
# Bacterial class distribution
contaminant_data = metrics['class_metrics']['contaminant']
contaminant_labels = ['Correct', 'Incorrect', 'Unknown']
contaminant_values = [contaminant_data['correct'], contaminant_data['incorrect'], contaminant_data['unknown']]
ax2.pie(contaminant_values, labels=contaminant_labels, autopct='%1.1f%%', startangle=90)
ax2.set_title('Bacterial Class Predictions')
plt.tight_layout()
if output_prefix:
plt.savefig(f"{output_prefix}_class_distribution.png", dpi=300, bbox_inches='tight')
else:
plt.show()
def create_misclassified_report(true_labels, predicted_labels, sequence_ids, output_file=None):
"""
Create a report of misclassified sequences
Arguments:
true_labels (list): List of true class labels
predicted_labels (list): List of predicted class labels
sequence_ids (list): List of sequence IDs
output_file (str, optional): Path to save the report to
"""
misclassified = []
for i, (true, pred, seq_id) in enumerate(zip(true_labels, predicted_labels, sequence_ids)):
if true != pred:
misclassified.append({
'id': seq_id,
'true': true,
'predicted': pred
})
# Start capturing output
if output_file:
import io
output_capture = io.StringIO()
original_stdout = sys.stdout
sys.stdout = output_capture
# Print header
print("\n" + "="*60)
print(" MISCLASSIFIED SEQUENCES REPORT")
print("="*60)
print(f"\nTotal misclassified: {len(misclassified)} out of {len(true_labels)} ({len(misclassified)/len(true_labels)*100:.2f}%)\n")
# Print algal sequences misclassified as contaminant
print("\n--- ALGAL SEQUENCES MISCLASSIFIED AS BACTERIAL ---")
algal_as_contaminant = [m for m in misclassified if m['true'] == 'algal' and m['predicted'] == 'contaminant']
for item in algal_as_contaminant:
print(f"ID: {item['id']}")
print(f"Total: {len(algal_as_contaminant)}")
# Print contaminant sequences misclassified as algal
print("\n--- BACTERIAL SEQUENCES MISCLASSIFIED AS ALGAL ---")
contaminant_as_algal = [m for m in misclassified if m['true'] == 'contaminant' and m['predicted'] == 'algal']
for item in contaminant_as_algal:
print(f"ID: {item['id']}")
print(f"Total: {len(contaminant_as_algal)}")
# Print unknown classifications
print("\n--- SEQUENCES WITH UNKNOWN CLASSIFICATION ---")
unknown = [m for m in misclassified if m['predicted'] == 'unknown']
for item in unknown:
print(f"ID: {item['id']} (True: {item['true']})")
print(f"Total: {len(unknown)}")
# If saving to file
if output_file:
# Restore stdout
sys.stdout = original_stdout
# Write to file
with open(output_file, 'w') as f:
f.write(output_capture.getvalue())
print(f"Misclassified report saved to {output_file}")
def main():
"""Main function to run the script"""
parser = argparse.ArgumentParser(description='LLM Classification Metrics Generator for Two-File Analysis')
parser.add_argument('algal_file', help='Path to the file containing algal sequences')
parser.add_argument('contaminant_file', help='Path to the file containing contaminant sequences')
parser.add_argument('-o', '--output', help='Path to save the metrics report')
parser.add_argument('-m', '--misclassified', help='Path to save the misclassified sequences report')
parser.add_argument('-v', '--visualize', action='store_true', help='Generate visualizations')
parser.add_argument('-p', '--prefix', default='llm_metrics', help='Prefix for output files')
args = parser.parse_args()
# Parse files and calculate metrics
true_labels, predicted_labels, sequence_ids = parse_files(args.algal_file, args.contaminant_file)
metrics = calculate_metrics(true_labels, predicted_labels)
# Display results
output_file = f"{args.prefix}_report.txt" if args.output else None
display_results(metrics, output_file)
# Generate visualizations if requested
if args.visualize:
generate_visualizations(metrics, args.prefix)
# Create misclassified report if requested
if args.misclassified:
misclassified_file = f"{args.prefix}_misclassified.txt" if args.misclassified is True else args.misclassified
create_misclassified_report(true_labels, predicted_labels, sequence_ids, misclassified_file)
# Return number of misclassifications (for automated testing)
misclassifications = sum(t != p for t, p in zip(true_labels, predicted_labels))
return misclassifications
if __name__ == "__main__":
sys.exit(main())