| |
| |
| |
| |
| |
| |
| |
| |
"""
PitVQA Multi-Agent Orchestration System

Specialized agents for methodologically rigorous VLM pipeline management:
1. JobMonitorAgent - Track HuggingFace Jobs status
2. CurationAgent - Quality-filter showcase examples
3. DatasetValidatorAgent - Validate image-embedded dataset
4. ModelVerifierAgent - Test merged model outputs
5. TrainingSpecialistAgent - Validate training/LoRA configuration
6. EvaluationSpecialistAgent - Compute metrics against quality thresholds
7. DemoSyncAgent - Update Gradio Space with results

Run with: python pitvqa_agent_orchestrator.py
"""
|
|
| import os |
| import json |
| import time |
| from dataclasses import dataclass |
| from typing import Dict, List, Optional, Any |
| from datetime import datetime |
| from enum import Enum |
|
|
| |
| |
| |
|
|
class AgentStatus(Enum):
    """Lifecycle states reported by each pipeline agent."""

    IDLE = "idle"          # agent has not run yet
    RUNNING = "running"    # agent is currently executing
    SUCCESS = "success"    # last run completed and passed its checks
    FAILED = "failed"      # last run found a hard failure
    WAITING = "waiting"    # prerequisites not ready yet; retry later
|
|
@dataclass
class AgentResult:
    """Outcome of a single agent run, with an auto-filled timestamp."""

    agent_name: str
    status: AgentStatus
    message: str
    data: Optional[Dict] = None
    timestamp: str = ""

    def __post_init__(self):
        # Stamp creation time when the caller left the timestamp falsy.
        self.timestamp = self.timestamp or datetime.now().isoformat()
|
|
| |
| |
| |
|
|
class BaseAgent:
    """Base class for all PitVQA agents.

    Subclasses override ``run`` and report via ``AgentResult``; shared
    logging and status reporting live here.
    """

    def __init__(self, name: str):
        self.name = name
        self.status = AgentStatus.IDLE
        # History of results this agent has produced (appended by subclasses).
        self.results: List[AgentResult] = []

    def log(self, message: str, level: str = "INFO"):
        """Print a namespaced, icon-prefixed status line.

        NOTE(review): the icon glyphs were mojibake in the source dump;
        restored to the intended emoji — confirm against the original file.
        """
        icon = {"INFO": "ℹ️", "SUCCESS": "✅", "ERROR": "❌", "WARN": "⚠️"}.get(level, "📝")
        print(f"[{self.name}] {icon} {message}")

    def run(self) -> AgentResult:
        """Execute the agent once; must be implemented by subclasses."""
        raise NotImplementedError

    def report(self) -> Dict:
        """Summarize this agent's status and accumulated results."""
        return {
            "agent": self.name,
            "status": self.status.value,
            "results": [r.__dict__ for r in self.results]
        }
|
|
| |
| |
| |
|
|
class JobMonitorAgent(BaseAgent):
    """Monitors HuggingFace Jobs and reports status."""

    def __init__(self, job_ids: List[str]):
        super().__init__("JobMonitor")
        self.job_ids = job_ids
        # job_id -> latest status dict produced by check_job()
        self.job_status = {}

    def check_job(self, job_id: str) -> Dict:
        """Check single job status using HF API.

        Returns a dict with ``id`` and ``status`` keys; on any failure
        ``status`` is ``"UNKNOWN"`` and the exception text is in ``error``.
        """
        try:
            from huggingface_hub import HfApi
            api = HfApi()

            # NOTE(review): assumes HfApi exposes get_job(); recent
            # huggingface_hub versions call this inspect_job() — confirm.
            job = api.get_job(job_id)
            return {
                "id": job_id,
                # Jobs API may expose status as an object with a .stage, or
                # as a plain value — handle both shapes defensively.
                "status": job.status.stage if hasattr(job.status, 'stage') else str(job.status),
                "message": job.status.message if hasattr(job.status, 'message') else None
            }
        except Exception as e:
            return {"id": job_id, "status": "UNKNOWN", "error": str(e)}

    def run(self) -> AgentResult:
        """Poll every tracked job and roll statuses up into one result."""
        self.status = AgentStatus.RUNNING
        self.log(f"Checking {len(self.job_ids)} jobs...")

        all_complete = True
        any_failed = False

        for job_id in self.job_ids:
            status = self.check_job(job_id)
            self.job_status[job_id] = status

            stage = status.get("status", "UNKNOWN")
            self.log(f"Job {job_id[:8]}: {stage}")

            # Anything not terminal-successful keeps the pipeline waiting;
            # explicit failures flip the whole run to FAILED.
            if stage not in ["COMPLETED", "SUCCESS"]:
                all_complete = False
            if stage in ["FAILED", "ERROR"]:
                any_failed = True

        if any_failed:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, "Some jobs failed", self.job_status)
        elif all_complete:
            self.status = AgentStatus.SUCCESS
            return AgentResult(self.name, AgentStatus.SUCCESS, "All jobs complete", self.job_status)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "Jobs still running", self.job_status)
