"""Generate one presentation-ready PNG page per metric found in a log file.

The script scans a training/evaluation text log for a fixed set of line
formats (epoch markers, loss / ADE / FDE / learning-rate lines and
``name | left | right`` comparison table rows), renders each detected metric
as a dark-themed 13.333 x 7.5 inch figure, and writes a plain-text manifest
listing what was detected and which requested metrics were missing.
"""

from __future__ import annotations

import argparse
import re
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import matplotlib

matplotlib.use("Agg")  # non-interactive backend: figures are only saved, never shown
import matplotlib.pyplot as plt

REPO_ROOT = Path(__file__).resolve().parents[3]
LOG_DIR = REPO_ROOT / "log"

# Shared numeric fragment (optional sign, optional fraction, optional exponent).
# It contains no capturing group, so wrapping it in "(...)" yields exactly one group.
_NUMBER = r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?"
_NUMBER_RE = re.compile(_NUMBER)

_EPOCH_RE = re.compile(r"^Epoch\s+(\d+)(?:/\d+)?$", re.IGNORECASE)
_HEADER_RE = re.compile(r"^METRIC\s*\|\s*(.+?)\s*\|\s*(.+?)\s*$", re.IGNORECASE)
_TABLE_ROW_RE = re.compile(r"^(.+?)\|\s*([^|]+)\|\s*([^|]+)$")

# (pattern, series names) pairs tried in order; the first pattern that matches a
# line consumes it.  Each captured group feeds the series name at the same index.
_SERIES_PATTERNS: Tuple[Tuple[re.Pattern[str], Tuple[str, ...]], ...] = (
    (
        re.compile(rf"Train Loss:\s*({_NUMBER})", re.IGNORECASE),
        ("Train Loss",),
    ),
    (
        re.compile(rf"^ADE:\s*({_NUMBER})\s*,\s*FDE:\s*({_NUMBER})", re.IGNORECASE),
        ("ADE", "FDE"),
    ),
    (
        re.compile(
            rf"^Val\s+ADE:\s*({_NUMBER})\s*\|\s*Val\s+FDE:\s*({_NUMBER})",
            re.IGNORECASE,
        ),
        ("Val ADE", "Val FDE"),
    ),
    (
        re.compile(rf"Current Learning Rate:\s*({_NUMBER})", re.IGNORECASE),
        ("Learning Rate",),
    ),
    (
        re.compile(rf"LR\s+base=({_NUMBER})\s*\|\s*fusion=({_NUMBER})", re.IGNORECASE),
        ("LR base", "LR fusion"),
    ),
)


@dataclass
class ParsedMetrics:
    """Everything extracted from one log file."""

    # Metric name -> list of (x position, value) points in the order they were read.
    series: Dict[str, List[Tuple[int, float]]] = field(default_factory=dict)
    # Metric name -> (left value, right value, is_percent) taken from table rows.
    paired: Dict[str, Tuple[float, float, bool]] = field(default_factory=dict)
    # Column labels for the paired values; replaced when a "METRIC | a | b" header is seen.
    paired_labels: Tuple[str, str] = ("Baseline", "Model")


def canonical_metric(name: str) -> str:
    """Return *name* lower-cased with every non-alphanumeric character removed.

    Used to compare metric names while ignoring case, spaces and punctuation.
    """
    return re.sub(r"[^a-z0-9]+", "", name.lower())


