# COMET-SEE — Gradio app for comet detection in SOHO/LASCO image sequences.
# NOTE(review): this file was recovered from a Hugging Face Spaces page that
# showed a "Runtime error" banner; the app likely failed at startup (e.g. a
# missing 'best_model.pth' checkpoint) — verify the deployment assets.
import base64
import io
import tempfile
import zipfile
from datetime import datetime, timezone
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import timm
import torch
from astropy.io import fits
from scipy.ndimage import zoom
from torchvision import transforms
class CometDetectorAPI:
    """End-to-end comet detection pipeline for SOHO/LASCO-style image sequences.

    Loads a trained EfficientNet-B0 binary classifier, builds a max-projection
    of frame-to-frame difference images (moving objects survive, static
    background cancels), classifies it, and renders an HTML result card for
    the Gradio UI.
    """

    def __init__(self, model_path='best_model.pth'):
        """Load the trained classifier from *model_path* onto GPU if available."""
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = timm.create_model('efficientnet_b0', pretrained=False, num_classes=2)
        self.model.load_state_dict(torch.load(model_path, map_location=self.device))
        self.model.to(self.device)
        self.model.eval()
        # ImageNet normalization — presumably matches the statistics used at
        # training time; verify against the training code.
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def fetch_soho_recent(self):
        """Generate demo SOHO data: 12 noisy 1024x1024 frames with a drifting bright spot."""
        images = []
        for i in range(12):
            # Synthetic background noise plus a "comet" that moves between frames.
            img = np.random.rand(1024, 1024) * 20
            y, x = 400 + i * 10, 300 + i * 15
            img[max(0, y - 30):min(1024, y + 30), max(0, x - 30):min(1024, x + 30)] += 100
            images.append(img.astype(np.float32))
        return np.array(images)

    def load_fits_from_zip(self, zip_file):
        """Extract up to 50 FITS images (*.fts / *.fits) from a ZIP archive.

        Each 2-D image is resampled to 1024x1024; header-only, non-2-D, or
        unreadable HDUs are skipped. Returns a float32 array of shape
        (N, 1024, 1024).
        """
        images = []
        with tempfile.TemporaryDirectory() as tmpdir:
            with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                zip_ref.extractall(tmpdir)
            fits_files = sorted(Path(tmpdir).rglob('*.fts')) + sorted(Path(tmpdir).rglob('*.fits'))
            for fpath in fits_files[:50]:  # cap to keep analysis time bounded
                try:
                    with fits.open(fpath) as hdul:
                        data = hdul[0].data
                        if data is None or np.ndim(data) != 2:
                            # Primary HDU may be header-only (data is None) or a
                            # cube/table we cannot stack into a frame sequence.
                            continue
                        img = data.astype(np.float32)
                        if img.shape != (1024, 1024):
                            # Per-axis zoom factors so non-square frames still come
                            # out 1024x1024 (the previous single factor derived from
                            # shape[0] produced rectangles for rectangular input,
                            # which broke the later np.array stacking).
                            factors = (1024 / img.shape[0], 1024 / img.shape[1])
                            img = zoom(img, factors, order=1)
                        images.append(img)
                except Exception:
                    # Skip any file astropy cannot parse. A bare `except:` here
                    # would also have swallowed KeyboardInterrupt/SystemExit.
                    continue
        return np.array(images)

    def create_difference_images(self, images):
        """Return the max projection of |frame[i+1] - frame[i]| over the sequence.

        Static background cancels in consecutive differences, so a moving
        comet shows up as a bright streak in the projection.
        """
        diff_images = [images[i + 1] - images[i] for i in range(len(images) - 1)]
        return np.max(np.abs(np.array(diff_images)), axis=0)

    def classify_image(self, max_proj):
        """Classify a max-projection image; return (pred_class, confidence).

        pred_class is 1 for "comet", 0 for background; confidence is the
        softmax probability of the predicted class.
        """
        # Min-max normalize to [0, 1]; the epsilon guards a constant image.
        img = (max_proj - max_proj.min()) / (max_proj.max() - max_proj.min() + 1e-8)
        img = cv2.resize(img, (512, 512))
        # Replicate the single channel into the 3-channel input the CNN expects.
        img_rgb = np.stack([img, img, img], axis=0)
        img_tensor = torch.FloatTensor(img_rgb).unsqueeze(0).to(self.device)
        img_tensor = self.transform(img_tensor)
        with torch.no_grad():
            output = self.model(img_tensor)
            probs = torch.softmax(output, dim=1)
            pred_class = torch.argmax(probs, dim=1).item()
            confidence = probs[0][pred_class].item()
        return pred_class, confidence

    def generate_html_result(self, images, max_proj, pred_class, confidence, source, num_images):
        """Render the analysis as a self-contained HTML result card.

        The matplotlib figure (first raw frame + difference max-projection) is
        embedded as a base64 data URI so the card needs no file hosting.
        """
        plt.style.use('dark_background')
        fig, axes = plt.subplots(1, 2, figsize=(14, 6))
        fig.patch.set_facecolor('#0a0e27')
        axes[0].imshow(images[0], cmap='gray')
        axes[0].set_title('Original Frame', fontsize=12, color='#00d9ff')
        axes[0].axis('off')
        axes[1].imshow(max_proj, cmap='hot')
        if pred_class == 1:
            axes[1].set_title('COMET DETECTED', fontsize=14, color='#00ff88', weight='bold')
        else:
            axes[1].set_title('Background', fontsize=12, color='#ffb366')
        axes[1].axis('off')
        plt.tight_layout()
        buf = io.BytesIO()
        fig.savefig(buf, format='png', facecolor='#0a0e27', dpi=100)
        buf.seek(0)
        img_base64 = base64.b64encode(buf.read()).decode()
        # Close this specific figure; a bare plt.close() only closes the
        # "current" figure, which can leak figures under concurrent requests.
        plt.close(fig)
        status_class = "status-detected" if pred_class == 1 else "status-not-detected"
        status_text = "π COMET DETECTED!" if pred_class == 1 else "π No Comet Activity"
        border_class = "result-detected" if pred_class == 1 else "result-not-detected"
        confidence_pct = int(confidence * 100)
        # datetime.utcnow() is deprecated (naive result); use an aware UTC timestamp.
        timestamp = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S UTC')
        html = f"""
        <div class="results {border_class}" style="display: block;">
            <div class="result-header">
                <div class="result-status {status_class}">{status_text}</div>
                <div class="confidence-bar">
                    <div class="confidence-fill" style="width: {confidence_pct}%">{confidence_pct}%</div>
                </div>
            </div>
            <img class="result-image" src="data:image/png;base64,{img_base64}" alt="Analysis Result">
            <div class="metadata">
                <div class="metadata-item">
                    <div class="metadata-label">Images Analyzed</div>
                    <div class="metadata-value">{num_images}</div>
                </div>
                <div class="metadata-item">
                    <div class="metadata-label">Data Source</div>
                    <div class="metadata-value">{source}</div>
                </div>
                <div class="metadata-item">
                    <div class="metadata-label">Analysis Time</div>
                    <div class="metadata-value">{timestamp}</div>
                </div>
                <div class="metadata-item">
                    <div class="metadata-label">Model Accuracy</div>
                    <div class="metadata-value">97.7%</div>
                </div>
            </div>
        </div>
        <style>
        .results {{
            background: linear-gradient(135deg, rgba(10, 14, 39, 0.95) 0%, rgba(0, 20, 40, 0.9) 100%);
            border-radius: 15px;
            padding: 30px;
            border: 2px solid rgba(0, 217, 255, 0.3);
            backdrop-filter: blur(10px);
            animation: slideIn 0.5s ease;
            margin-top: 20px;
        }}
        @keyframes slideIn {{
            from {{ opacity: 0; transform: translateY(20px); }}
            to {{ opacity: 1; transform: translateY(0); }}
        }}
        .result-detected {{ border-color: #00ff88; box-shadow: 0 0 40px rgba(0, 255, 136, 0.3); }}
        .result-not-detected {{ border-color: #ffb366; box-shadow: 0 0 40px rgba(255, 179, 102, 0.3); }}
        .result-header {{ text-align: center; margin-bottom: 30px; }}
        .result-status {{
            font-family: 'Orbitron', sans-serif;
            font-size: 2.5em;
            font-weight: bold;
            margin-bottom: 10px;
        }}
        .status-detected {{ color: #00ff88; text-shadow: 0 0 20px rgba(0, 255, 136, 0.8); }}
        .status-not-detected {{ color: #ffb366; text-shadow: 0 0 20px rgba(255, 179, 102, 0.8); }}
        .confidence-bar {{
            width: 100%;
            height: 30px;
            background: rgba(255, 255, 255, 0.1);
            border-radius: 15px;
            overflow: hidden;
            margin: 20px 0;
        }}
        .confidence-fill {{
            height: 100%;
            background: linear-gradient(90deg, #00d9ff 0%, #00ff88 100%);
            display: flex;
            align-items: center;
            justify-content: center;
            font-weight: bold;
            color: #000;
            transition: width 1s ease;
        }}
        .result-image {{
            width: 100%;
            border-radius: 10px;
            margin: 20px 0;
            border: 2px solid rgba(0, 217, 255, 0.3);
        }}
        .metadata {{
            display: grid;
            grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
            gap: 15px;
            margin-top: 20px;
        }}
        .metadata-item {{
            background: rgba(0, 217, 255, 0.1);
            padding: 15px;
            border-radius: 8px;
            border: 1px solid rgba(0, 217, 255, 0.3);
        }}
        .metadata-label {{ color: #00d9ff; font-size: 0.9em; margin-bottom: 5px; }}
        .metadata-value {{ font-size: 1.2em; font-weight: bold; color: #fff; }}
        </style>
        """
        return html

    def analyze_uploaded(self, zip_file, progress=gr.Progress()):
        """Analyze an uploaded ZIP of FITS images with progress tracking.

        Returns an HTML result card, or an error card when the input is
        missing, too small, or processing fails.
        """
        if zip_file is None:
            return "<div class='error' style='display:block;'>β No file uploaded</div>"
        try:
            progress(0.1, desc="Extracting ZIP file...")
            images = self.load_fits_from_zip(zip_file)
            # Differencing needs at least two frames.
            if len(images) < 2:
                return "<div class='error' style='display:block;'>β Need at least 2 FITS images</div>"
            progress(0.4, desc=f"Processing {len(images)} images...")
            max_proj = self.create_difference_images(images)
            progress(0.7, desc="Running AI classification...")
            pred_class, confidence = self.classify_image(max_proj)
            progress(1.0, desc="Complete!")
            return self.generate_html_result(
                images, max_proj, pred_class, confidence,
                "Uploaded Data", len(images)
            )
        except Exception as e:
            # Boundary handler: surface the failure to the UI instead of crashing.
            return f"<div class='error' style='display:block;'>β Error: {str(e)}</div>"

    def analyze_soho_live(self, progress=gr.Progress()):
        """Fetch and analyze (demo) live SOHO data with progress tracking."""
        try:
            progress(0.2, desc="Fetching SOHO data...")
            images = self.fetch_soho_recent()
            progress(0.5, desc="Creating difference images...")
            max_proj = self.create_difference_images(images)
            progress(0.8, desc="Running AI classification...")
            pred_class, confidence = self.classify_image(max_proj)
            progress(1.0, desc="Complete!")
            return self.generate_html_result(
                images, max_proj, pred_class, confidence,
                "Live SOHO Data (Demo)", len(images)
            )
        except Exception as e:
            # Boundary handler: surface the failure to the UI instead of crashing.
            return f"<div class='error' style='display:block;'>β Error: {str(e)}</div>"