|
|
| |
| |
| |
|
|
class CurationAgent(BaseAgent):
    """Curates showcase examples based on quality criteria."""

    # Declarative quality checks kept for reference/external use; the
    # scoring below implements its own heuristics.
    QUALITY_CRITERIA = {
        "coordinate_validity": lambda x, y: 0 <= x <= 100 and 0 <= y <= 100,
        "coordinate_diversity": lambda coords: len(set(coords)) > len(coords) * 0.5,
        "video_diversity": lambda vids: len(set(vids)) >= min(5, len(vids)),
        "frame_diversity": lambda frames: len(set(frames)) >= min(8, len(frames)),
    }

    def __init__(self, results_path: str = "./curation_review/all_results.json"):
        super().__init__("Curation")
        self.results_path = results_path
        self.curated_examples = []

    def load_results(self) -> List[Dict]:
        """Load raw curation results; return [] if the file is absent."""
        try:
            with open(self.results_path) as f:
                return json.load(f)
        except FileNotFoundError:
            self.log("Results file not found - job may still be running", "WARN")
            return []

    def score_example(self, example: Dict) -> float:
        """Score a single example (0-1) from success, geometry, and text."""
        score = 0.0

        # Base credit for a successful generation.
        if example.get("success"):
            score += 0.3

        if example.get("task") == "point":
            x, y = example.get("x"), example.get("y")
            # BUGFIX: compare against None so a legitimate coordinate of 0
            # is not treated as missing (the original used truthiness).
            if x is not None and y is not None:
                # Prefer points away from the image border.
                if 10 < x < 90 and 10 < y < 90:
                    score += 0.3
                else:
                    score += 0.1
        elif example.get("task") == "bbox":
            bbox = example.get("bbox")
            if bbox and len(bbox) == 4:
                # Prefer mid-sized boxes: neither degenerate nor full-frame.
                area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
                if 100 < area < 5000:
                    score += 0.3
                else:
                    score += 0.1

        # Structured-output credit: model emitted a point/box tag.
        response = example.get("response", "")
        if "<point" in response or "<box" in response:
            score += 0.2

        # Target-mention credit.
        # BUGFIX: guard against empty target ("" is a substring of every
        # string) and lower-case BOTH sides for a fair comparison.
        target = example.get("target", "")
        if target and target.lower() in response.lower():
            score += 0.2

        return min(score, 1.0)

    def curate(self, results: List[Dict], top_k: int = 12) -> List[Dict]:
        """Select best diverse examples (<=2 per video, unique frames,
        no task taking more than half the slots)."""
        if not results:
            return []

        # Score successful examples and rank best-first.
        scored = [(self.score_example(ex), ex) for ex in results if ex.get("success")]
        scored.sort(key=lambda x: x[0], reverse=True)

        curated = []
        # BUGFIX: the original used a set and called .count() on it, which
        # raises AttributeError — track per-video tallies in a dict instead.
        video_counts = {}
        used_frames = set()
        used_tasks = {"point": 0, "bbox": 0}

        for score, ex in scored:
            if len(curated) >= top_k:
                break

            video = ex.get("video_id")
            frame = ex.get("frame_idx")
            task = ex.get("task")

            # Diversity constraints.
            if video_counts.get(video, 0) >= 2:
                continue
            if (video, frame) in used_frames:
                continue
            if used_tasks.get(task, 0) >= top_k // 2:
                continue

            curated.append({**ex, "quality_score": score})
            video_counts[video] = video_counts.get(video, 0) + 1
            used_frames.add((video, frame))
            used_tasks[task] = used_tasks.get(task, 0) + 1

        return curated

    def run(self) -> AgentResult:
        """Load, score, and curate examples; succeed with >= 8 keepers."""
        self.status = AgentStatus.RUNNING
        self.log("Loading curation results...")

        results = self.load_results()
        if not results:
            self.status = AgentStatus.WAITING
            return AgentResult(self.name, AgentStatus.WAITING, "No results available yet")

        self.log(f"Scoring {len(results)} examples...")
        self.curated_examples = self.curate(results)

        if len(self.curated_examples) >= 8:
            self.status = AgentStatus.SUCCESS

            # Diversity stats for the log only.
            videos = set(ex["video_id"] for ex in self.curated_examples)
            frames = set(ex["frame_idx"] for ex in self.curated_examples)

            self.log(f"Curated {len(self.curated_examples)} examples", "SUCCESS")
            self.log(f"  Videos: {len(videos)} unique")
            self.log(f"  Frames: {len(frames)} unique")

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Curated {len(self.curated_examples)} high-quality diverse examples",
                {"examples": self.curated_examples}
            )
        else:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Only {len(self.curated_examples)} examples passed quality checks"
            )