def sanitize_filename(name: str) -> str:
    """Turn *name* into a string that is safe to use as a file-name stem.

    Runs of characters outside ``[a-zA-Z0-9._-]`` become a single underscore and
    leading/trailing underscores are dropped.  Returns ``"metric"`` when nothing
    is left.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9._-]+", "_", name.strip())
    cleaned = re.sub(r"_+", "_", cleaned).strip("_")
    return cleaned or "metric"


def parse_number(token: str) -> Optional[Tuple[float, bool]]:
    """Extract the first number found in *token*.

    Returns:
        ``(value, is_percent)`` where ``is_percent`` is true when the stripped
        token ends with ``%``, or ``None`` when the token contains no number.
    """
    stripped = token.strip()
    is_percent = stripped.endswith("%")
    match = _NUMBER_RE.search(stripped.replace("%", ""))
    if match is None:
        return None
    return float(match.group(0)), is_percent


def append_series(
    series: Dict[str, List[Tuple[int, float]]],
    metric: str,
    epoch: Optional[int],
    value: float,
) -> None:
    """Append ``(x, value)`` to the point list of *metric*, creating it if needed.

    When *epoch* is ``None`` the x position continues from the previous point
    (previous x + 1), starting at 1 for an empty series.
    """
    points = series.setdefault(metric, [])
    x = epoch
    if x is None:
        x = points[-1][0] + 1 if points else 1
    points.append((x, value))


def _consume_series_line(
    series: Dict[str, List[Tuple[int, float]]],
    line: str,
    epoch: Optional[int],
) -> bool:
    """Record *line* in *series* if it matches a known pattern; return whether it did."""
    for pattern, names in _SERIES_PATTERNS:
        match = pattern.search(line)
        if match is None:
            continue
        for name, token in zip(names, match.groups()):
            append_series(series, name, epoch, float(token))
        return True
    return False


def _consume_table_row(paired: Dict[str, Tuple[float, float, bool]], line: str) -> None:
    """Record *line* in *paired* if it is a ``name | left | right`` row with two numbers."""
    match = _TABLE_ROW_RE.search(line)
    if match is None:
        return
    # Skip separator rules and the header row, which is handled by the caller.
    if "----" in line or line.upper().startswith("METRIC"):
        return

    metric_name = match.group(1).strip()
    left_parsed = parse_number(match.group(2).strip())
    right_parsed = parse_number(match.group(3).strip())
    if left_parsed is None or right_parsed is None:
        return

    left_val, left_is_pct = left_parsed
    right_val, right_is_pct = right_parsed
    paired[metric_name] = (left_val, right_val, left_is_pct or right_is_pct)


def parse_metrics_from_log(log_path: Path) -> ParsedMetrics:
    """Parse *log_path* line by line into a :class:`ParsedMetrics`.

    Each non-empty stripped line is tested, in this order, as: an ``Epoch N``
    marker (sets the x position for following series points), a
    ``METRIC | a | b`` header (sets the paired labels), one of the series
    patterns, and finally a generic ``name | left | right`` table row.  The
    file is read as UTF-8 with undecodable bytes ignored.

    After parsing, ``ADE`` / ``FDE`` are filled with copies of ``Val ADE`` /
    ``Val FDE`` when only the validation variants were found.
    """
    parsed = ParsedMetrics()
    current_epoch: Optional[int] = None

    text = log_path.read_text(encoding="utf-8", errors="ignore")
    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue

        epoch_match = _EPOCH_RE.search(line)
        if epoch_match:
            current_epoch = int(epoch_match.group(1))
            continue

        header_match = _HEADER_RE.search(line)
        if header_match:
            parsed.paired_labels = (
                header_match.group(1).strip(),
                header_match.group(2).strip(),
            )
            continue

        if _consume_series_line(parsed.series, line, current_epoch):
            continue

        _consume_table_row(parsed.paired, line)

    # Alias validation trajectory metrics to generic names when only validation
    # labels are present.
    for generic, validation in (("ADE", "Val ADE"), ("FDE", "Val FDE")):
        if generic not in parsed.series and validation in parsed.series:
            parsed.series[generic] = list(parsed.series[validation])

    return parsed


def setup_theme() -> None:
    """Apply the white-on-black presentation theme to matplotlib's global rcParams."""
    plt.rcParams.update(
        {
            "figure.facecolor": "#000000",
            "axes.facecolor": "#000000",
            "savefig.facecolor": "#000000",
            "text.color": "#FFFFFF",
            "axes.labelcolor": "#FFFFFF",
            "xtick.color": "#FFFFFF",
            "ytick.color": "#FFFFFF",
            "axes.edgecolor": "#FFFFFF",
            "font.family": "Calibri",
            "font.size": 20,
        }
    )


