f7b8ac6 | """Configurable reward shaping settings for CyberSecurity_OWASP.""" | |
from __future__ import annotations

import hashlib
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any

import yaml

DEFAULT_GRPO_CONFIG_PATH = (
    Path(__file__).resolve().parent / "training" / "configs" / "grpo_small.yaml"
)

REWARD_MODES = {"dense_train", "sparse_eval"}
REWARD_STAGES = {"early", "middle", "late", "final"}


@dataclass
class RewardSettings:
    """Loaded reward settings with stage-aware helpers."""

    mode: str
    training_mode: str
    stage: str
    raw: dict[str, Any]
    source_path: str

    @property
    def dense_train(self) -> bool:
        """True when dense training rewards are active."""
        return self.mode == "dense_train"

    @property
    def shaping_weight(self) -> float:
        """Global shaping weight, with an environment override."""
        override = os.getenv("CYBERSECURITY_OWASP_SHAPING_WEIGHT")
        if override is not None:
            return float(override)
        return self.value("shaping_weight", 0.0)

    def entry(self, name: str) -> dict[str, Any]:
        """Return the raw mapping for a named reward entry, or an empty dict."""
        value = self.raw.get(name, {})
        return value if isinstance(value, dict) else {}

    def value(self, name: str, default: float = 0.0) -> float:
        """Resolve an entry's numeric value: the stage key wins, then "value"."""
        entry = self.entry(name)
        if self.stage in entry:
            return float(entry[self.stage])
        if "value" in entry:
            return float(entry["value"])
        return float(default)

    def cap(self, name: str, default: float | None = None) -> float | None:
        """Return an entry's cap, or the caller's default when none is set."""
        entry = self.entry(name)
        if "cap" not in entry:
            return default
        return float(entry["cap"])

    def int_value(self, name: str, key: str, default: int) -> int:
        entry = self.entry(name)
        return int(entry.get(key, default))

    def terminate(self, name: str) -> bool:
        return bool(self.entry(name).get("terminate", False))


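# Stage-aware lookup sketch (hypothetical entry values, not from a real
# config): with
#     raw = {"shaping_weight": {"early": 0.3, "value": 0.1, "description": "..."}}
# a settings object with stage="early" resolves value("shaping_weight") to
# 0.3, while stage="final" (no stage key present) falls back to the flat
# "value" key and yields 0.1.

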
def load_reward_settings(path: str | Path | None = None) -> RewardSettings:
    """Load reward settings from the GRPO YAML config with env overrides."""
    configured_path = Path(
        path
        or os.getenv("CYBERSECURITY_OWASP_REWARD_CONFIG", "")
        or DEFAULT_GRPO_CONFIG_PATH
    )
    raw = _load_yaml_with_extends(configured_path)
    reward = dict(raw.get("reward") or {})
    mode = os.getenv("CYBERSECURITY_OWASP_REWARD_MODE", str(reward.get("mode", "sparse_eval")))
    training_mode = str(reward.get("training_mode", "dense_train"))
    stage = os.getenv("CYBERSECURITY_OWASP_REWARD_STAGE", str(reward.get("stage", "early")))
    settings = RewardSettings(
        mode=mode,
        training_mode=training_mode,
        stage=stage,
        raw=reward,
        source_path=str(configured_path),
    )
    validate_reward_settings(settings)
    return settings


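# Minimal usage sketch (assumes the default config file exists on disk; the
# environment variable names are the real overrides read above):
#
#     os.environ["CYBERSECURITY_OWASP_REWARD_STAGE"] = "late"
#     settings = load_reward_settings()
#     assert settings.stage == "late"

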
def _load_yaml_with_extends(path: Path, seen: set[Path] | None = None) -> dict[str, Any]:
    """Load a YAML file, recursively merging an optional relative `extends` file."""
    resolved_path = path.expanduser().resolve()
    seen = seen or set()
    if resolved_path in seen:
        chain = " -> ".join(str(item) for item in [*seen, resolved_path])
        raise ValueError(f"reward config extends cycle detected: {chain}")
    seen.add(resolved_path)
    raw = yaml.safe_load(resolved_path.read_text(encoding="utf-8")) or {}
    if not isinstance(raw, dict):
        raise ValueError(f"reward config must be a YAML mapping: {resolved_path}")
    extends = raw.get("extends")
    if not extends:
        return raw
    if not isinstance(extends, str):
        raise ValueError("reward config extends must be a string path")
    base_path = Path(extends)
    if not base_path.is_absolute():
        base_path = resolved_path.parent / base_path
    child = {key: value for key, value in raw.items() if key != "extends"}
    return _deep_merge(_load_yaml_with_extends(base_path, seen), child)


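# Illustrative `extends` pair (hypothetical file names and values). The child
# deep-merges over its base, so only the keys being overridden need repeating:
#
#     # base.yaml
#     reward:
#       mode: sparse_eval
#       token_penalty:
#         value: -0.001
#         cap: -0.5
#         description: Penalize tokens beyond the target length.
#
#     # grpo_small.yaml
#     extends: base.yaml
#     reward:
#       mode: dense_train
#       token_penalty:
#         early: -0.002
#
# Loading grpo_small.yaml yields mode=dense_train, with token_penalty keeping
# the base's cap and description while gaining the stage-specific "early" value.

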
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Recursively merge override into base; non-dict values replace wholesale."""
    merged = dict(base)
    for key, value in override.items():
        base_value = merged.get(key)
        if isinstance(base_value, dict) and isinstance(value, dict):
            merged[key] = _deep_merge(base_value, value)
        else:
            merged[key] = value
    return merged


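# Merge semantics in one line:
#     _deep_merge({"a": {"x": 1, "y": 2}}, {"a": {"y": 3}, "b": 4})
#         -> {"a": {"x": 1, "y": 3}, "b": 4}
# Nested dicts merge key-by-key; any non-dict override replaces the base value.

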
def flatten_reward_config(
    settings: RewardSettings | None = None,
) -> list[dict[str, Any]]:
    """Return display-friendly reward config rows for tracking dashboards."""
    settings = settings or load_reward_settings()
    rows: list[dict[str, Any]] = []
    for key in sorted(settings.raw):
        entry = settings.raw[key]
        if not isinstance(entry, dict):
            continue
        has_resolved_value = "value" in entry or settings.stage in entry
        rows.append(
            {
                "key": key,
                "value": _empty_if_missing(entry.get("value")),
                "stage_value": _empty_if_missing(entry.get(settings.stage)),
                "resolved": settings.value(key, 0.0) if has_resolved_value else "",
                "cap": _empty_if_missing(entry.get("cap")),
                "threshold": _empty_if_missing(
                    entry.get("threshold", entry.get("threshold_lines"))
                ),
                "severe_threshold": _empty_if_missing(
                    entry.get("severe_threshold", entry.get("severe_threshold_lines"))
                ),
                "terminate": bool(entry.get("terminate", False)),
                "description": str(entry.get("description", "")),
            }
        )
    return rows


def reward_config_hash(settings: RewardSettings | None = None) -> str:
    """Return a deterministic hash for the effective reward configuration."""
    settings = settings or load_reward_settings()
    payload = {
        "mode": settings.mode,
        "training_mode": settings.training_mode,
        "stage": settings.stage,
        "shaping_weight": settings.shaping_weight,
        "raw": _strip_descriptions(settings.raw),
    }
    encoded = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(encoded.encode("utf-8")).hexdigest()


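# Note on determinism: descriptions are stripped and keys sorted before
# hashing, so two configs that differ only in prose (or key order) hash
# identically, while any numeric change to values, stage entries, or caps
# produces a new hash.

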
def reward_config_summary(settings: RewardSettings | None = None) -> dict[str, Any]:
    """Return reward config identity and flattened rows for run metadata."""
    settings = settings or load_reward_settings()
    config_hash = reward_config_hash(settings)
    source = Path(settings.source_path)
    return {
        "reward_config_id": (
            f"{source.stem}-{settings.mode}-{settings.stage}-{config_hash[:12]}"
        ),
        "reward_config_hash": config_hash,
        "reward_config_source": str(source),
        "reward_config_source_name": source.name,
        "reward_mode": settings.mode,
        "reward_training_mode": settings.training_mode,
        "reward_stage": settings.stage,
        "reward_shaping_weight": settings.shaping_weight,
        "reward_entries": flatten_reward_config(settings),
    }


def reward_config_run_config(settings: RewardSettings | None = None) -> dict[str, Any]:
    """Return compact reward config fields safe to store in Trackio run config."""
    summary = reward_config_summary(settings)
    reward_values = {
        str(row["key"]): {
            key: value
            for key, value in row.items()
            if key != "key" and value != ""
        }
        for row in summary["reward_entries"]
    }
    config = {
        "reward_config_id": summary["reward_config_id"],
        "reward_config_hash": summary["reward_config_hash"],
        "reward_config_source": summary["reward_config_source"],
        "reward_config_source_name": summary["reward_config_source_name"],
        "reward_variant": os.getenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "default") or "default",
        "reward_mode": summary["reward_mode"],
        "reward_training_mode": summary["reward_training_mode"],
        "reward_stage": summary["reward_stage"],
        "reward_shaping_weight": summary["reward_shaping_weight"],
        "reward_config_values": reward_values,
        "reward_config_values_json": json.dumps(reward_values, sort_keys=True),
    }
    for reward_key, values in reward_values.items():
        safe_reward_key = _config_key_safe(reward_key)
        for field, value in values.items():
            if isinstance(value, (int, float, bool)):
                config[f"reward_config__{safe_reward_key}__{field}"] = value
    return config


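# Sketch of the flattened scalar keys (hypothetical entry values): a
# "token_penalty" row carrying cap=-0.5 and terminate=False contributes
#     reward_config__token_penalty__cap = -0.5
#     reward_config__token_penalty__terminate = False
# alongside the JSON blob, so dashboards can filter on plain scalar columns
# instead of parsing reward_config_values_json.

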
def validate_reward_settings(settings: RewardSettings) -> None:
    """Reject unknown modes/stages and reward entries without a description."""
    if settings.mode not in REWARD_MODES:
        raise ValueError("reward.mode must be dense_train or sparse_eval")
    if settings.training_mode not in REWARD_MODES:
        raise ValueError("reward.training_mode must be dense_train or sparse_eval")
    if settings.stage not in REWARD_STAGES:
        raise ValueError("reward.stage must be early, middle, late, or final")
    for key, value in settings.raw.items():
        if not isinstance(value, dict):
            continue
        if not str(value.get("description", "")).strip():
            raise ValueError(f"reward.{key}.description is required")


def _empty_if_missing(value: Any) -> Any:
    return "" if value is None else value


def _strip_descriptions(value: Any) -> Any:
    if isinstance(value, dict):
        return {
            str(key): _strip_descriptions(item)
            for key, item in value.items()
            if key != "description"
        }
    if isinstance(value, list):
        return [_strip_descriptions(item) for item in value]
    return value


def _config_key_safe(value: str) -> str:
    return "".join(char if char.isalnum() or char == "_" else "_" for char in value).strip("_")


def compute_token_penalty(
    completion_tokens: int,
    settings: RewardSettings | None = None,
) -> float:
    """Return the trainer-side token penalty for a completion."""
    settings = settings or load_reward_settings()
    if not settings.dense_train:
        return 0.0
    target = settings.int_value("token_penalty", "target_tokens", 350)
    excess = max(0, int(completion_tokens) - target)
    penalty = settings.value("token_penalty", 0.0) * excess
    # The per-token value is expected to be negative, so max() clamps the
    # penalty at the (negative) cap instead of letting it grow unbounded.
    cap = settings.cap("token_penalty", -0.5)
    return max(penalty, cap if cap is not None else penalty)
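

# Worked example (hypothetical numbers): with target_tokens=350, a per-token
# value of -0.002, and cap=-0.5, a 600-token completion gives
#     excess = 250, penalty = -0.002 * 250 = -0.5  (exactly at the cap)
# while a 1000-token completion computes -1.3 and is clamped to -0.5.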