|
|
| |
| |
| |
|
|
class DatasetValidatorAgent(BaseAgent):
    """Validates image-embedded dataset quality."""

    def __init__(self, dataset_id: str = "mmrech/pitvqa-spatial-with-images"):
        super().__init__("DatasetValidator")
        self.dataset_id = dataset_id

    def run(self) -> AgentResult:
        """Spot-check the first ten rows for required fields and images."""
        self.status = AgentStatus.RUNNING
        self.log(f"Validating dataset: {self.dataset_id}")

        try:
            from datasets import load_dataset

            # Only a small prefix is needed for the sanity check.
            ds = load_dataset(self.dataset_id, split="train[:10]")

            # Schema check: both columns must exist.
            missing = [f for f in ["image", "messages"] if f not in ds.features]
            if missing:
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    f"Missing fields: {missing}"
                )

            # A row is valid when its image is present and non-empty.
            valid_images = sum(
                1
                for img in (row.get("image") for row in ds)
                if img and hasattr(img, "size") and img.size[0] > 0
            )

            if valid_images != len(ds):
                self.status = AgentStatus.FAILED
                return AgentResult(
                    self.name,
                    AgentStatus.FAILED,
                    f"Invalid images: {len(ds) - valid_images}/{len(ds)}"
                )

            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Dataset valid: {valid_images}/{len(ds)} images OK",
                {"sample_count": len(ds), "valid_images": valid_images}
            )

        except Exception as e:
            # Any load failure is treated as "not published yet".
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Dataset not yet available: {e}"
            )
|
|
| |
| |
| |
|
|
class ModelVerifierAgent(BaseAgent):
    """Verifies merged model outputs are correct."""

    # Prompts for a future end-to-end smoke test (not exercised yet; the
    # current run() only checks repo contents).
    TEST_PROMPTS = [
        ("Point to the suction device", "point"),
        ("Draw a bounding box around the surgical instrument", "bbox"),
        ("What surgical phase is this?", "classification"),
    ]

    def __init__(self, model_id: str = "mmrech/pitvqa-qwen2vl-merged"):
        super().__init__("ModelVerifier")
        self.model_id = model_id

    def run(self) -> AgentResult:
        """Check the model repo exists and contains weights plus config."""
        self.status = AgentStatus.RUNNING
        self.log(f"Verifying model: {self.model_id}")

        try:
            from huggingface_hub import HfApi
            api = HfApi()

            try:
                info = api.model_info(self.model_id)
                self.log(f"Model found: {info.modelId}")

                files = [f.rfilename for f in info.siblings]
                # (removed an unused `required` list — presence of weights
                # and config is checked by the two flags below)

                # Weights may be sharded or renamed, so substring-match.
                has_model = any("safetensors" in f or "pytorch" in f for f in files)
                has_config = "config.json" in files

                if has_model and has_config:
                    self.status = AgentStatus.SUCCESS
                    return AgentResult(
                        self.name,
                        AgentStatus.SUCCESS,
                        f"Model verified: {len(files)} files present",
                        {"files": files[:10]}
                    )
                else:
                    self.status = AgentStatus.FAILED
                    return AgentResult(
                        self.name,
                        AgentStatus.FAILED,
                        f"Missing model files (has_model={has_model}, has_config={has_config})"
                    )

            except Exception as e:
                # Repo lookup failures mean the merge job has not pushed yet.
                self.status = AgentStatus.WAITING
                return AgentResult(
                    self.name,
                    AgentStatus.WAITING,
                    f"Model not yet available: {e}"
                )

        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")