def create_series_page(
    metric_name: str,
    points: List[Tuple[int, float]],
    source_name: str,
    out_path: Path,
) -> None:
    """Render *points* as a line chart and save it to *out_path*.

    Points are sorted by x before plotting.  A min / max / last summary and the
    source log name are written below the axes.  *points* must be non-empty.
    """
    points = sorted(points, key=lambda p: p[0])
    x_vals = [p[0] for p in points]
    y_vals = [p[1] for p in points]

    fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
    ax.plot(x_vals, y_vals, color="#FFFFFF", linewidth=3.0, marker="o", markersize=5)
    ax.set_title(metric_name, fontsize=42, weight="bold", pad=20)
    ax.set_xlabel("Epoch / Step", fontsize=24, labelpad=12)
    ax.set_ylabel(metric_name, fontsize=24, labelpad=12)
    ax.grid(True, linestyle="--", linewidth=0.8, color="#5E5E5E", alpha=0.6)
    for spine in ax.spines.values():
        spine.set_linewidth(1.2)

    min_v = min(y_vals)
    max_v = max(y_vals)
    last_v = y_vals[-1]  # last after sorting by x, i.e. the value at the largest x
    summary = f"Min: {min_v:.4f} Max: {max_v:.4f} Last: {last_v:.4f}"
    fig.text(0.5, 0.05, summary, ha="center", va="center", fontsize=22, color="#FFFFFF")
    fig.text(
        0.01,
        0.01,
        f"Source: {source_name}",
        ha="left",
        va="bottom",
        fontsize=12,
        color="#D8D8D8",
    )

    # Leave room at the bottom of the figure for the summary and source lines.
    fig.tight_layout(rect=(0.02, 0.08, 0.98, 0.96))
    fig.savefig(out_path)
    plt.close(fig)


def create_paired_page(
    metric_name: str,
    left_value: float,
    right_value: float,
    is_percent: bool,
    left_label: str,
    right_label: str,
    source_name: str,
    out_path: Path,
) -> None:
    """Render a two-bar comparison of *left_value* vs *right_value* to *out_path*.

    Each bar is annotated with its value to two decimals, followed by ``%`` when
    *is_percent* is true.
    """
    fig, ax = plt.subplots(figsize=(13.333, 7.5), dpi=150)
    labels = [left_label, right_label]
    vals = [left_value, right_value]
    bars = ax.bar(labels, vals, color=["#B8B8B8", "#FFFFFF"], width=0.55)

    suffix = "%" if is_percent else ""
    for bar, val in zip(bars, vals):
        ax.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height(),
            f"{val:.2f}{suffix}",
            ha="center",
            # Keep the label outside the bar: above a positive bar, below a
            # negative one (otherwise the white text lands on the bar itself).
            va="bottom" if val >= 0 else "top",
            fontsize=20,
            color="#FFFFFF",
        )

    ax.set_title(metric_name, fontsize=42, weight="bold", pad=20)
    ax.set_ylabel(metric_name + (" (%)" if is_percent else ""), fontsize=24)
    ax.grid(True, axis="y", linestyle="--", linewidth=0.8, color="#5E5E5E", alpha=0.6)
    for spine in ax.spines.values():
        spine.set_linewidth(1.2)

    fig.text(
        0.01,
        0.01,
        f"Source: {source_name}",
        ha="left",
        va="bottom",
        fontsize=12,
        color="#D8D8D8",
    )
    fig.tight_layout(rect=(0.02, 0.06, 0.98, 0.96))
    fig.savefig(out_path)
    plt.close(fig)


def create_unavailable_page(metric_name: str, source_name: str, out_path: Path) -> None:
    """Save a placeholder page stating that *metric_name* is not in the log."""
    fig = plt.figure(figsize=(13.333, 7.5), dpi=150)
    fig.patch.set_facecolor("#000000")
    fig.text(
        0.5,
        0.62,
        metric_name,
        ha="center",
        va="center",
        fontsize=48,
        color="#FFFFFF",
        weight="bold",
    )
    fig.text(
        0.5,
        0.44,
        "Not available in selected log",
        ha="center",
        va="center",
        fontsize=26,
        color="#FFFFFF",
    )
    fig.text(
        0.01,
        0.01,
        f"Source: {source_name}",
        ha="left",
        va="bottom",
        fontsize=12,
        color="#D8D8D8",
    )
    fig.savefig(out_path)
    plt.close(fig)


def pick_default_log() -> Path:
    """Return the most recently modified log in ``LOG_DIR``.

    ``phase2_fusion_train_*.txt`` and ``train_log_*.txt`` files are preferred;
    any ``*.txt`` file is considered only when none of those exist.

    Raises:
        FileNotFoundError: if ``LOG_DIR`` contains no ``.txt`` file at all.
    """
    candidates = list(LOG_DIR.glob("phase2_fusion_train_*.txt"))
    candidates += list(LOG_DIR.glob("train_log_*.txt"))
    if not candidates:
        candidates = list(LOG_DIR.glob("*.txt"))
    if not candidates:
        raise FileNotFoundError("No .txt logs found in log folder.")
    return max(candidates, key=lambda p: p.stat().st_mtime)


