| """Metric tracking RL environment.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
try:
    from ..analysis_tools import AnalysisContext, SharedAnalysisToolkit, available_analysis_methods
    from ..evaluation import EvaluationConfig
    from ..models import (
        MetricTrackerRlAction,
        MetricTrackerRlObservation,
        MetricSubmissionRow,
        SyntheticAnomalyGenerator,
    )
    from ..tasks import DEFAULT_TASK_ID, available_task_specs, get_task_spec
    from .data_generator import (
        EpisodeConfig,
        EpisodeData,
        MetricDataGenerator,
        available_synthetic_generator_methods,
    )
except ImportError:
    from analysis_tools import AnalysisContext, SharedAnalysisToolkit, available_analysis_methods
    from evaluation import EvaluationConfig
    from models import (
        MetricTrackerRlAction,
        MetricTrackerRlObservation,
        MetricSubmissionRow,
        SyntheticAnomalyGenerator,
    )
    from tasks import DEFAULT_TASK_ID, available_task_specs, get_task_spec
    from server.data_generator import (
        EpisodeConfig,
        EpisodeData,
        MetricDataGenerator,
        available_synthetic_generator_methods,
    )


@dataclass
class RewardConfig:
    """Compatibility wrapper around the evaluator configuration."""

    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
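
# Usage sketch (illustrative comment, not part of the original module): a
# custom evaluator configuration can be injected via RewardConfig; both calls
# below only use constructors that already appear in this file.
#
#     custom = RewardConfig(evaluation=EvaluationConfig())
#     env = MetricTrackerRlEnvironment(reward_config=custom)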


class MetricTrackerRlEnvironment(Environment):
    """Iterative multi-anomaly benchmark with safe analysis methods."""

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(
        self,
        generator: MetricDataGenerator | None = None,
        reward_config: RewardConfig | None = None,
    ) -> None:
        initial_task = get_task_spec(DEFAULT_TASK_ID)
        self._generator = generator or MetricDataGenerator()
        self._reward_config = reward_config or RewardConfig()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._episode: EpisodeData | None = None
        self._completed = False
        self._debug_mode = False
        self._active_task = initial_task
        self._next_task_id = initial_task.task_id
        self._next_reset_config = initial_task.build_episode_config()
        self._last_analysis_result: dict | None = None
        self._expose_applied_generators = False

    def configure_next_reset(
        self,
        *,
        task_id: str | None = None,
        seed: int | None = None,
        scenario_family: str | None = None,
        difficulty: str | None = None,
        anomaly_density: str | None = None,
        anomaly_count: int | None = None,
        anomalies: list[dict] | list[SyntheticAnomalyGenerator] | None = None,
    ) -> None:
        """Update the configuration used for the next reset."""
        base_task = get_task_spec(task_id or self._next_task_id)
        base_config = base_task.build_episode_config() if task_id else self._next_reset_config
        anomaly_generators = tuple(
            item
            if isinstance(item, SyntheticAnomalyGenerator)
            else SyntheticAnomalyGenerator(**item)
            for item in (anomalies or [])
        )
        self._next_task_id = base_task.task_id
        self._next_reset_config = EpisodeConfig(
            seed=base_config.seed if seed is None else seed,
            scenario_family=base_config.scenario_family if scenario_family is None else scenario_family,
            difficulty=base_config.difficulty if difficulty is None else difficulty,
            anomaly_density=base_config.anomaly_density if anomaly_density is None else anomaly_density,
            anomaly_count=base_config.anomaly_count if anomaly_count is None else anomaly_count,
            anomaly_generators=anomaly_generators or base_config.anomaly_generators,
        ).normalized()
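
    # Usage sketch (illustrative comment only): keyword arguments left unset
    # keep their values from the currently staged config, so overrides can be
    # layered across calls before the next reset(). The "hard" difficulty
    # value below is a hypothetical example, not a documented option.
    #
    #     env.configure_next_reset(seed=7)
    #     env.configure_next_reset(difficulty="hard")  # seed=7 is preserved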

    def set_debug_mode(self, enabled: bool) -> None:
        """Enable or disable debug-only environment views."""
        self._debug_mode = bool(enabled)

    def export_debug_snapshot(self) -> dict:
        """Return a developer-only debug snapshot for the active episode."""
        if not self._debug_mode:
            raise RuntimeError("Debug mode is disabled.")
        if self._episode is None:
            return {}
        return {
            "config": self._episode.config.__dict__,
            "expected_payload": [row.model_dump() for row in self._episode.expected_rows],
            "anomaly_schedule": self._episode.anomaly_schedule,
            "applied_synthetic_generators": [
                row.model_dump() for row in self._episode.applied_synthetic_generators
            ],
        }
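
    # Usage sketch (illustrative comment only): debug views must be enabled
    # explicitly, otherwise export_debug_snapshot() raises RuntimeError.
    #
    #     env.set_debug_mode(True)
    #     snapshot = env.export_debug_snapshot()
    #     snapshot["anomaly_schedule"]  # hidden ground truth for the episode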

    def reset(
        self,
        task_id: str | None = None,
        seed: int | None = None,
        scenario_family: str | None = None,
        difficulty: str | None = None,
        anomaly_density: str | None = None,
        anomaly_count: int | None = None,
        anomalies: list[dict] | list[SyntheticAnomalyGenerator] | None = None,
    ) -> MetricTrackerRlObservation:
        """Generate a fresh dataset and hidden target payload."""
        overrides = (task_id, seed, scenario_family, difficulty, anomaly_density, anomaly_count)
        if any(value is not None for value in overrides) or anomalies is not None:
            self.configure_next_reset(
                task_id=task_id,
                seed=seed,
                scenario_family=scenario_family,
                difficulty=difficulty,
                anomaly_density=anomaly_density,
                anomaly_count=anomaly_count,
                anomalies=anomalies,
            )
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._active_task = get_task_spec(self._next_task_id)
        self._episode = self._generator.generate_episode(self._next_reset_config)
        self._completed = False
        self._last_analysis_result = None
        self._expose_applied_generators = anomalies is not None
        return self._build_observation(
            status="ready",
            message=self._active_task.objective,
            reward=0.0,
            done=False,
        )
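
    # Usage sketch (illustrative comment only): overrides passed to reset()
    # are routed through configure_next_reset() before the episode is built,
    # and the returned observation reports the active task's objective.
    #
    #     obs = env.reset(seed=123, anomaly_count=3)
    #     assert obs.status == "ready" and not obs.done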

    def step(self, action: MetricTrackerRlAction) -> MetricTrackerRlObservation:  # type: ignore[override]
        """Evaluate a submitted payload and return deterministic feedback."""
        if self._episode is None:
            return self.reset()
        if self._completed:
            return self._build_observation(
                status="completed",
                message="Dataset already solved. Call reset() to create a new dataset.",
                reward=1.0,
                done=True,
                submitted_rows=action.classifications,
            )
        if action.analysis_method:
            self._state.step_count += 1
            analysis_result = self._run_analysis(action.analysis_method, action.analysis_args)
            self._last_analysis_result = analysis_result
            return self._build_observation(
                status="analyzed",
                message=f"Ran analysis method `{action.analysis_method}`.",
                reward=0.0,
                done=False,
                analysis_result=analysis_result,
            )
        submitted_rows = action.classifications
        generated_rows: list[MetricSubmissionRow] = []
        if action.payload_generators:
            generator_result = self._run_analysis(
                "payload_generator",
                {"generator_methods": [item.model_dump() for item in action.payload_generators]},
            )
            self._last_analysis_result = generator_result
            generated_rows = [
                MetricSubmissionRow(**row)
                for row in generator_result["result"]["generated_rows"]
            ]
            submitted_rows = generated_rows
        self._state.step_count += 1
        result = self._active_task.grade_submission(
            submitted_rows,
            self._episode.expected_rows,
            config=self._reward_config.evaluation,
            include_debug_expected=self._debug_mode,
        )
        self._completed = result.is_perfect
        reward = result.reward_breakdown.total_score
        message = self._submission_message(result)
        return self._build_observation(
            status="evaluated" if result.is_perfect else "in_progress",
            message=message,
            reward=reward,
            done=result.is_perfect,
            submitted_rows=result.preview.normalized_rows,
            reward_breakdown=result.reward_breakdown,
            submission_preview=result.preview,
            issues=result.issues,
            correct_row_count=result.matched_rows,
            analysis_result=self._last_analysis_result,
            generated_rows=generated_rows,
        )
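
    # Usage sketch (illustrative comment only): a single step either runs one
    # analysis method or grades a submission. The keyword construction of
    # MetricTrackerRlAction below is an assumption; step() itself only relies
    # on the attributes it reads above.
    #
    #     probe = MetricTrackerRlAction(
    #         analysis_method="list_suspicious_dates",
    #         analysis_args={"limit": 5},
    #     )
    #     obs = env.step(probe)  # status == "analyzed", reward == 0.0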

    def state(self) -> State:
        """Return current episode state."""
        return self._state

    def _build_observation(
        self,
        *,
        status: str,
        message: str,
        reward: float,
        done: bool,
        submitted_rows=None,
        reward_breakdown=None,
        submission_preview=None,
        issues=None,
        correct_row_count: int = 0,
        analysis_result=None,
        generated_rows=None,
    ) -> MetricTrackerRlObservation:
        """Assemble the observation returned to the client."""
        assert self._episode is not None
        metadata = {
            "step": self._state.step_count,
| "current_state": self.state.model_dump(), | |
| "task_id": self._active_task.task_id, | |
| "objective": self._active_task.objective, | |
| "grader_name": self._active_task.grader_name, | |
| "seed": self._episode.config.seed, | |
| "scenario_family": self._episode.config.scenario_family, | |
| "difficulty": self._episode.config.difficulty, | |
| "anomaly_density": self._episode.config.anomaly_density, | |
| "anomaly_count": self._episode.config.anomaly_count, | |
| } | |
| return MetricTrackerRlObservation( | |
| task_id=self._active_task.task_id, | |
| status=status, | |
| message=message, | |
| instruction=self._active_task.instruction, | |
| conversion_metric_definitions=list(self._generator.config.conversion_definitions), | |
| available_synthetic_generator_methods=available_synthetic_generator_methods(), | |
| applied_synthetic_generators=( | |
| self._episode.applied_synthetic_generators | |
| if self._debug_mode or self._expose_applied_generators | |
| else [] | |
| ), | |
| available_methods=available_analysis_methods(), | |
| available_tasks=available_task_specs(), | |
| daily_metrics=[], | |
| hourly_metrics=[], | |
| analysis_result=analysis_result, | |
| generated_rows=generated_rows or [], | |
| submitted_rows=submitted_rows or [], | |
| submission_preview=submission_preview, | |
| submission_issues=issues or [], | |
| reward_breakdown=reward_breakdown, | |
| expected_row_count=len(self._episode.expected_rows), | |
| correct_row_count=correct_row_count, | |
| reward=reward, | |
| done=done, | |
| config=metadata, | |
| debug=( | |
| { | |
| "task_id": self._active_task.task_id, | |
| "expected_payload": [row.model_dump() for row in self._episode.expected_rows], | |
| "anomaly_schedule": self._episode.anomaly_schedule, | |
| "reward_breakdown": reward_breakdown.model_dump() if reward_breakdown else None, | |
| "issues": [item.model_dump() for item in (issues or [])], | |
| } | |
| if self._debug_mode | |
| else None | |
| ), | |
| ) | |

    def _run_analysis(self, method_name: str, arguments: dict) -> dict:
        """Dispatch a whitelisted analysis method against the current episode."""
        toolkit = SharedAnalysisToolkit(
            AnalysisContext(
                daily_metrics=self._episode.daily_metrics,
                hourly_metrics=self._episode.hourly_metrics,
                conversion_definitions=list(self._generator.config.conversion_definitions),
                instruction=self._active_task.instruction,
                config={
                    "task_id": self._active_task.task_id,
                    "objective": self._active_task.objective,
                    "grader_name": self._active_task.grader_name,
                    **self._episode.config.__dict__,
                },
            )
        )
        if method_name == "task_overview":
            result = toolkit.task_overview()
        elif method_name == "list_dates":
            result = toolkit.list_dates()
        elif method_name == "list_entities":
            result = toolkit.list_entities()
        elif method_name == "rows_for_date":
            result = toolkit.rows_for_date(arguments["date"])
        elif method_name == "hourly_rows_for_date":
            result = toolkit.hourly_rows_for_date(arguments["date"])
        elif method_name == "compare_rate_to_median":
            result = toolkit.compare_rate_to_median(arguments["date"], arguments["entity_name"])
        elif method_name == "compare_count_to_median":
            result = toolkit.compare_count_to_median(arguments["date"], arguments["entity_name"])
        elif method_name == "detect_funnel_break":
            result = toolkit.detect_funnel_break(arguments["date"])
        elif method_name == "check_impossible_counts":
            result = toolkit.check_impossible_counts(arguments["date"])
        elif method_name == "list_suspicious_dates":
            result = toolkit.list_suspicious_dates(limit=arguments.get("limit", 10))
        elif method_name == "preview_submission":
            result = toolkit.preview_submission(arguments.get("rows", []))
        elif method_name == "show_raw_data":
            result = toolkit.show_raw_data(limit=arguments.get("limit", 5))
        elif method_name == "get_metric_median":
            result = toolkit.get_metric_median_multi(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
            )
        elif method_name == "get_metric_std_dev_from_median":
            result = toolkit.get_metric_std_dev_from_median_multi(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
            )
        elif method_name == "get_rows_with_abs_diff_from_median_gt":
            result = toolkit.get_rows_with_abs_diff_from_median_gt_multi(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
                threshold=float(arguments["threshold"]),
            )
        elif method_name == "get_median_filter_rows":
            result = toolkit.get_median_filter_rows_multi(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_rate_drop_from_median_rows":
            result = toolkit.get_rate_drop_from_median_rows(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_rate_spike_from_median_rows":
            result = toolkit.get_rate_spike_from_median_rows(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_absolute_drop_in_event_count_rows":
            result = toolkit.get_absolute_drop_in_event_count_rows(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_absolute_spike_in_event_count_rows":
            result = toolkit.get_absolute_spike_in_event_count_rows(
                metric_name=arguments.get("metric_name"),
                metric_names=arguments.get("metric_names", []),
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_funnel_break_rows":
            result = toolkit.get_funnel_break_rows(
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_hourly_traffic_mix_shift_rows":
            result = toolkit.get_hourly_traffic_mix_shift_rows(
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "get_instrumentation_data_quality_issue_rows":
            result = toolkit.get_instrumentation_data_quality_issue_rows(
                threshold_multiplier=float(arguments["threshold_multiplier"]),
            )
        elif method_name == "payload_generator":
            result = toolkit.payload_generator(arguments.get("generator_methods", []))
        else:
            raise ValueError(f"Unsupported analysis method: {method_name}")
        return {
            "method": method_name,
            "arguments": arguments,
            "result": result,
        }
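
    # Return-shape sketch for _run_analysis (illustrative comment only; the
    # "2024-01-01" date is a made-up example value):
    #
    #     {
    #         "method": "rows_for_date",
    #         "arguments": {"date": "2024-01-01"},
    #         "result": ...,  # toolkit output, shape varies per method
    #     }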

    def _submission_message(self, result) -> str:
        """Summarize grading feedback for the latest submission."""
        if result.is_perfect:
            return "Submission is fully correct."
        extra_issues = [issue for issue in result.issues if issue.issue_type == "extra_row"]
        missing_count = result.reward_breakdown.missing_rows
        if not extra_issues and missing_count > 0:
            return (
                "All submitted rows are genuine anomalies, but some expected rows are missing. "
                f"Missing row count: {missing_count}."
            )
        if extra_issues:
            first = extra_issues[0]
            return f"Submitted row is not an anomaly: {first.row_key}."
        return (
            f"Matched {result.reward_breakdown.matched_rows}/"
            f"{result.reward_breakdown.expected_rows} expected rows. Review the feedback."
        )