|
|
| |
| |
| |
|
|
class TrainingSpecialistAgent(BaseAgent):
    """
    Specialist in HuggingFace LLM Training (TRL/SFT/LoRA/DPO).

    Responsibilities:
    - Validate training configurations
    - Check adapter quality
    - Recommend training improvements
    - Verify LoRA/PEFT setup
    """

    # Reference glossary of supported training methods.
    TRAINING_METHODS = {
        "SFT": "Supervised Fine-Tuning - learning from (input, output) pairs",
        "LoRA": "Low-Rank Adaptation - parameter-efficient adapters",
        "DPO": "Direct Preference Optimization - learning from preferences",
        "RLHF": "Reinforcement Learning from Human Feedback",
    }

    # Recommended baseline hyperparameters for this pipeline.
    OPTIMAL_CONFIG = {
        "lora_r": 16,
        "lora_alpha": 32,
        "learning_rate": 1e-4,
        "batch_size": 1,
        "gradient_accumulation_steps": 16,
        "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj"],
    }

    def __init__(self, adapter_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2"):
        super().__init__("TrainingSpecialist")
        self.adapter_repo = adapter_repo

    def validate_adapter_config(self) -> Dict:
        """Validate adapter configuration.

        Returns ``{"config", "issues", "recommendations", "valid"}`` on
        success, or ``{"error": ..., "valid": False}`` when the config
        cannot be downloaded or parsed.
        """
        try:
            from huggingface_hub import hf_hub_download
            # (removed a redundant local `import json`; json is imported
            # at module level)

            config_path = hf_hub_download(
                repo_id=self.adapter_repo,
                filename="stage4/adapter_config.json"
            )

            with open(config_path) as f:
                config = json.load(f)

            issues = []
            recommendations = []

            # LoRA rank sanity: too low likely underfits; very high wastes
            # memory without much gain.
            if config.get("r", 0) < 8:
                issues.append("LoRA rank too low (r < 8)")
            elif config.get("r", 0) > 64:
                recommendations.append("Consider reducing LoRA rank for efficiency")

            # At least one attention projection layer should be adapted.
            target_modules = config.get("target_modules", [])
            if not any("proj" in m for m in target_modules):
                issues.append("No projection layers targeted")

            return {
                "config": config,
                "issues": issues,
                "recommendations": recommendations,
                "valid": len(issues) == 0
            }

        except Exception as e:
            return {"error": str(e), "valid": False}

    def recommend_next_training(self, current_metrics: Optional[Dict] = None) -> Dict:
        """Recommend next training steps based on current metrics."""
        recommendations = []

        if not current_metrics:
            recommendations.append({
                "priority": "HIGH",
                "action": "Run evaluation to get baseline metrics",
                "method": "scripts/evaluate_unified_vlm.py"
            })
        else:
            accuracy = current_metrics.get("accuracy", 0)

            if accuracy < 0.7:
                recommendations.append({
                    "priority": "HIGH",
                    "action": "Increase training epochs or data",
                    "method": "SFT with more epochs"
                })

            if 0.7 <= accuracy < 0.85:
                recommendations.append({
                    "priority": "MEDIUM",
                    "action": "Consider DPO for preference learning",
                    "method": "Create chosen/rejected pairs from predictions"
                })

            if accuracy >= 0.85:
                recommendations.append({
                    "priority": "LOW",
                    "action": "Model performing well - focus on inference optimization",
                    "method": "Merge adapters, quantize for deployment"
                })

        return {"recommendations": recommendations}

    def run(self) -> AgentResult:
        """Validate the adapter config and attach training recommendations."""
        self.status = AgentStatus.RUNNING
        self.log(f"Validating training setup: {self.adapter_repo}")

        validation = self.validate_adapter_config()

        if validation.get("valid"):
            self.status = AgentStatus.SUCCESS
            recommendations = self.recommend_next_training()

            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"Training config valid. LoRA r={validation['config'].get('r')}",
                {
                    "config": validation["config"],
                    "recommendations": recommendations["recommendations"]
                }
            )
        elif validation.get("error"):
            # Download/parse failure: adapter likely not pushed yet.
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                f"Could not load adapter: {validation['error']}"
            )
        else:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"Issues found: {validation['issues']}",
                validation
            )
