from __future__ import annotations

import copy
import hashlib
from typing import Any

import yaml

from .data import TASK_REGISTRY, TaskSpec
from .models import ConfigAction, ConfigObservation, ConfigReward, EnvState, TaskType


class ConfigDebuggerEnv:
    """Step-based environment in which an agent repairs a broken YAML config.

    An episode starts with ``reset(task_id)``, which loads the task's broken
    config text from ``TASK_REGISTRY``.  Each ``step(action)`` applies one
    edit/add/delete at a dotted path, re-grades the config (60% schema match
    against the task target, 40% task-specific logic checks) and returns an
    observation, a reward, a done flag and an info dict.
    """

    # Overall score at or above which the episode counts as solved.
    _SUCCESS_THRESHOLD = 0.98
    # Deleting one of these as a whole top-level key is penalised as destructive.
    _CRITICAL_TOP_LEVEL_KEYS = frozenset({"services", "spec", "training", "hardware"})

    def __init__(self) -> None:
        """Create an uninitialised environment; call ``reset()`` before ``step()``."""
        self.task_spec: TaskSpec | None = None
        self.task_id: TaskType | None = None
        self.current_config_text: str = ""
        self.previous_score: float = 0.0
        self.step_count: int = 0
        self.done: bool = False
        self.max_steps: int = 15
        self.last_reward: ConfigReward | None = None
        # sha1(config text) -> number of times that exact text has been seen.
        self._state_visit_count: dict[str, int] = {}

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    def reset(self, task_id: TaskType | str) -> ConfigObservation:
        """Start a new episode for ``task_id`` and return the first observation.

        Raises:
            ValueError: if ``task_id`` is not a key of ``TASK_REGISTRY``.
        """
        normalized_task_id = task_id.value if isinstance(task_id, TaskType) else str(task_id)
        if normalized_task_id not in TASK_REGISTRY:
            valid = ", ".join(TASK_REGISTRY)
            raise ValueError(f"Unknown task_id '{task_id}'. Valid task ids: {valid}")

        spec = TASK_REGISTRY[normalized_task_id]
        self.task_spec = spec
        self.task_id = TaskType(normalized_task_id)
        self.current_config_text = spec.broken
        self.step_count = 0
        self.done = False
        self.max_steps = spec.max_steps
        self._state_visit_count = {}
        self.last_reward = None

        # The broken config's own score is the baseline for the first delta.
        self.previous_score = self._grade(self.current_config_text)["overall"]
        # Count the initial state so returning to it later is seen as a loop.
        self._track_state_visit(self.current_config_text)
        return self._build_observation()

    def step(self, action: ConfigAction) -> tuple[ConfigObservation, ConfigReward, bool, dict[str, Any]]:
        """Apply ``action`` and return ``(observation, reward, done, info)``.

        A failing action does not abort the step: the config is left unchanged
        and an ``invalid_action:<message>`` penalty is recorded.  Stepping a
        finished episode returns a zero reward and does not advance the step
        counter.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self.task_spec is None or self.task_id is None:
            raise RuntimeError("Environment is not initialized. Call reset() first.")

        if self.done:
            obs = self._build_observation()
            reward = ConfigReward(
                value=0.0,
                previous_score=self.previous_score,
                current_score=self.previous_score,
                delta=0.0,
                penalties=["episode_already_done"],
            )
            self.last_reward = reward
            return obs, reward, True, {"reason": "episode_already_done"}

        self.step_count += 1
        penalties: list[str] = []
        try:
            new_config_text, action_penalties = self._apply_action(self.current_config_text, action)
            penalties.extend(action_penalties)
            self.current_config_text = new_config_text
        except Exception as exc:
            # Boundary catch: whatever goes wrong while applying an agent
            # action (bad path, wrong node type, unparsable/undumpable YAML)
            # becomes a penalty instead of crashing the episode.
            penalties.append(f"invalid_action:{exc}")

        grading = self._grade(self.current_config_text)
        current_score = grading["overall"]
        delta = round(current_score - self.previous_score, 4)

        loop_penalty = self._track_state_visit(self.current_config_text)
        if loop_penalty > 0:
            penalties.append(f"loop_penalty:{loop_penalty:.2f}")

        reward = ConfigReward(
            value=self._compute_reward(current_score, delta, penalties, loop_penalty),
            previous_score=round(self.previous_score, 4),
            current_score=round(current_score, 4),
            delta=delta,
            penalties=penalties,
        )

        self.previous_score = current_score
        self.done = current_score >= self._SUCCESS_THRESHOLD or self.step_count >= self.max_steps
        self.last_reward = reward

        info = {
            "task_id": self.task_id.value,
            "schema_score": grading["schema"],
            "logic_score": grading["logic"],
            "syntax_valid": grading["syntax_valid"],
        }
        return self._build_observation(grading), reward, self.done, info

    def state(self) -> EnvState:
        """Return a snapshot of the environment; ``observation`` is None before ``reset()``."""
        observation = self._build_observation() if self.task_spec is not None else None
        return EnvState(
            task_id=self.task_id,
            done=self.done,
            step_count=self.step_count,
            max_steps=self.max_steps,
            observation=observation,
            last_reward=self.last_reward,
        )

    # ------------------------------------------------------------------ #
    # Observation / reward
    # ------------------------------------------------------------------ #

    def _build_observation(self, grading: dict[str, Any] | None = None) -> ConfigObservation:
        """Build the observation for the current config.

        Args:
            grading: result of ``_grade`` for the current config, if the caller
                already has it; otherwise the config is graded here.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.
        """
        if self.task_spec is None or self.task_id is None:
            raise RuntimeError("Environment is not initialized. Call reset() first.")
        if grading is None:
            grading = self._grade(self.current_config_text)
        return ConfigObservation(
            task_id=self.task_id,
            task_description=self.task_spec.description,
            current_config=self.current_config_text,
            syntax_valid=grading["syntax_valid"],
            validation_errors=grading["errors"],
            schema_score=grading["schema"],
            logic_score=grading["logic"],
            overall_score=grading["overall"],
            step_count=self.step_count,
            max_steps=self.max_steps,
        )

    def _compute_reward(self, current_score: float, delta: float, penalties: list[str], loop_penalty: float) -> float:
        """Turn the step's grading outcome into a reward clamped to ``[0.0, 1.0]``.

        The reward starts at ``current_score``; an improvement adds up to 0.15,
        a regression subtracts 40% of the drop.  ``loop_penalty`` plus 0.10 for
        any ``invalid_action`` and 0.08 for any ``destructive_delete`` penalty
        is subtracted, and a solved config earns a 0.05 bonus.
        """
        reward = current_score
        if delta > 0:
            reward += min(0.15, delta)
        elif delta < 0:
            reward += delta * 0.4

        penalty_total = loop_penalty
        if any(p.startswith("invalid_action") for p in penalties):
            penalty_total += 0.10
        if any(p.startswith("destructive_delete") for p in penalties):
            penalty_total += 0.08
        reward -= penalty_total

        if current_score >= self._SUCCESS_THRESHOLD:
            reward += 0.05
        return round(max(0.0, min(1.0, reward)), 4)

    def _track_state_visit(self, config_text: str) -> float:
        """Record a visit to ``config_text`` and return the resulting loop penalty.

        The first visit is free; each repeat adds 0.03, capped at 0.12.
        """
        # sha1 is only a compact dictionary key here, not a security measure.
        state_hash = hashlib.sha1(config_text.encode("utf-8")).hexdigest()
        count = self._state_visit_count.get(state_hash, 0) + 1
        self._state_visit_count[state_hash] = count
        # Penalize repeated states to discourage loops.
        if count <= 1:
            return 0.0
        return min(0.03 * (count - 1), 0.12)

    # ------------------------------------------------------------------ #
    # Applying actions
    # ------------------------------------------------------------------ #

    def _apply_action(self, config_text: str, action: ConfigAction) -> tuple[str, list[str]]:
        """Apply ``action`` to ``config_text`` and return ``(new_text, penalties)``.

        The text is parsed, modified as a Python structure and dumped again
        with ``yaml.safe_dump(sort_keys=False)``.

        Raises:
            ValueError: if the document root is not a mapping, the path is
                malformed or hits a node of the wrong type, or the operation
                is not one of ``edit``, ``add``, ``delete``.
            Exception: whatever ``yaml.safe_load`` / ``yaml.safe_dump`` raise
                for unparsable text or an undumpable value.
        """
        penalties: list[str] = []

        data = yaml.safe_load(config_text)
        if data is None:
            # An empty document is treated as an empty mapping.
            data = {}
        if not isinstance(data, dict):
            raise ValueError("current config is not a dictionary-like YAML document")
        root = copy.deepcopy(data)

        tokens = self._parse_path(action.path)
        operation = action.operation

        if (
            operation == "delete"
            and len(tokens) == 1
            and isinstance(tokens[0], str)
            and tokens[0] in self._CRITICAL_TOP_LEVEL_KEYS
        ):
            penalties.append("destructive_delete:top_level_critical_key")

        if operation in {"edit", "add"}:
            self._set_path(root, tokens, action.value)
        elif operation == "delete":
            if not self._delete_path(root, tokens):
                penalties.append("delete_noop")
        else:
            # Previously any unrecognised operation silently fell through to
            # the delete branch; reject it so step() reports invalid_action.
            raise ValueError(f"unsupported operation: {operation!r}")

        return yaml.safe_dump(root, sort_keys=False), penalties

    def _parse_path(self, path: str) -> list[str | int]:
        """Split a dotted path into tokens; all-digit chunks become ``int`` indices.

        Raises:
            ValueError: if any chunk is empty after stripping whitespace
                (this includes the empty path).
        """
        tokens: list[str | int] = []
        for raw_chunk in path.split("."):
            chunk = raw_chunk.strip()
            if chunk == "":
                raise ValueError("path contains empty token")
            tokens.append(int(chunk) if chunk.isdigit() else chunk)
        return tokens

    def _set_path(self, root: dict[str, Any], tokens: list[str | int], value: Any) -> None:
        """Set ``value`` at ``tokens`` inside ``root``, creating missing containers.

        A missing or ``None`` intermediate node is replaced by a dict when the
        next token is a string and by a list when it is an integer.  Lists are
        padded up to the requested index.

        Raises:
            ValueError: if ``tokens`` is empty or a token's kind does not match
                the node it is applied to.
        """
        if not tokens:
            raise ValueError("cannot set empty path")

        def new_container(next_token: str | int) -> Any:
            # A fresh object on every call so padded slots never alias.
            return {} if isinstance(next_token, str) else []

        # NOTE(review): list padding below is unbounded, so a very large index
        # in an action path allocates a correspondingly large list.
        cursor: Any = root
        for token, nxt in zip(tokens, tokens[1:]):
            if isinstance(token, int):
                if not isinstance(cursor, list):
                    raise ValueError("list index used on non-list node")
                while token >= len(cursor):
                    cursor.append(new_container(nxt))
                if cursor[token] is None:
                    cursor[token] = new_container(nxt)
            else:
                if not isinstance(cursor, dict):
                    raise ValueError("dict key used on non-dict node")
                if token not in cursor or cursor[token] is None:
                    cursor[token] = new_container(nxt)
            cursor = cursor[token]

        final = tokens[-1]
        if isinstance(final, int):
            if not isinstance(cursor, list):
                raise ValueError("final list index used on non-list node")
            while final >= len(cursor):
                cursor.append(None)
        elif not isinstance(cursor, dict):
            raise ValueError("final dict key used on non-dict node")
        cursor[final] = value

    def _delete_path(self, root: dict[str, Any], tokens: list[str | int]) -> bool:
        """Delete the node at ``tokens``; return False if the path does not exist."""
        if not tokens:
            return False

        cursor: Any = root
        for token in tokens[:-1]:
            if isinstance(token, int):
                if not isinstance(cursor, list) or token >= len(cursor):
                    return False
            elif not isinstance(cursor, dict) or token not in cursor:
                return False
            cursor = cursor[token]

        final = tokens[-1]
        if isinstance(final, int):
            if not isinstance(cursor, list) or final >= len(cursor):
                return False
            cursor.pop(final)
            return True
        if not isinstance(cursor, dict) or final not in cursor:
            return False
        del cursor[final]
        return True

    # ------------------------------------------------------------------ #
    # Grading
    # ------------------------------------------------------------------ #

    def _grade(self, config_text: str) -> dict[str, Any]:
        """Grade ``config_text`` against the current task.

        Returns a dict with keys ``syntax_valid``, ``schema``, ``logic``,
        ``overall`` (``0.60 * schema + 0.40 * logic``) and ``errors`` (at most
        20 messages).  Text that fails to load, or whose root is not a
        mapping, scores 0.0 everywhere.
        """
        assert self.task_spec is not None

        try:
            parsed = yaml.safe_load(config_text)
        except Exception as exc:
            # Broad on purpose: any failure to load counts as a syntax error.
            return {
                "syntax_valid": False,
                "schema": 0.0,
                "logic": 0.0,
                "overall": 0.0,
                "errors": [f"YAML syntax error: {exc}"],
            }

        if parsed is None:
            parsed = {}
        if not isinstance(parsed, dict):
            return {
                "syntax_valid": True,
                "schema": 0.0,
                "logic": 0.0,
                "overall": 0.0,
                "errors": ["Root document must be a mapping/dict"],
            }

        schema_score, schema_errors = self._grade_schema(parsed)
        logic_score, logic_errors = self._grade_logic(parsed)
        errors = [*schema_errors, *logic_errors]
        overall = round((0.60 * schema_score) + (0.40 * logic_score), 4)
        return {
            "syntax_valid": True,
            "schema": schema_score,
            "logic": logic_score,
            "overall": overall,
            "errors": errors[:20],
        }

    def _grade_schema(self, parsed: dict[str, Any]) -> tuple[float, list[str]]:
        """Score how many required paths in ``parsed`` equal the task target.

        Each entry of ``task_spec.required_paths`` contributes its weight when
        the value at that path equals the value at the same path in
        ``task_spec.target``.  Returns ``(matched / total weight, errors)``.
        """
        assert self.task_spec is not None

        total_weight = 0.0
        matched_weight = 0.0
        errors: list[str] = []
        for path, weight in self.task_spec.required_paths.items():
            total_weight += weight
            tokens = self._parse_path(path)
            expected = self._read_path(self.task_spec.target, tokens)
            got, exists = self._safe_read(parsed, tokens)
            if not exists:
                errors.append(f"Missing required path: {path}")
            elif got == expected:
                matched_weight += weight
            else:
                errors.append(f"Mismatch at {path}: expected={expected!r}, got={got!r}")

        score = 0.0 if total_weight == 0 else round(matched_weight / total_weight, 4)
        return score, errors

    def _grade_logic(self, parsed: dict[str, Any]) -> tuple[float, list[str]]:
        """Run the current task's logic checks and return ``(pass ratio, failed messages)``.

        A task id with no checks defined scores 0.0.
        """
        assert self.task_spec is not None

        # Plain == comparisons (not a dict lookup) keep the original matching
        # semantics whatever the concrete type of task_id is.
        task = self.task_spec.task_id
        if task == "easy_docker":
            checks = self._docker_checks(parsed)
        elif task == "medium_k8s":
            checks = self._k8s_checks(parsed)
        elif task == "hard_ml_config":
            checks = self._ml_config_checks(parsed)
        else:
            checks = []

        failed = [msg for msg, ok in checks if not ok]
        if not checks:
            return 0.0, failed
        return round((len(checks) - len(failed)) / len(checks), 4), failed

    def _docker_checks(self, parsed: dict[str, Any]) -> list[tuple[str, bool]]:
        """Logic checks for the ``easy_docker`` task (compose-style services)."""
        web_ports = self._safe_get(parsed, ["services", "web", "ports"], default=[])
        db_ports = self._safe_get(parsed, ["services", "db", "ports"], default=[])
        env_node = self._safe_get(parsed, ["services", "web", "environment"], default={})

        web_ports_is_list = isinstance(web_ports, list)
        # Only iterate when it really is a list: a scalar such as `ports: 8080`
        # used to raise TypeError here and crash grading.
        web_ports_mapped = web_ports_is_list and all(isinstance(p, str) and ":" in p for p in web_ports)
        return [
            ("web ports must be list", web_ports_is_list),
            ("all web ports must contain ':'", web_ports_mapped),
            ("db port must include host and container", isinstance(db_ports, list) and "5432:5432" in db_ports),
            ("environment must be dict", isinstance(env_node, dict)),
        ]

    def _k8s_checks(self, parsed: dict[str, Any]) -> list[tuple[str, bool]]:
        """Logic checks for the ``medium_k8s`` task (first container's resources)."""
        resources: list[str | int] = ["spec", "template", "spec", "containers", 0, "resources"]
        replicas = self._safe_get(parsed, ["spec", "replicas"], default=None)
        limits_mem = self._safe_get(parsed, [*resources, "limits", "memory"], default="")
        req_mem = self._safe_get(parsed, [*resources, "requests", "memory"], default="")
        req_cpu = self._safe_get(parsed, [*resources, "requests", "cpu"], default="")
        return [
            ("replicas should be int", isinstance(replicas, int)),
            ("limits memory must include unit", isinstance(limits_mem, str) and limits_mem.endswith(("Mi", "Gi"))),
            ("requests memory must include unit", isinstance(req_mem, str) and req_mem.endswith(("Mi", "Gi"))),
            ("cpu request should be millicore string", isinstance(req_cpu, str) and req_cpu.endswith("m")),
        ]

    def _ml_config_checks(self, parsed: dict[str, Any]) -> list[tuple[str, bool]]:
        """Logic checks for the ``hard_ml_config`` task (cross-field consistency)."""
        warmup = self._safe_get(parsed, ["training", "warmup_steps"], default=0)
        max_steps = self._safe_get(parsed, ["training", "max_steps"], default=0)
        use_cuda = self._safe_get(parsed, ["hardware", "use_cuda"], default=False)
        gpu_count = self._safe_get(parsed, ["hardware", "gpu_count"], default=0)
        batch_size = self._safe_get(parsed, ["training", "batch_size"], default=0)
        train_batch = self._safe_get(parsed, ["data", "train_batch_size"], default=0)
        log_interval = self._safe_get(parsed, ["logging", "log_interval"], default=999999)
        return [
            (
                "warmup_steps < max_steps",
                isinstance(warmup, int) and isinstance(max_steps, int) and warmup < max_steps,
            ),
            (
                "gpu_count >=1 when use_cuda",
                (not use_cuda) or (isinstance(gpu_count, int) and gpu_count >= 1),
            ),
            (
                "train_batch_size equals 2 * batch_size",
                isinstance(batch_size, int) and isinstance(train_batch, int) and train_batch == 2 * batch_size,
            ),
            ("log_interval <= 100", isinstance(log_interval, int) and log_interval <= 100),
        ]

    # ------------------------------------------------------------------ #
    # Path helpers
    # ------------------------------------------------------------------ #

    def _read_path(self, source: Any, tokens: list[str | int]) -> Any:
        """Return the value at ``tokens`` in ``source``.

        Unlike ``_safe_read`` this does not guard lookups: a missing key or
        index raises the container's own ``KeyError`` / ``IndexError`` /
        ``TypeError``.
        """
        cursor = source
        for token in tokens:
            cursor = cursor[token]
        return cursor

    def _safe_read(self, source: Any, tokens: list[str | int]) -> tuple[Any, bool]:
        """Return ``(value, True)`` for ``tokens`` in ``source``, or ``(None, False)`` if absent.

        Integer tokens only index lists and string tokens only index dicts; a
        kind mismatch counts as absent.
        """
        cursor = source
        for token in tokens:
            try:
                if isinstance(token, int):
                    if not isinstance(cursor, list):
                        return None, False
                elif not isinstance(cursor, dict) or token not in cursor:
                    return None, False
                cursor = cursor[token]
            except (IndexError, KeyError, TypeError):
                # e.g. a list index past the end of the list.
                return None, False
        return cursor, True

    def _safe_get(self, source: Any, tokens: list[str | int], default: Any) -> Any:
        """Return the value at ``tokens`` in ``source``, or ``default`` if the path is absent."""
        value, exists = self._safe_read(source, tokens)
        return value if exists else default