| import torch |
| import numpy as np |
| import json |
| import os |
| from tqdm import tqdm |
| import warnings |
| from datetime import datetime |
| warnings.filterwarnings('ignore') |
|
|
| |
| from generate_amps import AMPGenerator |
| from compressor_with_embeddings import Compressor, Decompressor |
| from final_sequence_decoder import EmbeddingToSequenceConverter |
|
|
| |
# Optional dependency: try to import the local APEX wrapper and record the
# outcome in APEX_AVAILABLE so the rest of the module can check the flag
# instead of failing at import time.
try:
    from local_apex_wrapper import LocalAPEXWrapper
    APEX_AVAILABLE = True
except ImportError as e:
    # Report why the import failed, then continue without APEX.
    print(f"Warning: Local APEX not available: {e}")
    APEX_AVAILABLE = False
|
|
class PeptideTester:
    """
    Generate peptides and test them using APEX for antimicrobial activity.

    The tester combines an ``AMPGenerator`` (samples embeddings), an
    ``EmbeddingToSequenceConverter`` (decodes embeddings into sequences) and,
    when the local APEX wrapper could be imported, a ``LocalAPEXWrapper``
    that scores each sequence.
    """

    # Threshold forwarded to ``LocalAPEXWrapper.is_amp`` for every sequence.
    AMP_MIC_THRESHOLD = 32.0

    # Number of candidates listed by ``analyze_results``.
    TOP_CANDIDATES = 10

    def __init__(self, model_path='amp_flow_model_final.pth', device='cuda'):
        """
        Build the generator, the converter and (if available) the predictor.

        Args:
            model_path: Checkpoint path handed to ``AMPGenerator``; also
                recorded in the metadata written by ``save_results``.
            device: Device identifier handed to the generator and converter.
        """
        self.device = device
        self.model_path = model_path

        print("Initializing peptide generator...")
        self.generator = AMPGenerator(model_path, device)

        print("Initializing embedding to sequence converter...")
        self.converter = EmbeddingToSequenceConverter(device)

        # ``self.apex`` is None whenever the optional import failed; the other
        # methods check this attribute before trying to predict.
        if APEX_AVAILABLE:
            print("Initializing local APEX predictor...")
            self.apex = LocalAPEXWrapper()
            print("✓ Local APEX loaded successfully!")
        else:
            self.apex = None
            print("⚠ Local APEX not available - will only generate sequences")

    def generate_peptides(self, num_samples=100, num_steps=25, batch_size=32):
        """
        Generate peptide sequences using the trained flow model.

        Args:
            num_samples: Forwarded to ``AMPGenerator.generate_amps``.
            num_steps: Forwarded to ``AMPGenerator.generate_amps``.
            batch_size: Forwarded to ``AMPGenerator.generate_amps``.

        Returns:
            The sequences that remain after
            ``EmbeddingToSequenceConverter.filter_valid_sequences``.
        """
        print(f"\n=== Generating {num_samples} Peptide Sequences ===")

        generated_embeddings = self.generator.generate_amps(
            num_samples=num_samples,
            num_steps=num_steps,
            batch_size=batch_size,
        )
        print(f"Generated embeddings shape: {generated_embeddings.shape}")

        # Decode the embeddings, then drop whatever the converter deems invalid.
        sequences = self.converter.batch_embedding_to_sequences(generated_embeddings)
        return self.converter.filter_valid_sequences(sequences)

    def test_with_apex(self, sequences):
        """
        Test generated sequences using APEX for antimicrobial activity.

        Args:
            sequences: Sized iterable of peptide sequences.

        Returns:
            ``None`` when no predictor is loaded; otherwise a list with one
            dict per successfully scored sequence, holding ``sequence``,
            ``sequence_id``, ``apex_score``, ``is_amp`` and ``length``.
            Sequences whose prediction raises are reported and skipped.
        """
        # Guard on the object that is actually dereferenced below.
        if self.apex is None:
            print("⚠ APEX not available - skipping activity prediction")
            return None

        print(f"\n=== Testing {len(sequences)} Sequences with APEX ===")

        results = []
        # tqdm wraps the sequences themselves (not the enumerate iterator) so
        # it can read the total and display a percentage / ETA.
        for i, seq in enumerate(tqdm(sequences, desc="Testing with APEX")):
            try:
                # Coerce to built-in types so json.dump can serialise the
                # results even if the wrapper returns NumPy/torch scalars
                # (the wrapper's return types are not visible here).
                avg_mic = float(self.apex.predict_single(seq))
                is_amp = bool(self.apex.is_amp(seq, threshold=self.AMP_MIC_THRESHOLD))

                results.append({
                    'sequence': seq,
                    'sequence_id': f'generated_{i:04d}',
                    'apex_score': avg_mic,
                    'is_amp': is_amp,
                    'length': len(seq),
                })
            except Exception as e:
                # Best effort: one failing sequence must not abort the batch.
                print(f"Error testing sequence {i}: {e}")
                continue

        return results

    def analyze_results(self, results):
        """
        Analyze the results of APEX testing.

        Prints summary statistics of the scores and the best candidates.

        Args:
            results: List of result dicts as produced by ``test_with_apex``.

        Returns:
            ``results`` unchanged, or ``None`` when there is nothing to analyze.
        """
        if not results:
            print("No results to analyze")
            return

        total = len(results)
        print(f"\n=== Analysis of {total} Generated Peptides ===")

        scores = [r['apex_score'] for r in results]
        amp_count = sum(1 for r in results if r['is_amp'])

        print(f"Total sequences tested: {total}")
        print(f"Predicted AMPs: {amp_count} ({amp_count/total*100:.1f}%)")
        print(f"Average MIC: {np.mean(scores):.2f} μg/mL")
        print(f"MIC range: {np.min(scores):.2f} - {np.max(scores):.2f} μg/mL")
        print(f"MIC std: {np.std(scores):.2f} μg/mL")

        # A lower MIC means stronger predicted activity, so the best
        # candidates are the smallest scores (same ordering main() uses).
        top_candidates = sorted(results, key=lambda r: r['apex_score'])[:self.TOP_CANDIDATES]

        print(f"\n=== Top 10 Candidates ===")
        for rank, candidate in enumerate(top_candidates, start=1):
            print(f"{rank:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"Length: {candidate['length']:2d} | "
                  f"Sequence: {candidate['sequence']}")

        return results

    def save_results(self, results, filename='generated_peptides_results.json'):
        """
        Save results to JSON file.

        Args:
            results: List of result dicts; nothing is written when it is empty
                or ``None``.
            filename: Destination path of the JSON document.
        """
        if not results:
            print("No results to save")
            return

        output = {
            'metadata': {
                'model_path': self.model_path,
                'num_sequences': len(results),
                # ISO-8601 wall-clock time taken when the file is written.
                'generation_timestamp': datetime.now().isoformat(timespec='seconds'),
                # Whether this instance actually holds a predictor.
                'apex_available': self.apex is not None,
            },
            'results': results,
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(output, f, indent=2)

        print(f"✓ Results saved to {filename}")

    def run_full_pipeline(self, num_samples=100, save_results=True):
        """
        Run the complete pipeline: generate peptides and test with APEX.

        Args:
            num_samples: Number of samples requested from ``generate_peptides``.
            save_results: When true, pass the results to ``self.save_results``.

        Returns:
            Whatever ``test_with_apex`` returned (``None`` without a predictor).
        """
        print("🚀 Starting Full Peptide Generation and Testing Pipeline")
        print("=" * 60)

        sequences = self.generate_peptides(num_samples=num_samples)
        results = self.test_with_apex(sequences)

        if results:
            self.analyze_results(results)

        # The parameter shadows the method name locally, so the method is
        # reached through ``self``; it handles empty/None results itself.
        if save_results:
            self.save_results(results)

        return results
|
|
def _cfg_slug(cfg_name):
    """Turn a CFG label such as 'Weak CFG (3.0)' into 'weak_cfg_30'."""
    slug = cfg_name.lower().replace(' ', '_')
    for ch in '().':
        slug = slug.replace(ch, '')
    return slug


def _load_sequences(file_path):
    """Read sequences from the second tab-separated column of ``file_path``.

    Blank lines, lines starting with '#' and lines without a tab are skipped.
    """
    sequences = []
    with open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith('#') or '\t' not in line:
                continue
            seq = line.split('\t')[1].strip()
            if seq:
                sequences.append(seq)
    return sequences


def _score_sequences(apex, sequences, cfg_name):
    """Score every sequence with ``apex`` and return one result dict each.

    Sequences whose prediction raises are reported and skipped.
    """
    slug = _cfg_slug(cfg_name)
    results = []
    # tqdm wraps the list itself so the progress bar knows the total.
    for i, seq in enumerate(tqdm(sequences, desc=f"Testing {cfg_name}")):
        try:
            # Coerce to built-in types so the results stay JSON-serialisable
            # even if the wrapper returns NumPy/torch scalars (its return
            # types are not visible here).
            avg_mic = float(apex.predict_single(seq))
            is_amp = bool(apex.is_amp(seq, threshold=32.0))

            results.append({
                'sequence': seq,
                'sequence_id': f'{slug}_{i:03d}',
                'cfg_setting': cfg_name,
                'apex_score': avg_mic,
                'is_amp': is_amp,
                'length': len(seq),
            })
        except Exception as e:
            # Best effort: keep going when a single sequence fails.
            print(f"Warning: Error testing sequence {i}: {e}")
            continue
    return results


def _summarize(results):
    """Return JSON-ready statistics for a non-empty list of result dicts."""
    scores = [r['apex_score'] for r in results]
    amp_count = sum(1 for r in results if r['is_amp'])
    return {
        'num_sequences': len(results),
        'amp_count': amp_count,
        'amp_percentage': amp_count / len(results) * 100,
        # float(): NumPy reductions return NumPy scalars, and not all of
        # them are accepted by json.dump.
        'avg_mic': float(np.mean(scores)),
        'min_mic': float(np.min(scores)),
        'max_mic': float(np.max(scores)),
        'std_mic': float(np.std(scores)),
        'all_mics': scores,
    }


def _write_json(path, payload):
    """Write ``payload`` to ``path`` as indented JSON."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, indent=2)


def main(input_dir='/data2/edwardsun/decoded_sequences',
         output_dir='/data2/edwardsun/apex_results',
         date=None):
    """
    Main function to test existing decoded sequence files with APEX.

    For each of four CFG settings, loads
    ``decoded_sequences_<setting>_<date>.txt`` from ``input_dir``, scores its
    sequences, prints statistics and writes a per-setting JSON file. Finally
    writes a combined results file and a MIC summary file to ``output_dir``.

    Args:
        input_dir: Directory containing the decoded sequence files.
        output_dir: Directory that receives the JSON outputs (created if
            missing).
        date: ``YYYYMMDD`` stamp used in input and output file names;
            defaults to today's date.
    """
    print("🧬 AMP Flow Model - Testing Decoded Sequences with APEX")
    print("=" * 60)

    if not APEX_AVAILABLE:
        print("❌ Local APEX not available - cannot test sequences")
        print("Please ensure local_apex_wrapper.py is properly set up")
        return

    print("Initializing APEX predictor...")
    apex = LocalAPEXWrapper()
    print("✓ Local APEX loaded successfully!")

    today = date or datetime.now().strftime('%Y%m%d')

    # CFG label -> file-name fragment identifying that setting's input file.
    cfg_variants = {
        'No CFG (0.0)': 'no_cfg_00',
        'Weak CFG (3.0)': 'weak_cfg_30',
        'Strong CFG (7.5)': 'strong_cfg_75',
        'Very Strong CFG (15.0)': 'very_strong_cfg_150',
    }
    cfg_files = {
        label: os.path.join(input_dir, f'decoded_sequences_{fragment}_{today}.txt')
        for label, fragment in cfg_variants.items()
    }

    os.makedirs(output_dir, exist_ok=True)

    all_results = {}

    for cfg_name, file_path in cfg_files.items():
        print(f"\n{'='*60}")
        print(f"Testing {cfg_name} sequences...")
        print(f"Loading: {file_path}")

        if not os.path.exists(file_path):
            print(f"❌ File not found: {file_path}")
            continue

        sequences = _load_sequences(file_path)
        print(f"✓ Loaded {len(sequences)} sequences from {file_path}")

        print(f"Testing {len(sequences)} sequences with APEX...")
        results = _score_sequences(apex, sequences, cfg_name)

        if results:
            stats = _summarize(results)
            print(f"\n=== Analysis of {cfg_name} ===")
            print(f"Total sequences tested: {stats['num_sequences']}")
            print(f"Predicted AMPs: {stats['amp_count']} ({stats['amp_percentage']:.1f}%)")
            print(f"Average MIC: {stats['avg_mic']:.2f} μg/mL")
            print(f"MIC range: {stats['min_mic']:.2f} - {stats['max_mic']:.2f} μg/mL")
            print(f"MIC std: {stats['std_mic']:.2f} μg/mL")

            # Ascending sort: the lowest MIC is treated as the best candidate.
            top_candidates = sorted(results, key=lambda r: r['apex_score'])[:5]

            print(f"\n=== Top 5 Candidates ({cfg_name}) ===")
            for rank, candidate in enumerate(top_candidates, start=1):
                print(f"{rank:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                      f"Length: {candidate['length']:2d} | "
                      f"Sequence: {candidate['sequence']}")

        all_results[cfg_name] = results

        output_file = os.path.join(
            output_dir, f"apex_results_{_cfg_slug(cfg_name)}_{today}.json")
        _write_json(output_file, {
            'metadata': {
                'cfg_setting': cfg_name,
                'num_sequences': len(results),
                'apex_available': APEX_AVAILABLE,
            },
            'results': results,
        })
        print(f"✓ Results saved to {output_file}")

    print(f"\n{'='*60}")
    print("OVERALL COMPARISON ACROSS CFG SETTINGS")
    print(f"{'='*60}")

    # Settings that produced no results are left out of the statistics.
    summary_by_cfg = {
        cfg_name: _summarize(results)
        for cfg_name, results in all_results.items()
        if results
    }

    for cfg_name, stats in summary_by_cfg.items():
        print(f"\n{cfg_name}:")
        print(f"  Total: {stats['num_sequences']} | AMPs: {stats['amp_count']} "
              f"({stats['amp_percentage']:.1f}%)")
        print(f"  Avg MIC: {stats['avg_mic']:.2f} μg/mL | "
              f"Best MIC: {stats['min_mic']:.2f} μg/mL")

    all_candidates = [r for results in all_results.values() for r in results]

    if all_candidates:
        print(f"\n{'='*60}")
        print("TOP 10 OVERALL CANDIDATES (All CFG Settings)")
        print(f"{'='*60}")

        top_overall = sorted(all_candidates, key=lambda r: r['apex_score'])[:10]
        for rank, candidate in enumerate(top_overall, start=1):
            print(f"{rank:2d}. MIC: {candidate['apex_score']:.2f} μg/mL | "
                  f"CFG: {candidate['cfg_setting']} | "
                  f"Sequence: {candidate['sequence']}")

    overall_results_file = os.path.join(
        output_dir, f'apex_results_all_cfg_comparison_{today}.json')
    _write_json(overall_results_file, {
        'metadata': {
            'date': today,
            'total_sequences': len(all_candidates),
            'apex_available': APEX_AVAILABLE,
            'cfg_settings_tested': list(all_results.keys()),
        },
        'results': all_candidates,
    })
    print(f"\n✓ Overall results saved to {overall_results_file}")

    mic_summary_file = os.path.join(output_dir, f'mic_summary_{today}.json')
    _write_json(mic_summary_file, {
        'date': today,
        'summary_by_cfg': summary_by_cfg,
        'all_mics': [r['apex_score'] for r in all_candidates],
        'amp_count': sum(1 for r in all_candidates if r['is_amp']),
        'total_sequences': len(all_candidates),
    })
    print(f"✓ MIC summary saved to {mic_summary_file}")

    if all_candidates:
        print(f"\n✅ APEX testing completed successfully!")
    else:
        # Do not claim success when no input file was found or every
        # prediction failed.
        print(f"\n⚠ APEX testing finished, but no sequences were tested")
    print(f"Tested {len(all_candidates)} total sequences across all CFG settings")
|
|
# Run the decoded-sequence APEX test only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()