|
|
| |
| |
| |
|
|
class EvaluationSpecialistAgent(BaseAgent):
    """
    Specialist in Model Evaluation (metrics, benchmarks, validation).

    Responsibilities:
    - Compute accuracy, F1, precision, recall
    - Validate coordinate predictions (MAE, quadrant accuracy)
    - Compare against baselines
    - Generate evaluation reports
    """

    # Metric families by task type (reference only).
    METRICS = {
        "classification": ["accuracy", "f1", "precision", "recall"],
        "localization": ["mae", "quadrant_accuracy", "distance_error"],
        "detection": ["iou", "ap", "ar"],
    }

    # Pass/fail gates; "mae" is lower-is-better, the rest higher-is-better.
    THRESHOLDS = {
        "quadrant_accuracy": 0.75,
        "mae": 15.0,
        "classification_accuracy": 0.80,
    }

    def __init__(self, model_repo: str = "mmrech/pitvqa-qwen2vl-unified-v2"):
        super().__init__("EvaluationSpecialist")
        self.model_repo = model_repo
        self.metrics = {}

    def load_evaluation_results(self) -> Dict:
        """Load existing evaluation results if available, else {}."""
        try:
            with open("evaluation_results.json") as f:
                return json.load(f)
        except FileNotFoundError:
            return {}

    def compute_quick_metrics(self, predictions: List[Dict]) -> Dict:
        """Compute quick metrics from predictions."""
        if not predictions:
            return {}

        metrics = {}

        # Coordinate (pointing) predictions.
        coord_preds = [p for p in predictions if p.get("task") in ["point", "pointing"]]
        if coord_preds:
            # BUGFIX: require BOTH coordinates; the original only checked x,
            # so a prediction missing y crashed the distance computation.
            valid = [
                p for p in coord_preds
                if p.get("x") is not None and p.get("y") is not None
            ]
            metrics["valid_rate"] = len(valid) / len(coord_preds)

            errors = []
            for p in valid:
                # BUGFIX: compare against None so a ground-truth coordinate
                # of 0 (a valid edge position) is not silently skipped.
                if p.get("gt_x") is not None and p.get("gt_y") is not None:
                    err = ((p["x"] - p["gt_x"])**2 + (p["y"] - p["gt_y"])**2)**0.5
                    errors.append(err)

            if errors:
                metrics["mae"] = sum(errors) / len(errors)
                # Within 25 units counts as landing in the right quadrant.
                metrics["quadrant_accuracy"] = sum(1 for e in errors if e < 25) / len(errors)

        # Classification predictions.
        class_preds = [p for p in predictions if p.get("task") == "classification"]
        if class_preds:
            correct = sum(1 for p in class_preds if p.get("prediction") == p.get("ground_truth"))
            metrics["classification_accuracy"] = correct / len(class_preds)

        return metrics

    def evaluate_against_thresholds(self, metrics: Dict) -> Dict:
        """Check metrics against quality thresholds."""
        results = {"passed": [], "failed": [], "warnings": []}

        for metric, threshold in self.THRESHOLDS.items():
            if metric in metrics:
                value = metrics[metric]
                # MAE is an error metric: lower is better.
                if metric == "mae":
                    passed = value <= threshold
                else:
                    passed = value >= threshold

                entry = {"metric": metric, "value": value, "threshold": threshold}
                if passed:
                    results["passed"].append(entry)
                else:
                    results["failed"].append(entry)

        return results

    def generate_report(self, metrics: Dict, threshold_results: Dict) -> str:
        """Generate a human-readable evaluation report string.

        NOTE(review): the section glyphs were mojibake in the source dump;
        restored to the intended emoji — confirm against the original file.
        """
        report = []
        report.append("=" * 50)
        report.append("EVALUATION REPORT")
        report.append("=" * 50)

        report.append("\n📊 METRICS:")
        for k, v in metrics.items():
            report.append(f"  {k}: {v:.4f}" if isinstance(v, float) else f"  {k}: {v}")

        report.append("\n✅ PASSED:")
        for item in threshold_results["passed"]:
            report.append(f"  {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        if threshold_results["failed"]:
            report.append("\n❌ FAILED:")
            for item in threshold_results["failed"]:
                report.append(f"  {item['metric']}: {item['value']:.4f} (threshold: {item['threshold']})")

        return "\n".join(report)

    def run(self, predictions: Optional[List[Dict]] = None) -> AgentResult:
        """Evaluate stored or supplied predictions against thresholds."""
        self.status = AgentStatus.RUNNING
        self.log("Running evaluation...")

        # Prefer a previously saved evaluation over recomputation.
        existing = self.load_evaluation_results()

        if existing:
            self.log("Found existing evaluation results")
            self.metrics = existing
        elif predictions:
            self.log(f"Computing metrics from {len(predictions)} predictions")
            self.metrics = self.compute_quick_metrics(predictions)
        else:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No predictions available for evaluation"
            )

        threshold_results = self.evaluate_against_thresholds(self.metrics)

        report = self.generate_report(self.metrics, threshold_results)
        self.log(f"\n{report}")

        if threshold_results["failed"]:
            self.status = AgentStatus.FAILED
            return AgentResult(
                self.name,
                AgentStatus.FAILED,
                f"{len(threshold_results['failed'])} metrics below threshold",
                {"metrics": self.metrics, "thresholds": threshold_results}
            )
        else:
            self.status = AgentStatus.SUCCESS
            return AgentResult(
                self.name,
                AgentStatus.SUCCESS,
                f"All {len(threshold_results['passed'])} metrics passed",
                {"metrics": self.metrics, "thresholds": threshold_results}
            )
