# metric_tracker_rl/server/metric_tracker_rl_environment.py
"""Metric tracking RL environment."""
from __future__ import annotations
from dataclasses import dataclass, field
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
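# Prefer package-relative imports; fall back to absolute imports so the module also
# works when the server package is executed directly (outside the installed package).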
try:
from ..analysis_tools import AnalysisContext, SharedAnalysisToolkit, available_analysis_methods
from ..evaluation import EvaluationConfig
from ..models import (
MetricTrackerRlAction,
MetricTrackerRlObservation,
MetricSubmissionRow,
SyntheticAnomalyGenerator,
)
from ..tasks import DEFAULT_TASK_ID, available_task_specs, get_task_spec
from .data_generator import (
EpisodeConfig,
EpisodeData,
MetricDataGenerator,
available_synthetic_generator_methods,
)
except ImportError:
from analysis_tools import AnalysisContext, SharedAnalysisToolkit, available_analysis_methods
from models import (
MetricTrackerRlAction,
MetricTrackerRlObservation,
MetricSubmissionRow,
SyntheticAnomalyGenerator,
)
from tasks import DEFAULT_TASK_ID, available_task_specs, get_task_spec
from server.data_generator import (
EpisodeConfig,
EpisodeData,
MetricDataGenerator,
available_synthetic_generator_methods,
)
from evaluation import EvaluationConfig
@dataclass(frozen=True)
class RewardConfig:
"""Compatibility wrapper around the evaluator configuration."""
    evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)  # default_factory avoids sharing one EvaluationConfig instance across RewardConfig objects
class MetricTrackerRlEnvironment(Environment):
"""Iterative multi-anomaly benchmark with safe analysis methods."""
SUPPORTS_CONCURRENT_SESSIONS: bool = True
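    # Typical interaction, sketched for illustration only (action/observation field
    # names are taken from this module; concrete values are placeholders):
    #     env = MetricTrackerRlEnvironment()
    #     obs = env.reset(seed=123)
    #     obs = env.step(MetricTrackerRlAction(analysis_method="list_dates", analysis_args={}))
    #     obs = env.step(MetricTrackerRlAction(classifications=[...]))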
def __init__(
self,
generator: MetricDataGenerator | None = None,
reward_config: RewardConfig | None = None,
) -> None:
initial_task = get_task_spec(DEFAULT_TASK_ID)
self._generator = generator or MetricDataGenerator()
self._reward_config = reward_config or RewardConfig()
self._state = State(episode_id=str(uuid4()), step_count=0)
self._episode: EpisodeData | None = None
self._completed = False
self._debug_mode = False
self._active_task = initial_task
self._next_task_id = initial_task.task_id
self._next_reset_config = initial_task.build_episode_config()
self._last_analysis_result: dict | None = None
self._expose_applied_generators = False
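    # configure_next_reset() only stages settings: they take effect on the next reset()
    # and never modify the episode currently in progress.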
def configure_next_reset(
self,
*,
task_id: str | None = None,
seed: int | None = None,
scenario_family: str | None = None,
difficulty: str | None = None,
anomaly_density: str | None = None,
anomaly_count: int | None = None,
anomalies: list[dict] | list[SyntheticAnomalyGenerator] | None = None,
) -> None:
"""Update the configuration used for the next reset."""
base_task = get_task_spec(task_id or self._next_task_id)
base_config = base_task.build_episode_config() if task_id else self._next_reset_config
anomaly_generators = tuple(
item if isinstance(item, SyntheticAnomalyGenerator) else SyntheticAnomalyGenerator(**item)
for item in (anomalies or [])
)
self._next_task_id = base_task.task_id
self._next_reset_config = EpisodeConfig(
seed=base_config.seed if seed is None else seed,
scenario_family=base_config.scenario_family if scenario_family is None else scenario_family,
difficulty=base_config.difficulty if difficulty is None else difficulty,
anomaly_density=base_config.anomaly_density if anomaly_density is None else anomaly_density,
anomaly_count=base_config.anomaly_count if anomaly_count is None else anomaly_count,
anomaly_generators=anomaly_generators or base_config.anomaly_generators,
).normalized()
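    # Illustrative use (anomaly dict fields depend on SyntheticAnomalyGenerator):
    #     env.configure_next_reset(seed=7, anomaly_count=2, anomalies=[{...}])
    #     obs = env.reset()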
def set_debug_mode(self, enabled: bool) -> None:
"""Enable or disable debug-only environment views."""
self._debug_mode = bool(enabled)
def export_debug_snapshot(self) -> dict:
"""Return a developer-only debug snapshot for the active episode."""
if not self._debug_mode:
raise RuntimeError("Debug mode is disabled.")
if self._episode is None:
return {}
return {
"config": self._episode.config.__dict__,
"expected_payload": [row.model_dump() for row in self._episode.expected_rows],
"anomaly_schedule": self._episode.anomaly_schedule,
"applied_synthetic_generators": [
row.model_dump() for row in self._episode.applied_synthetic_generators
],
}
def reset(
self,
task_id: str | None = None,
seed: int | None = None,
scenario_family: str | None = None,
difficulty: str | None = None,
anomaly_density: str | None = None,
anomaly_count: int | None = None,
anomalies: list[dict] | list[SyntheticAnomalyGenerator] | None = None,
) -> MetricTrackerRlObservation:
"""Generate a fresh dataset and hidden target payload."""
        if (
            any(
                value is not None
                for value in (task_id, seed, scenario_family, difficulty, anomaly_density, anomaly_count)
            )
            or anomalies is not None
        ):
self.configure_next_reset(
task_id=task_id,
seed=seed,
scenario_family=scenario_family,
difficulty=difficulty,
anomaly_density=anomaly_density,
anomaly_count=anomaly_count,
anomalies=anomalies,
)
self._state = State(episode_id=str(uuid4()), step_count=0)
self._active_task = get_task_spec(self._next_task_id)
self._episode = self._generator.generate_episode(self._next_reset_config)
self._completed = False
self._last_analysis_result = None
self._expose_applied_generators = anomalies is not None
return self._build_observation(
status="ready",
message=self._active_task.objective,
reward=0.0,
done=False,
)
def step(self, action: MetricTrackerRlAction) -> MetricTrackerRlObservation: # type: ignore[override]
"""Evaluate a submitted payload and return deterministic feedback."""
if self._episode is None:
return self.reset()
if self._completed:
return self._build_observation(
status="completed",
message="Dataset already solved. Call reset() to create a new dataset.",
reward=1.0,
done=True,
submitted_rows=action.classifications,
)
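        # Analysis actions consume a step and return the tool output without grading.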
if action.analysis_method:
self._state.step_count += 1
analysis_result = self._run_analysis(action.analysis_method, action.analysis_args)
self._last_analysis_result = analysis_result
return self._build_observation(
status="analyzed",
message=f"Ran analysis method `{action.analysis_method}`.",
reward=0.0,
done=False,
analysis_result=analysis_result,
)
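        # Otherwise grade a submission: either the explicit classifications or the rows
        # produced by the requested payload-generator methods.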
submitted_rows = action.classifications
generated_rows: list[MetricSubmissionRow] = []
if action.payload_generators:
generator_result = self._run_analysis(
"payload_generator",
{"generator_methods": [item.model_dump() for item in action.payload_generators]},
)
self._last_analysis_result = generator_result
generated_rows = [
MetricSubmissionRow(**row)
for row in generator_result["result"]["generated_rows"]
]
submitted_rows = generated_rows
self._state.step_count += 1
result = self._active_task.grade_submission(
submitted_rows,
self._episode.expected_rows,
config=self._reward_config.evaluation,
include_debug_expected=self._debug_mode,
)
self._completed = result.is_perfect
reward = result.reward_breakdown.total_score
message = self._submission_message(result)
return self._build_observation(
status="evaluated" if result.is_perfect else "in_progress",
message=message,
reward=reward,
done=result.is_perfect,
submitted_rows=result.preview.normalized_rows,
reward_breakdown=result.reward_breakdown,
submission_preview=result.preview,
issues=result.issues,
correct_row_count=result.matched_rows,
analysis_result=self._last_analysis_result,
generated_rows=generated_rows,
)
@property
def state(self) -> State:
"""Return current episode state."""
return self._state
def _build_observation(
self,
*,
status: str,
message: str,
reward: float,
done: bool,
submitted_rows=None,
reward_breakdown=None,
submission_preview=None,
issues=None,
correct_row_count: int = 0,
analysis_result=None,
generated_rows=None,
) -> MetricTrackerRlObservation:
assert self._episode is not None
metadata = {
"step": self._state.step_count,
"current_state": self.state.model_dump(),
"task_id": self._active_task.task_id,
"objective": self._active_task.objective,
"grader_name": self._active_task.grader_name,
"seed": self._episode.config.seed,
"scenario_family": self._episode.config.scenario_family,
"difficulty": self._episode.config.difficulty,
"anomaly_density": self._episode.config.anomaly_density,
"anomaly_count": self._episode.config.anomaly_count,
}
return MetricTrackerRlObservation(
task_id=self._active_task.task_id,
status=status,
message=message,
instruction=self._active_task.instruction,
conversion_metric_definitions=list(self._generator.config.conversion_definitions),
available_synthetic_generator_methods=available_synthetic_generator_methods(),
applied_synthetic_generators=(
self._episode.applied_synthetic_generators
if self._debug_mode or self._expose_applied_generators
else []
),
available_methods=available_analysis_methods(),
available_tasks=available_task_specs(),
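            # Raw daily/hourly metric tables are kept empty in observations; the data is
            # inspected through the analysis methods instead.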
daily_metrics=[],
hourly_metrics=[],
analysis_result=analysis_result,
generated_rows=generated_rows or [],
submitted_rows=submitted_rows or [],
submission_preview=submission_preview,
submission_issues=issues or [],
reward_breakdown=reward_breakdown,
expected_row_count=len(self._episode.expected_rows),
correct_row_count=correct_row_count,
reward=reward,
done=done,
config=metadata,
debug=(
{
"task_id": self._active_task.task_id,
"expected_payload": [row.model_dump() for row in self._episode.expected_rows],
"anomaly_schedule": self._episode.anomaly_schedule,
"reward_breakdown": reward_breakdown.model_dump() if reward_breakdown else None,
"issues": [item.model_dump() for item in (issues or [])],
}
if self._debug_mode
else None
),
)
def _run_analysis(self, method_name: str, arguments: dict) -> dict:
toolkit = SharedAnalysisToolkit(
AnalysisContext(
daily_metrics=self._episode.daily_metrics,
hourly_metrics=self._episode.hourly_metrics,
conversion_definitions=list(self._generator.config.conversion_definitions),
instruction=self._active_task.instruction,
config={
"task_id": self._active_task.task_id,
"objective": self._active_task.objective,
"grader_name": self._active_task.grader_name,
**self._episode.config.__dict__,
},
)
)
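        # Dispatch only to the whitelisted toolkit methods; unknown names are rejected below.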
if method_name == "task_overview":
result = toolkit.task_overview()
elif method_name == "list_dates":
result = toolkit.list_dates()
elif method_name == "list_entities":
result = toolkit.list_entities()
elif method_name == "rows_for_date":
result = toolkit.rows_for_date(arguments["date"])
elif method_name == "hourly_rows_for_date":
result = toolkit.hourly_rows_for_date(arguments["date"])
elif method_name == "compare_rate_to_median":
result = toolkit.compare_rate_to_median(arguments["date"], arguments["entity_name"])
elif method_name == "compare_count_to_median":
result = toolkit.compare_count_to_median(arguments["date"], arguments["entity_name"])
elif method_name == "detect_funnel_break":
result = toolkit.detect_funnel_break(arguments["date"])
elif method_name == "check_impossible_counts":
result = toolkit.check_impossible_counts(arguments["date"])
elif method_name == "list_suspicious_dates":
result = toolkit.list_suspicious_dates(limit=arguments.get("limit", 10))
elif method_name == "preview_submission":
result = toolkit.preview_submission(arguments.get("rows", []))
elif method_name == "show_raw_data":
result = toolkit.show_raw_data(limit=arguments.get("limit", 5))
elif method_name == "get_metric_median":
result = toolkit.get_metric_median_multi(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
)
elif method_name == "get_metric_std_dev_from_median":
result = toolkit.get_metric_std_dev_from_median_multi(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
)
elif method_name == "get_rows_with_abs_diff_from_median_gt":
result = toolkit.get_rows_with_abs_diff_from_median_gt_multi(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
threshold=float(arguments["threshold"]),
)
elif method_name == "get_median_filter_rows":
result = toolkit.get_median_filter_rows_multi(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_rate_drop_from_median_rows":
result = toolkit.get_rate_drop_from_median_rows(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_rate_spike_from_median_rows":
result = toolkit.get_rate_spike_from_median_rows(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_absolute_drop_in_event_count_rows":
result = toolkit.get_absolute_drop_in_event_count_rows(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_absolute_spike_in_event_count_rows":
result = toolkit.get_absolute_spike_in_event_count_rows(
metric_name=arguments.get("metric_name"),
metric_names=arguments.get("metric_names", []),
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_funnel_break_rows":
result = toolkit.get_funnel_break_rows(
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_hourly_traffic_mix_shift_rows":
result = toolkit.get_hourly_traffic_mix_shift_rows(
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "get_instrumentation_data_quality_issue_rows":
result = toolkit.get_instrumentation_data_quality_issue_rows(
threshold_multiplier=float(arguments["threshold_multiplier"]),
)
elif method_name == "payload_generator":
result = toolkit.payload_generator(arguments.get("generator_methods", []))
else:
raise ValueError(f"Unsupported analysis method: {method_name}")
return {
"method": method_name,
"arguments": arguments,
"result": result,
}
@staticmethod
def _submission_message(result) -> str:
if result.is_perfect:
return "Submission is fully correct."
extra_issues = [issue for issue in result.issues if issue.issue_type == "extra_row"]
missing_count = result.reward_breakdown.missing_rows
        if not extra_issues and missing_count > 0:
            return (
                "All submitted rows are anomalies, but some expected rows are missing. "
                f"Missing row count: {missing_count}."
            )
        if extra_issues:
            first = extra_issues[0]
            return f"Submitted row is not an anomaly: {first.row_key}."
return (
f"Matched {result.reward_breakdown.matched_rows}/"
f"{result.reward_breakdown.expected_rows} expected rows. Review the feedback."
)