def main() -> None:
    """Command-line entry point: parse a log, write metric pages and a manifest.

    Relative ``--log-file`` and ``--output-dir`` values are resolved against
    ``REPO_ROOT``.  Existing ``*.png`` files in the output directory are deleted
    before new pages are written.

    Raises:
        FileNotFoundError: if the selected log file does not exist.
    """
    parser = argparse.ArgumentParser(
        description="Generate one PPT-ready page per metric from training/evaluation logs."
    )
    parser.add_argument(
        "--log-file",
        type=str,
        default="",
        help="Path to source log file. Default: latest train/eval log.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help=(
            "Directory to save generated metric pages. "
            "Default: log/ppt_metric_pages/<log stem>/"
        ),
    )
    parser.add_argument(
        "--requested",
        type=str,
        default="ADE,FDE,Val ADE,Val FDE,Train Loss,MSE,F1,Precision,Recall,Accuracy",
        help="Comma-separated metrics to include as missing pages if absent.",
    )
    parser.add_argument(
        "--include-missing-pages",
        action="store_true",
        help="Create a separate page for requested metrics that are not found in the log.",
    )
    args = parser.parse_args()

    setup_theme()

    log_path = Path(args.log_file) if args.log_file else pick_default_log()
    if not log_path.is_absolute():
        log_path = REPO_ROOT / log_path
    if not log_path.exists():
        raise FileNotFoundError(f"Log file not found: {log_path}")

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        output_dir = LOG_DIR / "ppt_metric_pages" / log_path.stem
    if not output_dir.is_absolute():
        output_dir = REPO_ROOT / output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    # Keep output deterministic for presentation export by removing old pages
    # from previous runs.  NOTE: this deletes every *.png in the directory, not
    # only files this script created.
    for old_png in output_dir.glob("*.png"):
        old_png.unlink()

    parsed = parse_metrics_from_log(log_path)
    generated: List[str] = []

    for metric_name in sorted(parsed.series):
        out_path = output_dir / f"{sanitize_filename(metric_name)}.png"
        create_series_page(metric_name, parsed.series[metric_name], log_path.name, out_path)
        generated.append(metric_name)

    left_label, right_label = parsed.paired_labels
    for metric_name in sorted(parsed.paired):
        left_value, right_value, is_percent = parsed.paired[metric_name]
        out_path = output_dir / f"{sanitize_filename(metric_name)}_comparison.png"
        create_paired_page(
            metric_name=metric_name,
            left_value=left_value,
            right_value=right_value,
            is_percent=is_percent,
            left_label=left_label,
            right_label=right_label,
            source_name=log_path.name,
            out_path=out_path,
        )
        generated.append(metric_name)

    # Requested metrics are matched against detected ones ignoring case,
    # whitespace and punctuation.
    requested = [m.strip() for m in args.requested.split(",") if m.strip()]
    generated_canonical = {canonical_metric(m) for m in generated}
    missing = [m for m in requested if canonical_metric(m) not in generated_canonical]

    if args.include_missing_pages:
        for metric_name in missing:
            out_path = output_dir / f"{sanitize_filename(metric_name)}_not_available.png"
            create_unavailable_page(metric_name, log_path.name, out_path)

    manifest_path = output_dir / "metrics_manifest.txt"
    manifest_lines: List[str] = [
        f"Source log: {log_path}",
        f"Output directory: {output_dir}",
        "",
        "Detected metrics:",
    ]
    manifest_lines.extend(f"- {m}" for m in sorted(set(generated)))
    manifest_lines.append("")
    manifest_lines.append("Requested but missing:")
    if missing:
        manifest_lines.extend(f"- {m}" for m in missing)
    else:
        manifest_lines.append("- None")
    manifest_path.write_text("\n".join(manifest_lines), encoding="utf-8")

    page_count = sum(1 for _ in output_dir.glob("*.png"))
    print(f"Generated {page_count} metric pages in: {output_dir}")
    print(f"Manifest: {manifest_path}")


if __name__ == "__main__":
    main()