|
|
| |
| |
| |
|
|
class DemoSyncAgent(BaseAgent):
    """Syncs curated examples to Gradio Space."""

    def __init__(self, space_id: str = "mmrech/pitvqa-surgical-vlm"):
        super().__init__("DemoSync")
        self.space_id = space_id

    def run(self, curated_examples: Optional[List[Dict]] = None) -> AgentResult:
        """Check the Space is live and report readiness of curated examples.

        NOTE: the actual upload is not implemented yet — this only verifies
        the Space runtime state and counts the examples awaiting sync.
        """
        self.status = AgentStatus.RUNNING
        self.log(f"Syncing to Space: {self.space_id}")

        if not curated_examples:
            self.status = AgentStatus.WAITING
            return AgentResult(
                self.name,
                AgentStatus.WAITING,
                "No curated examples to sync"
            )

        try:
            from huggingface_hub import HfApi
            api = HfApi()

            try:
                info = api.space_info(self.space_id)
                runtime = info.runtime

                if runtime and runtime.stage == "RUNNING":
                    # (removed a stray f-prefix on a placeholder-free string,
                    # and an unused json.dumps of the examples — nothing was
                    # ever uploaded with it)
                    self.log("Space is running", "SUCCESS")

                    self.status = AgentStatus.SUCCESS
                    return AgentResult(
                        self.name,
                        AgentStatus.SUCCESS,
                        f"Space running, {len(curated_examples)} examples ready for sync",
                        {"space_status": "RUNNING", "examples_count": len(curated_examples)}
                    )
                else:
                    self.status = AgentStatus.WAITING
                    return AgentResult(
                        self.name,
                        AgentStatus.WAITING,
                        f"Space not running: {runtime.stage if runtime else 'unknown'}"
                    )

            except Exception as e:
                self.status = AgentStatus.FAILED
                return AgentResult(self.name, AgentStatus.FAILED, f"Space error: {e}")

        except Exception as e:
            self.status = AgentStatus.FAILED
            return AgentResult(self.name, AgentStatus.FAILED, f"Error: {e}")
