"""
CSV Logger for training metrics.
Replaces wandb logging with simple CSV files that can be viewed later.
"""
import csv
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
class CSVLogger:
    """
    Logger that writes metrics to CSV files for easy viewing and analysis.

    Each run creates a timestamped directory with:
    - metrics.csv: Main training metrics (key metrics only)
    - detailed_metrics/: Detailed metrics per iteration (JSON)
    - config.json: Configuration parameters
    - summary.json: Final summary statistics
    """

    # Define which metrics to include in the main CSV (keep it concise)
    KEY_METRICS = {
        "iteration", "step", "timestamp",
        # Training metrics
        "train/policy_loss", "train/value_loss", "train/entropy",
        "train/approx_kl", "train/clip_fraction",
        # Evaluation metrics
        "eval/accuracy", "eval/correct", "eval/total",
        # Buffer/rollout metrics
        "rollout/mean_reward", "rollout/num_trajectories", "rollout/mean_length",
        # Curriculum metrics (high-level)
        "curriculum/topic_diversity", "curriculum/avg_difficulty",
        "curriculum/avg_novelty", "curriculum/replay_ratio",
        # Performance metrics
        "perf/rollout_time", "perf/train_time", "perf/total_time",
        "perf/tokens_per_second",
        # Consensus metrics
        "consensus/rate", "consensus/answer_diversity",
        # Disk/resource metrics
        "system/disk_free_gb", "system/gpu_util_percent",
    }

    def __init__(
        self,
        project: str = "training",
        run_name: Optional[str] = None,
        log_dir: str = "logs",
        config: Optional[Dict[str, Any]] = None,
        log_detailed: bool = True,
    ):
        """
        Initialize CSV logger.

        Args:
            project: Project name (used as subdirectory)
            run_name: Optional run name, defaults to timestamp
            log_dir: Base directory for logs
            config: Optional configuration dict to save
            log_detailed: If True, save full metrics as JSON per iteration
        """
        self.project = project
        self.run_name = run_name or f"run_{datetime.now():%Y%m%d_%H%M%S}"
        self.log_detailed = log_detailed

        # Create log directory
        self.log_path = Path(log_dir) / project / self.run_name
        self.log_path.mkdir(parents=True, exist_ok=True)
        if self.log_detailed:
            self.detailed_path = self.log_path / "detailed_metrics"
            self.detailed_path.mkdir(exist_ok=True)

        # CSV writer is created lazily on the first log() call, once the
        # set of key metrics actually present is known.
        self.metrics_file = self.log_path / "metrics.csv"
        self.metrics_writer = None
        self.metrics_handle = None
        self.fieldnames: List[str] = []
        self.step_count = 0

        # Save config (default=str so non-JSON-serializable values degrade
        # to their repr instead of raising)
        if config:
            config_file = self.log_path / "config.json"
            with open(config_file, "w", encoding="utf-8") as f:
                json.dump(config, f, indent=2, default=str)

        print(f"CSV Logger initialized: {self.log_path}")

    def log(self, metrics: Dict[str, Any], step: Optional[int] = None):
        """
        Log metrics to CSV file (only key metrics) and optionally full metrics to JSON.

        Args:
            metrics: Dictionary of metric names and values
            step: Optional step/iteration number. When omitted, an internal
                counter is used and auto-incremented; an explicit step does
                NOT advance the internal counter.
        """
        if step is None:
            step = self.step_count
            self.step_count += 1

        # Save full detailed metrics to JSON if enabled
        if self.log_detailed:
            detailed_file = self.detailed_path / f"step_{step:04d}.json"
            with open(detailed_file, "w", encoding="utf-8") as f:
                json.dump(metrics, f, indent=2, default=str)

        # Flatten nested dicts ({"train": {"loss": x}} -> {"train/loss": x})
        flat_metrics = self._flatten_dict(metrics)
        flat_metrics["step"] = step
        flat_metrics["timestamp"] = datetime.now().isoformat()

        # Filter to only key metrics for CSV; "iteration"-prefixed keys are
        # always kept regardless of KEY_METRICS membership.
        csv_metrics = {k: v for k, v in flat_metrics.items()
                       if k in self.KEY_METRICS or k.startswith("iteration")}

        # Initialize CSV writer lazily on first log
        if self.metrics_writer is None:
            # Fixed leading columns, then the remaining key metrics sorted
            self.fieldnames = ["step", "timestamp"] + sorted(
                [k for k in csv_metrics.keys() if k not in ["step", "timestamp"]]
            )
            self.metrics_handle = open(self.metrics_file, "w", newline="", encoding="utf-8")
            self.metrics_writer = csv.DictWriter(
                self.metrics_handle,
                fieldnames=self.fieldnames,
                extrasaction="ignore"
            )
            self.metrics_writer.writeheader()

        # Key metrics that first appear mid-run require a header rewrite
        new_fields = [k for k in csv_metrics.keys() if k not in self.fieldnames]
        if new_fields:
            self._add_columns(new_fields)

        # Write row; flush so the CSV is readable while training runs
        self.metrics_writer.writerow(csv_metrics)
        self.metrics_handle.flush()

    def _flatten_dict(self, d: Dict[str, Any], parent_key: str = "", sep: str = "/") -> Dict[str, Any]:
        """
        Flatten nested dictionary using separator.

        Example: {"train": {"loss": 0.5}} -> {"train/loss": 0.5}
        """
        items = []
        for k, v in d.items():
            new_key = f"{parent_key}{sep}{k}" if parent_key else k
            if isinstance(v, dict):
                items.extend(self._flatten_dict(v, new_key, sep=sep).items())
            else:
                # Convert to JSON string if not a simple type
                if isinstance(v, (list, tuple)):
                    v = json.dumps(v)
                elif not isinstance(v, (str, int, float, bool, type(None))):
                    v = str(v)
                items.append((new_key, v))
        return dict(items)

    def _add_columns(self, new_fields: List[str]):
        """Add new columns to existing CSV by rewriting it.

        Existing rows get an empty string (DictWriter restval) for the
        new columns.
        """
        self.fieldnames.extend(new_fields)

        # Read existing data (newline="" per csv module requirements)
        self.metrics_handle.close()
        existing_data = []
        if self.metrics_file.exists():
            with open(self.metrics_file, "r", newline="", encoding="utf-8") as f:
                reader = csv.DictReader(f)
                existing_data = list(reader)

        # Rewrite with new fieldnames
        self.metrics_handle = open(self.metrics_file, "w", newline="", encoding="utf-8")
        self.metrics_writer = csv.DictWriter(
            self.metrics_handle,
            fieldnames=self.fieldnames,
            extrasaction="ignore"
        )
        self.metrics_writer.writeheader()
        for row in existing_data:
            self.metrics_writer.writerow(row)

    def save_summary(self, summary: Dict[str, Any]):
        """
        Save a summary dictionary to JSON.

        Args:
            summary: Summary statistics or final results
        """
        summary_file = self.log_path / "summary.json"
        with open(summary_file, "w", encoding="utf-8") as f:
            json.dump(summary, f, indent=2, default=str)

    def save_artifact(self, name: str, data: Any):
        """
        Save arbitrary data as JSON artifact.

        Args:
            name: Artifact name (will be used as filename)
            data: Data to save (must be JSON serializable)
        """
        artifact_file = self.log_path / f"{name}.json"
        with open(artifact_file, "w", encoding="utf-8") as f:
            json.dump(data, f, indent=2, default=str)

    def finish(self):
        """Close logger and clean up resources."""
        if self.metrics_handle:
            self.metrics_handle.close()
        print(f"Logs saved to: {self.log_path}")

    def __del__(self):
        """Ensure file handle is closed.

        getattr guards against __init__ having failed before
        metrics_handle was assigned.
        """
        handle = getattr(self, "metrics_handle", None)
        if handle and not handle.closed:
            handle.close()