# Single shared detector instance, created at import time so both Gradio
# callbacks reuse the same loaded model (weights are loaded exactly once).
detector = CometDetectorAPI('best_model.pth')
# Custom CSS injected into the Gradio page: dark "mission control" theme,
# Orbitron/Share Tech Mono fonts, glowing header, pulsing team badges, and
# styled primary/secondary buttons. The `.error` class is referenced by the
# HTML error cards returned from the analysis callbacks.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Orbitron:wght@400;700;900&family=Share+Tech+Mono&display=swap');
* { margin: 0; padding: 0; box-sizing: border-box; }
body {
    font-family: 'Share Tech Mono', monospace;
    background: linear-gradient(135deg, #000814 0%, #001d3d 50%, #000814 100%);
    color: #fff;
}
.gradio-container {
    max-width: 1400px !important;
    margin: 0 auto !important;
}
.header {
    text-align: center;
    padding: 40px 20px;
    background: linear-gradient(135deg, rgba(10, 14, 39, 0.9) 0%, rgba(0, 20, 40, 0.8) 100%);
    border-radius: 20px;
    border: 2px solid rgba(0, 217, 255, 0.3);
    box-shadow: 0 0 40px rgba(0, 217, 255, 0.2);
    margin-bottom: 30px;
    backdrop-filter: blur(10px);
}
.title {
    font-family: 'Orbitron', sans-serif;
    font-size: 4em;
    font-weight: 900;
    background: linear-gradient(135deg, #00d9ff 0%, #00ff88 50%, #ff00ff 100%);
    -webkit-background-clip: text;
    -webkit-text-fill-color: transparent;
    margin-bottom: 10px;
    animation: glow 2s ease-in-out infinite alternate;
}
@keyframes glow {
    from { filter: drop-shadow(0 0 5px rgba(0, 217, 255, 0.5)); }
    to { filter: drop-shadow(0 0 20px rgba(0, 217, 255, 0.8)); }
}
.subtitle {
    font-size: 1.2em;
    color: #00d9ff;
    letter-spacing: 3px;
    margin-bottom: 20px;
}
.team {
    font-size: 1.1em;
    color: #00ff88;
    margin-top: 15px;
}
.team-member {
    display: inline-block;
    padding: 8px 20px;
    margin: 0 10px;
    background: linear-gradient(135deg, rgba(0, 255, 136, 0.2) 0%, rgba(0, 217, 255, 0.2) 100%);
    border: 2px solid #00ff88;
    border-radius: 25px;
    font-weight: bold;
    animation: pulse 2s ease-in-out infinite;
    box-shadow: 0 0 20px rgba(0, 255, 136, 0.3);
}
@keyframes pulse {
    0%, 100% { transform: scale(1); box-shadow: 0 0 20px rgba(0, 255, 136, 0.3); }
    50% { transform: scale(1.05); box-shadow: 0 0 30px rgba(0, 255, 136, 0.5); }
}
.panel {
    background: linear-gradient(135deg, rgba(10, 14, 39, 0.95) 0%, rgba(0, 20, 40, 0.9) 100%);
    border-radius: 15px;
    padding: 30px;
    border: 2px solid rgba(0, 217, 255, 0.3);
    backdrop-filter: blur(10px);
    margin-bottom: 20px;
}
.gr-button-primary {
    background: linear-gradient(135deg, #00d9ff 0%, #00ff88 100%) !important;
    border: none !important;
    color: #000 !important;
    font-family: 'Orbitron', sans-serif !important;
    font-weight: bold !important;
    font-size: 1.1em !important;
    padding: 15px 40px !important;
    letter-spacing: 2px !important;
    transition: all 0.3s ease !important;
}
.gr-button-primary:hover {
    transform: translateY(-3px) !important;
    box-shadow: 0 10px 30px rgba(0, 217, 255, 0.5) !important;
}
.gr-button-secondary {
    background: linear-gradient(135deg, #ff00ff 0%, #ff0080 100%) !important;
    border: none !important;
    color: #fff !important;
    font-family: 'Orbitron', sans-serif !important;
    font-weight: bold !important;
    font-size: 1.1em !important;
    padding: 15px 40px !important;
    letter-spacing: 2px !important;
    transition: all 0.3s ease !important;
}
.gr-button-secondary:hover {
    transform: translateY(-3px) !important;
    box-shadow: 0 10px 30px rgba(255, 0, 255, 0.5) !important;
}
.error {
    background: rgba(255, 0, 0, 0.1);
    border: 2px solid rgba(255, 0, 0, 0.5);
    color: #ff4444;
    padding: 20px;
    border-radius: 10px;
    margin: 20px 0;
    font-size: 1.1em;
}
"""
# Create Gradio Interface: header, two input panels (ZIP upload and live demo
# fetch), a shared HTML output area, and a footer. Both buttons write into the
# same `result_output` component.
# NOTE(review): several strings below contain mojibake (e.g. "π©βπ") that were
# presumably emoji in the original source — confirm against the deployed app.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft(), title="COMET-SEE Mission Control") as demo:
    # Animated mission-control header with team credits.
    gr.HTML("""
    <div class="header">
        <div class="title">COMET-SEE</div>
        <div class="subtitle">COmet Motion Extraction & Tracking β Statistical Exploration Engine</div>
        <div class="team">
            <span class="team-member">π©βπ SHAMBHAVI</span>
            <span class="team-member">π©βπ¬ EMILY</span>
            <span class="team-member">π¨βπ» MOHAMMED</span>
        </div>
    </div>
    """)
    with gr.Row():
        # Left column: user-supplied data (ZIP of FITS frames).
        with gr.Column(scale=1):
            gr.HTML('<div class="panel"><h2 style="color: #00d9ff; font-family: Orbitron;">π€ Upload FITS Data</h2></div>')
            file_input = gr.File(
                label="Upload ZIP of FITS Images",
                file_types=[".zip"],
                type="filepath"
            )
            upload_btn = gr.Button("π Analyze Upload", variant="primary", size="lg")
        # Right column: one-click demo on synthetic "live" SOHO data.
        with gr.Column(scale=1):
            gr.HTML('<div class="panel"><h2 style="color: #00d9ff; font-family: Orbitron;">π‘ Live SOHO Data</h2></div>')
            gr.HTML("""
            <div style="text-align: center; padding: 20px;">
                <div style="font-size: 3em; margin-bottom: 15px;">π°οΈ</div>
                <p style="color: #aaa; margin-bottom: 20px;">
                    Fetch recent SOHO/LASCO C3 images<br>and analyze for comet activity
                </p>
            </div>
            """)
            fetch_btn = gr.Button("π Fetch Live Data", variant="secondary", size="lg")
    # Results section — both analysis paths render their HTML card here.
    result_output = gr.HTML(label="Analysis Results")
    # Connect buttons to the shared detector's callbacks.
    upload_btn.click(
        fn=detector.analyze_uploaded,
        inputs=[file_input],
        outputs=[result_output]
    )
    # analyze_soho_live takes no UI inputs, so `inputs` is omitted.
    fetch_btn.click(
        fn=detector.analyze_soho_live,
        outputs=[result_output]
    )
    # Footer with model metrics and credits.
    gr.HTML("""
    <div style="text-align: center; margin-top: 40px; padding: 20px; color: #888;">
        <p>Model Performance: 97.7% Accuracy β’ 98% Precision β’ 99% Recall</p>
        <p style="margin-top: 10px;">Built with β€οΈ by Shambhavi Srivastava, Emily Margaret Foley & Mohammed Sameer Syed</p>
    </div>
    """)
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 — the convention for Hugging Face
    # Spaces deployments, where the container routes external traffic to 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)