|
|
| |
| |
| |
|
|
class PitVQAOrchestrator:
    """Coordinates all agents for the PitVQA pipeline."""

    def __init__(self, job_ids: List[str]):
        self.agents = {
            "monitor": JobMonitorAgent(job_ids),
            "curation": CurationAgent(),
            "dataset": DatasetValidatorAgent(),
            "model": ModelVerifierAgent(),
            "training": TrainingSpecialistAgent(),
            "evaluation": EvaluationSpecialistAgent(),
            "demo": DemoSyncAgent(),
        }
        # agent key -> latest AgentResult from the most recent cycle
        self.results = {}
        self.run_count = 0

    def run_cycle(self) -> Dict:
        """Run one orchestration cycle and return the status report.

        NOTE(review): the banner emoji were mojibake in the source dump and
        were restored to plausible glyphs — confirm against the original.
        """
        self.run_count += 1
        print(f"\n{'='*60}")
        print(f"🔄 ORCHESTRATION CYCLE {self.run_count}")
        print(f"{'='*60}")

        # Phase 1: always check job status first.
        print("\n🔍 Phase 1: Job Monitoring")
        monitor_result = self.agents["monitor"].run()
        self.results["monitor"] = monitor_result

        # Phase 2: training validation runs regardless of job state.
        print("\n🎓 Phase 2: Training Validation (HF-LLM-Trainer)")
        training_result = self.agents["training"].run()
        self.results["training"] = training_result

        # Downstream phases only run while jobs are healthy (complete or
        # still going). BUGFIX: phases 6-7 consumed curation_result, which
        # was undefined when this guard was not taken — they now live
        # inside it so a failed monitor cannot raise NameError.
        if monitor_result.status in (AgentStatus.SUCCESS, AgentStatus.WAITING):
            print("\n🎨 Phase 3: Curation")
            curation_result = self.agents["curation"].run()
            self.results["curation"] = curation_result

            print("\n📦 Phase 4: Dataset Validation")
            self.results["dataset"] = self.agents["dataset"].run()

            print("\n🤖 Phase 5: Model Verification")
            self.results["model"] = self.agents["model"].run()

            print("\n📊 Phase 6: Evaluation (Metrics & Quality)")
            curated = curation_result.data.get("examples", []) if curation_result.data else []
            self.results["evaluation"] = self.agents["evaluation"].run(predictions=curated)

            print("\n🌐 Phase 7: Demo Sync")
            self.results["demo"] = self.agents["demo"].run(curated)

        return self.generate_report()

    def generate_report(self) -> Dict:
        """Generate comprehensive status report."""
        report = {
            "timestamp": datetime.now().isoformat(),
            "cycle": self.run_count,
            "overall_status": self._compute_overall_status(),
            "agents": {}
        }

        for name, result in self.results.items():
            report["agents"][name] = {
                "status": result.status.value,
                "message": result.message
            }

        return report

    def _compute_overall_status(self) -> str:
        """Compute overall pipeline status from the latest agent results."""
        statuses = [r.status for r in self.results.values()]

        # BUGFIX: guard the empty case — all() over [] is True, which would
        # report COMPLETE before any agent has run.
        if not statuses:
            return "UNKNOWN"
        if all(s == AgentStatus.SUCCESS for s in statuses):
            return "COMPLETE"
        elif any(s == AgentStatus.FAILED for s in statuses):
            return "NEEDS_ATTENTION"
        elif any(s == AgentStatus.WAITING for s in statuses):
            return "IN_PROGRESS"
        else:
            return "UNKNOWN"

    def print_summary(self, report: Dict):
        """Print human-readable summary of the last cycle's report."""
        print(f"\n{'='*60}")
        print("📊 ORCHESTRATION SUMMARY")
        print(f"{'='*60}")
        print(f"Time: {report['timestamp']}")
        print(f"Cycle: {report['cycle']}")
        print(f"Overall: {report['overall_status']}")
        print("\nAgent Status:")
        for name, info in report["agents"].items():
            # NOTE(review): mojibake glyphs restored here as well.
            icon = {"success": "✅", "failed": "❌", "waiting": "⏳", "running": "🔄"}.get(info["status"], "❓")
            print(f"  {icon} {name}: {info['status']} - {info['message'][:50]}")
|
|
| |
| |
| |
|
|
def main():
    """Entry point: run one orchestration cycle and persist the report."""
    print("🚀 PitVQA Multi-Agent Orchestrator Starting...")

    # HuggingFace Job IDs currently being tracked by this pipeline.
    job_ids = [
        "696cfe9946affbb321046bd9",
        "696cfebf57a10a9d296ca042",
    ]

    orchestrator = PitVQAOrchestrator(job_ids)

    # One cycle per invocation; re-run the script to poll again.
    report = orchestrator.run_cycle()
    orchestrator.print_summary(report)

    # Persist the cycle report for later inspection.
    with open("orchestration_report.json", "w") as f:
        json.dump(report, f, indent=2)
    # (removed a stray f-prefix — the string has no placeholders)
    print("\n💾 Report saved to orchestration_report.json")

    return report


if __name__ == "__main__":
    main()
|
|