Spaces:
Sleeping
Sleeping
| """ | |
| CSV Logger for training metrics. | |
| Replaces wandb logging with simple CSV files that can be viewed later. | |
| """ | |
| import csv | |
| import json | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Set | |
class CSVLogger:
    """
    Logger that writes metrics to CSV files for easy viewing and analysis.

    Each run writes into ``<log_dir>/<project>/<run_name>/`` (``run_name``
    defaults to a ``run_YYYYmmdd_HHMMSS`` timestamp), containing:
    - metrics.csv: one row per ``log()`` call with ``step``, ``timestamp`` and
      the key metrics selected by ``KEY_METRICS`` / ``KEY_METRIC_PREFIXES``
    - detailed_metrics/: the full, unfiltered metrics dict per step (JSON),
      only when ``log_detailed`` is True
    - config.json: configuration parameters, when a non-empty config is given
    - summary.json: written by ``save_summary()``
    - <name>.json: arbitrary artifacts written by ``save_artifact()``

    The logger can also be used as a context manager; ``finish()`` is called
    on exit.
    """

    # Flattened ("/"-joined) metric names kept in the main CSV (keep it concise).
    # e.g. {"train": {"entropy": x}} is flattened to "train/entropy".
    KEY_METRICS: Set[str] = {
        "iteration", "step", "timestamp",
        # Training metrics
        "train/policy_loss", "train/value_loss", "train/entropy",
        "train/approx_kl", "train/clip_fraction",
        # Evaluation metrics
        "eval/accuracy", "eval/correct", "eval/total",
        # Buffer/rollout metrics
        "rollout/mean_reward", "rollout/num_trajectories", "rollout/mean_length",
        # Curriculum metrics (high-level)
        "curriculum/topic_diversity", "curriculum/avg_difficulty",
        "curriculum/avg_novelty", "curriculum/replay_ratio",
        # Performance metrics
        "perf/rollout_time", "perf/train_time", "perf/total_time",
        "perf/tokens_per_second",
        # Consensus metrics
        "consensus/rate", "consensus/answer_diversity",
        # Disk/resource metrics
        "system/disk_free_gb", "system/gpu_util_percent",
    }

    # Any flattened metric name starting with one of these prefixes is also
    # kept in the CSV (tuple so it can be passed straight to str.startswith).
    KEY_METRIC_PREFIXES = ("iteration",)

    def __init__(
        self,
        project: str = "training",
        run_name: Optional[str] = None,
        log_dir: str = "logs",
        config: Optional[Dict[str, Any]] = None,
        log_detailed: bool = True,
    ):
        """
        Initialize CSV logger and create the run directory.

        Args:
            project: Project name (used as subdirectory)
            run_name: Optional run name, defaults to a timestamp
            log_dir: Base directory for logs
            config: Optional configuration dict, saved to config.json when non-empty
            log_detailed: If True, save full metrics as JSON per step
        """
        self.project = project
        self.run_name = run_name or f"run_{datetime.now():%Y%m%d_%H%M%S}"
        self.log_detailed = log_detailed

        # CSV state is assigned before any disk I/O so finish()/__del__ stay
        # safe even if creating the directories below raises. The handle and
        # writer are opened lazily on the first log() call, once the initial
        # set of columns is known.
        self.metrics_writer = None
        self.metrics_handle = None
        self.fieldnames: List[str] = []
        self.step_count = 0

        self.log_path = Path(log_dir) / project / self.run_name
        self.log_path.mkdir(parents=True, exist_ok=True)
        if self.log_detailed:
            self.detailed_path = self.log_path / "detailed_metrics"
            self.detailed_path.mkdir(exist_ok=True)

        self.metrics_file = self.log_path / "metrics.csv"

        if config:
            config_file = self.log_path / "config.json"
            with open(config_file, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2, default=str)

        print(f"CSV Logger initialized: {self.log_path}")

    def log(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
        """
        Log metrics to the CSV file (key metrics only) and optionally the full
        metrics dict to a per-step JSON file.

        Args:
            metrics: Dictionary of metric names and values. Nested dicts are
                flattened into "/"-joined names before filtering.
            step: Optional step/iteration number. When omitted, an internal
                counter is used and incremented.

        Columns that first appear in a later call are added by rewriting the
        CSV (see ``_add_columns``). Calling ``log()`` after ``finish()``
        reopens the CSV in append mode instead of failing on a closed file.
        """
        if step is None:
            step = self.step_count
            self.step_count += 1

        # Save full detailed metrics to JSON if enabled
        if self.log_detailed:
            detailed_file = self.detailed_path / f"step_{step:04d}.json"
            with open(detailed_file, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2, default=str)

        flat_metrics = self._flatten_dict(metrics)
        flat_metrics["step"] = step
        flat_metrics["timestamp"] = datetime.now().isoformat()

        # Filter to only key metrics for the CSV
        csv_metrics = {
            k: v
            for k, v in flat_metrics.items()
            if k in self.KEY_METRICS or k.startswith(self.KEY_METRIC_PREFIXES)
        }

        if self.metrics_writer is None:
            # First call: fix the column order as step, timestamp, then sorted keys.
            self.fieldnames = ["step", "timestamp"] + sorted(
                k for k in csv_metrics if k not in ("step", "timestamp")
            )
            self._open_metrics("w")
            self.metrics_writer.writeheader()
        elif self.metrics_handle.closed:
            # finish() was called earlier: append rather than truncate.
            self._open_metrics("a")

        # Add any key metrics not seen before as new columns
        new_fields = [k for k in csv_metrics if k not in self.fieldnames]
        if new_fields:
            self._add_columns(new_fields)

        self.metrics_writer.writerow(csv_metrics)
        # Flush after every row so the on-disk file stays current.
        self.metrics_handle.flush()

    def _flatten_dict(self, d: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
        """
        Flatten a nested dictionary, joining keys with ``sep``.

        Example: {"train": {"loss": 0.5}} -> {"train/loss": 0.5}

        Keys are always returned as strings, so non-string keys (e.g. ints)
        cannot break the prefix filtering or column sorting in ``log()``.
        Lists/tuples are JSON-encoded. Other values that are not
        str/int/float/bool/None are converted with ``str()``.
        """
        flat: Dict[str, Any] = {}
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else str(k)
            if isinstance(v, dict):
                flat.update(self._flatten_dict(v, new_key, sep=sep))
                continue
            if isinstance(v, (list, tuple)):
                v = json.dumps(v)
            elif not isinstance(v, (str, int, float, bool, type(None))):
                v = str(v)
            flat[new_key] = v
        return flat

    def _open_metrics(self, mode: str) -> None:
        """(Re)open metrics.csv with ``mode`` and bind a DictWriter over the current fieldnames."""
        self.metrics_handle = open(self.metrics_file, mode, newline="", encoding="utf-8")
        self.metrics_writer = csv.DictWriter(
            self.metrics_handle,
            fieldnames=self.fieldnames,
            extrasaction="ignore",
        )

    def _add_columns(self, new_fields: List[str]) -> None:
        """
        Add new columns to the existing CSV by rewriting it with the extended header.

        Rows written before the new columns existed get empty cells for them.
        The rewrite goes to a temporary file that then replaces metrics.csv, so
        the original file stays intact until the complete copy is in place.
        Afterwards the CSV is reopened in append mode for subsequent rows.
        """
        self.fieldnames.extend(new_fields)

        if self.metrics_handle is not None and not self.metrics_handle.closed:
            self.metrics_handle.close()

        existing_data: List[Dict[str, str]] = []
        if self.metrics_file.exists():
            with open(self.metrics_file, "r", newline="", encoding="utf-8") as f:
                existing_data = list(csv.DictReader(f))

        tmp_file = self.metrics_file.with_name(self.metrics_file.name + ".tmp")
        with open(tmp_file, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=self.fieldnames, extrasaction="ignore")
            writer.writeheader()
            writer.writerows(existing_data)
        tmp_file.replace(self.metrics_file)

        self._open_metrics("a")

    def save_summary(self, summary: Dict[str, Any]) -> None:
        """
        Save a summary dictionary to summary.json (overwriting any previous one).

        Args:
            summary: Summary statistics or final results
        """
        summary_file = self.log_path / "summary.json"
        with open(summary_file, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, default=str)

    def save_artifact(self, name: str, data: Any) -> None:
        """
        Save arbitrary data as a JSON artifact named ``<name>.json``.

        Args:
            name: Artifact name (used as the filename stem)
            data: Data to save; values json cannot encode are converted with str()
        """
        artifact_file = self.log_path / f"{name}.json"
        with open(artifact_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)

    def finish(self) -> None:
        """Close the metrics file and report where logs were written. Safe to call repeatedly."""
        if self.metrics_handle is not None and not self.metrics_handle.closed:
            self.metrics_handle.close()
        print(f"Logs saved to: {self.log_path}")

    def __enter__(self) -> "CSVLogger":
        """Allow ``with CSVLogger(...) as logger:``; the logger itself is returned."""
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Call ``finish()`` on leaving the ``with`` block; exceptions are not suppressed."""
        self.finish()

    def __del__(self):
        """Ensure the file handle is closed, even if __init__ did not complete."""
        handle = getattr(self, "metrics_handle", None)
        if handle is not None and not handle.closed:
            handle.close()