kushalExplores committed on
Commit e415506 · verified · 1 Parent(s): a5fe7ab

Upload folder using huggingface_hub

.dockerignore ADDED
@@ -0,0 +1,14 @@
+ .git
+ .pytest_cache
+ .ruff_cache
+ .mypy_cache
+ .DS_Store
+ .env
+ .env.*
+ .venv
+ __pycache__
+ *.pyc
+ *.pyo
+ *.pyd
+ tests
+ openenv_metric_tracker_rl.egg-info
.gitignore ADDED
@@ -0,0 +1,25 @@
+ # Secrets / local env
+ .env
+ .env.*
+
+ # Virtual environments
+ .venv/
+ venv/
+
+ # Python cache
+ __pycache__/
+ *.py[cod]
+
+ # Test / tool cache
+ .pytest_cache/
+ .mypy_cache/
+ .ruff_cache/
+ .coverage
+
+ # Build / packaging artifacts
+ build/
+ dist/
+ *.egg-info/
+
+ # OS files
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,31 @@
+ FROM python:3.11-slim
+
+ ENV PYTHONDONTWRITEBYTECODE=1 \
+     PYTHONUNBUFFERED=1 \
+     PIP_NO_CACHE_DIR=1 \
+     UV_PROJECT_ENVIRONMENT=/opt/venv \
+     PATH="/opt/venv/bin:/root/.local/bin:${PATH}" \
+     PYTHONPATH=/app \
+     PORT=8000 \
+     ENABLE_WEB_INTERFACE=true
+
+ WORKDIR /app
+
+ RUN apt-get update && \
+     apt-get install -y --no-install-recommends curl git && \
+     rm -rf /var/lib/apt/lists/*
+
+ RUN pip install --no-cache-dir uv
+
+ COPY pyproject.toml uv.lock README.md openenv.yaml /app/
+ COPY __init__.py analysis_tools.py client.py evaluation.py inference.py models.py payload_generation.py tasks.py /app/
+ COPY server /app/server
+
+ RUN uv sync --frozen --no-dev
+
+ EXPOSE 8000
+
+ HEALTHCHECK --interval=30s --timeout=5s --start-period=20s --retries=3 \
+     CMD curl -fsS "http://127.0.0.1:${PORT}/health" || exit 1
+
+ CMD ["python", "-m", "uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,10 +1,228 @@
  ---
- title: Metric Tracker Rl
- emoji: 😻
- colorFrom: indigo
- colorTo: blue
+ title: Metric Tracker RL
+ emoji: 📈
+ colorFrom: blue
+ colorTo: green
  sdk: docker
+ app_port: 8000
  pinned: false
+ tags:
+ - openenv
+ - reinforcement-learning
+ - analytics
+ - anomaly-detection
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Metric Tracker RL
+
+ `metric_tracker_rl` is an OpenEnv benchmark for investigating synthetic product-funnel metrics and submitting a structured anomaly report. It is designed to run as a containerized Hugging Face Space and exposes the same environment through both an OpenEnv-compatible HTTP API and a Gradio debugger.
+
+ ## Environment Description And Motivation
+
+ This environment models a common analytics workflow: a team notices a KPI shift, inspects daily and hourly aggregates, compares observed values to historical baselines, and decides which anomalies are real enough to report. The benchmark focuses on disciplined investigation rather than raw generation. Agents must use safe analysis tools, avoid over-submitting, and produce a precise anomaly payload that matches hidden seeded ground truth.
+
+ The motivation for the benchmark is to test whether an agent can:
+
+ - navigate a realistic tabular analytics task without direct oracle access
+ - combine count-based, rate-based, funnel, and hourly reasoning
+ - preserve precision when multiple anomaly families may be present
+ - translate evidence into a stable machine-graded submission format
+
+ Each reset creates a deterministic four-week synthetic dataset with daily and hourly funnel aggregates. Hidden anomaly labels are derived from the reset configuration, so tasks are reproducible and programmatically graded.
+
+ ## Action Space
+
+ The environment accepts `MetricTrackerRlAction` with four fields:
+
+ - `classifications`: final anomaly rows to grade
+ - `analysis_method`: optional safe method name to call instead of grading
+ - `analysis_args`: arguments for the selected analysis method
+ - `payload_generators`: optional declarative generator methods that create submission rows inside the environment
+
+ Each `classifications` row must include the following fields; an example follows the list.
+
+ - `date`: ISO date in `YYYY-MM-DD`
+ - `entity_type`: one of the stable families such as `conversion_rate`, `event_count`, `funnel_step`, `hourly_mix`, or `data_quality`
+ - `entity_name`: stable metric or entity identifier
+ - `anomaly_type`: anomaly family identifier
+ - `detection_method`: analysis method used to justify the row
+ - `baseline_value`: historical reference value
+ - `observed_value`: measured anomalous value
+ - `delta_value`: `observed_value - baseline_value`
+ - `severity`: one of `low`, `medium`, `high`, or `critical`
+
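+ As a purely illustrative sketch (the values are hypothetical, and it assumes the other `MetricTrackerRlAction` fields default to empty), a grading action with one row looks like this:
+
+ ```python
+ from metric_tracker_rl import MetricSubmissionRow, MetricTrackerRlAction
+
+ # Hypothetical example row; real values come out of the analysis methods.
+ row = MetricSubmissionRow(
+     date="2024-01-15",
+     entity_type="event_count",
+     entity_name="app_opens",
+     anomaly_type="absolute_spike_in_event_count",
+     detection_method="compare_count_to_median",
+     baseline_value=10000.0,
+     observed_value=14200.0,
+     delta_value=4200.0,
+     severity="high",
+ )
+ action = MetricTrackerRlAction(classifications=[row])
+ ```
+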
+ ## Observation Space
+
+ The environment returns `MetricTrackerRlObservation`, which includes:
+
+ - task metadata: `task_id`, `instruction`, `status`, and visible episode config
+ - method surface: `available_methods` and `available_synthetic_generator_methods`
+ - task catalog: `available_tasks`
+ - metric definitions: `conversion_metric_definitions`
+ - latest tool output: `analysis_result`
+ - latest submission output: `generated_rows`, `submitted_rows`, `submission_preview`, `submission_issues`, and `reward_breakdown`
+ - progress counters: `expected_row_count` and `correct_row_count`
+
+ In standard benchmark mode, raw `daily_metrics`, raw `hourly_metrics`, and hidden debug payloads are not exposed directly. Agents are expected to inspect the data through the read-only shared analysis methods instead.
+
+ ## Shared Analysis Surface
+
+ Humans in the Gradio debugger and agents in `inference.py` use the same read-only analysis surface:
+
+ - `task_overview`
+ - `list_dates`
+ - `list_entities`
+ - `rows_for_date`
+ - `hourly_rows_for_date`
+ - `compare_rate_to_median`
+ - `compare_count_to_median`
+ - `detect_funnel_break`
+ - `check_impossible_counts`
+ - `list_suspicious_dates`
+ - `preview_submission`
+ - payload-generator helpers such as `get_median_filter_rows`
+
+ This keeps the benchmark focused on investigation quality rather than privileged access.
+
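+ A minimal client-side sketch of driving one of these methods through the action's `analysis_method` field. The `base_url` constructor argument and the example date are assumptions for illustration, and the client is assumed to expose the usual OpenEnv `reset`/`step` surface:
+
+ ```python
+ from metric_tracker_rl import MetricTrackerRlAction, MetricTrackerRlEnv
+
+ # Assumption: the EnvClient base class accepts the server URL like this.
+ env = MetricTrackerRlEnv(base_url="http://localhost:8000")
+ env.reset()
+
+ # Run a read-only analysis step instead of submitting rows for grading.
+ result = env.step(
+     MetricTrackerRlAction(
+         analysis_method="compare_count_to_median",
+         analysis_args={"date": "2024-01-15", "entity_name": "app_opens"},
+     )
+ )
+ print(result.observation.analysis_result)
+ ```
+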
+ ## Tasks And Expected Difficulty
+
+ The benchmark ships with three named deterministic tasks:
+
+ 1. `easy_single_spike`
+    Expected difficulty: easy.
+    One obvious event-count spike is present. A careful single-method investigation should usually be enough.
+ 2. `medium_mixed_pair`
+    Expected difficulty: medium.
+    Three anomalies are present across mixed count and rate signals. Precision matters because over-submission is penalized.
+ 3. `hard_mixed_multi`
+    Expected difficulty: hard.
+    Five anomalies are present with higher density and weaker signal separation. Agents need broader exploration and tighter filtering.
+
+ Supported anomaly families across resets:
+
+ - `rate_drop_from_median`
+ - `rate_spike_from_median`
+ - `absolute_drop_in_event_count`
+ - `absolute_spike_in_event_count`
+ - `funnel_break`
+ - `hourly_traffic_mix_shift`
+ - `instrumentation_data_quality_issue`
+
+ ## Reward And Grading
+
+ Grading is deterministic and normalized to `[0, 1]`. The evaluator rewards:
+
+ - precision
+ - recall
+ - correct `anomaly_type`
+ - correct `detection_method`
+ - numeric accuracy for `baseline_value`, `observed_value`, and `delta_value` within tolerance
+ - correct `severity`
+
+ Penalties apply for:
+
+ - extra rows
+ - duplicate rows
+ - invalid rows
+ - exploit-style mass submission patterns
+
+ The observation exposes `submission_preview`, `submission_issues`, and `reward_breakdown` after a graded step.
+
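+ As a rough illustration only, not the evaluator's actual formula, a normalized score of this general shape blends precision and recall and subtracts penalties before clipping to `[0, 1]`:
+
+ ```python
+ def illustrative_score(tp: int, fp: int, fn: int, penalty: float) -> float:
+     """Toy precision/recall blend; the real evaluator also scores
+     anomaly_type, detection_method, numeric tolerance, and severity."""
+     precision = tp / (tp + fp) if (tp + fp) else 0.0
+     recall = tp / (tp + fn) if (tp + fn) else 0.0
+     return max(0.0, min(1.0, 0.5 * precision + 0.5 * recall - penalty))
+
+ print(illustrative_score(tp=2, fp=1, fn=1, penalty=0.05))  # ~0.6167
+ ```
+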
+ ## Baseline Scores
+
+ Reference scores below were measured locally with a deterministic scripted payload-generator baseline that submits:
+
+ - `easy_single_spike`: `get_absolute_spike_in_event_count_rows(threshold_multiplier=2.0)`
+ - `medium_mixed_pair`: `get_median_filter_rows(threshold_multiplier=2.0)`
+ - `hard_mixed_multi`: `get_median_filter_rows(threshold_multiplier=2.0)`
+
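+ A sketch of invoking the same generators through the action's `payload_generators` field (the argument shape is assumed to match `PayloadGeneratorMethod`, which the shared toolkit also accepts as plain dicts):
+
+ ```python
+ from metric_tracker_rl import MetricTrackerRlAction
+
+ # Declarative generator request; the environment builds the rows server-side.
+ action = MetricTrackerRlAction(
+     payload_generators=[
+         {"method_name": "get_absolute_spike_in_event_count_rows", "threshold_multiplier": 2.0},
+     ]
+ )
+ ```
+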
+ Measured normalized scores:
+
+ - `easy_single_spike`: `1.000000`
+ - `medium_mixed_pair`: `0.662500`
+ - `hard_mixed_multi`: `0.421818`
+ - average across named tasks: `0.694773`
+
+ These numbers are useful as a simple non-LLM reference point, not as a ceiling. A perfect submission still scores `1.0` on each task.
+
+ ## Hugging Face Space Deployment
+
+ This repository is configured for a containerized Hugging Face Space:
+
+ - `README.md` frontmatter sets `sdk: docker`
+ - the Space is tagged with `openenv`
+ - [`openenv.yaml`](openenv.yaml) points to `server.app:app`
+ - [`Dockerfile`](Dockerfile) starts the OpenEnv HTTP server on port `8000`
+
+ ## Setup
+
+ ### Local Python Setup
+
+ ```bash
+ cd metric_tracker_rl
+ uv sync
+ ```
+
+ ### Run The Environment Locally
+
+ ```bash
+ cd metric_tracker_rl
+ uv run python -m uvicorn server.app:app --host 0.0.0.0 --port 8000
+ ```
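+
+ A quick smoke test against the same `/health` endpoint the Dockerfile healthcheck probes:
+
+ ```bash
+ curl -fsS http://localhost:8000/health
+ ```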
+
+ ### Run The Inference Baseline
+
+ Set credentials in [`.env.inference`](.env.inference), then run:
+
+ ```bash
+ cd metric_tracker_rl
+ source .env.inference
+ uv run python inference.py
+ ```
+
+ The inference baseline runs:
+
+ - `easy_single_spike`
+ - `medium_mixed_pair`
+ - `hard_mixed_multi`
+
+ It prints one score per task and an overall average benchmark score.
+
+ ## Container Build And Run
+
+ Build the image:
+
+ ```bash
+ cd metric_tracker_rl
+ docker build -t metric-tracker-rl .
+ ```
+
+ Run the container:
+
+ ```bash
+ docker run --rm -p 8000:8000 metric-tracker-rl
+ ```
+
+ Once running, the Space-compatible server is available at `http://localhost:8000`.
+
+ ## Validation
+
+ Useful checks:
+
+ ```bash
+ cd metric_tracker_rl
+ openenv validate .
+ python -m uvicorn server.app:app --host 0.0.0.0 --port 8000
+ ```
+
+ ## Manual Debugging UI
+
+ The bundled Gradio UI exposes:
+
+ - named-task selection
+ - reset controls for `seed`, `scenario_family`, `difficulty`, and `anomaly_density`
+ - the same shared analysis methods used by the agent baseline
+ - payload preview and submission feedback
+ - charts for daily counts, rates, hourly metrics, and funnel shape
+
+ Debug mode can expose expected rows and anomaly schedules for development, but that view is intentionally gated and is not part of standard benchmark play.
__init__.py ADDED
@@ -0,0 +1,26 @@
+ """Metric Tracker Rl Environment."""
+
+ from .client import MetricTrackerRlEnv
+ from .models import MetricSubmissionRow, MetricTrackerRlAction, MetricTrackerRlObservation
+ from .payload_generation import (
+     available_analysis_methods,
+     available_payload_generation_methods,
+     available_synthetic_generator_methods,
+ )
+ from .tasks import DEFAULT_TASK_ID, DEFAULT_TASK_ORDER, TASKS, TaskSpec, available_task_specs, get_task_spec
+
+ __all__ = [
+     "MetricSubmissionRow",
+     "MetricTrackerRlAction",
+     "MetricTrackerRlObservation",
+     "MetricTrackerRlEnv",
+     "available_analysis_methods",
+     "available_payload_generation_methods",
+     "available_synthetic_generator_methods",
+     "TaskSpec",
+     "TASKS",
+     "DEFAULT_TASK_ID",
+     "DEFAULT_TASK_ORDER",
+     "get_task_spec",
+     "available_task_specs",
+ ]
analysis_tools.py ADDED
@@ -0,0 +1,1229 @@
+ """Shared safe analysis methods for agents and the manual UI."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ import math
+ from statistics import median
+ from typing import Any
+
+ try:
+     from .models import (
+         ConversionMetricDefinition,
+         MethodSpec,
+         MetricRecord,
+         MetricSubmissionRow,
+         PayloadGeneratorMethod,
+         SubmissionIssue,
+         SubmissionPreview,
+     )
+ except ImportError:
+     from models import (
+         ConversionMetricDefinition,
+         MethodSpec,
+         MetricRecord,
+         MetricSubmissionRow,
+         PayloadGeneratorMethod,
+         SubmissionIssue,
+         SubmissionPreview,
+     )
+
+
+ FUNNEL_STEPS: tuple[tuple[str, str], ...] = (
+     ("menu_opens", "app_opens"),
+     ("product_added_to_cart", "menu_opens"),
+     ("orders_placed", "product_added_to_cart"),
+     ("payment_successful", "orders_placed"),
+ )
+
+ COUNT_METRICS: tuple[str, ...] = (
+     "app_opens",
+     "menu_opens",
+     "product_added_to_cart",
+     "orders_placed",
+     "payment_successful",
+ )
+
+ DEFAULT_METHOD_SPECS: tuple[MethodSpec, ...] = (
+     MethodSpec(
+         name="task_overview",
+         description="Return compact task context, config, entity catalog, and payload schema.",
+     ),
+     MethodSpec(name="list_dates", description="List all dates in the dataset."),
+     MethodSpec(
+         name="list_entities",
+         description="List count, rate, funnel, hourly mix, and data quality entities.",
+     ),
+     MethodSpec(
+         name="rows_for_date",
+         description="Return daily counts and derived rates for one date.",
+         parameters=["date"],
+     ),
+     MethodSpec(
+         name="hourly_rows_for_date",
+         description="Return hourly rows and traffic-share summaries for one date.",
+         parameters=["date"],
+     ),
+     MethodSpec(
+         name="compare_rate_to_median",
+         description="Compare one conversion rate against its daily median baseline.",
+         parameters=["date", "entity_name"],
+     ),
+     MethodSpec(
+         name="compare_count_to_median",
+         description="Compare one event count against its daily median baseline.",
+         parameters=["date", "entity_name"],
+     ),
+     MethodSpec(
+         name="detect_funnel_break",
+         description="Inspect funnel-step rates and monotonicity for a date.",
+         parameters=["date"],
+     ),
+     MethodSpec(
+         name="check_impossible_counts",
+         description="Find impossible daily or hourly count relationships for a date.",
+         parameters=["date"],
+     ),
+     MethodSpec(
+         name="list_suspicious_dates",
+         description="Rank dates by anomaly suspicion using shared heuristics.",
+         parameters=["limit"],
+     ),
+     MethodSpec(
+         name="preview_submission",
+         description="Validate candidate payload rows without revealing ground truth.",
+         parameters=["rows"],
+     ),
+     MethodSpec(
+         name="show_raw_data",
+         description="Return a head() style view of daily aggregate rows with count and rate metrics.",
+         parameters=["limit"],
+     ),
+     MethodSpec(
+         name="get_metric_median",
+         description="Return the median for a count metric or conversion metric.",
+         parameters=["metric_name"],
+     ),
+     MethodSpec(
+         name="get_metric_std_dev_from_median",
+         description="Return sqrt(mean((value - median)^2)) for a metric.",
+         parameters=["metric_name"],
+     ),
+     MethodSpec(
+         name="get_rows_with_abs_diff_from_median_gt",
+         description="Return all dates where abs(value - median) is greater than a threshold.",
+         parameters=["metric_name", "threshold"],
+     ),
+     MethodSpec(
+         name="get_median_filter_rows",
+         description="Build payload rows where abs(value - median) > threshold_multiplier * std_from_median.",
+         parameters=["metric_name", "threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_rate_drop_from_median_rows",
+         description="Build conversion-rate payload rows where median-filtered values drop below baseline.",
+         parameters=["metric_name", "threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_rate_spike_from_median_rows",
+         description="Build conversion-rate payload rows where median-filtered values spike above baseline.",
+         parameters=["metric_name", "threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_absolute_drop_in_event_count_rows",
+         description="Build event-count payload rows where median-filtered values drop below baseline.",
+         parameters=["metric_name", "threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_absolute_spike_in_event_count_rows",
+         description="Build event-count payload rows where median-filtered values spike above baseline.",
+         parameters=["metric_name", "threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_funnel_break_rows",
+         description="Build payload rows for funnel-step breaks by scanning dates for large funnel-rate drops.",
+         parameters=["threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_hourly_traffic_mix_shift_rows",
+         description="Build payload rows for dates with unusual app_open daytime-share shifts.",
+         parameters=["threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="get_instrumentation_data_quality_issue_rows",
+         description="Build payload rows for dates with impossible count relationships or instrumentation issues.",
+         parameters=["threshold_multiplier"],
+     ),
+     MethodSpec(
+         name="payload_generator",
+         description="Run a list of payload generation methods and merge the generated rows.",
+         parameters=["generator_methods"],
+     ),
+ )
+
+
+ def available_analysis_methods() -> list[MethodSpec]:
+     """Return the shared safe method surface."""
+     return list(DEFAULT_METHOD_SPECS)
+
+
+ @dataclass
+ class AnalysisContext:
+     """Structured input for the shared method implementation."""
+
+     daily_metrics: list[MetricRecord]
+     hourly_metrics: list[MetricRecord]
+     conversion_definitions: list[ConversionMetricDefinition]
+     instruction: str = ""
+     config: dict[str, Any] | None = None
+
+
+ class SharedAnalysisToolkit:
+     """Shared method implementation for agents and the manual UI."""
+
+     def __init__(self, context: AnalysisContext) -> None:
+         self._context = context
+         self._daily_by_date = {row.date: row for row in context.daily_metrics}
+         self._hourly_by_date: dict[str, list[MetricRecord]] = {}
+         for row in context.hourly_metrics:
+             self._hourly_by_date.setdefault(row.date, []).append(row)
+         for rows in self._hourly_by_date.values():
+             rows.sort(key=lambda item: item.hour if item.hour is not None else -1)
+         self._dates = sorted(self._daily_by_date)
+         self._conversion_map = {item.name: item for item in context.conversion_definitions}
+
+     def task_overview(self) -> dict[str, Any]:
+         """Return a compact overview of the task and available entities."""
+         return {
+             "instruction": self._context.instruction,
+             "config": self._context.config or {},
+             "date_count": len(self._dates),
+             "dates": self._dates,
+             "threshold_search_space": {
+                 "rate_delta_pct_points": [3.0, 4.5, 6.0, 8.0],
+                 "count_delta_pct": [10.0, 15.0, 22.0, 30.0],
+                 "funnel_delta_pct_points": [3.5, 5.0, 7.0, 10.0],
+                 "hourly_mix_delta_pct": [8.0, 12.0, 18.0, 25.0],
+             },
+             "payload_schema": [
+                 "date",
+                 "entity_type",
+                 "entity_name",
+                 "anomaly_type",
+                 "detection_method",
+                 "baseline_value",
+                 "observed_value",
+                 "delta_value",
+                 "severity",
+             ],
+             "available_methods": [item.model_dump() for item in available_analysis_methods()],
+             "entities": self.list_entities()["entities"],
+         }
+
+     def list_dates(self) -> dict[str, Any]:
+         return {"dates": self._dates}
+
+     def list_entities(self) -> dict[str, Any]:
+         conversions = [
+             {
+                 "entity_type": "conversion_rate",
+                 "entity_name": item.name,
+                 "formula": item.description,
+             }
+             for item in self._context.conversion_definitions
+         ]
+         counts = [
+             {
+                 "entity_type": "event_count",
+                 "entity_name": metric_name,
+             }
+             for metric_name in COUNT_METRICS
+         ]
+         funnels = [
+             {
+                 "entity_type": "funnel_step",
+                 "entity_name": f"{numerator}_from_{denominator}",
+             }
+             for numerator, denominator in FUNNEL_STEPS
+         ]
+         quality = [
+             {
+                 "entity_type": "data_quality",
+                 "entity_name": f"{numerator}_lte_{denominator}",
+             }
+             for numerator, denominator in FUNNEL_STEPS
+         ]
+         hourly = [
+             {
+                 "entity_type": "hourly_mix",
+                 "entity_name": "app_opens:daytime_share",
+             }
+         ]
+         return {"entities": conversions + counts + funnels + quality + hourly}
+
+     def rows_for_date(self, date: str) -> dict[str, Any]:
+         row = self._daily_by_date.get(date)
+         if row is None:
+             return {"found": False, "date": date, "error": "Date not found."}
+         derived_rates = self._conversion_rates(row)
+         return {
+             "found": True,
+             "date": date,
+             "daily_metrics": row.model_dump(),
+             "derived_rates": derived_rates,
+         }
+
+     def hourly_rows_for_date(self, date: str) -> dict[str, Any]:
+         rows = self._hourly_by_date.get(date, [])
+         if not rows:
+             return {"found": False, "date": date, "error": "Date not found."}
+         total = sum(item.app_opens for item in rows) or 1
+         daytime_hours = {8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19}
+         daytime_share = round(
+             sum(item.app_opens for item in rows if item.hour in daytime_hours) / total,
+             4,
+         )
+         return {
+             "found": True,
+             "date": date,
+             "summary": {
+                 "daytime_share": daytime_share,
+                 "top_hours": sorted(
+                     (
+                         {
+                             "hour": item.hour,
+                             "app_opens": item.app_opens,
+                             "share": round(item.app_opens / total, 4),
+                         }
+                         for item in rows
+                     ),
+                     key=lambda item: item["app_opens"],
+                     reverse=True,
+                 )[:5],
+             },
+             "rows": [item.model_dump() for item in rows],
+         }
+
+     def compare_rate_to_median(self, date: str, entity_name: str) -> dict[str, Any]:
+         record = self._daily_by_date.get(date)
+         definition = self._conversion_map.get(entity_name)
+         if record is None or definition is None:
+             return {
+                 "found": False,
+                 "date": date,
+                 "entity_name": entity_name,
+                 "error": "Date or conversion entity not found.",
+             }
+         series = [self._rate_for_record(item, definition) for item in self._context.daily_metrics]
+         baseline = round(median(series), 4)
+         observed = round(self._rate_for_record(record, definition), 4)
+         delta = round(observed - baseline, 4)
+         anomaly_type = "normal"
+         if delta <= -self._rate_threshold():
+             anomaly_type = "rate_drop_from_median"
+         elif delta >= self._rate_threshold():
+             anomaly_type = "rate_spike_from_median"
+         return {
+             "found": True,
+             "date": date,
+             "entity_type": "conversion_rate",
+             "entity_name": entity_name,
+             "detection_method": "compare_rate_to_median",
+             "baseline_value": baseline,
+             "observed_value": observed,
+             "delta_value": delta,
+             "anomaly_type": anomaly_type,
+             "severity": self._severity(abs(delta), medium=4.0, high=8.0, critical=12.0),
+         }
+
+     def compare_count_to_median(self, date: str, entity_name: str) -> dict[str, Any]:
+         record = self._daily_by_date.get(date)
+         if record is None or entity_name not in COUNT_METRICS:
+             return {
+                 "found": False,
+                 "date": date,
+                 "entity_name": entity_name,
+                 "error": "Date or count entity not found.",
+             }
+         series = [float(getattr(item, entity_name)) for item in self._context.daily_metrics]
+         baseline = round(median(series), 4)
+         observed = round(float(getattr(record, entity_name)), 4)
+         delta = round(observed - baseline, 4)
+         threshold = max(50.0, baseline * self._count_threshold_fraction())
+         anomaly_type = "normal"
+         if delta <= -threshold:
+             anomaly_type = "absolute_drop_in_event_count"
+         elif delta >= threshold:
+             anomaly_type = "absolute_spike_in_event_count"
+         return {
+             "found": True,
+             "date": date,
+             "entity_type": "event_count",
+             "entity_name": entity_name,
+             "detection_method": "compare_count_to_median",
+             "baseline_value": baseline,
+             "observed_value": observed,
+             "delta_value": delta,
+             "anomaly_type": anomaly_type,
+             "severity": self._severity(
+                 abs(delta) / max(baseline, 1.0) * 100.0,
+                 medium=12.0,
+                 high=22.0,
+                 critical=35.0,
+             ),
+         }
+
+     def detect_funnel_break(self, date: str) -> dict[str, Any]:
+         record = self._daily_by_date.get(date)
+         if record is None:
+             return {"found": False, "date": date, "error": "Date not found."}
+         candidates: list[dict[str, Any]] = []
+         for numerator, denominator in FUNNEL_STEPS:
+             entity_name = f"{numerator}_from_{denominator}"
+             baseline_series = [
+                 self._ratio(getattr(item, numerator), getattr(item, denominator)) * 100.0
+                 for item in self._context.daily_metrics
+             ]
+             baseline = round(median(baseline_series), 4)
+             observed = round(
+                 self._ratio(getattr(record, numerator), getattr(record, denominator)) * 100.0,
+                 4,
+             )
+             delta = round(observed - baseline, 4)
+             issue = {
+                 "entity_type": "funnel_step",
+                 "entity_name": entity_name,
+                 "detection_method": "detect_funnel_break",
+                 "baseline_value": baseline,
+                 "observed_value": observed,
+                 "delta_value": delta,
+                 "monotonicity_broken": getattr(record, numerator) > getattr(record, denominator),
+                 "severity": self._severity(abs(delta), medium=5.0, high=10.0, critical=15.0),
+             }
+             if issue["monotonicity_broken"] or delta <= -self._funnel_threshold():
+                 issue["anomaly_type"] = "funnel_break"
+                 candidates.append(issue)
+         return {"found": True, "date": date, "candidates": candidates}
+
+     def check_impossible_counts(self, date: str) -> dict[str, Any]:
+         daily = self._daily_by_date.get(date)
+         hourly_rows = self._hourly_by_date.get(date, [])
+         if daily is None:
+             return {"found": False, "date": date, "error": "Date not found."}
+         issues = []
+         issues.extend(self._impossible_issues(daily, scope="daily"))
+         for row in hourly_rows:
+             issues.extend(self._impossible_issues(row, scope=f"hour_{row.hour:02d}"))
+         total_excess = round(sum(item["excess_value"] for item in issues), 4)
+         return {
+             "found": True,
+             "date": date,
+             "issue_count": len(issues),
+             "total_excess": total_excess,
+             "issues": issues,
+         }
+
+     def list_suspicious_dates(self, limit: int = 10) -> dict[str, Any]:
+         ranked = []
+         hourly_baseline = self._median_daytime_share()
+         for date in self._dates:
+             rate_signal = 0.0
+             for definition in self._context.conversion_definitions:
+                 comparison = self.compare_rate_to_median(date, definition.name)
+                 rate_signal = max(rate_signal, abs(comparison["delta_value"]))
+             count_signal = 0.0
+             for metric_name in COUNT_METRICS:
+                 comparison = self.compare_count_to_median(date, metric_name)
+                 baseline = max(comparison["baseline_value"], 1.0)
+                 count_signal = max(
+                     count_signal,
+                     abs(comparison["delta_value"]) / baseline * 100.0,
+                 )
+             funnel_candidates = self.detect_funnel_break(date)["candidates"]
+             impossible = self.check_impossible_counts(date)
+             hourly_share = self.hourly_rows_for_date(date)["summary"]["daytime_share"]
+             hourly_signal = abs(hourly_share - hourly_baseline) * 100.0
+             suspicion_score = round(
+                 rate_signal + count_signal + hourly_signal + impossible["total_excess"] * 0.05
+                 + len(funnel_candidates) * 6.0,
+                 4,
+             )
+             ranked.append(
+                 {
+                     "date": date,
+                     "suspicion_score": suspicion_score,
+                     "max_rate_delta": round(rate_signal, 4),
+                     "max_count_delta_pct": round(count_signal, 4),
+                     "hourly_mix_delta_pct": round(hourly_signal, 4),
+                     "funnel_candidate_count": len(funnel_candidates),
+                     "impossible_issue_count": impossible["issue_count"],
+                 }
+             )
+         ranked.sort(key=lambda item: (item["suspicion_score"], item["date"]), reverse=True)
+         return {"dates": ranked[: max(limit, 1)]}
+
+     def preview_submission(self, rows: list[dict[str, Any]] | list[MetricSubmissionRow]) -> dict[str, Any]:
+         preview = preview_submission_rows(rows)
+         return preview.model_dump()
+
+     def show_raw_data(self, limit: int = 5) -> dict[str, Any]:
+         rows = []
+         for record in self._context.daily_metrics[: max(limit, 1)]:
+             row = record.model_dump()
+             row.update(self._conversion_rates(record))
+             rows.append(row)
+         return {
+             "row_count": len(self._context.daily_metrics),
+             "returned_rows": len(rows),
+             "rows": rows,
+         }
+
+     def get_metric_median(self, metric_name: str) -> dict[str, Any]:
+         descriptor = self._metric_descriptor(metric_name)
+         values = descriptor["values"]
+         metric_median = round(median(values), 4) if values else 0.0
+         return {
+             "metric_name": metric_name,
+             "metric_type": descriptor["metric_type"],
+             "median_value": metric_median,
+             "sample_size": len(values),
+         }
+
+     def get_metric_median_multi(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+     ) -> dict[str, Any]:
+         resolved_metrics = self._resolve_metric_names(metric_name=metric_name, metric_names=metric_names)
+         results = [self.get_metric_median(name) for name in resolved_metrics]
+         return {
+             "metric_name": metric_name,
+             "metric_names": resolved_metrics,
+             "results": results,
+         }
+
+     def get_metric_std_dev_from_median(self, metric_name: str) -> dict[str, Any]:
+         descriptor = self._metric_descriptor(metric_name)
+         values = descriptor["values"]
+         metric_median = median(values) if values else 0.0
+         std_from_median = math.sqrt(
+             sum((value - metric_median) ** 2 for value in values) / len(values)
+         ) if values else 0.0
+         return {
+             "metric_name": metric_name,
+             "metric_type": descriptor["metric_type"],
+             "median_value": round(metric_median, 4),
+             "std_dev_from_median": round(std_from_median, 4),
+             "sample_size": len(values),
+         }
+
+     def get_metric_std_dev_from_median_multi(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+     ) -> dict[str, Any]:
+         resolved_metrics = self._resolve_metric_names(metric_name=metric_name, metric_names=metric_names)
+         results = [self.get_metric_std_dev_from_median(name) for name in resolved_metrics]
+         return {
+             "metric_name": metric_name,
+             "metric_names": resolved_metrics,
+             "results": results,
+         }
+
+     def get_rows_with_abs_diff_from_median_gt(self, metric_name: str, threshold: float) -> dict[str, Any]:
+         descriptor = self._metric_descriptor(metric_name)
+         metric_median = median(descriptor["values"]) if descriptor["values"] else 0.0
+         matches = []
+         for date_key, value in descriptor["per_date_values"].items():
+             abs_diff = abs(value - metric_median)
+             if abs_diff <= threshold:
+                 continue
+             row = {
+                 "date": date_key,
+                 "metric_name": metric_name,
+                 "metric_type": descriptor["metric_type"],
+                 "median_value": round(metric_median, 4),
+                 "observed_value": round(value, 4),
+                 "abs_diff": round(abs_diff, 4),
+             }
+             suggested = self._build_submission_row_for_metric(
+                 metric_name=metric_name,
+                 date=date_key,
+                 baseline_value=float(metric_median),
+                 observed_value=float(value),
+             )
+             if suggested is not None:
+                 row["suggested_payload_row"] = suggested.model_dump()
+             matches.append(row)
+         return {
+             "metric_name": metric_name,
+             "threshold": threshold,
+             "match_count": len(matches),
+             "rows": matches,
+         }
+
+     def get_rows_with_abs_diff_from_median_gt_multi(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+         threshold: float = 0.0,
+     ) -> dict[str, Any]:
+         resolved_metrics = self._resolve_metric_names(metric_name=metric_name, metric_names=metric_names)
+         results = [
+             self.get_rows_with_abs_diff_from_median_gt(name, threshold)
+             for name in resolved_metrics
+         ]
+         return {
+             "metric_name": metric_name,
+             "metric_names": resolved_metrics,
+             "threshold": threshold,
+             "results": results,
+         }
+
+     def get_median_filter_rows(self, metric_name: str, threshold_multiplier: float) -> dict[str, Any]:
+         return self.get_median_filter_rows_multi(
+             metric_name=metric_name,
+             metric_names=[],
+             threshold_multiplier=threshold_multiplier,
+         )
+
+     def get_median_filter_rows_multi(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+         threshold_multiplier: float = 2.0,
+     ) -> dict[str, Any]:
+         resolved_metrics = self._resolve_metric_names(metric_name=metric_name, metric_names=metric_names)
+         details = []
+         generated: dict[str, dict[str, Any]] = {}
+         total_matches = 0
+         for resolved_metric in resolved_metrics:
+             stats = self.get_metric_std_dev_from_median(resolved_metric)
+             threshold = stats["std_dev_from_median"] * threshold_multiplier
+             rows_result = self.get_rows_with_abs_diff_from_median_gt(resolved_metric, threshold)
+             payload_rows = [
+                 row["suggested_payload_row"]
+                 for row in rows_result["rows"]
+                 if row.get("suggested_payload_row")
+             ]
+             total_matches += rows_result["match_count"]
+             for row in payload_rows:
+                 submission_row = MetricSubmissionRow(**row)
+                 generated[submission_row_key(submission_row)] = submission_row.model_dump()
+             details.append(
+                 {
+                     "metric_name": resolved_metric,
+                     "threshold": round(threshold, 4),
+                     "match_count": rows_result["match_count"],
+                     "rows": rows_result["rows"],
+                     "generated_rows": payload_rows,
+                 }
+             )
+         return {
+             "method_name": "get_median_filter_rows",
+             "metric_name": metric_name,
+             "metric_names": resolved_metrics,
+             "threshold_multiplier": threshold_multiplier,
+             "match_count": total_matches,
+             "generated_rows": list(generated.values()),
+             "details": details,
+         }
+
+     def get_rate_drop_from_median_rows(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+         threshold_multiplier: float = 2.0,
+     ) -> dict[str, Any]:
+         return self._metric_family_filter_rows(
+             method_name="get_rate_drop_from_median_rows",
+             metric_name=metric_name,
+             metric_names=metric_names,
+             threshold_multiplier=threshold_multiplier,
+             metric_type="conversion_rate",
+             allowed_anomaly_types={"rate_drop_from_median"},
+         )
+
+     def get_rate_spike_from_median_rows(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+         threshold_multiplier: float = 2.0,
+     ) -> dict[str, Any]:
+         return self._metric_family_filter_rows(
+             method_name="get_rate_spike_from_median_rows",
+             metric_name=metric_name,
+             metric_names=metric_names,
+             threshold_multiplier=threshold_multiplier,
+             metric_type="conversion_rate",
+             allowed_anomaly_types={"rate_spike_from_median"},
+         )
+
+     def get_absolute_drop_in_event_count_rows(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+         threshold_multiplier: float = 2.0,
+     ) -> dict[str, Any]:
+         return self._metric_family_filter_rows(
+             method_name="get_absolute_drop_in_event_count_rows",
+             metric_name=metric_name,
+             metric_names=metric_names,
+             threshold_multiplier=threshold_multiplier,
+             metric_type="event_count",
+             allowed_anomaly_types={"absolute_drop_in_event_count"},
+         )
+
+     def get_absolute_spike_in_event_count_rows(
+         self,
+         metric_name: str | None = None,
+         metric_names: list[str] | None = None,
+         threshold_multiplier: float = 2.0,
+     ) -> dict[str, Any]:
+         return self._metric_family_filter_rows(
+             method_name="get_absolute_spike_in_event_count_rows",
+             metric_name=metric_name,
+             metric_names=metric_names,
+             threshold_multiplier=threshold_multiplier,
+             metric_type="event_count",
+             allowed_anomaly_types={"absolute_spike_in_event_count"},
+         )
+
+     def get_funnel_break_rows(self, threshold_multiplier: float = 2.0) -> dict[str, Any]:
+         details = []
+         generated: dict[str, dict[str, Any]] = {}
+         total_matches = 0
+         for numerator, denominator in FUNNEL_STEPS:
+             entity_name = f"{numerator}_from_{denominator}"
+             per_date_values = {
+                 date_key: round(
+                     self._ratio(getattr(record, numerator), getattr(record, denominator)) * 100.0,
+                     4,
+                 )
+                 for date_key, record in self._daily_by_date.items()
+             }
+             values = list(per_date_values.values())
+             baseline = median(values) if values else 0.0
+             std_from_median = math.sqrt(
+                 sum((value - baseline) ** 2 for value in values) / len(values)
+             ) if values else 0.0
+             threshold = max(std_from_median * float(threshold_multiplier), self._funnel_threshold())
+             rows = []
+             generated_rows = []
+             for date_key, observed_value in per_date_values.items():
+                 delta_value = round(observed_value - baseline, 4)
+                 if delta_value > -threshold:
+                     continue
+                 row = {
+                     "date": date_key,
+                     "entity_type": "funnel_step",
+                     "entity_name": entity_name,
+                     "anomaly_type": "funnel_break",
+                     "detection_method": "detect_funnel_break",
+                     "baseline_value": round(baseline, 4),
+                     "observed_value": round(observed_value, 4),
+                     "delta_value": delta_value,
+                     "severity": self._severity(abs(delta_value), medium=5.0, high=10.0, critical=15.0),
+                 }
+                 total_matches += 1
+                 rows.append(row)
+                 submission_row = MetricSubmissionRow(**row)
+                 generated[submission_row_key(submission_row)] = submission_row.model_dump()
+                 generated_rows.append(submission_row.model_dump())
+             details.append(
+                 {
+                     "entity_name": entity_name,
+                     "threshold": round(threshold, 4),
+                     "match_count": len(rows),
+                     "rows": rows,
+                     "generated_rows": generated_rows,
+                 }
+             )
+         return {
+             "method_name": "get_funnel_break_rows",
+             "threshold_multiplier": threshold_multiplier,
+             "match_count": total_matches,
+             "generated_rows": list(generated.values()),
+             "details": details,
+         }
+
+     def get_hourly_traffic_mix_shift_rows(self, threshold_multiplier: float = 2.0) -> dict[str, Any]:
+         per_date_values = {}
+         for date_key in self._dates:
+             summary = self.hourly_rows_for_date(date_key).get("summary", {})
+             per_date_values[date_key] = float(summary.get("daytime_share", 0.0))
+         values = list(per_date_values.values())
+         baseline = median(values) if values else 0.0
+         std_from_median = math.sqrt(
+             sum((value - baseline) ** 2 for value in values) / len(values)
+         ) if values else 0.0
+         threshold = std_from_median * float(threshold_multiplier)
+         rows = []
+         generated_rows = []
+         for date_key, observed_value in per_date_values.items():
+             delta_value = round(observed_value - baseline, 4)
+             if abs(delta_value) <= threshold:
+                 continue
+             row = {
+                 "date": date_key,
+                 "entity_type": "hourly_mix",
+                 "entity_name": "app_opens:daytime_share",
+                 "anomaly_type": "hourly_traffic_mix_shift",
+                 "detection_method": "hourly_rows_for_date",
+                 "baseline_value": round(baseline, 4),
+                 "observed_value": round(observed_value, 4),
+                 "delta_value": delta_value,
+                 "severity": self._severity(abs(delta_value) * 100.0, medium=10.0, high=18.0, critical=25.0),
+             }
+             rows.append(row)
+             generated_rows.append(row)
+         return {
+             "method_name": "get_hourly_traffic_mix_shift_rows",
+             "threshold_multiplier": threshold_multiplier,
+             "match_count": len(rows),
+             "generated_rows": generated_rows,
+             "details": [
+                 {
+                     "entity_name": "app_opens:daytime_share",
+                     "threshold": round(threshold, 4),
+                     "match_count": len(rows),
+                     "rows": rows,
+                     "generated_rows": generated_rows,
+                 }
+             ],
+         }
+
+     def get_instrumentation_data_quality_issue_rows(
+         self,
+         threshold_multiplier: float = 2.0,
+     ) -> dict[str, Any]:
+         per_date_totals = {
+             date_key: float(self.check_impossible_counts(date_key).get("total_excess", 0.0))
+             for date_key in self._dates
+         }
+         values = list(per_date_totals.values())
+         baseline = median(values) if values else 0.0
+         std_from_median = math.sqrt(
+             sum((value - baseline) ** 2 for value in values) / len(values)
+         ) if values else 0.0
+         threshold = std_from_median * float(threshold_multiplier)
+         generated: dict[str, dict[str, Any]] = {}
+         details = []
+         total_matches = 0
+         for numerator, denominator in FUNNEL_STEPS:
+             entity_name = f"{numerator}_lte_{denominator}"
+             rows = []
+             generated_rows = []
+             for date_key in self._dates:
+                 result = self.check_impossible_counts(date_key)
+                 issue_names = {item["entity_name"] for item in result.get("issues", [])}
+                 observed_value = float(result.get("total_excess", 0.0))
+                 if entity_name not in issue_names or observed_value <= threshold:
+                     continue
+                 row = {
+                     "date": date_key,
+                     "entity_type": "data_quality",
+                     "entity_name": entity_name,
+                     "anomaly_type": "instrumentation_data_quality_issue",
+                     "detection_method": "check_impossible_counts",
+                     "baseline_value": round(baseline, 4),
+                     "observed_value": round(observed_value, 4),
+                     "delta_value": round(observed_value - baseline, 4),
+                     "severity": self._severity(observed_value, medium=20.0, high=60.0, critical=120.0),
+                 }
+                 total_matches += 1
+                 rows.append(row)
+                 submission_row = MetricSubmissionRow(**row)
+                 generated[submission_row_key(submission_row)] = submission_row.model_dump()
+                 generated_rows.append(submission_row.model_dump())
+             details.append(
+                 {
+                     "entity_name": entity_name,
+                     "threshold": round(threshold, 4),
+                     "match_count": len(rows),
+                     "rows": rows,
+                     "generated_rows": generated_rows,
+                 }
+             )
+         return {
+             "method_name": "get_instrumentation_data_quality_issue_rows",
+             "threshold_multiplier": threshold_multiplier,
+             "match_count": total_matches,
+             "generated_rows": list(generated.values()),
+             "details": details,
+         }
+
+     def payload_generator(
+         self,
+         generator_methods: list[dict[str, Any]] | list[PayloadGeneratorMethod],
+     ) -> dict[str, Any]:
+         methods = [
+             item if isinstance(item, PayloadGeneratorMethod) else PayloadGeneratorMethod(**item)
+             for item in generator_methods
+         ]
+         generated: dict[str, MetricSubmissionRow] = {}
+         details = []
+         for method in methods:
+             result = self._run_payload_generator_method(method)
+             if "error" in result:
+                 details.append(result)
+                 continue
+             for row in result["generated_rows"]:
+                 submission_row = MetricSubmissionRow(**row)
+                 generated[submission_row_key(submission_row)] = submission_row
+             details.append(result)
+         return {
+             "generator_methods": [item.model_dump() for item in methods],
+             "generated_row_count": len(generated),
+             "generated_rows": [row.model_dump() for row in generated.values()],
+             "details": details,
+         }
+
882
+ def _run_payload_generator_method(self, method: PayloadGeneratorMethod) -> dict[str, Any]:
883
+ if method.method_name == "get_median_filter_rows":
884
+ return self.get_median_filter_rows(
885
+ metric_name=method.metric_name,
886
+ threshold_multiplier=method.threshold_multiplier,
887
+ ) if not method.metric_names else self.get_median_filter_rows_multi(
888
+ metric_name=method.metric_name,
889
+ metric_names=method.metric_names,
890
+ threshold_multiplier=method.threshold_multiplier,
891
+ )
892
+ if method.method_name == "get_rate_drop_from_median_rows":
893
+ return self.get_rate_drop_from_median_rows(
894
+ metric_name=method.metric_name,
895
+ metric_names=method.metric_names,
896
+ threshold_multiplier=method.threshold_multiplier,
897
+ )
898
+ if method.method_name == "get_rate_spike_from_median_rows":
899
+ return self.get_rate_spike_from_median_rows(
900
+ metric_name=method.metric_name,
901
+ metric_names=method.metric_names,
902
+ threshold_multiplier=method.threshold_multiplier,
903
+ )
904
+ if method.method_name == "get_absolute_drop_in_event_count_rows":
905
+ return self.get_absolute_drop_in_event_count_rows(
906
+ metric_name=method.metric_name,
907
+ metric_names=method.metric_names,
908
+ threshold_multiplier=method.threshold_multiplier,
909
+ )
910
+ if method.method_name == "get_absolute_spike_in_event_count_rows":
911
+ return self.get_absolute_spike_in_event_count_rows(
912
+ metric_name=method.metric_name,
913
+ metric_names=method.metric_names,
914
+ threshold_multiplier=method.threshold_multiplier,
915
+ )
916
+ if method.method_name == "get_funnel_break_rows":
917
+ return self.get_funnel_break_rows(threshold_multiplier=method.threshold_multiplier)
918
+ if method.method_name == "get_hourly_traffic_mix_shift_rows":
919
+ return self.get_hourly_traffic_mix_shift_rows(threshold_multiplier=method.threshold_multiplier)
920
+ if method.method_name == "get_instrumentation_data_quality_issue_rows":
921
+ return self.get_instrumentation_data_quality_issue_rows(threshold_multiplier=method.threshold_multiplier)
922
+ return {
923
+ "method": method.model_dump(),
924
+ "error": "Unsupported payload generator method.",
925
+ }
926
+
927
+ def build_row_from_analysis(self, analysis_result: dict[str, Any]) -> dict[str, Any] | None:
928
+ """Extract a payload row when an analysis result directly maps to one."""
929
+ required_fields = {
930
+ "date",
931
+ "entity_type",
932
+ "entity_name",
933
+ "anomaly_type",
934
+ "detection_method",
935
+ "baseline_value",
936
+ "observed_value",
937
+ "delta_value",
938
+ "severity",
939
+ }
940
+ if required_fields.issubset(analysis_result) and analysis_result.get("anomaly_type") != "normal":
941
+ return {field: analysis_result[field] for field in required_fields}
942
+ return None
943
+
944
+ def _conversion_rates(self, record: MetricRecord) -> dict[str, float]:
945
+ return {
946
+ item.name: round(self._rate_for_record(record, item), 4)
947
+ for item in self._context.conversion_definitions
948
+ }
949
+
950
+ def _metric_descriptor(self, metric_name: str) -> dict[str, Any]:
951
+ if metric_name in COUNT_METRICS:
952
+ values = [float(getattr(item, metric_name)) for item in self._context.daily_metrics]
953
+ per_date_values = {
954
+ item.date: float(getattr(item, metric_name))
955
+ for item in self._context.daily_metrics
956
+ }
957
+ return {
958
+ "metric_type": "event_count",
959
+ "values": values,
960
+ "per_date_values": per_date_values,
961
+ }
962
+ definition = self._conversion_map.get(metric_name)
963
+ if definition is None:
964
+ raise ValueError(f"Unknown metric_name: {metric_name}")
965
+ values = [self._rate_for_record(item, definition) for item in self._context.daily_metrics]
966
+ per_date_values = {
967
+ item.date: self._rate_for_record(item, definition)
968
+ for item in self._context.daily_metrics
969
+ }
970
+ return {
971
+ "metric_type": "conversion_rate",
972
+ "values": values,
973
+ "per_date_values": per_date_values,
974
+ }
975
+
976
+ def _resolve_metric_names(
977
+ self,
978
+ *,
979
+ metric_name: str | None,
980
+ metric_names: list[str] | None,
981
+ ) -> list[str]:
982
+ names = [item for item in (metric_names or []) if item]
983
+ if metric_name:
984
+ names.append(metric_name)
985
+ if not names:
986
+ names = list(COUNT_METRICS) + list(self._conversion_map.keys())
987
+ deduped = []
988
+ seen = set()
989
+ for item in names:
990
+ if item in seen:
991
+ continue
992
+ seen.add(item)
993
+ deduped.append(item)
994
+ return deduped
995
+
996
+ def _resolve_metric_names_for_type(
997
+ self,
998
+ *,
999
+ metric_name: str | None,
1000
+ metric_names: list[str] | None,
1001
+ metric_type: str,
1002
+ ) -> list[str]:
1003
+ resolved = self._resolve_metric_names(metric_name=metric_name, metric_names=metric_names)
1004
+ return [
1005
+ item
1006
+ for item in resolved
1007
+ if self._metric_descriptor(item)["metric_type"] == metric_type
1008
+ ]
1009
+
1010
+ def _metric_family_filter_rows(
1011
+ self,
1012
+ *,
1013
+ method_name: str,
1014
+ metric_name: str | None,
1015
+ metric_names: list[str] | None,
1016
+ threshold_multiplier: float,
1017
+ metric_type: str,
1018
+ allowed_anomaly_types: set[str],
1019
+ ) -> dict[str, Any]:
1020
+ resolved_metrics = self._resolve_metric_names_for_type(
1021
+ metric_name=metric_name,
1022
+ metric_names=metric_names,
1023
+ metric_type=metric_type,
1024
+ )
1025
+ raw_result = self.get_median_filter_rows_multi(
1026
+ metric_name=None,
1027
+ metric_names=resolved_metrics,
1028
+ threshold_multiplier=threshold_multiplier,
1029
+ )
1030
+ generated: dict[str, dict[str, Any]] = {}
1031
+ details = []
1032
+ total_matches = 0
1033
+ for detail in raw_result["details"]:
1034
+ filtered_rows = []
1035
+ filtered_generated = []
1036
+ for row in detail["rows"]:
1037
+ suggested = row.get("suggested_payload_row")
1038
+ if not suggested or suggested.get("anomaly_type") not in allowed_anomaly_types:
1039
+ continue
1040
+ filtered_rows.append(row)
1041
+ submission_row = MetricSubmissionRow(**suggested)
1042
+                 generated[submission_row_key(submission_row)] = submission_row.model_dump()
+                 filtered_generated.append(submission_row.model_dump())
+             total_matches += len(filtered_rows)
+             details.append(
+                 {
+                     **detail,
+                     "match_count": len(filtered_rows),
+                     "rows": filtered_rows,
+                     "generated_rows": filtered_generated,
+                 }
+             )
+         return {
+             "method_name": method_name,
+             "metric_name": metric_name,
+             "metric_names": resolved_metrics,
+             "threshold_multiplier": threshold_multiplier,
+             "match_count": total_matches,
+             "generated_rows": list(generated.values()),
+             "details": details,
+         }
+
+     def _build_submission_row_for_metric(
+         self,
+         *,
+         metric_name: str,
+         date: str,
+         baseline_value: float,
+         observed_value: float,
+     ) -> MetricSubmissionRow | None:
+         delta_value = round(observed_value - baseline_value, 4)
+         if metric_name in COUNT_METRICS:
+             threshold = max(50.0, baseline_value * self._count_threshold_fraction())
+             if abs(delta_value) <= threshold:
+                 return None
+             anomaly_type = (
+                 "absolute_spike_in_event_count"
+                 if delta_value > 0
+                 else "absolute_drop_in_event_count"
+             )
+             return MetricSubmissionRow(
+                 date=date,
+                 entity_type="event_count",
+                 entity_name=metric_name,
+                 anomaly_type=anomaly_type,
+                 detection_method="compare_count_to_median",
+                 baseline_value=round(baseline_value, 4),
+                 observed_value=round(observed_value, 4),
+                 delta_value=delta_value,
+                 severity=self._severity(
+                     abs(delta_value) / max(baseline_value, 1.0) * 100.0,
+                     medium=12.0,
+                     high=22.0,
+                     critical=35.0,
+                 ),
+             )
+         threshold = self._rate_threshold()
+         if abs(delta_value) <= threshold:
+             return None
+         anomaly_type = "rate_spike_from_median" if delta_value > 0 else "rate_drop_from_median"
+         return MetricSubmissionRow(
+             date=date,
+             entity_type="conversion_rate",
+             entity_name=metric_name,
+             anomaly_type=anomaly_type,
+             detection_method="compare_rate_to_median",
+             baseline_value=round(baseline_value, 4),
+             observed_value=round(observed_value, 4),
+             delta_value=delta_value,
+             severity=self._severity(abs(delta_value), medium=4.0, high=8.0, critical=12.0),
+         )
+
+     def _impossible_issues(self, row: MetricRecord, scope: str) -> list[dict[str, Any]]:
+         issues = []
+         for numerator, denominator in FUNNEL_STEPS:
+             numerator_value = getattr(row, numerator)
+             denominator_value = getattr(row, denominator)
+             if numerator_value > denominator_value:
+                 issues.append(
+                     {
+                         "scope": scope,
+                         "entity_name": f"{numerator}_lte_{denominator}",
+                         "numerator": numerator_value,
+                         "denominator": denominator_value,
+                         "excess_value": round(float(numerator_value - denominator_value), 4),
+                     }
+                 )
+         return issues
+
+     def _median_daytime_share(self) -> float:
+         shares = []
+         for date in self._dates:
+             hourly_data = self.hourly_rows_for_date(date)
+             shares.append(hourly_data["summary"]["daytime_share"])
+         return round(median(shares), 4) if shares else 0.0
+
+     @staticmethod
+     def _ratio(numerator: int, denominator: int) -> float:
+         if denominator <= 0:
+             return 0.0
+         return numerator / denominator
+
+     def _rate_for_record(
+         self,
+         record: MetricRecord,
+         definition: ConversionMetricDefinition,
+     ) -> float:
+         return self._ratio(
+             getattr(record, definition.numerator),
+             getattr(record, definition.denominator),
+         ) * 100.0
+
+     def _rate_threshold(self) -> float:
+         difficulty = (self._context.config or {}).get("difficulty", "medium")
+         return {"easy": 6.0, "medium": 4.5, "hard": 3.0}.get(difficulty, 4.5)
+
+     def _funnel_threshold(self) -> float:
+         difficulty = (self._context.config or {}).get("difficulty", "medium")
+         return {"easy": 7.0, "medium": 5.0, "hard": 3.5}.get(difficulty, 5.0)
+
+     def _count_threshold_fraction(self) -> float:
+         difficulty = (self._context.config or {}).get("difficulty", "medium")
+         return {"easy": 0.22, "medium": 0.15, "hard": 0.10}.get(difficulty, 0.15)
+
+     @staticmethod
+     def _severity(value: float, *, medium: float, high: float, critical: float) -> str:
+         if value >= critical:
+             return "critical"
+         if value >= high:
+             return "high"
+         if value >= medium:
+             return "medium"
+         return "low"
+
+
+ def preview_submission_rows(
+     rows: list[dict[str, Any]] | list[MetricSubmissionRow],
+ ) -> SubmissionPreview:
+     """Validate submission rows without using ground truth."""
+     normalized_rows: list[MetricSubmissionRow] = []
+     issues: list[SubmissionIssue] = []
+     seen: set[str] = set()
+     duplicate_rows = 0
+     invalid_rows = 0
+
+     for index, row in enumerate(rows):
+         try:
+             normalized = row if isinstance(row, MetricSubmissionRow) else MetricSubmissionRow(**row)
+         except Exception as exc:
+             invalid_rows += 1
+             issues.append(
+                 SubmissionIssue(
+                     row_key=f"row_{index}",
+                     issue_type="invalid_row",
+                     message=f"Row could not be parsed: {exc}",
+                     submitted_row=row if isinstance(row, dict) else None,
+                 )
+             )
+             continue
+
+         row_key = submission_row_key(normalized)
+         if row_key in seen:
+             duplicate_rows += 1
+             issues.append(
+                 SubmissionIssue(
+                     row_key=row_key,
+                     issue_type="duplicate_row",
+                     message="Duplicate date/entity row detected.",
+                     submitted_row=normalized.model_dump(),
+                 )
+             )
+             continue
+
+         seen.add(row_key)
+         normalized_rows.append(normalized)
+
+     return SubmissionPreview(
+         valid_rows=len(normalized_rows),
+         invalid_rows=invalid_rows,
+         duplicate_rows=duplicate_rows,
+         unique_keys=len(seen),
+         issues=issues,
+         normalized_rows=normalized_rows,
+     )
+
+
+ def submission_row_key(row: MetricSubmissionRow) -> str:
+     """Stable row key for matching submissions and expectations."""
+     return f"{row.date}|{row.entity_type}|{row.entity_name}"
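For orientation, this is how the preview helper above behaves on a small payload. A minimal sketch, assuming the repository root is importable as the `metric_tracker_rl` package (the `pyproject.toml` below maps it that way); the row values are made up for illustration:

```python
from metric_tracker_rl.analysis_tools import preview_submission_rows, submission_row_key

# Two copies of the same date/entity row: the second is flagged as a duplicate.
rows = [
    {
        "date": "2024-05-01",
        "entity_type": "conversion_rate",
        "entity_name": "app_open_to_order_placed",
        "anomaly_type": "rate_drop_from_median",
        "detection_method": "compare_rate_to_median",
        "baseline_value": 12.5,
        "observed_value": 6.1,
        "delta_value": -6.4,
        "severity": "medium",
    },
] * 2

preview = preview_submission_rows(rows)
print(preview.valid_rows)      # 1
print(preview.duplicate_rows)  # 1
print(submission_row_key(preview.normalized_rows[0]))
# 2024-05-01|conversion_rate|app_open_to_order_placed
```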
client.py ADDED
@@ -0,0 +1,35 @@
+ """Client for the metric tracker RL environment."""
+
+ from typing import Dict
+
+ from openenv.core import EnvClient
+ from openenv.core.client_types import StepResult
+ from openenv.core.env_server.types import State
+
+ from .models import MetricTrackerRlAction, MetricTrackerRlObservation
+
+
+ class MetricTrackerRlEnv(
+     EnvClient[MetricTrackerRlAction, MetricTrackerRlObservation, State]
+ ):
+     """Typed client for the metric tracking environment."""
+
+     def _step_payload(self, action: MetricTrackerRlAction) -> Dict:
+         """Serialize the action as JSON for the environment server."""
+         return action.model_dump()
+
+     def _parse_result(self, payload: Dict) -> StepResult[MetricTrackerRlObservation]:
+         """Parse environment responses into a typed observation."""
+         observation = MetricTrackerRlObservation(**payload.get("observation", {}))
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> State:
+         """Parse environment state payloads."""
+         return State(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+         )
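A minimal sketch of driving this client against a running server, matching how `inference.py` below uses it; the URL is a placeholder, and `reset`/`step`/`close` are awaited as in that script:

```python
import asyncio

from metric_tracker_rl.client import MetricTrackerRlEnv
from metric_tracker_rl.models import MetricTrackerRlAction


async def main() -> None:
    env = MetricTrackerRlEnv(base_url="http://localhost:8000")  # placeholder URL
    reset_result = await env.reset()
    print(reset_result.observation.instruction)

    # Run one safe analysis method instead of grading a submission.
    step = await env.step(MetricTrackerRlAction(analysis_method="list_dates"))
    print(step.observation.analysis_result)

    await env.close()


asyncio.run(main())
```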
evaluation.py ADDED
@@ -0,0 +1,256 @@
+ """Deterministic grading for the metric tracker RL environment."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+
+ try:
+     from .analysis_tools import preview_submission_rows, submission_row_key
+     from .models import MetricSubmissionRow, RewardBreakdown, SubmissionIssue, SubmissionPreview
+ except ImportError:
+     from analysis_tools import preview_submission_rows, submission_row_key
+     from models import MetricSubmissionRow, RewardBreakdown, SubmissionIssue, SubmissionPreview
+
+
+ @dataclass(frozen=True)
+ class EvaluationConfig:
+     """Tunable parameters for deterministic scoring."""
+
+     value_tolerance: float = 0.06
+     delta_tolerance: float = 0.06
+     precision_weight: float = 0.30
+     recall_weight: float = 0.30
+     anomaly_type_weight: float = 0.12
+     detection_method_weight: float = 0.10
+     value_weight: float = 0.12
+     severity_weight: float = 0.06
+     extra_row_penalty: float = 0.03
+     duplicate_row_penalty: float = 0.04
+     invalid_row_penalty: float = 0.05
+     exploit_row_multiplier: float = 3.0
+     exploit_penalty: float = 0.15
+
+
+ @dataclass
+ class EvaluationResult:
+     """Complete scoring result."""
+
+     preview: SubmissionPreview
+     issues: list[SubmissionIssue]
+     reward_breakdown: RewardBreakdown
+     matched_rows: int
+     is_perfect: bool
+
+
+ def evaluate_submission(
+     submitted_rows: list[dict] | list[MetricSubmissionRow],
+     expected_rows: list[MetricSubmissionRow],
+     config: EvaluationConfig | None = None,
+     *,
+     include_debug_expected: bool = False,
+ ) -> EvaluationResult:
+     """Grade one submission against deterministic expectations."""
+     cfg = config or EvaluationConfig()
+     preview = preview_submission_rows(submitted_rows)
+     expected_map = {submission_row_key(row): row for row in expected_rows}
+     submitted_map = {submission_row_key(row): row for row in preview.normalized_rows}
+
+     issues = list(preview.issues)
+     matched_keys = [key for key in submitted_map if key in expected_map]
+     extra_keys = [key for key in submitted_map if key not in expected_map]
+     missing_keys = [key for key in expected_map if key not in submitted_map]
+
+     anomaly_type_hits = 0
+     detection_method_hits = 0
+     value_hits = 0.0
+     severity_hits = 0
+
+     for key in matched_keys:
+         submitted = submitted_map[key]
+         expected = expected_map[key]
+         field_issues = _field_issues(submitted, expected, cfg, include_debug_expected)
+         issues.extend(field_issues)
+         if submitted.anomaly_type == expected.anomaly_type:
+             anomaly_type_hits += 1
+         if submitted.detection_method == expected.detection_method:
+             detection_method_hits += 1
+         value_hits += _value_match_score(submitted, expected, cfg)
+         if submitted.severity == expected.severity:
+             severity_hits += 1
+
+     for key in extra_keys:
+         submitted = submitted_map[key]
+         issues.append(
+             SubmissionIssue(
+                 row_key=key,
+                 issue_type="extra_row",
+                 message="Row is not expected for this episode.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=None,
+             )
+         )
+
+     for key in missing_keys:
+         expected = expected_map[key]
+         issues.append(
+             SubmissionIssue(
+                 row_key=key,
+                 issue_type="missing_row",
+                 message="Expected anomaly row is missing from the submission.",
+                 submitted_row=None,
+                 expected_row=expected.model_dump() if include_debug_expected else None,
+             )
+         )
+
+     valid_submitted = len(preview.normalized_rows)
+     matched_count = len(matched_keys)
+     expected_count = len(expected_rows)
+     precision = matched_count / valid_submitted if valid_submitted else 0.0
+     recall = matched_count / expected_count if expected_count else 1.0
+     denominator = max(matched_count, 1)
+     anomaly_type_accuracy = anomaly_type_hits / denominator if matched_count else 0.0
+     detection_method_accuracy = detection_method_hits / denominator if matched_count else 0.0
+     value_accuracy = value_hits / denominator if matched_count else 0.0
+     severity_accuracy = severity_hits / denominator if matched_count else 0.0
+
+     extra_penalty = min(0.5, len(extra_keys) * cfg.extra_row_penalty)
+     duplicate_penalty = min(0.4, preview.duplicate_rows * cfg.duplicate_row_penalty)
+     invalid_penalty = min(0.4, preview.invalid_rows * cfg.invalid_row_penalty)
+     exploit_penalty = 0.0
+     exploit_limit = max(6, int(expected_count * cfg.exploit_row_multiplier))
+     if valid_submitted > exploit_limit:
+         exploit_penalty = cfg.exploit_penalty
+
+     total_score = (
+         precision * cfg.precision_weight
+         + recall * cfg.recall_weight
+         + anomaly_type_accuracy * cfg.anomaly_type_weight
+         + detection_method_accuracy * cfg.detection_method_weight
+         + value_accuracy * cfg.value_weight
+         + severity_accuracy * cfg.severity_weight
+         - extra_penalty
+         - duplicate_penalty
+         - invalid_penalty
+         - exploit_penalty
+     )
+     total_score = max(0.0, min(1.0, round(total_score, 6)))
+
+     breakdown = RewardBreakdown(
+         precision=round(precision, 6),
+         recall=round(recall, 6),
+         anomaly_type_accuracy=round(anomaly_type_accuracy, 6),
+         detection_method_accuracy=round(detection_method_accuracy, 6),
+         value_accuracy=round(value_accuracy, 6),
+         severity_accuracy=round(severity_accuracy, 6),
+         extra_row_penalty=round(extra_penalty, 6),
+         duplicate_penalty=round(duplicate_penalty, 6),
+         invalid_row_penalty=round(invalid_penalty, 6),
+         exploit_penalty=round(exploit_penalty, 6),
+         total_score=total_score,
+         matched_rows=matched_count,
+         expected_rows=expected_count,
+         submitted_rows=len(submitted_rows),
+         valid_submitted_rows=valid_submitted,
+         extra_rows=len(extra_keys),
+         duplicate_rows=preview.duplicate_rows,
+         invalid_rows=preview.invalid_rows,
+         missing_rows=len(missing_keys),
+     )
+     is_perfect = total_score >= 0.999999 and not issues
+     return EvaluationResult(
+         preview=preview,
+         issues=issues,
+         reward_breakdown=breakdown,
+         matched_rows=matched_count,
+         is_perfect=is_perfect,
+     )
+
+
+ def _field_issues(
+     submitted: MetricSubmissionRow,
+     expected: MetricSubmissionRow,
+     cfg: EvaluationConfig,
+     include_debug_expected: bool,
+ ) -> list[SubmissionIssue]:
+     issues: list[SubmissionIssue] = []
+     row_key = submission_row_key(expected)
+     expected_dump = expected.model_dump() if include_debug_expected else None
+     if submitted.anomaly_type != expected.anomaly_type:
+         issues.append(
+             SubmissionIssue(
+                 row_key=row_key,
+                 issue_type="wrong_anomaly_type",
+                 message=f"Expected anomaly_type={expected.anomaly_type}.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=expected_dump,
+             )
+         )
+     if submitted.detection_method != expected.detection_method:
+         issues.append(
+             SubmissionIssue(
+                 row_key=row_key,
+                 issue_type="wrong_detection_method",
+                 message=f"Expected detection_method={expected.detection_method}.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=expected_dump,
+             )
+         )
+     if not _close(submitted.baseline_value, expected.baseline_value, cfg.value_tolerance):
+         issues.append(
+             SubmissionIssue(
+                 row_key=row_key,
+                 issue_type="wrong_baseline_value",
+                 message="Baseline value is outside tolerance.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=expected_dump,
+             )
+         )
+     if not _close(submitted.observed_value, expected.observed_value, cfg.value_tolerance):
+         issues.append(
+             SubmissionIssue(
+                 row_key=row_key,
+                 issue_type="wrong_observed_value",
+                 message="Observed value is outside tolerance.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=expected_dump,
+             )
+         )
+     if not _close(submitted.delta_value, expected.delta_value, cfg.delta_tolerance):
+         issues.append(
+             SubmissionIssue(
+                 row_key=row_key,
+                 issue_type="wrong_delta_value",
+                 message="Delta value is outside tolerance.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=expected_dump,
+             )
+         )
+     if submitted.severity != expected.severity:
+         issues.append(
+             SubmissionIssue(
+                 row_key=row_key,
+                 issue_type="wrong_severity",
+                 message=f"Expected severity={expected.severity}.",
+                 submitted_row=submitted.model_dump(),
+                 expected_row=expected_dump,
+             )
+         )
+     return issues
+
+
+ def _value_match_score(
+     submitted: MetricSubmissionRow,
+     expected: MetricSubmissionRow,
+     cfg: EvaluationConfig,
+ ) -> float:
+     checks = [
+         _close(submitted.baseline_value, expected.baseline_value, cfg.value_tolerance),
+         _close(submitted.observed_value, expected.observed_value, cfg.value_tolerance),
+         _close(submitted.delta_value, expected.delta_value, cfg.delta_tolerance),
+     ]
+     return sum(1.0 for ok in checks if ok) / len(checks)
+
+
+ def _close(submitted: float, expected: float, tolerance: float) -> bool:
+     allowed = max(tolerance, abs(expected) * tolerance)
+     return abs(submitted - expected) <= allowed
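Note that the six component weights sum to 0.30 + 0.30 + 0.12 + 0.10 + 0.12 + 0.06 = 1.00, so a submission that matches every expected row exactly and incurs no penalties scores exactly 1.0. A minimal sketch; the row values are illustrative, since real expected rows come from the environment's hidden seed:

```python
from metric_tracker_rl.evaluation import evaluate_submission
from metric_tracker_rl.models import MetricSubmissionRow

expected = [
    MetricSubmissionRow(
        date="2024-05-01",
        entity_type="event_count",
        entity_name="orders_placed",
        anomaly_type="absolute_drop_in_event_count",
        detection_method="compare_count_to_median",
        baseline_value=5200.0,
        observed_value=3100.0,
        delta_value=-2100.0,
        severity="critical",
    )
]

# Submitting the expected row verbatim matches every weighted component.
result = evaluate_submission([row.model_dump() for row in expected], expected)
print(result.reward_breakdown.total_score)  # 1.0
print(result.is_perfect)                    # True
```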
inference.py ADDED
@@ -0,0 +1,586 @@
+ """Tool-driven inference for the metric tracker RL environment."""
+
+ from __future__ import annotations
+
+ import asyncio
+ import json
+ import os
+ import textwrap
+ from dataclasses import dataclass, field
+ from typing import Any
+
+ from openai import APIStatusError, OpenAI
+
+ from metric_tracker_rl import DEFAULT_TASK_ORDER, MetricTrackerRlAction, MetricTrackerRlEnv, get_task_spec
+ from metric_tracker_rl.analysis_tools import available_analysis_methods
+ from metric_tracker_rl.models import (
+     MetricSubmissionRow,
+     MetricTrackerRlObservation,
+     PayloadGeneratorMethod,
+ )
+
+
+ IMAGE_NAME = os.getenv("IMAGE_NAME") or "metric_tracker:latest"
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
+ API_BASE_URL = (
+     os.getenv("API_BASE_URL")
+     or os.getenv("OPENAI_BASE_URL")
+     or "https://router.huggingface.co/v1"
+ )
+ MODEL_NAME = os.getenv("MODEL_NAME") or os.getenv("OPENAI_MODEL") or "Qwen/Qwen2.5-72B-Instruct"
+ BASE_URL = os.getenv("BASE_URL")
+ TASK_NAME = os.getenv("MetricTrackerRl_TASK", "multi_task_agent_baseline")
+ BENCHMARK = os.getenv("MetricTrackerRl_BENCHMARK", "metric_tracker_rl")
+ TEMPERATURE = float(os.getenv("TEMPERATURE", "0"))
+ MAX_TOKENS = min(int(os.getenv("MAX_TOKENS", "1000")), 4096)
+ MAX_TOOL_ROUNDS = int(os.getenv("MAX_TOOL_ROUNDS", "16"))
+
+ SYSTEM_PROMPT = textwrap.dedent(
+     """
+     You are solving a multi-anomaly analytics benchmark with tool use.
+
+     Rules:
+     - Use only the shared safe analysis methods.
+     - Do not request full hidden answers or assume direct access to ground truth.
+     - Prefer declarative payload generators over manual row construction.
+     - Start from the default reset observation only.
+     - Start by trying `get_median_filter_rows` across different metrics to learn which metrics produce useful anomaly rows.
+     - Compare candidate metrics, then refine with raw-data inspection and median/std methods only when needed.
+     - Prefer: task_overview -> get_median_filter_rows on several metrics -> compare useful results -> payload_generator -> submit_payload_generator.
+     - Keep notes brief and factual.
+     """
+ ).strip()
+
+
+ @dataclass
+ class ToolRuntimeState:
+     """Mutable state shared across tool calls."""
+
+     method_log: list[dict[str, Any]] = field(default_factory=list)
+     last_preview: dict[str, Any] | None = None
+
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_method(tool_name: str, arguments: dict[str, Any], note: str) -> None:
+     print(
+         f"[METHOD] name={tool_name} args={json.dumps(arguments, sort_keys=True)} why={note}",
+         flush=True,
+     )
+
+
+ def log_payload_generator_methods(tool_name: str, generator_methods: list[dict[str, Any]]) -> None:
+     print(
+         f"[PAYLOAD_GENERATOR_METHODS] source={tool_name} methods={json.dumps(generator_methods, sort_keys=True)}",
+         flush=True,
+     )
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: str | None) -> None:
+     error_val = error if error else "null"
+     print(
+         f"[STEP] step={step} action={action} reward={reward:.3f} done={str(done).lower()} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, method_log: list[dict[str, Any]]) -> None:
+     print(
+         f"[END] success={str(success).lower()} steps={steps} score={score:.3f} methods={len(method_log)}",
+         flush=True,
+     )
+     print(json.dumps({"method_log": method_log}, indent=2), flush=True)
+
+
+ def log_task_boundary(task_id: str, difficulty: str, phase: str) -> None:
+     print(f"[TASK_{phase}] task_id={task_id} difficulty={difficulty}", flush=True)
+
+
+ def tool_schemas() -> list[dict[str, Any]]:
+     """OpenAI-compatible tool definitions."""
+     shared_schemas = []
+     for spec in available_analysis_methods():
+         properties = {}
+         required = []
+         if spec.name in {"rows_for_date", "hourly_rows_for_date", "detect_funnel_break", "check_impossible_counts"}:
+             properties = {"date": {"type": "string"}}
+             required = ["date"]
+         elif spec.name in {"compare_rate_to_median", "compare_count_to_median"}:
+             properties = {
+                 "date": {"type": "string"},
+                 "entity_name": {"type": "string"},
+             }
+             required = ["date", "entity_name"]
+         elif spec.name == "list_suspicious_dates":
+             properties = {"limit": {"type": "integer", "default": 10}}
+         elif spec.name == "preview_submission":
+             properties = {
+                 "rows": {
+                     "type": "array",
+                     "items": {"type": "object"},
+                 }
+             }
+         elif spec.name == "show_raw_data":
+             properties = {"limit": {"type": "integer", "default": 5}}
+         elif spec.name in {"get_metric_median", "get_metric_std_dev_from_median"}:
+             properties = {
+                 "metric_name": {"type": "string"},
+                 "metric_names": {"type": "array", "items": {"type": "string"}},
+             }
+         elif spec.name == "get_rows_with_abs_diff_from_median_gt":
+             properties = {
+                 "metric_name": {"type": "string"},
+                 "metric_names": {"type": "array", "items": {"type": "string"}},
+                 "threshold": {"type": "number"},
+             }
+             required = ["threshold"]
+         elif spec.name in {
+             "get_median_filter_rows",
+             "get_rate_drop_from_median_rows",
+             "get_rate_spike_from_median_rows",
+             "get_absolute_drop_in_event_count_rows",
+             "get_absolute_spike_in_event_count_rows",
+         }:
+             properties = {
+                 "metric_name": {"type": "string"},
+                 "metric_names": {"type": "array", "items": {"type": "string"}},
+                 "threshold_multiplier": {"type": "number"},
+             }
+             required = ["threshold_multiplier"]
+         elif spec.name in {
+             "get_funnel_break_rows",
+             "get_hourly_traffic_mix_shift_rows",
+             "get_instrumentation_data_quality_issue_rows",
+         }:
+             properties = {
+                 "threshold_multiplier": {"type": "number"},
+             }
+             required = ["threshold_multiplier"]
+         elif spec.name == "payload_generator":
+             properties = {
+                 "generator_methods": {
+                     "type": "array",
+                     "items": {"type": "object"},
+                 }
+             }
+             required = ["generator_methods"]
+         shared_schemas.append(
+             {
+                 "type": "function",
+                 "function": {
+                     "name": spec.name,
+                     "description": spec.description,
+                     "parameters": {
+                         "type": "object",
+                         "properties": properties,
+                         "required": required,
+                         "additionalProperties": False,
+                     },
+                 },
+             }
+         )
+     shared_schemas.append(
+         {
+             "type": "function",
+             "function": {
+                 "name": "submit_payload_generator",
+                 "description": "Submit declarative payload generator methods for environment-side payload generation and grading.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "generator_methods": {
+                             "type": "array",
+                             "items": {"type": "object"},
+                         }
+                     },
+                     "required": ["generator_methods"],
+                     "additionalProperties": False,
+                 },
+             },
+         }
+     )
+     shared_schemas.append(
+         {
+             "type": "function",
+             "function": {
+                 "name": "submit_solution",
+                 "description": "Submit the final anomaly payload to the environment.",
+                 "parameters": {
+                     "type": "object",
+                     "properties": {
+                         "rows": {
+                             "type": "array",
+                             "items": {"type": "object"},
+                         }
+                     },
+                     "required": ["rows"],
+                     "additionalProperties": False,
+                 },
+             },
+         }
+     )
+     return shared_schemas
+
+
+ def build_initial_user_prompt(observation: MetricTrackerRlObservation) -> str:
+     return textwrap.dedent(
+         f"""
+         Solve the RL environment with tools.
+
+         Initial observation:
+         {json.dumps(observation.model_dump(exclude={"debug"}), indent=2)}
+
+         Prefer building a payload generator first, then submit it.
+         Start by calling `get_median_filter_rows` on several different metrics and see which ones return useful anomaly rows.
+         If a metric returns nothing or low-signal rows, try another metric.
+         For funnel, hourly mix, or data-quality tasks, use the family-specific generator methods instead.
+
+         Final payload rows use:
+         `date`, `entity_type`, `entity_name`, `anomaly_type`, `detection_method`,
+         `baseline_value`, `observed_value`, `delta_value`, `severity`.
+
+         Supported generator method example:
+         `{{"method_name":"get_median_filter_rows","threshold_multiplier":2.0}}`
+         or
+         `{{"method_name":"get_median_filter_rows","metric_names":["app_open_to_order_placed","orders_placed"],"threshold_multiplier":2.0}}`
+
+         Use shared analysis methods only. Prefer `submit_payload_generator` over `submit_solution`.
+         """
+     ).strip()
+
+
+ def create_chat_completion(client: OpenAI, **kwargs):
+     try:
+         return client.chat.completions.create(**kwargs)
+     except APIStatusError as exc:
+         if exc.status_code == 402:
+             raise RuntimeError(
+                 "The configured inference provider rejected the request with HTTP 402. "
+                 "Your Hugging Face router credits are depleted. Update `.env.inference` "
+                 "with a working provider/key, or switch `API_BASE_URL`/`MODEL_NAME`."
+             ) from exc
+         raise
+
+
+ def decode_arguments(raw_arguments: str | None) -> dict[str, Any]:
+     if not raw_arguments:
+         return {}
+     return json.loads(raw_arguments)
+
+
+ def preview_text(text: str, limit: int = 220) -> str:
+     return text.replace("\n", " ")[:limit]
+
+
+ async def connect_env() -> MetricTrackerRlEnv:
+     if BASE_URL:
+         return MetricTrackerRlEnv(base_url=BASE_URL)
+     return await MetricTrackerRlEnv.from_docker_image(IMAGE_NAME)
+
+
+ async def execute_tool_call(
+     env: MetricTrackerRlEnv,
+     observation: MetricTrackerRlObservation,
+     runtime_state: ToolRuntimeState,
+     tool_name: str,
+     arguments: dict[str, Any],
+ ) -> tuple[dict[str, Any], Any | None, MetricTrackerRlObservation]:
+     """Execute one model-requested tool locally."""
+     if tool_name == "submit_payload_generator":
+         methods = [
+             PayloadGeneratorMethod(**item)
+             for item in arguments.get("generator_methods", [])
+         ]
+         runtime_state.method_log.append(
+             {
+                 "tool_name": tool_name,
+                 "arguments": arguments,
+                 "generator_methods": [item.model_dump() for item in methods],
+                 "note": _tool_note(tool_name, arguments),
+             }
+         )
+         result = await env.step(MetricTrackerRlAction(payload_generators=methods))
+         return (
+             {
+                 "status": result.observation.status,
+                 "message": result.observation.message,
+                 "reward": result.reward,
+                 "done": result.done,
+                 "generated_rows": [row.model_dump() for row in result.observation.generated_rows],
+                 "submission_issues": [issue.model_dump() for issue in result.observation.submission_issues],
+                 "reward_breakdown": (
+                     result.observation.reward_breakdown.model_dump()
+                     if result.observation.reward_breakdown
+                     else None
+                 ),
+             },
+             result,
+             result.observation,
+         )
+     if tool_name == "submit_solution":
+         rows = [MetricSubmissionRow(**row) for row in arguments.get("rows", [])]
+         result = await env.step(MetricTrackerRlAction(classifications=rows))
+         return (
+             {
+                 "status": result.observation.status,
+                 "message": result.observation.message,
+                 "reward": result.reward,
+                 "done": result.done,
+                 "reward_breakdown": (
+                     result.observation.reward_breakdown.model_dump()
+                     if result.observation.reward_breakdown
+                     else None
+                 ),
+                 "issue_count": len(result.observation.submission_issues),
+                 "correct_row_count": result.observation.correct_row_count,
+             },
+             result,
+             result.observation,
+         )
+
+     result = await env.step(
+         MetricTrackerRlAction(
+             analysis_method=tool_name,
+             analysis_args=arguments,
+         )
+     )
+     output = result.observation.analysis_result or {
+         "method": tool_name,
+         "arguments": arguments,
+         "result": None,
+     }
+     log_arguments = {
+         "tool_name": tool_name,
+         "arguments": arguments,
+         "note": _tool_note(tool_name, arguments),
+     }
+     if tool_name == "payload_generator":
+         log_arguments["generator_methods"] = arguments.get("generator_methods", [])
+     runtime_state.method_log.append(
+         log_arguments
+     )
+     if tool_name == "preview_submission":
+         runtime_state.last_preview = output
+     return output, None, result.observation
+
+
+ def _tool_note(tool_name: str, arguments: dict[str, Any]) -> str:
+     notes = {
+         "task_overview": "bootstrap the task and payload schema",
+         "list_dates": "confirm the date range",
+         "list_entities": "confirm valid entities",
+         "rows_for_date": "inspect daily counts on one date",
+         "hourly_rows_for_date": "inspect hourly traffic shape",
+         "compare_rate_to_median": "check a conversion-rate anomaly against median baseline",
+         "compare_count_to_median": "check an absolute count anomaly against median baseline",
+         "detect_funnel_break": "test whether a funnel step is broken",
+         "check_impossible_counts": "test for instrumentation or impossible count issues",
+         "list_suspicious_dates": "prioritize dates worth deeper inspection",
+         "preview_submission": "validate payload structure before submit",
+         "show_raw_data": "inspect daily aggregate rows in head() form",
+         "get_metric_median": "measure a baseline median for one metric",
+         "get_metric_std_dev_from_median": "measure metric spread around the median",
+         "get_rows_with_abs_diff_from_median_gt": "inspect dates outside a chosen absolute-difference threshold",
+         "get_median_filter_rows": "generate candidate anomaly rows using median and std-from-median filtering",
+         "get_rate_drop_from_median_rows": "generate candidate conversion-rate drop rows using median and std-from-median filtering",
+         "get_rate_spike_from_median_rows": "generate candidate conversion-rate spike rows using median and std-from-median filtering",
+         "get_absolute_drop_in_event_count_rows": "generate candidate event-count drop rows using median and std-from-median filtering",
+         "get_absolute_spike_in_event_count_rows": "generate candidate event-count spike rows using median and std-from-median filtering",
+         "get_funnel_break_rows": "generate candidate funnel-break rows across funnel steps",
+         "get_hourly_traffic_mix_shift_rows": "generate candidate hourly traffic mix shift rows across dates",
+         "get_instrumentation_data_quality_issue_rows": "generate candidate impossible-count or instrumentation-issue rows across dates",
+         "payload_generator": "merge multiple generator methods into one candidate payload",
+         "submit_payload_generator": "submit generator methods for environment-side generation and grading",
+     }
+     return notes.get(tool_name, f"run {tool_name} with {arguments}")
+
+
+ async def run_agent_loop(
+     client: OpenAI,
+     env: MetricTrackerRlEnv,
+     observation: MetricTrackerRlObservation,
+ ) -> tuple[Any, str, int, list[dict[str, Any]]]:
+     """Run a tool-calling loop until the env is solved or the round limit is hit."""
+     runtime_state = ToolRuntimeState()
+     current_observation = observation
+     messages: list[dict[str, Any]] = [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {"role": "user", "content": build_initial_user_prompt(current_observation)},
+     ]
+     last_result = None
+     final_text = ""
+     tool_rounds = 0
+
+     for _ in range(MAX_TOOL_ROUNDS):
+         completion = create_chat_completion(
+             client,
+             model=MODEL_NAME,
+             messages=messages,
+             tools=tool_schemas(),
+             tool_choice="auto",
+             temperature=TEMPERATURE,
+             max_tokens=MAX_TOKENS,
+             stream=False,
+         )
+         message = completion.choices[0].message
+         assistant_payload: dict[str, Any] = {
+             "role": "assistant",
+             "content": message.content or "",
+         }
+         if message.tool_calls:
+             assistant_payload["tool_calls"] = [
+                 {
+                     "id": tool_call.id,
+                     "type": tool_call.type,
+                     "function": {
+                         "name": tool_call.function.name,
+                         "arguments": tool_call.function.arguments,
+                     },
+                 }
+                 for tool_call in message.tool_calls
+             ]
+         messages.append(assistant_payload)
+
+         if not message.tool_calls:
+             final_text = (message.content or "").strip()
+             break
+
+         tool_rounds += 1
+         for tool_call in message.tool_calls:
+             tool_name = tool_call.function.name
+             arguments = decode_arguments(tool_call.function.arguments)
+             if tool_name != "submit_solution":
+                 log_method(tool_name, arguments, _tool_note(tool_name, arguments))
+             if tool_name in {"payload_generator", "submit_payload_generator"}:
+                 log_payload_generator_methods(
+                     tool_name,
+                     arguments.get("generator_methods", []),
+                 )
+             tool_output, maybe_result, current_observation = await execute_tool_call(
+                 env,
+                 current_observation,
+                 runtime_state,
+                 tool_name,
+                 arguments,
+             )
+             messages.append(
+                 {
+                     "role": "tool",
+                     "tool_call_id": tool_call.id,
+                     "content": json.dumps(tool_output),
+                 }
+             )
+             if maybe_result is not None:
+                 last_result = maybe_result
+
+         if last_result is not None:
+             completion = create_chat_completion(
+                 client,
+                 model=MODEL_NAME,
+                 messages=messages,
+                 temperature=TEMPERATURE,
+                 max_tokens=MAX_TOKENS,
+                 stream=False,
+             )
+             final_text = (completion.choices[0].message.content or "").strip()
+             break
+
+     return last_result, final_text, tool_rounds, runtime_state.method_log
+
+
+ async def run_single_task(
+     client: OpenAI,
+     env: MetricTrackerRlEnv,
+     task_id: str,
+ ) -> dict[str, Any]:
+     """Run one named benchmark task and return a reproducible summary."""
+     task_spec = get_task_spec(task_id)
+     log_task_boundary(task_spec.task_id, task_spec.difficulty, "START")
+     reset_result = await env.reset(task_id=task_spec.task_id)
+     final_result, final_text, tool_rounds, method_log = await run_agent_loop(
+         client,
+         env,
+         reset_result.observation,
+     )
+     if final_result is None:
+         raise RuntimeError(f"The model never submitted a graded action for task `{task_spec.task_id}`.")
+
+     reward = float(final_result.reward or 0.0)
+     success = bool(final_result.done and reward >= 0.999999)
+     log_step(
+         step=1,
+         action=preview_text(final_text or "graded_submission"),
+         reward=reward,
+         done=bool(final_result.done),
+         error=None,
+     )
+     log_end(success=success, steps=1, score=reward, method_log=method_log)
+     log_task_boundary(task_spec.task_id, task_spec.difficulty, "END")
+     return {
+         "task_id": task_spec.task_id,
+         "difficulty": task_spec.difficulty,
+         "objective": task_spec.objective,
+         "grader_name": task_spec.grader_name,
+         "normalized_score": max(0.0, min(1.0, reward)),
+         "done": final_result.done,
+         "success": success,
+         "final_status": final_result.observation.status,
+         "final_message": final_result.observation.message,
+         "issue_count": len(final_result.observation.submission_issues),
+         "correct_row_count": final_result.observation.correct_row_count,
+         "expected_row_count": final_result.observation.expected_row_count,
+         "tool_rounds": tool_rounds,
+         "assistant_summary": final_text,
+         "reward_breakdown": (
+             final_result.observation.reward_breakdown.model_dump()
+             if final_result.observation.reward_breakdown
+             else None
+         ),
+     }
+
+
+ async def main() -> None:
+     if not API_KEY:
+         raise RuntimeError("Set OPENAI_API_KEY, HF_TOKEN, or API_KEY.")
+
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+     env = await connect_env()
+     task_summaries: list[dict[str, Any]] = []
+
+     log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
+
+     try:
+         for task_id in DEFAULT_TASK_ORDER:
+             task_summaries.append(await run_single_task(client, env, task_id))
+     finally:
+         try:
+             await env.close()
+         except Exception:
+             pass
+
+     average_score = (
+         round(sum(item["normalized_score"] for item in task_summaries) / len(task_summaries), 6)
+         if task_summaries
+         else 0.0
+     )
+     print(
+         json.dumps(
+             {
+                 "benchmark": BENCHMARK,
+                 "model": MODEL_NAME,
+                 "task_count": len(task_summaries),
+                 "task_ids": [item["task_id"] for item in task_summaries],
+                 "average_score": average_score,
+                 "successful_tasks": sum(1 for item in task_summaries if item["success"]),
+                 "tasks": task_summaries,
+             },
+             indent=2,
+         ),
+         flush=True,
+     )
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
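The entry point above is configured entirely through environment variables that are read at import time. A minimal sketch of driving it against an already-running server; the token value is a placeholder, and setting `BASE_URL` makes `connect_env` skip the Docker path:

```python
# sketch_run_inference.py -- illustrative only; the token is a placeholder.
import asyncio
import os

# inference.py reads these at import time, so set them before importing it.
os.environ["HF_TOKEN"] = "hf_xxx_placeholder"
os.environ["BASE_URL"] = "http://localhost:8000"  # reuse a running server instead of Docker
os.environ["MAX_TOOL_ROUNDS"] = "8"

import inference  # noqa: E402

asyncio.run(inference.main())
```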
models.py ADDED
@@ -0,0 +1,324 @@
+ """Data models for the metric tracker RL environment."""
+
+ from __future__ import annotations
+
+ from typing import Any, Literal
+
+ from pydantic import BaseModel, Field
+
+ from openenv.core.env_server.types import Action, Observation
+
+
+ class MetricRecord(BaseModel):
+     """Hourly or daily aggregate metrics for the app funnel."""
+
+     date: str = Field(..., description="ISO date in YYYY-MM-DD format.")
+     hour: int | None = Field(
+         default=None,
+         description="Hour bucket in 24h format. Null for daily aggregates.",
+     )
+     app_opens: int = Field(default=0, description="Count of app_open events.")
+     menu_opens: int = Field(default=0, description="Count of menu_open events.")
+     product_added_to_cart: int = Field(
+         default=0,
+         description="Count of product_added_to_cart events.",
+     )
+     orders_placed: int = Field(default=0, description="Count of order_placed events.")
+     payment_successful: int = Field(
+         default=0,
+         description="Count of payment_successful events.",
+     )
+
+
+ class ConversionMetricDefinition(BaseModel):
+     """Definition for a conversion metric that the agent can cite."""
+
+     name: str = Field(..., description="Stable conversion metric identifier.")
+     numerator: str = Field(..., description="Numerator event.")
+     denominator: str = Field(..., description="Denominator event.")
+     description: str = Field(..., description="Human-readable formula.")
+
+
+ class MethodSpec(BaseModel):
+     """Description of a shared safe analysis method."""
+
+     name: str = Field(..., description="Method name.")
+     description: str = Field(..., description="What the method does.")
+     parameters: list[str] = Field(
+         default_factory=list,
+         description="Ordered parameter names for the method.",
+     )
+
+
+ class MetricSubmissionRow(BaseModel):
+     """Submitted anomaly row."""
+
+     date: str = Field(..., description="ISO date in YYYY-MM-DD format.")
+     entity_type: str = Field(
+         ...,
+         description=(
+             "Stable entity family such as conversion_rate, event_count, funnel_step, "
+             "hourly_mix, or data_quality."
+         ),
+     )
+     entity_name: str = Field(..., description="Stable entity identifier.")
+     anomaly_type: str = Field(..., description="Stable anomaly type identifier.")
+     detection_method: str = Field(..., description="Shared analysis method used.")
+     baseline_value: float = Field(..., description="Reference baseline value.")
+     observed_value: float = Field(..., description="Observed anomalous value.")
+     delta_value: float = Field(..., description="Observed minus baseline.")
+     severity: Literal["low", "medium", "high", "critical"] = Field(
+         ...,
+         description="Severity label.",
+     )
+
+
+ class PayloadGeneratorMethod(BaseModel):
+     """A declarative payload generation method."""
+
+     method_name: str = Field(
+         ...,
+         description="Generator method name, for example get_median_filter_rows.",
+     )
+     metric_name: str | None = Field(
+         default=None,
+         description="Single count metric or conversion metric name. Optional.",
+     )
+     metric_names: list[str] = Field(
+         default_factory=list,
+         description="Optional list of metrics to run. Empty means all metrics.",
+     )
+     threshold_multiplier: float = Field(
+         ...,
+         description="Multiplier applied to the metric std-from-median value.",
+     )
+
+
+ class SyntheticAnomalyGenerator(BaseModel):
+     """A declarative reset-time synthetic anomaly generator."""
+
+     method_name: str = Field(
+         default="metric_stddev_shift",
+         description="Synthetic generator method name.",
+     )
+     metric_name: str | None = Field(
+         default=None,
+         description="Single count metric or conversion metric name. Optional.",
+     )
+     metric_names: list[str] = Field(
+         default_factory=list,
+         description="Optional list of metrics to generate on. Empty means use metric_name.",
+     )
+     date: str | None = Field(
+         default=None,
+         description="Single ISO date to inject on. Optional.",
+     )
+     dates: list[str] = Field(
+         default_factory=list,
+         description="Optional list of ISO dates to inject on.",
+     )
+     stddev_factor: float = Field(
+         default=2.0,
+         description="Multiplier applied to std_dev_from_median when creating the target value.",
+     )
+     direction: Literal["up", "down", "auto"] = Field(
+         default="auto",
+         description="Whether to shift the metric upward or downward.",
+     )
+
+
+ class SyntheticGeneratorApplication(BaseModel):
+     """Resolved synthetic generator application used for the active episode."""
+
+     method_name: str = Field(..., description="Synthetic generator method used.")
+     date: str = Field(..., description="ISO date the generator was applied to.")
+     metric_name: str = Field(..., description="Metric name used by the generator.")
+     metric_type: Literal["event_count", "conversion_rate"] = Field(
+         ...,
+         description="Resolved metric family.",
+     )
+     direction: Literal["up", "down"] = Field(..., description="Resolved direction.")
+     anomaly_type: str = Field(..., description="Expected anomaly type generated.")
+     detection_method: str = Field(..., description="Shared analysis method that should detect it.")
+     baseline_value: float = Field(..., description="Median baseline used during generation.")
+     pre_applied_value: float = Field(..., description="Metric value before generation.")
+     std_dev_from_median: float = Field(..., description="Std-from-median used during generation.")
+     stddev_factor: float = Field(..., description="Configured stddev factor.")
+     threshold_value: float = Field(..., description="stddev_factor * std_dev_from_median.")
+     target_value: float = Field(..., description="Requested target value before rebalancing.")
+     actual_value: float = Field(..., description="Observed value after generation.")
+     formula: str = Field(..., description="Human-readable formula used for generation.")
+
+
+ class SubmissionIssue(BaseModel):
+     """Feedback about a submitted row or missing expectation."""
+
+     row_key: str = Field(..., description="Stable key in date|entity_type|entity_name form.")
+     issue_type: str = Field(..., description="Issue class.")
+     message: str = Field(..., description="Human-readable explanation.")
+     submitted_row: dict[str, Any] | None = Field(
+         default=None,
+         description="Submitted row fragment when relevant.",
+     )
+     expected_row: dict[str, Any] | None = Field(
+         default=None,
+         description="Expected row fragment when debug is enabled.",
+     )
+
+
+ class RewardBreakdown(BaseModel):
+     """Deterministic grading components."""
+
+     precision: float = 0.0
+     recall: float = 0.0
+     anomaly_type_accuracy: float = 0.0
+     detection_method_accuracy: float = 0.0
+     value_accuracy: float = 0.0
+     severity_accuracy: float = 0.0
+     extra_row_penalty: float = 0.0
+     duplicate_penalty: float = 0.0
+     invalid_row_penalty: float = 0.0
+     exploit_penalty: float = 0.0
+     total_score: float = 0.0
+     matched_rows: int = 0
+     expected_rows: int = 0
+     submitted_rows: int = 0
+     valid_submitted_rows: int = 0
+     extra_rows: int = 0
+     duplicate_rows: int = 0
+     invalid_rows: int = 0
+     missing_rows: int = 0
+
+
+ class SubmissionPreview(BaseModel):
+     """Safe preview of a candidate submission before grading."""
+
+     valid_rows: int = 0
+     invalid_rows: int = 0
+     duplicate_rows: int = 0
+     unique_keys: int = 0
+     issues: list[SubmissionIssue] = Field(default_factory=list)
+     normalized_rows: list[MetricSubmissionRow] = Field(default_factory=list)
+
+
+ class BenchmarkTaskSpec(BaseModel):
+     """Public metadata for a benchmark task."""
+
+     task_id: str = Field(..., description="Stable benchmark task identifier.")
+     difficulty: Literal["easy", "medium", "hard"] = Field(
+         ...,
+         description="Canonical task difficulty.",
+     )
+     instruction: str = Field(..., description="Task instruction shown to the agent.")
+     objective: str = Field(..., description="Concrete success objective.")
+     scenario_family: str = Field(..., description="Scenario family used to generate the task episode.")
+     anomaly_density: str = Field(..., description="Relative anomaly density for the task episode.")
+     anomaly_count: int = Field(..., description="Number of anomalous rows expected for the task.")
+     grader_name: str = Field(..., description="Programmatic grader used for the task.")
+
+
+ class MetricTrackerRlAction(Action):
+     """Submitted anomaly payload for the current episode."""
+
+     classifications: list[MetricSubmissionRow] = Field(
+         default_factory=list,
+         description="Submitted anomaly rows for the dataset.",
+     )
+     analysis_method: str | None = Field(
+         default=None,
+         description="Optional shared analysis method to call instead of grading a submission.",
+     )
+     analysis_args: dict[str, Any] = Field(
+         default_factory=dict,
+         description="Arguments for the selected analysis method.",
+     )
+     payload_generators: list[PayloadGeneratorMethod] = Field(
+         default_factory=list,
+         description="Declarative payload generation methods to run inside the environment.",
+     )
+
+
+ class MetricTrackerRlObservation(Observation):
+     """Observation containing the dataset and analysis surface."""
+
+     task_id: str = Field(
+         default="",
+         description="Stable identifier for the active benchmark task.",
+     )
+     status: str = Field(
+         default="ready",
+         description="Episode status: ready, in_progress, evaluated, or completed.",
+     )
+     message: str = Field(default="", description="Human-readable environment feedback.")
+     instruction: str = Field(
+         default="",
+         description="Task presented to the model for the current episode.",
+     )
+     conversion_metric_definitions: list[ConversionMetricDefinition] = Field(
+         default_factory=list,
+         description="Conversion formulas the model may cite.",
+     )
+     available_synthetic_generator_methods: list[MethodSpec] = Field(
+         default_factory=list,
+         description="Reset-time synthetic generator methods available for seeded data creation.",
+     )
+     applied_synthetic_generators: list[SyntheticGeneratorApplication] = Field(
+         default_factory=list,
+         description="Resolved synthetic generator applications used for the active episode.",
+     )
+     available_methods: list[MethodSpec] = Field(
+         default_factory=list,
+         description="Safe shared analysis methods available to agents and humans.",
+     )
+     available_tasks: list[BenchmarkTaskSpec] = Field(
+         default_factory=list,
+         description="Catalog of benchmark tasks available in this environment.",
+     )
+     daily_metrics: list[MetricRecord] = Field(
+         default_factory=list,
+         description="Deprecated raw daily data field. Kept empty in standard mode.",
+     )
+     hourly_metrics: list[MetricRecord] = Field(
+         default_factory=list,
+         description="Deprecated raw hourly data field. Kept empty in standard mode.",
+     )
+     analysis_result: dict[str, Any] | None = Field(
+         default=None,
+         description="Result of the latest analysis-method call.",
+     )
+     generated_rows: list[MetricSubmissionRow] = Field(
+         default_factory=list,
+         description="Rows generated from payload generator methods, if used.",
+     )
+     submitted_rows: list[MetricSubmissionRow] = Field(
+         default_factory=list,
+         description="Most recent submitted anomaly rows.",
+     )
+     submission_preview: SubmissionPreview | None = Field(
+         default=None,
+         description="Safe preview information for the latest submitted payload.",
+     )
+     submission_issues: list[SubmissionIssue] = Field(
+         default_factory=list,
+         description="Feedback for the latest submitted payload.",
+     )
+     reward_breakdown: RewardBreakdown | None = Field(
+         default=None,
+         description="Deterministic reward components for the latest step.",
+     )
+     expected_row_count: int = Field(
+         default=0,
+         description="Number of expected anomaly rows in the current episode.",
+     )
+     correct_row_count: int = Field(
+         default=0,
+         description="Number of matched anomaly rows in the latest step.",
+     )
+     config: dict[str, Any] = Field(
+         default_factory=dict,
+         description="Episode configuration visible in standard mode.",
+     )
+     debug: dict[str, Any] | None = Field(
+         default=None,
+         description="Developer-only debug payload. Hidden in standard mode.",
+     )
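To make the action surface concrete, here is a minimal sketch of building a declarative payload-generator action from these models; the metric names mirror the examples used in `inference.py` and are otherwise illustrative:

```python
from metric_tracker_rl.models import MetricTrackerRlAction, PayloadGeneratorMethod

# Flag dates whose metric sits more than 2x the std-from-median away from baseline.
action = MetricTrackerRlAction(
    payload_generators=[
        PayloadGeneratorMethod(
            method_name="get_median_filter_rows",
            metric_names=["app_open_to_order_placed", "orders_placed"],
            threshold_multiplier=2.0,
        )
    ]
)
print(action.model_dump()["payload_generators"][0]["method_name"])
# get_median_filter_rows
```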
openenv.yaml ADDED
@@ -0,0 +1,7 @@
+ spec_version: 1
+ name: metric_tracker_rl
+ type: space
+ runtime: fastapi
+ app: server.app:app
+ port: 8000
+
payload_generation.py ADDED
@@ -0,0 +1,46 @@
+ """Shared method registry and submission preview helpers."""
+
+ from __future__ import annotations
+
+ from typing import Any
+
+ try:
+     from .analysis_tools import (
+         SharedAnalysisToolkit,
+         available_analysis_methods,
+         preview_submission_rows,
+         submission_row_key,
+     )
+     from .models import MetricSubmissionRow, SubmissionPreview
+     from .server.data_generator import available_synthetic_generator_methods
+ except ImportError:
+     from analysis_tools import (
+         SharedAnalysisToolkit,
+         available_analysis_methods,
+         preview_submission_rows,
+         submission_row_key,
+     )
+     from models import MetricSubmissionRow, SubmissionPreview
+     from server.data_generator import available_synthetic_generator_methods
+
+
+ def available_payload_generation_methods():
+     """Backward-compatible alias for the shared analysis method list."""
+     return available_analysis_methods()
+
+
+ def preview_submission(
+     rows: list[MetricSubmissionRow] | list[dict[str, Any]],
+ ) -> SubmissionPreview:
+     """Validate a submission without using hidden labels."""
+     return preview_submission_rows(rows)
+
+
+ __all__ = [
+     "SharedAnalysisToolkit",
+     "available_analysis_methods",
+     "available_payload_generation_methods",
+     "available_synthetic_generator_methods",
+     "preview_submission",
+     "submission_row_key",
+ ]
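A minimal sketch of the backward-compatible alias above; the exact method list depends on what `available_analysis_methods()` registers, so the output is not reproduced here:

```python
from metric_tracker_rl.payload_generation import available_payload_generation_methods

# The alias returns the same MethodSpec list as available_analysis_methods().
for spec in available_payload_generation_methods():
    print(f"{spec.name}: {spec.description}")
```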
pyproject.toml ADDED
@@ -0,0 +1,49 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-metric_tracker_rl"
+ version = "0.1.0"
+ description = "Metric Tracker Rl environment for OpenEnv"
+ requires-python = ">=3.10"
+ dependencies = [
+     # Core OpenEnv runtime (provides FastAPI server + HTTP client types)
+     # install from github
+     # "openenv-core[core] @ git+https://github.com/meta-pytorch/OpenEnv.git",
+     "openenv-core[core]>=0.2.1",
+     # Environment-specific dependencies
+     # Add all dependencies needed for your environment here
+     # Examples:
+     # "numpy>=1.19.0",
+     # "torch>=2.0.0",
+     # "gymnasium>=0.29.0",
+     # "openspiel>=1.0.0",
+     # "smolagents>=1.22.0,<2",
+     "gradio>=5.0.0",
+     "pandas>=2.2.0",
+     "plotly>=5.24.0",
+     "openai>=1.0.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ # Server entry point - enables running via: uv run --project . server
+ # or: python -m metric_tracker_rl.server.app
+ server = "metric_tracker_rl.server.app:main"
+
+ [tool.setuptools]
+ include-package-data = true
+ packages = ["metric_tracker_rl", "metric_tracker_rl.server"]
+ package-dir = { "metric_tracker_rl" = ".", "metric_tracker_rl.server" = "server" }
server/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """Metric Tracker Rl environment server components."""
+
+ from .metric_tracker_rl_environment import MetricTrackerRlEnvironment
+
+ __all__ = ["MetricTrackerRlEnvironment"]
server/app.py ADDED
@@ -0,0 +1,82 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ """
+ FastAPI application for the Metric Tracker RL Environment.
+
+ This module creates an HTTP server that exposes the MetricTrackerRlEnvironment
+ over HTTP and WebSocket endpoints, compatible with EnvClient.
+
+ Endpoints:
+ - POST /reset: Reset the environment
+ - POST /step: Execute an action
+ - GET /state: Get current environment state
+ - GET /schema: Get action/observation schemas
+ - WS /ws: WebSocket endpoint for persistent sessions
+
+ Usage:
+     # Development (with auto-reload):
+     uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
+
+     # Production:
+     uvicorn server.app:app --host 0.0.0.0 --port 8000 --workers 4
+
+     # Or run directly:
+     python -m server.app
+ """
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as e:  # pragma: no cover
+     raise ImportError(
+         "openenv is required for the web interface. Install dependencies with 'uv sync'."
+     ) from e
+
+ try:
+     from ..models import MetricTrackerRlAction, MetricTrackerRlObservation
+     from .gradio_ui import build_metric_tracker_gradio_app
+     from .metric_tracker_rl_environment import MetricTrackerRlEnvironment
+ except ImportError:
+     from models import MetricTrackerRlAction, MetricTrackerRlObservation
+     from server.gradio_ui import build_metric_tracker_gradio_app
+     from server.metric_tracker_rl_environment import MetricTrackerRlEnvironment
+
+
+ # Create the app with web interface and README integration
+ app = create_app(
+     MetricTrackerRlEnvironment,
+     MetricTrackerRlAction,
+     MetricTrackerRlObservation,
+     env_name="metric_tracker_rl",
+     gradio_builder=build_metric_tracker_gradio_app,
+     max_concurrent_envs=1,  # increase this number to allow more concurrent WebSocket sessions
+ )
+
+
+ def main(host: str = "0.0.0.0", port: int = 8000):
+     """
+     Entry point for direct execution via uv run or python -m.
+
+     This function enables running the server without Docker:
+         uv run --project . server
+         uv run --project . server --port 8001
+         python -m metric_tracker_rl.server.app
+
+     Args:
+         host: Host address to bind to (default: "0.0.0.0")
+         port: Port number to listen on (default: 8000)
+
+     For production deployments, consider using uvicorn directly with
+     multiple workers:
+         uvicorn metric_tracker_rl.server.app:app --workers 4
+     """
+     import uvicorn
+
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     main()
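For quick manual checks against the endpoints listed in the module docstring, here is a minimal sketch. It assumes the server is running locally on port 8000 and that `requests` is installed; the step body mirrors the action fields used by the Gradio debugger (`analysis_method`, `analysis_args`, `classifications`, `payload_generators`), while the exact request envelope is defined by openenv's `create_app`, so treat `GET /schema` as authoritative:

```python
# Sketch: exercise /reset and /step over plain HTTP. Field names mirror the
# Gradio debugger's payloads; verify the envelope against GET /schema.
import requests

BASE = "http://127.0.0.1:8000"

print(requests.get(f"{BASE}/schema").json())  # action/observation schemas

requests.post(f"{BASE}/reset", json={"seed": 0, "scenario_family": "mixed"})
step = requests.post(
    f"{BASE}/step",
    json={
        "analysis_method": "show_raw_data",
        "analysis_args": {"limit": 5},
        "classifications": [],
        "payload_generators": [],
    },
)
print(step.json().get("observation", {}).get("message"))
```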
server/data_generator.py ADDED
@@ -0,0 +1,1016 @@
+ """Synthetic multi-anomaly data generator for the metric tracker RL environment."""
+
+ from __future__ import annotations
+
+ import random
+ from dataclasses import dataclass, field
+ from datetime import date, timedelta
+ from statistics import median
+
+ try:
+     from ..analysis_tools import COUNT_METRICS, FUNNEL_STEPS, SharedAnalysisToolkit, AnalysisContext
+     from ..models import (
+         ConversionMetricDefinition,
+         MethodSpec,
+         MetricRecord,
+         MetricSubmissionRow,
+         SyntheticAnomalyGenerator,
+         SyntheticGeneratorApplication,
+     )
+ except ImportError:
+     from analysis_tools import COUNT_METRICS, FUNNEL_STEPS, SharedAnalysisToolkit, AnalysisContext
+     from models import (
+         ConversionMetricDefinition,
+         MethodSpec,
+         MetricRecord,
+         MetricSubmissionRow,
+         SyntheticAnomalyGenerator,
+         SyntheticGeneratorApplication,
+     )
+
+
+ ALL_SCENARIO_FAMILIES: tuple[str, ...] = (
+     "mixed",
+     "rate_drop_from_median",
+     "rate_spike_from_median",
+     "absolute_drop_in_event_count",
+     "absolute_spike_in_event_count",
+     "funnel_break",
+     "hourly_traffic_mix_shift",
+     "instrumentation_data_quality_issue",
+ )
+
+ SYNTHETIC_GENERATOR_METHOD_SPECS: tuple[MethodSpec, ...] = (
+     MethodSpec(
+         name="metric_stddev_shift",
+         description=(
+             "Inject a count or conversion anomaly on specific dates by setting the metric to "
+             "median +/- stddev_factor * std_dev_from_median."
+         ),
+         parameters=["metric_name", "metric_names", "date", "dates", "stddev_factor", "direction"],
+     ),
+ )
+
+
+ def available_synthetic_generator_methods() -> list[MethodSpec]:
+     """Return supported reset-time synthetic generator methods."""
+     return list(SYNTHETIC_GENERATOR_METHOD_SPECS)
+
+
+ @dataclass(frozen=True)
+ class GeneratorConfig:
+     """Configurable parameters for synthetic metric generation."""
+
+     conversion_definitions: tuple[ConversionMetricDefinition, ...] = (
+         ConversionMetricDefinition(
+             name="app_open_to_menu_open",
+             numerator="menu_opens",
+             denominator="app_opens",
+             description="menu_opens / app_opens * 100",
+         ),
+         ConversionMetricDefinition(
+             name="menu_open_to_product_added_to_cart",
+             numerator="product_added_to_cart",
+             denominator="menu_opens",
+             description="product_added_to_cart / menu_opens * 100",
+         ),
+         ConversionMetricDefinition(
+             name="product_added_to_cart_to_order_placed",
+             numerator="orders_placed",
+             denominator="product_added_to_cart",
+             description="orders_placed / product_added_to_cart * 100",
+         ),
+         ConversionMetricDefinition(
+             name="order_placed_to_payment_successful",
+             numerator="payment_successful",
+             denominator="orders_placed",
+             description="payment_successful / orders_placed * 100",
+         ),
+         ConversionMetricDefinition(
+             name="app_open_to_order_placed",
+             numerator="orders_placed",
+             denominator="app_opens",
+             description="orders_placed / app_opens * 100",
+         ),
+         ConversionMetricDefinition(
+             name="app_open_to_payment_successful",
+             numerator="payment_successful",
+             denominator="app_opens",
+             description="payment_successful / app_opens * 100",
+         ),
+     )
+     num_weeks: int = 4
+     end_date_offset_days: int = 1
+     base_daily_app_opens: int = 18000
+     weekday_factors: tuple[float, ...] = (0.95, 1.0, 1.02, 1.04, 1.06, 1.12, 1.08)
+     hourly_weights: tuple[float, ...] = (
+         0.010,
+         0.008,
+         0.007,
+         0.007,
+         0.010,
+         0.018,
+         0.028,
+         0.040,
+         0.050,
+         0.055,
+         0.058,
+         0.060,
+         0.058,
+         0.056,
+         0.054,
+         0.052,
+         0.054,
+         0.060,
+         0.072,
+         0.078,
+         0.075,
+         0.060,
+         0.038,
+         0.025,
+     )
+     baseline_rates: dict[str, float] = field(
+         default_factory=lambda: {
+             "menu_opens": 0.63,
+             "product_added_to_cart": 0.29,
+             "orders_placed": 0.44,
+             "payment_successful": 0.91,
+         }
+     )
+
+     @property
+     def num_days(self) -> int:
+         return self.num_weeks * 7
+
+
+ @dataclass(frozen=True)
+ class EpisodeConfig:
+     """Per-episode configuration."""
+
+     seed: int = 0
+     scenario_family: str = "mixed"
+     difficulty: str = "medium"
+     anomaly_density: str = "medium"
+     anomaly_count: int = 3
+     anomaly_generators: tuple[SyntheticAnomalyGenerator, ...] = ()
+
+     def normalized(self) -> "EpisodeConfig":
+         family = self.scenario_family if self.scenario_family in ALL_SCENARIO_FAMILIES else "mixed"
+         difficulty = self.difficulty if self.difficulty in {"easy", "medium", "hard"} else "medium"
+         density = self.anomaly_density if self.anomaly_density in {"low", "medium", "high"} else "medium"
+         return EpisodeConfig(
+             seed=int(self.seed),
+             scenario_family=family,
+             difficulty=difficulty,
+             anomaly_density=density,
+             anomaly_count=max(1, int(self.anomaly_count or 3)),
+             anomaly_generators=tuple(self.anomaly_generators or ()),
+         )
+
+
+ @dataclass
+ class PlannedAnomaly:
+     """Internal anomaly schedule item."""
+
+     date: str
+     anomaly_type: str
+     entity_type: str
+     entity_name: str
+     detection_method: str
+     details: dict[str, str]
+
+
+ @dataclass
+ class EpisodeData:
+     """Synthetic dataset and ground truth used for one episode."""
+
+     config: EpisodeConfig
+     scenario_label: str
+     daily_metrics: list[MetricRecord]
+     hourly_metrics: list[MetricRecord]
+     expected_rows: list[MetricSubmissionRow]
+     anomaly_schedule: list[dict[str, str]]
+     applied_synthetic_generators: list[SyntheticGeneratorApplication]
+
+
+ class MetricDataGenerator:
+     """Reusable synthetic data generator used by the env and custom UI."""
+
+     def __init__(self, config: GeneratorConfig | None = None, seed: int | None = None) -> None:
+         self.config = config or GeneratorConfig()
+         self._default_seed = int(seed or 0)
+
+     def generate_episode(self, episode_config: EpisodeConfig | None = None) -> EpisodeData:
+         """Generate one seeded episode."""
+         config = (episode_config or EpisodeConfig(seed=self._default_seed)).normalized()
+         rng = random.Random(config.seed)
+         end_date = date.today() - timedelta(days=self.config.end_date_offset_days)
+         start_date = end_date - timedelta(days=self.config.num_days - 1)
+         base_hourly = self._generate_base_hourly_metrics(start_date, rng, config)
+         applied_synthetic_generators: list[SyntheticGeneratorApplication] = []
+         if self._use_synthetic_metric_generators(config):
+             anomaly_plan, applied_synthetic_generators = self._apply_metric_generators(
+                 base_hourly,
+                 rng,
+                 config,
+             )
+         else:
+             anomaly_plan = self._plan_anomalies(base_hourly, rng, config)
+             self._apply_anomalies(base_hourly, anomaly_plan, rng, config)
+         daily_metrics, hourly_metrics = self._materialize_metrics(base_hourly)
+         if applied_synthetic_generators:
+             self._refresh_applied_generator_actuals(
+                 applied_synthetic_generators,
+                 daily_metrics,
+             )
+         expected_rows = self._build_expected_rows(daily_metrics, hourly_metrics, anomaly_plan, config)
+         anomaly_schedule = [
+             {
+                 "date": item.date,
+                 "anomaly_type": item.anomaly_type,
+                 "entity_type": item.entity_type,
+                 "entity_name": item.entity_name,
+                 "detection_method": item.detection_method,
+             }
+             for item in anomaly_plan
+         ]
+         return EpisodeData(
+             config=config,
+             scenario_label=config.scenario_family,
+             daily_metrics=daily_metrics,
+             hourly_metrics=hourly_metrics,
+             expected_rows=expected_rows,
+             anomaly_schedule=anomaly_schedule,
+             applied_synthetic_generators=applied_synthetic_generators,
+         )
+
+     def _use_synthetic_metric_generators(self, episode_config: EpisodeConfig) -> bool:
+         if episode_config.anomaly_generators:
+             return True
+         return episode_config.scenario_family in {
+             "mixed",
+             "rate_drop_from_median",
+             "rate_spike_from_median",
+             "absolute_drop_in_event_count",
+             "absolute_spike_in_event_count",
+         }
+
+     def _generate_base_hourly_metrics(
+         self,
+         start_date: date,
+         rng: random.Random,
+         episode_config: EpisodeConfig,
+     ) -> dict[str, list[MetricRecord]]:
+         hourly: dict[str, list[MetricRecord]] = {}
+         difficulty_noise = {"easy": 0.015, "medium": 0.025, "hard": 0.035}[episode_config.difficulty]
+         for day_index in range(self.config.num_days):
+             current_date = start_date + timedelta(days=day_index)
+             date_key = current_date.isoformat()
+             weekday_factor = self.config.weekday_factors[current_date.weekday()]
+             trend_factor = 1.0 + day_index * 0.0025
+             noise_factor = 1.0 + rng.uniform(-0.02, 0.02)
+             total_app_opens = round(
+                 self.config.base_daily_app_opens * weekday_factor * trend_factor * noise_factor
+             )
+             weights = self._hour_weights(current_date.weekday(), rng)
+             hourly_app_opens = self._allocate_total(total_app_opens, weights, rng)
+             day_rows: list[MetricRecord] = []
+             for hour, app_opens in enumerate(hourly_app_opens):
+                 menu_rate = self._bounded(
+                     self.config.baseline_rates["menu_opens"] * (1.0 + rng.uniform(-difficulty_noise, difficulty_noise)),
+                     0.50,
+                     0.80,
+                 )
+                 cart_rate = self._bounded(
+                     self.config.baseline_rates["product_added_to_cart"]
+                     * (1.0 + rng.uniform(-difficulty_noise * 1.2, difficulty_noise * 1.2)),
+                     0.18,
+                     0.42,
+                 )
+                 order_rate = self._bounded(
+                     self.config.baseline_rates["orders_placed"]
+                     * (1.0 + rng.uniform(-difficulty_noise * 1.2, difficulty_noise * 1.2)),
+                     0.28,
+                     0.62,
+                 )
+                 payment_rate = self._bounded(
+                     self.config.baseline_rates["payment_successful"]
+                     * (1.0 + rng.uniform(-difficulty_noise, difficulty_noise)),
+                     0.76,
+                     0.99,
+                 )
+                 menu_opens = round(app_opens * menu_rate)
+                 carts = round(menu_opens * cart_rate)
+                 orders = round(carts * order_rate)
+                 payments = round(orders * payment_rate)
+                 day_rows.append(
+                     MetricRecord(
+                         date=date_key,
+                         hour=hour,
+                         app_opens=app_opens,
+                         menu_opens=menu_opens,
+                         product_added_to_cart=carts,
+                         orders_placed=orders,
+                         payment_successful=payments,
+                     )
+                 )
+             hourly[date_key] = day_rows
+         return hourly
+
+     def _plan_anomalies(
+         self,
+         base_hourly: dict[str, list[MetricRecord]],
+         rng: random.Random,
+         episode_config: EpisodeConfig,
+     ) -> list[PlannedAnomaly]:
+         dates = sorted(base_hourly)
+         candidate_dates = dates[3:-2] if len(dates) > 8 else dates
+         family_pool = (
+             list(ALL_SCENARIO_FAMILIES[1:])
+             if episode_config.scenario_family == "mixed"
+             else [episode_config.scenario_family]
+         )
+         target_count = max(
+             1,
+             int(
+                 episode_config.anomaly_count
+                 or {"low": 3, "medium": 5, "high": 7}[episode_config.anomaly_density]
+             ),
+         )
+         plan: list[PlannedAnomaly] = []
+         used_pairs: set[tuple[str, str, str]] = set()
+         family_order = family_pool[:]
+         rng.shuffle(family_order)
+         family_index = 0
+
+         while len(plan) < target_count:
+             if family_index >= len(family_order):
+                 family_order = family_pool[:]
+                 rng.shuffle(family_order)
+                 family_index = 0
+             anomaly_type = family_order[family_index]
+             family_index += 1
+             date_key = rng.choice(candidate_dates)
+             entity_type, entity_name, detection_method, details = self._pick_entity_for_family(
+                 anomaly_type,
+                 rng,
+             )
+             dedupe_key = (date_key, entity_type, entity_name)
+             if dedupe_key in used_pairs:
+                 continue
+             used_pairs.add(dedupe_key)
+             plan.append(
+                 PlannedAnomaly(
+                     date=date_key,
+                     anomaly_type=anomaly_type,
+                     entity_type=entity_type,
+                     entity_name=entity_name,
+                     detection_method=detection_method,
+                     details=details,
+                 )
+             )
+         plan.sort(key=lambda item: (item.date, item.entity_type, item.entity_name))
+         return plan
+
+     def _pick_entity_for_family(
+         self,
+         anomaly_type: str,
+         rng: random.Random,
+     ) -> tuple[str, str, str, dict[str, str]]:
+         if anomaly_type in {"rate_drop_from_median", "rate_spike_from_median"}:
+             definition = rng.choice(list(self.config.conversion_definitions))
+             return (
+                 "conversion_rate",
+                 definition.name,
+                 "compare_rate_to_median",
+                 {"conversion_name": definition.name},
+             )
+         if anomaly_type in {"absolute_drop_in_event_count", "absolute_spike_in_event_count"}:
+             metric_name = rng.choice(list(COUNT_METRICS))
+             return (
+                 "event_count",
+                 metric_name,
+                 "compare_count_to_median",
+                 {"metric_name": metric_name},
+             )
+         if anomaly_type == "funnel_break":
+             numerator, denominator = rng.choice(list(FUNNEL_STEPS))
+             return (
+                 "funnel_step",
+                 f"{numerator}_from_{denominator}",
+                 "detect_funnel_break",
+                 {"numerator": numerator, "denominator": denominator},
+             )
+         if anomaly_type == "hourly_traffic_mix_shift":
+             return (
+                 "hourly_mix",
+                 "app_opens:daytime_share",
+                 "hourly_rows_for_date",
+                 {},
+             )
+         numerator, denominator = rng.choice(list(FUNNEL_STEPS))
+         return (
+             "data_quality",
+             f"{numerator}_lte_{denominator}",
+             "check_impossible_counts",
+             {"numerator": numerator, "denominator": denominator},
+         )
+
+     def _apply_anomalies(
+         self,
+         hourly: dict[str, list[MetricRecord]],
+         plan: list[PlannedAnomaly],
+         rng: random.Random,
+         episode_config: EpisodeConfig,
+     ) -> None:
+         difficulty = episode_config.difficulty
+         for item in plan:
+             rows = hourly[item.date]
+             if item.anomaly_type == "rate_drop_from_median":
+                 self._apply_rate_change(rows, item.details["conversion_name"], rng, difficulty, direction="down")
+             elif item.anomaly_type == "rate_spike_from_median":
+                 self._apply_rate_change(rows, item.details["conversion_name"], rng, difficulty, direction="up")
+             elif item.anomaly_type == "absolute_drop_in_event_count":
+                 self._apply_count_change(rows, item.details["metric_name"], rng, difficulty, direction="down")
+             elif item.anomaly_type == "absolute_spike_in_event_count":
+                 self._apply_count_change(rows, item.details["metric_name"], rng, difficulty, direction="up")
+             elif item.anomaly_type == "funnel_break":
+                 self._apply_funnel_break(rows, item.details["numerator"], item.details["denominator"], rng, difficulty)
+             elif item.anomaly_type == "hourly_traffic_mix_shift":
+                 self._apply_hourly_mix_shift(rows, rng, difficulty)
+             elif item.anomaly_type == "instrumentation_data_quality_issue":
+                 self._apply_data_quality_issue(rows, item.details["numerator"], item.details["denominator"], rng, difficulty)
+
+     def _apply_rate_change(
+         self,
+         rows: list[MetricRecord],
+         conversion_name: str,
+         rng: random.Random,
+         difficulty: str,
+         *,
+         direction: str,
+     ) -> None:
+         definition = next(item for item in self.config.conversion_definitions if item.name == conversion_name)
+         multipliers = {
+             "easy": (0.74, 1.32),
+             "medium": (0.82, 1.22),
+             "hard": (0.88, 1.15),
+         }[difficulty]
+         multiplier = multipliers[0] if direction == "down" else multipliers[1]
+         for row in rows:
+             denominator_value = getattr(row, definition.denominator)
+             observed = round(denominator_value * multiplier * self._base_rate_from_metric(definition.numerator))
+             setattr_value = min(max(observed, 0), denominator_value)
+             self._set_metric_and_rebalance(row, definition.numerator, setattr_value)
+
+     def _apply_count_change(
+         self,
+         rows: list[MetricRecord],
+         metric_name: str,
+         rng: random.Random,
+         difficulty: str,
+         *,
+         direction: str,
+     ) -> None:
+         multipliers = {
+             "easy": (0.58, 1.42),
+             "medium": (0.72, 1.28),
+             "hard": (0.82, 1.18),
+         }[difficulty]
+         multiplier = multipliers[0] if direction == "down" else multipliers[1]
+         for row in rows:
+             original = getattr(row, metric_name)
+             updated = max(0, round(original * multiplier))
+             self._set_metric_and_rebalance(row, metric_name, updated)
+
+     def _apply_funnel_break(
+         self,
+         rows: list[MetricRecord],
+         numerator: str,
+         denominator: str,
+         rng: random.Random,
+         difficulty: str,
+     ) -> None:
+         if numerator == "menu_opens":
+             return
+         drop = {"easy": 0.45, "medium": 0.58, "hard": 0.7}[difficulty]
+         for row in rows:
+             denominator_value = getattr(row, denominator)
+             broken_value = max(0, round(denominator_value * drop))
+             self._set_metric_and_rebalance(row, numerator, broken_value)
+
+     def _apply_hourly_mix_shift(
+         self,
+         rows: list[MetricRecord],
+         rng: random.Random,
+         difficulty: str,
+     ) -> None:
+         total = sum(row.app_opens for row in rows)
+         if total <= 0:
+             return
+         shift = {"easy": 0.28, "medium": 0.20, "hard": 0.14}[difficulty]
+         boosted_hours = {0, 1, 2, 3, 4, 21, 22, 23}
+         weights = []
+         for row in rows:
+             base = row.app_opens / total
+             if row.hour in boosted_hours:
+                 base *= 1.0 + shift
+             elif 9 <= (row.hour or 0) <= 18:
+                 base *= max(0.2, 1.0 - shift)
+             weights.append(base)
+         normalized = [value / sum(weights) for value in weights]
+         redistributed = self._allocate_total(total, normalized, rng)
+         for row, app_opens in zip(rows, redistributed, strict=False):
+             row.app_opens = app_opens
+             menu_rate = self._ratio(row.menu_opens, max(row.app_opens, 1))
+             row.menu_opens = min(row.app_opens, round(app_opens * menu_rate))
+             cart_rate = self._ratio(row.product_added_to_cart, max(row.menu_opens, 1))
+             row.product_added_to_cart = min(row.menu_opens, round(row.menu_opens * cart_rate))
+             order_rate = self._ratio(row.orders_placed, max(row.product_added_to_cart, 1))
+             row.orders_placed = min(row.product_added_to_cart, round(row.product_added_to_cart * order_rate))
+             payment_rate = self._ratio(row.payment_successful, max(row.orders_placed, 1))
+             row.payment_successful = min(row.orders_placed, round(row.orders_placed * payment_rate))
+
+     def _apply_data_quality_issue(
+         self,
+         rows: list[MetricRecord],
+         numerator: str,
+         denominator: str,
+         rng: random.Random,
+         difficulty: str,
+     ) -> None:
+         affected_hours = {"easy": 5, "medium": 4, "hard": 3}[difficulty]
+         for row in rng.sample(rows, k=min(affected_hours, len(rows))):
+             denominator_value = getattr(row, denominator)
+             violation = max(1, round(denominator_value * {"easy": 0.12, "medium": 0.08, "hard": 0.05}[difficulty]))
+             setattr(row, numerator, denominator_value + violation)
+             self._rebalance_downstream_from(row, numerator)
+
+     def _apply_metric_generators(
+         self,
+         hourly: dict[str, list[MetricRecord]],
+         rng: random.Random,
+         episode_config: EpisodeConfig,
+     ) -> tuple[list[PlannedAnomaly], list[SyntheticGeneratorApplication]]:
+         generator_specs = self._resolve_metric_generators(hourly, rng, episode_config)
+         if not generator_specs:
+             return [], []
+
+         daily_metrics, hourly_metrics = self._materialize_metrics(hourly)
+         toolkit = SharedAnalysisToolkit(
+             AnalysisContext(
+                 daily_metrics=daily_metrics,
+                 hourly_metrics=hourly_metrics,
+                 conversion_definitions=list(self.config.conversion_definitions),
+                 config=episode_config.__dict__,
+             )
+         )
+
+         anomaly_plan: list[PlannedAnomaly] = []
+         applications: list[SyntheticGeneratorApplication] = []
+         seen_pairs: set[tuple[str, str]] = set()
+         for spec in generator_specs:
+             for date_key in self._resolve_generator_dates(spec, hourly, rng):
+                 for metric_name in self._resolve_generator_metrics(spec):
+                     dedupe_key = (date_key, metric_name)
+                     if dedupe_key in seen_pairs:
+                         continue
+                     seen_pairs.add(dedupe_key)
+                     application = self._build_metric_generator_application(
+                         toolkit=toolkit,
+                         date_key=date_key,
+                         metric_name=metric_name,
+                         spec=spec,
+                         rng=rng,
+                     )
+                     self._apply_metric_generator_application(hourly[date_key], application)
+                     applications.append(application)
+                     anomaly_plan.append(
+                         PlannedAnomaly(
+                             date=date_key,
+                             anomaly_type=application.anomaly_type,
+                             entity_type=application.metric_type,
+                             entity_name=metric_name,
+                             detection_method=application.detection_method,
+                             details={"metric_name": metric_name},
+                         )
+                     )
+         applications.sort(key=lambda item: (item.date, item.metric_name))
+         anomaly_plan.sort(key=lambda item: (item.date, item.entity_type, item.entity_name))
+         return anomaly_plan, applications
+
+     def _resolve_metric_generators(
+         self,
+         hourly: dict[str, list[MetricRecord]],
+         rng: random.Random,
+         episode_config: EpisodeConfig,
+     ) -> list[SyntheticAnomalyGenerator]:
+         if episode_config.anomaly_generators:
+             return list(episode_config.anomaly_generators)
+
+         dates = sorted(hourly)
+         candidate_dates = dates[3:-2] if len(dates) > 8 else dates
+         metric_pool = self._metric_pool_for_family(episode_config.scenario_family)
+         if not metric_pool:
+             return []
+
+         used_pairs: set[tuple[str, str]] = set()
+         generated: list[SyntheticAnomalyGenerator] = []
+         default_stddev = {"easy": 2.6, "medium": 2.2, "hard": 1.8}[episode_config.difficulty]
+         while len(generated) < max(1, episode_config.anomaly_count):
+             date_key = rng.choice(candidate_dates)
+             metric_name = rng.choice(metric_pool)
+             if (date_key, metric_name) in used_pairs:
+                 continue
+             used_pairs.add((date_key, metric_name))
+             generated.append(
+                 SyntheticAnomalyGenerator(
+                     method_name="metric_stddev_shift",
+                     metric_name=metric_name,
+                     date=date_key,
+                     stddev_factor=default_stddev,
+                     direction=self._default_direction_for_family(episode_config.scenario_family, rng),
+                 )
+             )
+         return generated
+
+     def _metric_pool_for_family(self, scenario_family: str) -> list[str]:
+         conversion_metrics = [item.name for item in self.config.conversion_definitions]
+         if scenario_family in {"rate_drop_from_median", "rate_spike_from_median"}:
+             return conversion_metrics
+         if scenario_family in {"absolute_drop_in_event_count", "absolute_spike_in_event_count"}:
+             return list(COUNT_METRICS)
+         if scenario_family == "mixed":
+             return list(COUNT_METRICS) + conversion_metrics
+         return []
+
+     @staticmethod
+     def _default_direction_for_family(scenario_family: str, rng: random.Random) -> str:
+         if scenario_family in {"rate_drop_from_median", "absolute_drop_in_event_count"}:
+             return "down"
+         if scenario_family in {"rate_spike_from_median", "absolute_spike_in_event_count"}:
+             return "up"
+         return "down" if rng.random() < 0.5 else "up"
+
+     def _resolve_generator_dates(
+         self,
+         spec: SyntheticAnomalyGenerator,
+         hourly: dict[str, list[MetricRecord]],
+         rng: random.Random,
+     ) -> list[str]:
+         dates = [item for item in spec.dates if item in hourly]
+         if spec.date and spec.date in hourly:
+             dates.append(spec.date)
+         if not dates:
+             dates = [rng.choice(sorted(hourly))]
+         seen = set()
+         deduped = []
+         for item in dates:
+             if item in seen:
+                 continue
+             seen.add(item)
+             deduped.append(item)
+         return deduped
+
+     def _resolve_generator_metrics(self, spec: SyntheticAnomalyGenerator) -> list[str]:
+         metrics = [item for item in spec.metric_names if item]
+         if spec.metric_name:
+             metrics.append(spec.metric_name)
+         if not metrics:
+             metrics = list(COUNT_METRICS) + [item.name for item in self.config.conversion_definitions]
+         seen = set()
+         deduped = []
+         for item in metrics:
+             if item in seen:
+                 continue
+             seen.add(item)
+             deduped.append(item)
+         return deduped
+
+     def _build_metric_generator_application(
+         self,
+         *,
+         toolkit: SharedAnalysisToolkit,
+         date_key: str,
+         metric_name: str,
+         spec: SyntheticAnomalyGenerator,
+         rng: random.Random,
+     ) -> SyntheticGeneratorApplication:
+         stats = toolkit.get_metric_std_dev_from_median(metric_name)
+         descriptor = toolkit._metric_descriptor(metric_name)
+         baseline_value = float(stats["median_value"])
+         std_dev_from_median = float(stats["std_dev_from_median"])
+         pre_applied_value = float(descriptor["per_date_values"][date_key])
+         direction = spec.direction if spec.direction != "auto" else ("down" if rng.random() < 0.5 else "up")
+         sign = -1.0 if direction == "down" else 1.0
+         threshold_value = round(std_dev_from_median * float(spec.stddev_factor), 4)
+         metric_type = "event_count" if metric_name in COUNT_METRICS else "conversion_rate"
+         if metric_type == "event_count":
+             minimum_shift = max(50.0, baseline_value * toolkit._count_threshold_fraction()) * 1.05
+             applied_shift = max(threshold_value, round(minimum_shift, 4))
+             target_value = max(0.0, baseline_value + sign * applied_shift)
+             anomaly_type = "absolute_spike_in_event_count" if sign > 0 else "absolute_drop_in_event_count"
+             detection_method = "compare_count_to_median"
+         else:
+             applied_shift = max(threshold_value, round(toolkit._rate_threshold() * 1.05, 4))
+             target_value = self._bounded(baseline_value + sign * applied_shift, 0.0, 100.0)
+             anomaly_type = "rate_spike_from_median" if sign > 0 else "rate_drop_from_median"
+             detection_method = "compare_rate_to_median"
+         return SyntheticGeneratorApplication(
+             method_name=spec.method_name,
+             date=date_key,
+             metric_name=metric_name,
+             metric_type=metric_type,
+             direction="up" if sign > 0 else "down",
+             anomaly_type=anomaly_type,
+             detection_method=detection_method,
+             baseline_value=round(baseline_value, 4),
+             pre_applied_value=round(pre_applied_value, 4),
+             std_dev_from_median=round(std_dev_from_median, 4),
+             stddev_factor=round(float(spec.stddev_factor), 4),
+             threshold_value=threshold_value,
+             target_value=round(target_value, 4),
+             actual_value=round(target_value, 4),
+             formula=(
+                 f"{metric_name} = median {'+' if sign > 0 else '-'} "
+                 "max(stddev_factor * std_dev_from_median, detector_threshold)"
+             ),
+         )
+
+     def _apply_metric_generator_application(
+         self,
+         rows: list[MetricRecord],
+         application: SyntheticGeneratorApplication,
+     ) -> None:
+         if application.metric_type == "event_count":
+             self._apply_daily_count_target(
+                 rows,
+                 application.metric_name,
+                 int(round(application.target_value)),
+             )
+             return
+         self._apply_daily_conversion_target(
+             rows,
+             application.metric_name,
+             float(application.target_value),
+         )
+
+     def _apply_daily_count_target(
+         self,
+         rows: list[MetricRecord],
+         metric_name: str,
+         target_total: int,
+     ) -> None:
+         target_total = max(0, target_total)
+         current_values = [max(0, getattr(row, metric_name)) for row in rows]
+         current_total = sum(current_values)
+         if current_total > 0:
+             weights = [value / current_total for value in current_values]
+         else:
+             app_total = sum(max(0, row.app_opens) for row in rows) or len(rows)
+             weights = [max(0, row.app_opens) / app_total for row in rows]
+         allocated = self._allocate_total(target_total, weights, random.Random(target_total + len(rows)))
+         for row, value in zip(rows, allocated, strict=False):
+             self._set_metric_and_rebalance(row, metric_name, value)
+
+     def _apply_daily_conversion_target(
+         self,
+         rows: list[MetricRecord],
+         conversion_name: str,
+         target_rate_pct: float,
+     ) -> None:
+         definition = next(item for item in self.config.conversion_definitions if item.name == conversion_name)
+         bounded_rate = self._bounded(target_rate_pct / 100.0, 0.0, 1.0)
+         for row in rows:
+             denominator_value = getattr(row, definition.denominator)
+             numerator_target = round(denominator_value * bounded_rate)
+             self._set_metric_and_rebalance(row, definition.numerator, numerator_target)
+
+     def _refresh_applied_generator_actuals(
+         self,
+         applications: list[SyntheticGeneratorApplication],
+         daily_metrics: list[MetricRecord],
+     ) -> None:
+         by_date = {row.date: row for row in daily_metrics}
+         conversion_map = {item.name: item for item in self.config.conversion_definitions}
+         for application in applications:
+             record = by_date.get(application.date)
+             if record is None:
+                 continue
+             if application.metric_type == "event_count":
+                 actual_value = float(getattr(record, application.metric_name))
+             else:
+                 definition = conversion_map[application.metric_name]
+                 denominator = getattr(record, definition.denominator)
+                 actual_value = round(
+                     (getattr(record, definition.numerator) / denominator * 100.0)
+                     if denominator > 0
+                     else 0.0,
+                     4,
+                 )
+             application.actual_value = round(actual_value, 4)
+
+     def _build_expected_rows(
+         self,
+         daily_metrics: list[MetricRecord],
+         hourly_metrics: list[MetricRecord],
+         plan: list[PlannedAnomaly],
+         episode_config: EpisodeConfig,
+     ) -> list[MetricSubmissionRow]:
+         toolkit = SharedAnalysisToolkit(
+             AnalysisContext(
+                 daily_metrics=daily_metrics,
+                 hourly_metrics=hourly_metrics,
+                 conversion_definitions=list(self.config.conversion_definitions),
+                 config={
+                     "seed": episode_config.seed,
+                     "scenario_family": episode_config.scenario_family,
+                     "difficulty": episode_config.difficulty,
+                     "anomaly_density": episode_config.anomaly_density,
+                     "anomaly_count": episode_config.anomaly_count,
+                 },
+             )
+         )
+         rows: list[MetricSubmissionRow] = []
+         for item in plan:
+             if item.detection_method == "compare_rate_to_median":
+                 result = toolkit.compare_rate_to_median(item.date, item.entity_name)
+             elif item.detection_method == "compare_count_to_median":
+                 result = toolkit.compare_count_to_median(item.date, item.entity_name)
+             elif item.detection_method == "detect_funnel_break":
+                 candidates = toolkit.detect_funnel_break(item.date)["candidates"]
+                 result = next((row for row in candidates if row["entity_name"] == item.entity_name), None)
+                 if result is None:
+                     numerator = item.details["numerator"]
+                     denominator = item.details["denominator"]
+                     daily_row = next(row for row in daily_metrics if row.date == item.date)
+                     baseline_series = [
+                         (getattr(row, numerator) / getattr(row, denominator) * 100.0)
+                         if getattr(row, denominator) > 0
+                         else 0.0
+                         for row in daily_metrics
+                     ]
+                     baseline = round(median(baseline_series), 4)
+                     observed = round(
+                         (getattr(daily_row, numerator) / getattr(daily_row, denominator) * 100.0)
+                         if getattr(daily_row, denominator) > 0
+                         else 0.0,
+                         4,
+                     )
+                     delta = round(observed - baseline, 4)
+                     result = {
+                         "entity_type": item.entity_type,
+                         "entity_name": item.entity_name,
+                         "baseline_value": baseline,
+                         "observed_value": observed,
+                         "delta_value": delta,
+                         "severity": self._severity_from_ratio(abs(delta), 5.0, 10.0, 15.0),
+                     }
+             elif item.detection_method == "check_impossible_counts":
+                 impossible = toolkit.check_impossible_counts(item.date)
+                 result = {
+                     "date": item.date,
+                     "entity_type": item.entity_type,
+                     "entity_name": item.entity_name,
+                     "anomaly_type": item.anomaly_type,
+                     "detection_method": item.detection_method,
+                     "baseline_value": 0.0,
+                     "observed_value": round(impossible["total_excess"], 4),
+                     "delta_value": round(impossible["total_excess"], 4),
+                     "severity": self._severity_from_ratio(impossible["total_excess"], 20.0, 60.0, 120.0),
+                 }
+             else:
+                 observed_share = toolkit.hourly_rows_for_date(item.date)["summary"]["daytime_share"]
+                 baseline_share = toolkit._median_daytime_share()
+                 delta = round(observed_share - baseline_share, 4)
+                 result = {
+                     "date": item.date,
+                     "entity_type": item.entity_type,
+                     "entity_name": item.entity_name,
+                     "anomaly_type": item.anomaly_type,
+                     "detection_method": item.detection_method,
+                     "baseline_value": round(baseline_share, 4),
+                     "observed_value": round(observed_share, 4),
+                     "delta_value": delta,
+                     "severity": self._severity_from_ratio(abs(delta) * 100.0, 10.0, 18.0, 25.0),
+                 }
+
+             if not result:
+                 continue
+             normalized = dict(result)
+             normalized["date"] = item.date
+             normalized["anomaly_type"] = item.anomaly_type
+             normalized["detection_method"] = item.detection_method
+             rows.append(MetricSubmissionRow(**normalized))
+         deduped = {f"{row.date}|{row.entity_type}|{row.entity_name}": row for row in rows}
+         return sorted(deduped.values(), key=lambda row: (row.date, row.entity_type, row.entity_name))
+
+     def _materialize_metrics(
+         self,
+         base_hourly: dict[str, list[MetricRecord]],
+     ) -> tuple[list[MetricRecord], list[MetricRecord]]:
+         hourly_metrics = []
+         daily_metrics = []
+         for date_key in sorted(base_hourly):
+             rows = base_hourly[date_key]
+             hourly_metrics.extend(rows)
+             daily_metrics.append(
+                 MetricRecord(
+                     date=date_key,
+                     hour=None,
+                     app_opens=sum(item.app_opens for item in rows),
+                     menu_opens=sum(item.menu_opens for item in rows),
+                     product_added_to_cart=sum(item.product_added_to_cart for item in rows),
+                     orders_placed=sum(item.orders_placed for item in rows),
+                     payment_successful=sum(item.payment_successful for item in rows),
+                 )
+             )
+         return daily_metrics, hourly_metrics
+
+     def _set_metric_and_rebalance(self, row: MetricRecord, metric_name: str, value: int) -> None:
+         caps = {
+             "app_opens": None,
+             "menu_opens": row.app_opens,
+             "product_added_to_cart": row.menu_opens,
+             "orders_placed": row.product_added_to_cart,
+             "payment_successful": row.orders_placed,
+         }
+         cap = caps.get(metric_name)
+         bounded = max(0, value if cap is None else min(value, cap))
+         setattr(row, metric_name, bounded)
+         self._rebalance_downstream_from(row, metric_name)
+         self._rebalance_upstream_to(row, metric_name)
+
+     def _rebalance_downstream_from(self, row: MetricRecord, metric_name: str) -> None:
+         order = list(COUNT_METRICS)
+         start_index = order.index(metric_name)
+         for index in range(start_index + 1, len(order)):
+             parent_name = order[index - 1]
+             current_name = order[index]
+             parent_value = getattr(row, parent_name)
+             current_value = min(getattr(row, current_name), parent_value)
+             setattr(row, current_name, max(0, current_value))
+
+     def _rebalance_upstream_to(self, row: MetricRecord, metric_name: str) -> None:
+         order = list(COUNT_METRICS)
+         start_index = order.index(metric_name)
+         for index in range(start_index - 1, -1, -1):
+             child_name = order[index + 1]
+             current_name = order[index]
+             child_value = getattr(row, child_name)
+             current_value = max(getattr(row, current_name), child_value)
+             setattr(row, current_name, current_value)
+
+     def _base_rate_from_metric(self, metric_name: str) -> float:
+         if metric_name == "menu_opens":
+             return self.config.baseline_rates["menu_opens"]
+         if metric_name == "product_added_to_cart":
+             return self.config.baseline_rates["product_added_to_cart"]
+         if metric_name == "orders_placed":
+             return self.config.baseline_rates["orders_placed"]
+         if metric_name == "payment_successful":
+             return self.config.baseline_rates["payment_successful"]
+         return 1.0
+
+     def _hour_weights(self, weekday: int, rng: random.Random) -> list[float]:
+         weekend_multiplier = 1.12 if weekday >= 5 else 1.0
+         weights = [
+             max(0.001, value * weekend_multiplier * (1.0 + rng.uniform(-0.08, 0.08)))
+             for value in self.config.hourly_weights
+         ]
+         total = sum(weights)
+         return [value / total for value in weights]
+
+     @staticmethod
+     def _allocate_total(total: int, weights: list[float], rng: random.Random) -> list[int]:
+         raw = [total * weight for weight in weights]
+         integers = [int(value) for value in raw]
+         remainder = total - sum(integers)
+         ranked = sorted(
+             range(len(weights)),
+             key=lambda index: (raw[index] - integers[index], rng.random()),
+             reverse=True,
+         )
+         for index in ranked[:remainder]:
+             integers[index] += 1
+         return integers
+
+     @staticmethod
+     def _ratio(numerator: int, denominator: int) -> float:
+         if denominator <= 0:
+             return 0.0
+         return numerator / denominator
+
+     @staticmethod
+     def _bounded(value: float, lower: float, upper: float) -> float:
+         return min(max(value, lower), upper)
+
+     @staticmethod
+     def _severity_from_ratio(value: float, medium: float, high: float, critical: float) -> str:
+         if value >= critical:
+             return "critical"
+         if value >= high:
+             return "high"
+         if value >= medium:
+             return "medium"
+         return "low"
server/gradio_ui.py ADDED
@@ -0,0 +1,728 @@
+ """Custom Gradio UI for testing the metric tracker RL environment."""
+
+ from __future__ import annotations
+
+ import json
+
+ import pandas as pd
+
+ try:
+     from ..analysis_tools import available_analysis_methods
+     from ..tasks import DEFAULT_TASK_ORDER, available_task_specs, get_task_spec
+ except ImportError:
+     from analysis_tools import available_analysis_methods
+     from tasks import DEFAULT_TASK_ORDER, available_task_specs, get_task_spec
+
+ try:
+     import gradio as gr
+ except ImportError:  # pragma: no cover
+     gr = None
+
+
+ GENERATOR_METHODS = [
+     "get_median_filter_rows",
+     "get_rate_drop_from_median_rows",
+     "get_rate_spike_from_median_rows",
+     "get_absolute_drop_in_event_count_rows",
+     "get_absolute_spike_in_event_count_rows",
+     "get_funnel_break_rows",
+     "get_hourly_traffic_mix_shift_rows",
+     "get_instrumentation_data_quality_issue_rows",
+ ]
+ METHOD_CHOICES = [item.name for item in available_analysis_methods()]
+ TASK_CHOICES = list(DEFAULT_TASK_ORDER)
+ TASK_SUMMARIES = {
+     item.task_id: item.model_dump()
+     for item in available_task_specs()
+ }
+ METRIC_CHOICES = [
+     "app_opens",
+     "menu_opens",
+     "product_added_to_cart",
+     "orders_placed",
+     "payment_successful",
+     "app_open_to_menu_open",
+     "menu_open_to_product_added_to_cart",
+     "product_added_to_cart_to_order_placed",
+     "order_placed_to_payment_successful",
+     "app_open_to_order_placed",
+     "app_open_to_payment_successful",
+ ]
+
+
+ def build_metric_tracker_gradio_app(
+     web_manager,
+     action_fields,
+     metadata,
+     is_chat_env,
+     title,
+     quick_start_md,
+ ):
+     """Build a method-driven and generator-driven debugger."""
+     del action_fields, metadata, is_chat_env, quick_start_md
+     if gr is None:  # pragma: no cover
+         raise ImportError("gradio is required to build the custom metric tracker UI.")
+
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             f"""
+ # {title} Generator Debugger
+
+ The UI now supports the same named benchmark tasks used by the agent baseline.
+ Pick a task to load its canonical easy, medium, or hard setup, then optionally
+ override the reset controls for custom debugging.
+
+ Standard mode exposes method calls only. You inspect data through methods like
+ `show_raw_data`, `get_metric_median`, `get_metric_std_dev_from_median`,
+ and `get_rows_with_abs_diff_from_median_gt`, then assemble payload generators
+ such as `get_median_filter_rows(metric_name, threshold_multiplier)`.
+ """
+         )
+
+         session_state = gr.State(_empty_state())
+
+         with gr.Row():
+             task_id = gr.Dropdown(
+                 label="Named Task",
+                 choices=TASK_CHOICES,
+                 value=TASK_CHOICES[0],
+             )
+             task_details = gr.JSON(
+                 label="Selected Task Details",
+                 value=TASK_SUMMARIES[TASK_CHOICES[0]],
+             )
+
+         with gr.Row():
+             initial_task = get_task_spec(TASK_CHOICES[0])
+             seed = gr.Number(label="Seed", value=initial_task.seed, precision=0)
+             scenario_family = gr.Dropdown(
+                 label="Scenario Family",
+                 choices=[
+                     "mixed",
+                     "rate_drop_from_median",
+                     "rate_spike_from_median",
+                     "absolute_drop_in_event_count",
+                     "absolute_spike_in_event_count",
+                     "funnel_break",
+                     "hourly_traffic_mix_shift",
+                     "instrumentation_data_quality_issue",
+                 ],
+                 value=initial_task.scenario_family,
+             )
+             difficulty = gr.Dropdown(
+                 label="Difficulty",
+                 choices=["easy", "medium", "hard"],
+                 value=initial_task.difficulty,
+             )
+             anomaly_density = gr.Dropdown(
+                 label="Anomaly Density",
+                 choices=["low", "medium", "high"],
+                 value=initial_task.anomaly_density,
+             )
+             anomaly_count = gr.Number(label="Anomaly Count", value=initial_task.anomaly_count, precision=0)
+             debug_mode = gr.Checkbox(label="Debug Mode", value=False)
+
+         reset_anomalies = gr.Code(
+             label="Reset Anomalies JSON",
+             language="json",
+             value="[]",
+             interactive=True,
+         )
+
+         with gr.Row():
+             reset_btn = gr.Button("Reset Episode", variant="primary")
+             preview_btn = gr.Button("Preview Generator Payload", variant="secondary")
+             submit_btn = gr.Button("Submit Generator Payload", variant="secondary")
+             get_state_btn = gr.Button("Get State", variant="secondary")
+
+         gr.Markdown("## Methods")
+         gr.Markdown(
+             "Run a method after reset to fetch exactly the daily aggregate data, median, "
+             "std-from-median, filtered rows, or generated payload rows you want."
+         )
+         with gr.Row():
+             method_name = gr.Dropdown(
+                 label="Method",
+                 choices=METHOD_CHOICES,
+                 value="show_raw_data",
+             )
+             method_metric = gr.Dropdown(
+                 label="metrics",
+                 choices=METRIC_CHOICES,
+                 value=[],
+                 multiselect=True,
+             )
+             method_threshold = gr.Number(label="threshold / multiplier", value=2.0)
+             method_limit = gr.Number(label="limit", value=5, precision=0)
+             run_method_btn = gr.Button("Run Method", variant="secondary")
+         with gr.Row():
+             method_date = gr.Textbox(label="date", placeholder="YYYY-MM-DD")
+             method_entity = gr.Textbox(label="entity_name", placeholder="orders_placed or app_open_to_order_placed")
+             method_rows_json = gr.Code(
+                 label="rows JSON for preview_submission",
+                 language="json",
+                 value="[]",
+                 interactive=True,
+             )
+         analysis_result = gr.JSON(label="Last Method Results")
+
+         with gr.Tab("Method Data"):
+             gr.Markdown(
+                 "This panel shows only method-returned data. Use `show_raw_data` for daily "
+                 "aggregate rows, then median/std/filter methods to inspect candidate anomalies."
+             )
+             method_rows = gr.Dataframe(label="Method Rows", interactive=False)
+
+         gr.Markdown("## Payload Generators")
+         gr.Markdown(
+             "Add generator methods here, then preview or submit using the buttons at the top."
+         )
+         generator_methods_df = gr.Dataframe(
+             headers=["method_name", "metric_name", "metric_names", "threshold_multiplier"],
+             datatype=["str", "str", "str", "number"],
+             label="Generator Methods",
+             interactive=True,
+         )
+         payload_generator_methods = gr.JSON(label="Methods Passed to Payload Generator")
+         with gr.Row():
+             generator_method_name = gr.Dropdown(label="method_name", choices=GENERATOR_METHODS, value="get_median_filter_rows")
+             generator_metric_name = gr.Dropdown(
+                 label="metrics",
+                 choices=METRIC_CHOICES,
+                 value=[],
+                 multiselect=True,
+             )
+             generator_multiplier = gr.Number(label="threshold_multiplier", value=2.0)
+         with gr.Row():
+             add_generator_btn = gr.Button("Add / Update Generator", variant="secondary")
+             remove_generator_btn = gr.Button("Remove Generator", variant="secondary")
+             clear_generators_btn = gr.Button("Clear Generators", variant="secondary")
+
+         status = gr.Textbox(label="Status", interactive=False)
+         summary = gr.JSON(label="Episode Summary")
+         active_task = gr.JSON(label="Active Task", value=TASK_SUMMARIES[TASK_CHOICES[0]])
+         task_catalog = gr.JSON(label="Available Tasks", value=list(TASK_SUMMARIES.values()))
+         synthetic_methods = gr.JSON(label="Synthetic Generator Methods")
+         applied_synthetic_generators = gr.Dataframe(label="Applied Synthetic Generators", interactive=False)
+         available_methods = gr.JSON(label="Shared Methods")
+         submission_feedback = gr.JSON(label="Submission Feedback")
+         reward_breakdown = gr.JSON(label="Reward Breakdown")
+         generated_rows = gr.Dataframe(label="Generated Payload Rows", interactive=False)
+         raw_json = gr.Code(label="Latest Environment Response", language="json", interactive=False)
+         debug_snapshot = gr.JSON(label="Debug Snapshot")
+
+         def apply_task_defaults(selected_task_id: str):
+             task = get_task_spec(selected_task_id)
+             return (
+                 task.seed,
+                 task.scenario_family,
+                 task.difficulty,
+                 task.anomaly_density,
+                 task.anomaly_count,
+                 task.to_model().model_dump(),
+             )
+
+         async def reset_episode(selected_task_id, seed_value, family, level, density, anomaly_count_value, reset_anomalies_json, debug_enabled):
+             try:
+                 parsed_anomalies = json.loads(reset_anomalies_json or "[]")
+                 if not isinstance(parsed_anomalies, list):
+                     raise ValueError("Reset anomalies JSON must be a list.")
+             except Exception as exc:
+                 return (
+                     _empty_state(),
+                     f"Invalid reset anomalies JSON: {exc}",
+                     {"status": "error"},
+                     {},
+                     list(TASK_SUMMARIES.values()),
+                     [],
+                     _generator_frame([]),
+                     [],
+                     {},
+                     {},
+                     {},
+                     _generator_frame([]),
+                     _generator_frame([]),
+                     "",
+                     _debug_snapshot(web_manager, debug_enabled),
+                     _generator_frame([]),
+                     [],
+                 )
+             web_manager.env.set_debug_mode(bool(debug_enabled))
+             data = await web_manager.reset_environment(
+                 {
+                     "task_id": selected_task_id,
+                     "seed": int(seed_value or 0),
+                     "scenario_family": family,
+                     "difficulty": level,
+                     "anomaly_density": density,
+                     "anomaly_count": int(anomaly_count_value or 3),
+                     "anomalies": parsed_anomalies,
+                 }
+             )
+             method_data = await web_manager.step_environment(
+                 {
+                     "analysis_method": "show_raw_data",
+                     "analysis_args": {"limit": 5},
+                     "classifications": [],
+                     "payload_generators": [],
+                 }
+             )
+             state = _state_from_response(data)
+             state["latest_response"] = method_data
+             state["last_method_result"] = method_data.get("observation", {}).get("analysis_result")
+             obs = data.get("observation", {})
+             method_result = state["last_method_result"] or {}
+             available_tasks = obs.get("available_tasks") or list(TASK_SUMMARIES.values())
+             active_task_payload = next(
+                 (item for item in available_tasks if item.get("task_id") == obs.get("task_id")),
+                 {
+                     "task_id": obs.get("task_id"),
+                     "instruction": obs.get("instruction"),
+                     "objective": obs.get("message"),
+                     "difficulty": (obs.get("config") or {}).get("difficulty"),
+                     "grader_name": (obs.get("config") or {}).get("grader_name"),
+                 },
+             )
+             return (
+                 state,
+                 obs.get("message", ""),
+                 {
+                     "task_id": obs.get("task_id"),
+                     "status": obs.get("status"),
+                     "config": obs.get("config"),
+                     "expected_row_count": obs.get("expected_row_count"),
+                 },
+                 active_task_payload,
+                 available_tasks,
+                 [item for item in obs.get("available_synthetic_generator_methods", [])],
+                 pd.DataFrame([item for item in obs.get("applied_synthetic_generators", [])]),
+                 [item for item in obs.get("available_methods", [])],
+                 method_result,
+                 obs.get("submission_issues") or [],
+                 obs.get("reward_breakdown") or {},
+                 _method_frame(method_result),
+                 pd.DataFrame(),
+                 json.dumps(method_data, indent=2),
+                 _debug_snapshot(web_manager, debug_enabled),
+                 _generator_frame(state["payload_generators"]),
+                 state["payload_generators"],
+             )
+
+         async def run_method(
+             payload: dict,
+             selected_method: str,
+             metric_names: list[str],
+             method_date_value: str,
+             method_entity_value: str,
+             method_rows_value: str,
+             threshold: float,
+             limit_value: int,
+         ):
+             if not payload.get("active"):
+                 return payload, {"error": "Reset the environment first."}, "", gr.skip(), gr.skip(), gr.skip()
+             args = _method_args(
+                 selected_method,
+                 metric_names,
+                 method_date_value,
+                 method_entity_value,
+                 method_rows_value,
+                 threshold,
+                 limit_value,
+                 payload["payload_generators"],
+             )
+             data = await web_manager.step_environment(
+                 {
+                     "analysis_method": selected_method,
+                     "analysis_args": args,
+                     "classifications": [],
+                     "payload_generators": [],
+                 }
+             )
+             payload["latest_response"] = data
+             payload["last_method_result"] = data.get("observation", {}).get("analysis_result")
+             method_result = payload["last_method_result"] or {}
+             generated = method_result.get("result", {}).get("generated_rows", [])
+             method_frame = _method_frame(method_result)
+             return (
+                 payload,
+                 method_result,
+                 data.get("observation", {}).get("message", ""),
+                 method_frame,
+                 pd.DataFrame(generated),
+                 json.dumps(data, indent=2),
+             )
+
+         def add_or_update_generator(payload: dict, method_name_value: str, metric_names: list[str], threshold_multiplier: float):
+             if not payload.get("active"):
+                 return payload, _generator_frame([]), []
+             metric_names = [item for item in (metric_names or []) if item]
+             row = {
+                 "method_name": method_name_value,
+                 "metric_name": metric_names[0] if len(metric_names) == 1 else None,
+                 "metric_names": metric_names,
+                 "threshold_multiplier": float(threshold_multiplier),
+             }
+             keyed = {
+                 _generator_row_key(item): item
+                 for item in payload["payload_generators"]
+             }
+             keyed[_generator_row_key(row)] = row
+             payload["payload_generators"] = list(keyed.values())
+             return payload, _generator_frame(payload["payload_generators"]), payload["payload_generators"]
+
+         def remove_generator(payload: dict, method_name_value: str, metric_names: list[str]):
+             if not payload.get("active"):
+                 return payload, _generator_frame([]), []
+             metric_names = [item for item in (metric_names or []) if item]
+             payload["payload_generators"] = [
+                 item
+                 for item in payload["payload_generators"]
+                 if not (
+                     item.get("method_name") == method_name_value
+                     and [name for name in item.get("metric_names", []) if name] == metric_names
+                 )
+             ]
+             return payload, _generator_frame(payload["payload_generators"]), payload["payload_generators"]
+
+         def clear_generators(payload: dict):
+             payload["payload_generators"] = []
+             return payload, _generator_frame([]), []
+
+         def sync_generator_rows(payload: dict, generator_rows):
+             normalized = _normalize_generator_rows(generator_rows)
+             payload["payload_generators"] = normalized
+             return payload, _generator_frame(normalized), normalized
+
+         async def preview_payload(payload: dict, generator_rows):
+             if not payload.get("active"):
+                 return payload, {"error": "Reset the environment first."}, _generator_frame([]), []
+             payload["payload_generators"] = _normalize_generator_rows(generator_rows)
+             if not payload.get("payload_generators"):
+                 return payload, {"error": "Add at least one payload generator first."}, _generator_frame([]), []
+             data = await web_manager.step_environment(
+                 {
+                     "analysis_method": "payload_generator",
+                     "analysis_args": {"generator_methods": payload["payload_generators"]},
+                     "classifications": [],
+                     "payload_generators": [],
+                 }
+             )
+             payload["latest_response"] = data
+             payload["last_method_result"] = data.get("observation", {}).get("analysis_result")
+             result = payload["last_method_result"] or {}
+             return payload, result, pd.DataFrame(result.get("result", {}).get("generated_rows", [])), payload["payload_generators"]
+
+         async def submit_payload(payload: dict, debug_enabled: bool, generator_rows):
+             if not payload.get("active"):
+                 return payload, "Reset the environment first.", gr.skip(), gr.skip(), gr.skip(), "", gr.skip(), gr.skip(), gr.skip()
+             payload["payload_generators"] = _normalize_generator_rows(generator_rows)
+             if not payload.get("payload_generators"):
+                 return (
+                     payload,
+                     "Add at least one payload generator before submitting.",
+                     {
+                         "status": "ready",
+                         "generator_count": 0,
+                     },
+                     {"error": "No payload generators configured."},
+                     {},
+                     "",
+                     _debug_snapshot(web_manager, debug_enabled),
+                     _generator_frame([]),
+                     [],
+                 )
+             data = await web_manager.step_environment(
+                 {
+                     "payload_generators": payload["payload_generators"],
+                     "classifications": [],
+                 }
+             )
+             payload["latest_response"] = data
+             obs = data.get("observation", {})
+             summary = {
+                 "task_id": obs.get("task_id"),
+                 "status": obs.get("status"),
+                 "message": obs.get("message"),
+                 "config": obs.get("config"),
+                 "expected_row_count": obs.get("expected_row_count"),
+                 "correct_row_count": obs.get("correct_row_count"),
+                 "generated_row_count": len(obs.get("generated_rows") or []),
+                 "submitted_row_count": len(obs.get("submitted_rows") or []),
+                 "issue_count": len(obs.get("submission_issues") or []),
+                 "reward": data.get("reward", 0.0),
+                 "done": data.get("done", False),
+             }
+             feedback = {
+                 "message": obs.get("message", ""),
+                 "issue_count": len(obs.get("submission_issues") or []),
+                 "issues": obs.get("submission_issues") or [],
+                 "generated_row_count": len(obs.get("generated_rows") or []),
+                 "generator_count": len(payload.get("payload_generators") or []),
+             }
+             return (
+                 payload,
+                 obs.get("message", ""),
+                 summary,
+                 feedback,
+                 obs.get("reward_breakdown") or {},
+                 json.dumps(data, indent=2),
+                 _debug_snapshot(web_manager, debug_enabled),
+                 pd.DataFrame([row for row in obs.get("generated_rows", [])]),
+                 payload["payload_generators"],
+             )
+
+         def get_state_sync():
+             return json.dumps(web_manager.get_state(), indent=2)
+
+         reset_btn.click(
+             fn=reset_episode,
+             inputs=[task_id, seed, scenario_family, difficulty, anomaly_density, anomaly_count, reset_anomalies, debug_mode],
+             outputs=[
+                 session_state,
+                 status,
+                 summary,
+                 active_task,
+                 task_catalog,
+                 synthetic_methods,
+                 applied_synthetic_generators,
+                 available_methods,
+                 analysis_result,
+                 submission_feedback,
+                 reward_breakdown,
+                 method_rows,
+                 generated_rows,
+                 raw_json,
+                 debug_snapshot,
+                 generator_methods_df,
+                 payload_generator_methods,
+             ],
+         )
+         task_id.change(
+             fn=apply_task_defaults,
+             inputs=[task_id],
+             outputs=[seed, scenario_family, difficulty, anomaly_density, anomaly_count, task_details],
+         )
+         run_method_btn.click(
+             fn=run_method,
+             inputs=[
+                 session_state,
+                 method_name,
+                 method_metric,
+                 method_date,
+                 method_entity,
+                 method_rows_json,
+                 method_threshold,
+                 method_limit,
+             ],
+             outputs=[session_state, analysis_result, status, method_rows, generated_rows, raw_json],
518
+ )
519
+ add_generator_btn.click(
520
+ fn=add_or_update_generator,
521
+ inputs=[session_state, generator_method_name, generator_metric_name, generator_multiplier],
522
+ outputs=[session_state, generator_methods_df, payload_generator_methods],
523
+ )
524
+ remove_generator_btn.click(
525
+ fn=remove_generator,
526
+ inputs=[session_state, generator_method_name, generator_metric_name],
527
+ outputs=[session_state, generator_methods_df, payload_generator_methods],
528
+ )
529
+ clear_generators_btn.click(
530
+ fn=clear_generators,
531
+ inputs=[session_state],
532
+ outputs=[session_state, generator_methods_df, payload_generator_methods],
533
+ )
534
+ generator_methods_df.change(
535
+ fn=sync_generator_rows,
536
+ inputs=[session_state, generator_methods_df],
537
+ outputs=[session_state, generator_methods_df, payload_generator_methods],
538
+ )
539
+ preview_btn.click(
540
+ fn=preview_payload,
541
+ inputs=[session_state, generator_methods_df],
542
+ outputs=[session_state, analysis_result, generated_rows, payload_generator_methods],
543
+ )
544
+ submit_btn.click(
545
+ fn=submit_payload,
546
+ inputs=[session_state, debug_mode, generator_methods_df],
547
+ outputs=[session_state, status, summary, submission_feedback, reward_breakdown, raw_json, debug_snapshot, generated_rows, payload_generator_methods],
548
+ )
549
+ get_state_btn.click(fn=get_state_sync, outputs=[raw_json])
550
+
551
+ return demo
552
+
553
+
554
+ def _method_args(
555
+ method_name: str,
556
+ metric_names: list[str],
557
+ method_date: str,
558
+ method_entity: str,
559
+ method_rows_json: str,
560
+ threshold: float,
561
+ limit_value: int,
562
+ payload_generators: list[dict],
563
+ ) -> dict:
564
+ selected = [item for item in (metric_names or []) if item]
565
+ resolved_date = (method_date or "").strip()
566
+ resolved_entity = (method_entity or "").strip()
567
+ if method_name == "show_raw_data":
568
+ return {"limit": int(limit_value or 5)}
569
+ if method_name in {"rows_for_date", "hourly_rows_for_date", "detect_funnel_break", "check_impossible_counts"}:
570
+ return {"date": resolved_date}
571
+ if method_name in {"compare_rate_to_median", "compare_count_to_median"}:
572
+ return {
573
+ "date": resolved_date,
574
+ "entity_name": resolved_entity,
575
+ }
576
+ if method_name in {"get_metric_median", "get_metric_std_dev_from_median"}:
577
+ return {
578
+ "metric_name": selected[0] if len(selected) == 1 else None,
579
+ "metric_names": selected,
580
+ }
581
+ if method_name == "get_rows_with_abs_diff_from_median_gt":
582
+ return {
583
+ "metric_name": selected[0] if len(selected) == 1 else None,
584
+ "metric_names": selected,
585
+ "threshold": float(threshold),
586
+ }
587
+ if method_name in {
588
+ "get_median_filter_rows",
589
+ "get_rate_drop_from_median_rows",
590
+ "get_rate_spike_from_median_rows",
591
+ "get_absolute_drop_in_event_count_rows",
592
+ "get_absolute_spike_in_event_count_rows",
593
+ }:
594
+ return {
595
+ "metric_name": selected[0] if len(selected) == 1 else None,
596
+ "metric_names": selected,
597
+ "threshold_multiplier": float(threshold),
598
+ }
599
+ if method_name in {
600
+ "get_funnel_break_rows",
601
+ "get_hourly_traffic_mix_shift_rows",
602
+ "get_instrumentation_data_quality_issue_rows",
603
+ }:
604
+ return {"threshold_multiplier": float(threshold)}
605
+ if method_name == "payload_generator":
606
+ return {"generator_methods": payload_generators}
607
+ if method_name == "list_suspicious_dates":
608
+ return {"limit": int(limit_value or 10)}
609
+ if method_name == "preview_submission":
610
+ return {"rows": _parse_rows_json(method_rows_json)}
611
+ return {}
612
+
613
+
614
+ def _parse_rows_json(raw_value: str) -> list[dict]:
615
+ if not raw_value or not raw_value.strip():
616
+ return []
617
+ parsed = json.loads(raw_value)
618
+ if not isinstance(parsed, list):
619
+ raise ValueError("rows JSON must be a list.")
620
+ return [item for item in parsed if isinstance(item, dict)]
621
+
622
+
623
+ def _method_frame(method_result: dict) -> pd.DataFrame:
624
+ result = (method_result or {}).get("result") or {}
625
+ if isinstance(result, dict):
626
+ if isinstance(result.get("results"), list):
627
+ rows = []
628
+ for item in result["results"]:
629
+ if isinstance(item, dict) and isinstance(item.get("rows"), list):
630
+ for row in item["rows"]:
631
+ enriched = dict(row)
632
+ enriched["metric_name"] = item.get("metric_name", enriched.get("metric_name"))
633
+ rows.append(enriched)
634
+ elif isinstance(item, dict):
635
+ rows.append(item)
636
+ return pd.DataFrame(rows)
637
+ if isinstance(result.get("rows"), list):
638
+ return pd.DataFrame(result["rows"])
639
+ if isinstance(result.get("dates"), list):
640
+ return pd.DataFrame(result["dates"])
641
+ if isinstance(result.get("generated_rows"), list):
642
+ return pd.DataFrame(result["generated_rows"])
643
+ return pd.DataFrame()
644
+
645
+
646
+ def _state_from_response(data: dict) -> dict:
647
+ return {
648
+ "active": True,
649
+ "payload_generators": [],
650
+ "last_method_result": data.get("observation", {}).get("analysis_result"),
651
+ "latest_response": data,
652
+ }
653
+
654
+
655
+ def _normalize_generator_rows(generator_rows) -> list[dict]:
656
+ if generator_rows is None:
657
+ return []
658
+ if isinstance(generator_rows, pd.DataFrame):
659
+ rows = generator_rows.to_dict(orient="records")
660
+ elif isinstance(generator_rows, list):
661
+ rows = generator_rows
662
+ else:
663
+ return []
664
+
665
+ normalized = []
666
+ for row in rows:
667
+ if not isinstance(row, dict):
668
+ continue
669
+ metric_names = row.get("metric_names", [])
670
+ if isinstance(metric_names, str):
671
+ metric_names = [item for item in metric_names.split(",") if item]
672
+ elif not isinstance(metric_names, list):
673
+ metric_names = []
674
+ normalized.append(
675
+ {
676
+ "method_name": row.get("method_name"),
677
+ "metric_name": row.get("metric_name"),
678
+ "metric_names": metric_names,
679
+ "threshold_multiplier": float(row.get("threshold_multiplier", 0.0)),
680
+ }
681
+ )
682
+ return normalized
683
+
684
+
685
+ def _generator_row_key(row: dict) -> str:
686
+ metric_names = [item for item in (row.get("metric_names") or []) if item]
687
+ return (
688
+ f"{row.get('method_name') or ''}"
689
+ f"|{','.join(metric_names)}"
690
+ f"|{row.get('metric_name') or ''}"
691
+ f"|{float(row.get('threshold_multiplier', 0.0)):.6f}"
692
+ )
693
+
694
+
695
+ def _generator_frame(rows: list[dict]) -> pd.DataFrame:
696
+ normalized = []
697
+ for row in rows or []:
698
+ metric_names = [item for item in (row.get("metric_names") or []) if item]
699
+ normalized.append(
700
+ {
701
+ "method_name": row.get("method_name") or "",
702
+ "metric_name": row.get("metric_name") or "",
703
+ "metric_names": ",".join(metric_names),
704
+ "threshold_multiplier": float(row.get("threshold_multiplier", 0.0)),
705
+ }
706
+ )
707
+ return pd.DataFrame(
708
+ normalized,
709
+ columns=["method_name", "metric_name", "metric_names", "threshold_multiplier"],
710
+ )
711
+
712
+
713
+ def _debug_snapshot(web_manager, debug_enabled: bool) -> dict:
714
+ if not debug_enabled:
715
+ return {}
716
+ try:
717
+ return web_manager.env.export_debug_snapshot()
718
+ except Exception as exc:
719
+ return {"error": str(exc)}
720
+
721
+
722
+ def _empty_state() -> dict:
723
+ return {
724
+ "active": False,
725
+ "payload_generators": [],
726
+ "last_method_result": None,
727
+ "latest_response": None,
728
+ }
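
The handlers above all speak the same one-action-per-step wire format: a plain JSON dict carrying `analysis_method`/`analysis_args` for read-only exploration, or `payload_generators` for a graded submission. A minimal sketch of both shapes, assuming a `web_manager` that exposes the same async `step_environment` coroutine the app uses (the metric name is taken from the tests further down):

```python
# Minimal sketch of the two action shapes sent by the Gradio handlers above.
# Assumes `web_manager.step_environment` is the same coroutine the app wires up.
async def demo_actions(web_manager) -> None:
    # Read-only analysis step: no rows are submitted, reward stays 0.0.
    analysis = await web_manager.step_environment(
        {
            "analysis_method": "list_suspicious_dates",
            "analysis_args": {"limit": 3},
            "classifications": [],
            "payload_generators": [],
        }
    )
    print(analysis["observation"]["analysis_result"]["result"]["dates"])

    # Graded submission: the server expands generators into rows and scores them.
    graded = await web_manager.step_environment(
        {
            "payload_generators": [
                {
                    "method_name": "get_median_filter_rows",
                    "metric_name": "app_open_to_order_placed",
                    "metric_names": ["app_open_to_order_placed"],
                    "threshold_multiplier": 2.0,
                }
            ],
            "classifications": [],
        }
    )
    print(graded.get("reward"), graded.get("done"))
```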
server/metric_tracker_rl_environment.py ADDED
@@ -0,0 +1,417 @@
+ """Metric tracking RL environment."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from uuid import uuid4
+
+ from openenv.core.env_server.interfaces import Environment
+ from openenv.core.env_server.types import State
+
+ try:
+     from ..analysis_tools import AnalysisContext, SharedAnalysisToolkit, available_analysis_methods
+     from ..evaluation import EvaluationConfig
+     from ..models import (
+         MetricTrackerRlAction,
+         MetricTrackerRlObservation,
+         MetricSubmissionRow,
+         SyntheticAnomalyGenerator,
+     )
+     from ..tasks import DEFAULT_TASK_ID, available_task_specs, get_task_spec
+     from .data_generator import (
+         EpisodeConfig,
+         EpisodeData,
+         MetricDataGenerator,
+         available_synthetic_generator_methods,
+     )
+ except ImportError:
+     from analysis_tools import AnalysisContext, SharedAnalysisToolkit, available_analysis_methods
+     from models import (
+         MetricTrackerRlAction,
+         MetricTrackerRlObservation,
+         MetricSubmissionRow,
+         SyntheticAnomalyGenerator,
+     )
+     from tasks import DEFAULT_TASK_ID, available_task_specs, get_task_spec
+     from server.data_generator import (
+         EpisodeConfig,
+         EpisodeData,
+         MetricDataGenerator,
+         available_synthetic_generator_methods,
+     )
+     from evaluation import EvaluationConfig
+
+
+ @dataclass(frozen=True)
+ class RewardConfig:
+     """Compatibility wrapper around the evaluator configuration."""
+
+     evaluation: EvaluationConfig = EvaluationConfig()
+
+
+ class MetricTrackerRlEnvironment(Environment):
+     """Iterative multi-anomaly benchmark with safe analysis methods."""
+
+     SUPPORTS_CONCURRENT_SESSIONS: bool = True
+
+     def __init__(
+         self,
+         generator: MetricDataGenerator | None = None,
+         reward_config: RewardConfig | None = None,
+     ) -> None:
+         initial_task = get_task_spec(DEFAULT_TASK_ID)
+         self._generator = generator or MetricDataGenerator()
+         self._reward_config = reward_config or RewardConfig()
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._episode: EpisodeData | None = None
+         self._completed = False
+         self._debug_mode = False
+         self._active_task = initial_task
+         self._next_task_id = initial_task.task_id
+         self._next_reset_config = initial_task.build_episode_config()
+         self._last_analysis_result: dict | None = None
+         self._expose_applied_generators = False
+
+     def configure_next_reset(
+         self,
+         *,
+         task_id: str | None = None,
+         seed: int | None = None,
+         scenario_family: str | None = None,
+         difficulty: str | None = None,
+         anomaly_density: str | None = None,
+         anomaly_count: int | None = None,
+         anomalies: list[dict] | list[SyntheticAnomalyGenerator] | None = None,
+     ) -> None:
+         """Update the configuration used for the next reset."""
+         base_task = get_task_spec(task_id or self._next_task_id)
+         base_config = base_task.build_episode_config() if task_id else self._next_reset_config
+         anomaly_generators = tuple(
+             item if isinstance(item, SyntheticAnomalyGenerator) else SyntheticAnomalyGenerator(**item)
+             for item in (anomalies or [])
+         )
+         self._next_task_id = base_task.task_id
+         self._next_reset_config = EpisodeConfig(
+             seed=base_config.seed if seed is None else seed,
+             scenario_family=base_config.scenario_family if scenario_family is None else scenario_family,
+             difficulty=base_config.difficulty if difficulty is None else difficulty,
+             anomaly_density=base_config.anomaly_density if anomaly_density is None else anomaly_density,
+             anomaly_count=base_config.anomaly_count if anomaly_count is None else anomaly_count,
+             anomaly_generators=anomaly_generators or base_config.anomaly_generators,
+         ).normalized()
+
+     def set_debug_mode(self, enabled: bool) -> None:
+         """Enable or disable debug-only environment views."""
+         self._debug_mode = bool(enabled)
+
+     def export_debug_snapshot(self) -> dict:
+         """Return a developer-only debug snapshot for the active episode."""
+         if not self._debug_mode:
+             raise RuntimeError("Debug mode is disabled.")
+         if self._episode is None:
+             return {}
+         return {
+             "config": self._episode.config.__dict__,
+             "expected_payload": [row.model_dump() for row in self._episode.expected_rows],
+             "anomaly_schedule": self._episode.anomaly_schedule,
+             "applied_synthetic_generators": [
+                 row.model_dump() for row in self._episode.applied_synthetic_generators
+             ],
+         }
+
+     def reset(
+         self,
+         task_id: str | None = None,
+         seed: int | None = None,
+         scenario_family: str | None = None,
+         difficulty: str | None = None,
+         anomaly_density: str | None = None,
+         anomaly_count: int | None = None,
+         anomalies: list[dict] | list[SyntheticAnomalyGenerator] | None = None,
+     ) -> MetricTrackerRlObservation:
+         """Generate a fresh dataset and hidden target payload."""
+         if any(value is not None for value in (task_id, seed, scenario_family, difficulty, anomaly_density, anomaly_count)) or anomalies is not None:
+             self.configure_next_reset(
+                 task_id=task_id,
+                 seed=seed,
+                 scenario_family=scenario_family,
+                 difficulty=difficulty,
+                 anomaly_density=anomaly_density,
+                 anomaly_count=anomaly_count,
+                 anomalies=anomalies,
+             )
+         self._state = State(episode_id=str(uuid4()), step_count=0)
+         self._active_task = get_task_spec(self._next_task_id)
+         self._episode = self._generator.generate_episode(self._next_reset_config)
+         self._completed = False
+         self._last_analysis_result = None
+         self._expose_applied_generators = anomalies is not None
+         return self._build_observation(
+             status="ready",
+             message=self._active_task.objective,
+             reward=0.0,
+             done=False,
+         )
+
+     def step(self, action: MetricTrackerRlAction) -> MetricTrackerRlObservation:  # type: ignore[override]
+         """Evaluate a submitted payload and return deterministic feedback."""
+         if self._episode is None:
+             return self.reset()
+         if self._completed:
+             return self._build_observation(
+                 status="completed",
+                 message="Dataset already solved. Call reset() to create a new dataset.",
+                 reward=1.0,
+                 done=True,
+                 submitted_rows=action.classifications,
+             )
+
+         if action.analysis_method:
+             self._state.step_count += 1
+             analysis_result = self._run_analysis(action.analysis_method, action.analysis_args)
+             self._last_analysis_result = analysis_result
+             return self._build_observation(
+                 status="analyzed",
+                 message=f"Ran analysis method `{action.analysis_method}`.",
+                 reward=0.0,
+                 done=False,
+                 analysis_result=analysis_result,
+             )
+
+         submitted_rows = action.classifications
+         generated_rows: list[MetricSubmissionRow] = []
+         if action.payload_generators:
+             generator_result = self._run_analysis(
+                 "payload_generator",
+                 {"generator_methods": [item.model_dump() for item in action.payload_generators]},
+             )
+             self._last_analysis_result = generator_result
+             generated_rows = [
+                 MetricSubmissionRow(**row)
+                 for row in generator_result["result"]["generated_rows"]
+             ]
+             submitted_rows = generated_rows
+
+         self._state.step_count += 1
+         result = self._active_task.grade_submission(
+             submitted_rows,
+             self._episode.expected_rows,
+             config=self._reward_config.evaluation,
+             include_debug_expected=self._debug_mode,
+         )
+         self._completed = result.is_perfect
+         reward = result.reward_breakdown.total_score
+         message = self._submission_message(result)
+         return self._build_observation(
+             status="evaluated" if result.is_perfect else "in_progress",
+             message=message,
+             reward=reward,
+             done=result.is_perfect,
+             submitted_rows=result.preview.normalized_rows,
+             reward_breakdown=result.reward_breakdown,
+             submission_preview=result.preview,
+             issues=result.issues,
+             correct_row_count=result.matched_rows,
+             analysis_result=self._last_analysis_result,
+             generated_rows=generated_rows,
+         )
+
+     @property
+     def state(self) -> State:
+         """Return current episode state."""
+         return self._state
+
+     def _build_observation(
+         self,
+         *,
+         status: str,
+         message: str,
+         reward: float,
+         done: bool,
+         submitted_rows=None,
+         reward_breakdown=None,
+         submission_preview=None,
+         issues=None,
+         correct_row_count: int = 0,
+         analysis_result=None,
+         generated_rows=None,
+     ) -> MetricTrackerRlObservation:
+         assert self._episode is not None
+         metadata = {
+             "step": self._state.step_count,
+             "current_state": self.state.model_dump(),
+             "task_id": self._active_task.task_id,
+             "objective": self._active_task.objective,
+             "grader_name": self._active_task.grader_name,
+             "seed": self._episode.config.seed,
+             "scenario_family": self._episode.config.scenario_family,
+             "difficulty": self._episode.config.difficulty,
+             "anomaly_density": self._episode.config.anomaly_density,
+             "anomaly_count": self._episode.config.anomaly_count,
+         }
+         return MetricTrackerRlObservation(
+             task_id=self._active_task.task_id,
+             status=status,
+             message=message,
+             instruction=self._active_task.instruction,
+             conversion_metric_definitions=list(self._generator.config.conversion_definitions),
+             available_synthetic_generator_methods=available_synthetic_generator_methods(),
+             applied_synthetic_generators=(
+                 self._episode.applied_synthetic_generators
+                 if self._debug_mode or self._expose_applied_generators
+                 else []
+             ),
+             available_methods=available_analysis_methods(),
+             available_tasks=available_task_specs(),
+             daily_metrics=[],
+             hourly_metrics=[],
+             analysis_result=analysis_result,
+             generated_rows=generated_rows or [],
+             submitted_rows=submitted_rows or [],
+             submission_preview=submission_preview,
+             submission_issues=issues or [],
+             reward_breakdown=reward_breakdown,
+             expected_row_count=len(self._episode.expected_rows),
+             correct_row_count=correct_row_count,
+             reward=reward,
+             done=done,
+             config=metadata,
+             debug=(
+                 {
+                     "task_id": self._active_task.task_id,
+                     "expected_payload": [row.model_dump() for row in self._episode.expected_rows],
+                     "anomaly_schedule": self._episode.anomaly_schedule,
+                     "reward_breakdown": reward_breakdown.model_dump() if reward_breakdown else None,
+                     "issues": [item.model_dump() for item in (issues or [])],
+                 }
+                 if self._debug_mode
+                 else None
+             ),
+         )
+
+     def _run_analysis(self, method_name: str, arguments: dict) -> dict:
+         toolkit = SharedAnalysisToolkit(
+             AnalysisContext(
+                 daily_metrics=self._episode.daily_metrics,
+                 hourly_metrics=self._episode.hourly_metrics,
+                 conversion_definitions=list(self._generator.config.conversion_definitions),
+                 instruction=self._active_task.instruction,
+                 config={
+                     "task_id": self._active_task.task_id,
+                     "objective": self._active_task.objective,
+                     "grader_name": self._active_task.grader_name,
+                     **self._episode.config.__dict__,
+                 },
+             )
+         )
+         if method_name == "task_overview":
+             result = toolkit.task_overview()
+         elif method_name == "list_dates":
+             result = toolkit.list_dates()
+         elif method_name == "list_entities":
+             result = toolkit.list_entities()
+         elif method_name == "rows_for_date":
+             result = toolkit.rows_for_date(arguments["date"])
+         elif method_name == "hourly_rows_for_date":
+             result = toolkit.hourly_rows_for_date(arguments["date"])
+         elif method_name == "compare_rate_to_median":
+             result = toolkit.compare_rate_to_median(arguments["date"], arguments["entity_name"])
+         elif method_name == "compare_count_to_median":
+             result = toolkit.compare_count_to_median(arguments["date"], arguments["entity_name"])
+         elif method_name == "detect_funnel_break":
+             result = toolkit.detect_funnel_break(arguments["date"])
+         elif method_name == "check_impossible_counts":
+             result = toolkit.check_impossible_counts(arguments["date"])
+         elif method_name == "list_suspicious_dates":
+             result = toolkit.list_suspicious_dates(limit=arguments.get("limit", 10))
+         elif method_name == "preview_submission":
+             result = toolkit.preview_submission(arguments.get("rows", []))
+         elif method_name == "show_raw_data":
+             result = toolkit.show_raw_data(limit=arguments.get("limit", 5))
+         elif method_name == "get_metric_median":
+             result = toolkit.get_metric_median_multi(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+             )
+         elif method_name == "get_metric_std_dev_from_median":
+             result = toolkit.get_metric_std_dev_from_median_multi(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+             )
+         elif method_name == "get_rows_with_abs_diff_from_median_gt":
+             result = toolkit.get_rows_with_abs_diff_from_median_gt_multi(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+                 threshold=float(arguments["threshold"]),
+             )
+         elif method_name == "get_median_filter_rows":
+             result = toolkit.get_median_filter_rows_multi(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_rate_drop_from_median_rows":
+             result = toolkit.get_rate_drop_from_median_rows(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_rate_spike_from_median_rows":
+             result = toolkit.get_rate_spike_from_median_rows(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_absolute_drop_in_event_count_rows":
+             result = toolkit.get_absolute_drop_in_event_count_rows(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_absolute_spike_in_event_count_rows":
+             result = toolkit.get_absolute_spike_in_event_count_rows(
+                 metric_name=arguments.get("metric_name"),
+                 metric_names=arguments.get("metric_names", []),
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_funnel_break_rows":
+             result = toolkit.get_funnel_break_rows(
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_hourly_traffic_mix_shift_rows":
+             result = toolkit.get_hourly_traffic_mix_shift_rows(
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "get_instrumentation_data_quality_issue_rows":
+             result = toolkit.get_instrumentation_data_quality_issue_rows(
+                 threshold_multiplier=float(arguments["threshold_multiplier"]),
+             )
+         elif method_name == "payload_generator":
+             result = toolkit.payload_generator(arguments.get("generator_methods", []))
+         else:
+             raise ValueError(f"Unsupported analysis method: {method_name}")
+
+         return {
+             "method": method_name,
+             "arguments": arguments,
+             "result": result,
+         }
+
+     @staticmethod
+     def _submission_message(result) -> str:
+         if result.is_perfect:
+             return "Submission is fully correct."
+         extra_issues = [issue for issue in result.issues if issue.issue_type == "extra_row"]
+         missing_count = result.reward_breakdown.missing_rows
+         if not extra_issues and missing_count > 0:
+             return (
+                 "All submitted rows are anomalies, but a few are missing. "
+                 f"Missing value count: {missing_count}."
+             )
+         if extra_issues:
+             first = extra_issues[0]
+             return f"Specific row is not an anomaly: {first.row_key}."
+         return (
+             f"Matched {result.reward_breakdown.matched_rows}/"
+             f"{result.reward_breakdown.expected_rows} expected rows. Review the feedback."
+         )
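
For orientation, here is a minimal in-process episode against the environment above, mirroring the test suite at the end of this commit; the import paths assume the project is importable as the `metric_tracker_rl` package:

```python
# Minimal in-process episode sketch; import paths assume the `metric_tracker_rl` package.
from metric_tracker_rl import MetricTrackerRlAction
from metric_tracker_rl.models import PayloadGeneratorMethod
from metric_tracker_rl.server.metric_tracker_rl_environment import MetricTrackerRlEnvironment

env = MetricTrackerRlEnvironment()
env.reset(task_id="easy_single_spike")

# Step 1: a safe, read-only analysis method (reward 0.0, episode continues).
analyzed = env.step(
    MetricTrackerRlAction(analysis_method="list_suspicious_dates", analysis_args={"limit": 3})
)
print(analyzed.analysis_result["result"]["dates"])

# Step 2: a generator-backed submission the environment expands into rows and grades.
result = env.step(
    MetricTrackerRlAction(
        payload_generators=[
            PayloadGeneratorMethod(
                method_name="get_median_filter_rows",
                metric_name="app_open_to_order_placed",
                threshold_multiplier=2.0,
            )
        ]
    )
)
print(result.status, result.reward, result.done)
```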
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openenv[core]>=0.2.0
+ fastapi>=0.115.0
+ uvicorn>=0.24.0
+ gradio>=5.0.0
+ pandas>=2.2.0
+ plotly>=5.24.0
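
These pins cover the HTTP API, the Gradio debugger, and the dataframe/plotting views. One way to exercise them locally, sketched under the assumption that the FastAPI application object is exposed as `server.app:app`:

```python
# Local launch sketch; assumes the FastAPI app object lives at server.app:app.
import uvicorn

if __name__ == "__main__":
    uvicorn.run("server.app:app", host="0.0.0.0", port=8000)
```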
tasks.py ADDED
@@ -0,0 +1,141 @@
+ """Named benchmark tasks and deterministic task graders."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+
+ try:
+     from .evaluation import EvaluationConfig, EvaluationResult, evaluate_submission
+     from .models import BenchmarkTaskSpec, MetricSubmissionRow
+     from .server.data_generator import EpisodeConfig
+ except ImportError:
+     from evaluation import EvaluationConfig, EvaluationResult, evaluate_submission
+     from models import BenchmarkTaskSpec, MetricSubmissionRow
+     from server.data_generator import EpisodeConfig
+
+
+ DEFAULT_GRADER_NAME = "deterministic_exact_match"
+
+
+ @dataclass(frozen=True)
+ class TaskSpec:
+     """A concrete benchmark task that an agent can solve and be graded on."""
+
+     task_id: str
+     difficulty: str
+     instruction: str
+     objective: str
+     seed: int
+     scenario_family: str
+     anomaly_density: str
+     anomaly_count: int
+     grader_name: str = DEFAULT_GRADER_NAME
+     evaluation_config: EvaluationConfig = field(default_factory=EvaluationConfig)
+
+     def build_episode_config(self) -> EpisodeConfig:
+         """Return the canonical episode configuration for this task."""
+         return EpisodeConfig(
+             seed=self.seed,
+             scenario_family=self.scenario_family,
+             difficulty=self.difficulty,
+             anomaly_density=self.anomaly_density,
+             anomaly_count=self.anomaly_count,
+         ).normalized()
+
+     def grade_submission(
+         self,
+         submitted_rows: list[dict] | list[MetricSubmissionRow],
+         expected_rows: list[MetricSubmissionRow],
+         *,
+         config: EvaluationConfig | None = None,
+         include_debug_expected: bool = False,
+     ) -> EvaluationResult:
+         """Grade one candidate submission for this task."""
+         return evaluate_submission(
+             submitted_rows,
+             expected_rows,
+             config=config or self.evaluation_config,
+             include_debug_expected=include_debug_expected,
+         )
+
+     def to_model(self) -> BenchmarkTaskSpec:
+         """Return a typed summary safe to expose in observations."""
+         return BenchmarkTaskSpec(
+             task_id=self.task_id,
+             difficulty=self.difficulty,
+             instruction=self.instruction,
+             objective=self.objective,
+             scenario_family=self.scenario_family,
+             anomaly_density=self.anomaly_density,
+             anomaly_count=self.anomaly_count,
+             grader_name=self.grader_name,
+         )
+
+
+ TASKS: dict[str, TaskSpec] = {
+     "easy_single_spike": TaskSpec(
+         task_id="easy_single_spike",
+         difficulty="easy",
+         instruction=(
+             "Investigate the seeded funnel dataset and submit the single anomalous row. "
+             "Use the shared analysis methods before submitting."
+         ),
+         objective=(
+             "Find the one obvious anomaly and submit exactly one correctly populated anomaly row."
+         ),
+         seed=11,
+         scenario_family="absolute_spike_in_event_count",
+         anomaly_density="low",
+         anomaly_count=1,
+     ),
+     "medium_mixed_pair": TaskSpec(
+         task_id="medium_mixed_pair",
+         difficulty="medium",
+         instruction=(
+             "Investigate the seeded funnel dataset and submit every anomalous row. "
+             "Expect both event-count and conversion-rate reasoning."
+         ),
+         objective=(
+             "Find the full set of medium-difficulty anomalies without submitting extras."
+         ),
+         seed=23,
+         scenario_family="mixed",
+         anomaly_density="medium",
+         anomaly_count=3,
+     ),
+     "hard_mixed_multi": TaskSpec(
+         task_id="hard_mixed_multi",
+         difficulty="hard",
+         instruction=(
+             "Investigate the seeded funnel dataset and submit every anomalous row. "
+             "Some anomalies are subtle, so use the analysis methods carefully and avoid over-submitting."
+         ),
+         objective=(
+             "Recover the complete set of hard mixed anomalies while preserving precision."
+         ),
+         seed=37,
+         scenario_family="mixed",
+         anomaly_density="high",
+         anomaly_count=5,
+     ),
+ }
+
+ DEFAULT_TASK_ORDER: tuple[str, ...] = (
+     "easy_single_spike",
+     "medium_mixed_pair",
+     "hard_mixed_multi",
+ )
+ DEFAULT_TASK_ID = DEFAULT_TASK_ORDER[0]
+
+
+ def get_task_spec(task_id: str) -> TaskSpec:
+     """Return the task spec for a known task id."""
+     try:
+         return TASKS[task_id]
+     except KeyError as exc:
+         raise ValueError(f"Unsupported task_id: {task_id}") from exc
+
+
+ def available_task_specs() -> list[BenchmarkTaskSpec]:
+     """Return typed summaries for all named benchmark tasks."""
+     return [TASKS[task_id].to_model() for task_id in DEFAULT_TASK_ORDER]
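
The registry above plugs directly into the deterministic grader. A short sketch, under the same package-name assumption as the tests below: grading the hidden expected rows against themselves must score a perfect 1.0.

```python
# Grade a perfect submission for a named task; mirrors test_task_grader_scores_perfect_submission.
from metric_tracker_rl.server.data_generator import MetricDataGenerator
from metric_tracker_rl.tasks import get_task_spec

task = get_task_spec("medium_mixed_pair")
episode = MetricDataGenerator().generate_episode(task.build_episode_config())

# Submitting the expected rows themselves is the upper bound on the reward.
result = task.grade_submission(episode.expected_rows, episode.expected_rows)
assert result.is_perfect
assert result.reward_breakdown.total_score == 1.0
```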
tests/test_metric_tracker_rl.py ADDED
@@ -0,0 +1,367 @@
+ from __future__ import annotations
+
+ from metric_tracker_rl import MetricTrackerRlAction
+ from metric_tracker_rl.analysis_tools import AnalysisContext, SharedAnalysisToolkit
+ from metric_tracker_rl.evaluation import evaluate_submission
+ from metric_tracker_rl.models import MetricSubmissionRow, PayloadGeneratorMethod
+ from metric_tracker_rl.server.data_generator import ALL_SCENARIO_FAMILIES, EpisodeConfig, MetricDataGenerator
+ from metric_tracker_rl.server.metric_tracker_rl_environment import MetricTrackerRlEnvironment
+ from metric_tracker_rl.tasks import DEFAULT_TASK_ORDER, TASKS, get_task_spec
+
+
+ def _toolkit_for(seed: int = 11, scenario_family: str = "mixed") -> tuple[SharedAnalysisToolkit, list[MetricSubmissionRow]]:
+     generator = MetricDataGenerator()
+     episode = generator.generate_episode(
+         EpisodeConfig(
+             seed=seed,
+             scenario_family=scenario_family,
+             difficulty="medium",
+             anomaly_density="medium",
+             anomaly_count=5,
+         )
+     )
+     toolkit = SharedAnalysisToolkit(
+         AnalysisContext(
+             daily_metrics=episode.daily_metrics,
+             hourly_metrics=episode.hourly_metrics,
+             conversion_definitions=list(generator.config.conversion_definitions),
+             config=episode.config.__dict__,
+         )
+     )
+     return toolkit, episode.expected_rows
+
+
+ def test_seed_reproducibility():
+     generator = MetricDataGenerator()
+     config = EpisodeConfig(seed=17, scenario_family="mixed", difficulty="hard", anomaly_density="high")
+     first = generator.generate_episode(config)
+     second = generator.generate_episode(config)
+
+     assert [row.model_dump() for row in first.daily_metrics] == [row.model_dump() for row in second.daily_metrics]
+     assert [row.model_dump() for row in first.hourly_metrics] == [row.model_dump() for row in second.hourly_metrics]
+     assert [row.model_dump() for row in first.expected_rows] == [row.model_dump() for row in second.expected_rows]
+
+
+ def test_anomaly_variety():
+     generator = MetricDataGenerator()
+     family_results = {}
+     for family in ALL_SCENARIO_FAMILIES[1:]:
+         episode = generator.generate_episode(
+             EpisodeConfig(
+                 seed=7,
+                 scenario_family=family,
+                 difficulty="medium",
+                 anomaly_density="medium",
+                 anomaly_count=5,
+             )
+         )
+         family_results[family] = {row.anomaly_type for row in episode.expected_rows}
+
+     assert family_results["rate_drop_from_median"] == {"rate_drop_from_median"}
+     assert family_results["rate_spike_from_median"] == {"rate_spike_from_median"}
+     assert family_results["absolute_drop_in_event_count"] == {"absolute_drop_in_event_count"}
+     assert family_results["absolute_spike_in_event_count"] == {"absolute_spike_in_event_count"}
+     assert family_results["funnel_break"] == {"funnel_break"}
+     assert family_results["hourly_traffic_mix_shift"] == {"hourly_traffic_mix_shift"}
+     assert family_results["instrumentation_data_quality_issue"] == {"instrumentation_data_quality_issue"}
+
+     mixed = generator.generate_episode(
+         EpisodeConfig(
+             seed=7,
+             scenario_family="mixed",
+             difficulty="medium",
+             anomaly_density="medium",
+             anomaly_count=5,
+         )
+     )
+     assert len(mixed.expected_rows) == 5
+     assert {row.anomaly_type for row in mixed.expected_rows}.issubset(
+         {
+             "rate_drop_from_median",
+             "rate_spike_from_median",
+             "absolute_drop_in_event_count",
+             "absolute_spike_in_event_count",
+         }
+     )
+     assert len({row.anomaly_type for row in mixed.expected_rows}) >= 2
+
+
+ def test_evaluator_scores_perfect_submission():
+     _, expected_rows = _toolkit_for()
+     result = evaluate_submission(expected_rows, expected_rows)
+
+     assert result.is_perfect is True
+     assert result.reward_breakdown.total_score == 1.0
+     assert result.reward_breakdown.extra_rows == 0
+     assert result.reward_breakdown.duplicate_rows == 0
+     assert result.reward_breakdown.invalid_rows == 0
+
+
+ def test_named_task_registry_covers_easy_medium_hard():
+     assert DEFAULT_TASK_ORDER == (
+         "easy_single_spike",
+         "medium_mixed_pair",
+         "hard_mixed_multi",
+     )
+     assert len(TASKS) == 3
+     assert {TASKS[task_id].difficulty for task_id in DEFAULT_TASK_ORDER} == {"easy", "medium", "hard"}
+     assert all(TASKS[task_id].grader_name for task_id in DEFAULT_TASK_ORDER)
+
+
+ def test_task_grader_scores_perfect_submission():
+     generator = MetricDataGenerator()
+     task = get_task_spec("medium_mixed_pair")
+     episode = generator.generate_episode(task.build_episode_config())
+
+     result = task.grade_submission(episode.expected_rows, episode.expected_rows)
+
+     assert result.is_perfect is True
+     assert result.reward_breakdown.total_score == 1.0
+
+
+ def test_duplicate_and_extra_rows_are_penalized():
+     _, expected_rows = _toolkit_for()
+     extra_row = MetricSubmissionRow(
+         date=expected_rows[0].date,
+         entity_type="event_count",
+         entity_name="nonexistent_metric",
+         anomaly_type="absolute_spike_in_event_count",
+         detection_method="compare_count_to_median",
+         baseline_value=100.0,
+         observed_value=120.0,
+         delta_value=20.0,
+         severity="low",
+     )
+     submitted = [expected_rows[0], expected_rows[0], extra_row]
+     result = evaluate_submission(submitted, expected_rows)
+
+     assert result.is_perfect is False
+     assert result.reward_breakdown.duplicate_rows == 1
+     assert result.reward_breakdown.extra_rows == 1
+     assert result.reward_breakdown.total_score < 1.0
+
+
+ def test_shared_methods_behave_consistently():
+     toolkit, expected_rows = _toolkit_for(seed=3, scenario_family="mixed")
+     overview = toolkit.task_overview()
+     suspicious = toolkit.list_suspicious_dates(limit=5)
+     first_row = expected_rows[0]
+
+     assert overview["payload_schema"][0] == "date"
+     method_names = {item["name"] for item in overview["available_methods"]}
+     assert "show_raw_data" in method_names
+     assert "get_median_filter_rows" in method_names
+     assert "get_funnel_break_rows" in method_names
+     assert "get_hourly_traffic_mix_shift_rows" in method_names
+     assert "get_instrumentation_data_quality_issue_rows" in method_names
+     assert "payload_generator" in method_names
+     assert len(suspicious["dates"]) == 5
+
+     if first_row.detection_method == "compare_rate_to_median":
+         result = toolkit.compare_rate_to_median(first_row.date, first_row.entity_name)
+         assert result["anomaly_type"] == first_row.anomaly_type
+     elif first_row.detection_method == "compare_count_to_median":
+         result = toolkit.compare_count_to_median(first_row.date, first_row.entity_name)
+         assert result["anomaly_type"] == first_row.anomaly_type
+     elif first_row.detection_method == "detect_funnel_break":
+         result = toolkit.detect_funnel_break(first_row.date)
+         assert any(item["entity_name"] == first_row.entity_name for item in result["candidates"])
+     elif first_row.detection_method == "check_impossible_counts":
+         result = toolkit.check_impossible_counts(first_row.date)
+         assert result["issue_count"] > 0
+     else:
+         result = toolkit.hourly_rows_for_date(first_row.date)
+         assert result["found"] is True
+
+     raw = toolkit.show_raw_data(limit=3)
+     assert raw["returned_rows"] == 3
+     median_stats = toolkit.get_metric_median("app_open_to_order_placed")
+     std_stats = toolkit.get_metric_std_dev_from_median("app_open_to_order_placed")
+     assert median_stats["sample_size"] > 0
+     assert std_stats["std_dev_from_median"] >= 0
+
+
+ def test_debug_mode_is_gated():
+     env = MetricTrackerRlEnvironment()
+     observation = env.reset()
+
+     assert observation.debug is None
+     assert observation.daily_metrics == []
+     assert observation.hourly_metrics == []
+
+     try:
+         env.export_debug_snapshot()
+     except RuntimeError as exc:
+         assert "Debug mode is disabled" in str(exc)
+     else:
+         raise AssertionError("Expected debug snapshot to be gated.")
+
+     env.set_debug_mode(True)
+     debug_observation = env.reset()
+     snapshot = env.export_debug_snapshot()
+
+     assert debug_observation.debug is not None
+     assert "expected_payload" in snapshot
+     assert "applied_synthetic_generators" in snapshot
+
+
+ def test_reset_exposes_synthetic_generator_metadata():
+     env = MetricTrackerRlEnvironment()
+     observation = env.reset()
+
+     assert observation.task_id == "easy_single_spike"
+     assert len(observation.available_tasks) == 3
+     assert observation.available_synthetic_generator_methods
+     assert observation.available_synthetic_generator_methods[0].name == "metric_stddev_shift"
+     assert observation.applied_synthetic_generators == []
+
+
+ def test_named_task_reset_updates_instruction_and_config():
+     env = MetricTrackerRlEnvironment()
+     observation = env.reset(task_id="hard_mixed_multi")
+
+     assert observation.task_id == "hard_mixed_multi"
+     assert observation.config["task_id"] == "hard_mixed_multi"
+     assert observation.config["grader_name"] == "deterministic_exact_match"
+     assert observation.config["difficulty"] == "hard"
+     assert observation.instruction == get_task_spec("hard_mixed_multi").instruction
+
+
+ def test_custom_reset_anomalies_support_specific_dates_and_stddev_factor():
+     env = MetricTrackerRlEnvironment()
+     observation = env.reset(
+         seed=21,
+         scenario_family="mixed",
+         anomaly_count=2,
+         anomalies=[
+             {
+                 "method_name": "metric_stddev_shift",
+                 "metric_name": "orders_placed",
+                 "date": "2026-03-20",
+                 "stddev_factor": 2.5,
+                 "direction": "down",
+             },
+             {
+                 "method_name": "metric_stddev_shift",
+                 "metric_name": "app_open_to_order_placed",
+                 "date": "2026-03-25",
+                 "stddev_factor": 2.0,
+                 "direction": "up",
+             },
+         ],
+     )
+
+     applied = {item.date: item for item in observation.applied_synthetic_generators}
+     assert "2026-03-20" in applied
+     assert "2026-03-25" in applied
+     assert applied["2026-03-20"].metric_name == "orders_placed"
+     assert applied["2026-03-20"].stddev_factor == 2.5
+     assert applied["2026-03-20"].threshold_value == round(
+         applied["2026-03-20"].std_dev_from_median * 2.5,
+         4,
+     )
+     assert applied["2026-03-25"].metric_type == "conversion_rate"
+
+
+ def test_analysis_methods_run_through_step_api():
+     env = MetricTrackerRlEnvironment()
+     env.reset()
+     analyzed = env.step(
+         MetricTrackerRlAction(
+             analysis_method="list_suspicious_dates",
+             analysis_args={"limit": 3},
+         )
+     )
+
+     assert analyzed.analysis_result is not None
+     assert analyzed.analysis_result["method"] == "list_suspicious_dates"
+     assert len(analyzed.analysis_result["result"]["dates"]) == 3
+
+
+ def test_payload_generator_method_creates_rows():
+     toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
+     result = toolkit.get_median_filter_rows("app_open_to_order_placed", 2.0)
+     assert result["details"][0]["threshold"] >= 0
+     assert isinstance(result["generated_rows"], list)
+
+
+ def test_payload_generator_method_without_metric_runs_all_metrics():
+     toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
+     result = toolkit.get_median_filter_rows_multi(metric_name=None, metric_names=[], threshold_multiplier=2.0)
+     assert "app_opens" in result["metric_names"]
+     assert "app_open_to_order_placed" in result["metric_names"]
+     assert isinstance(result["generated_rows"], list)
+
+
+ def test_family_specific_generator_methods_create_matching_anomaly_types():
+     cases = [
+         ("rate_drop_from_median", "get_rate_drop_from_median_rows", 1.5),
+         ("rate_spike_from_median", "get_rate_spike_from_median_rows", 1.5),
+         ("absolute_drop_in_event_count", "get_absolute_drop_in_event_count_rows", 1.5),
+         ("absolute_spike_in_event_count", "get_absolute_spike_in_event_count_rows", 1.5),
+         ("funnel_break", "get_funnel_break_rows", 1.0),
+         ("hourly_traffic_mix_shift", "get_hourly_traffic_mix_shift_rows", 1.0),
+         ("instrumentation_data_quality_issue", "get_instrumentation_data_quality_issue_rows", 1.0),
+     ]
+
+     for family, method_name, threshold_multiplier in cases:
+         toolkit, _ = _toolkit_for(seed=7, scenario_family=family)
+         method = getattr(toolkit, method_name)
+         if "rate_" in method_name or "event_count" in method_name:
+             result = method(metric_name=None, metric_names=[], threshold_multiplier=threshold_multiplier)
+         else:
+             result = method(threshold_multiplier=threshold_multiplier)
+         assert result["generated_rows"], method_name
+         assert {row["anomaly_type"] for row in result["generated_rows"]} == {family}
+
+
+ def test_metric_summary_methods_without_metric_run_all_metrics():
+     toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
+     medians = toolkit.get_metric_median_multi(metric_name=None, metric_names=[])
+     stds = toolkit.get_metric_std_dev_from_median_multi(metric_name=None, metric_names=[])
+     diffs = toolkit.get_rows_with_abs_diff_from_median_gt_multi(
+         metric_name=None,
+         metric_names=[],
+         threshold=1.0,
+     )
+     assert "app_opens" in medians["metric_names"]
+     assert "app_open_to_order_placed" in stds["metric_names"]
+     assert len(medians["results"]) == len(medians["metric_names"])
+     assert len(stds["results"]) == len(stds["metric_names"])
+     assert len(diffs["results"]) == len(diffs["metric_names"])
+
+
+ def test_generator_submission_path_runs():
+     env = MetricTrackerRlEnvironment()
+     env.reset()
+     result = env.step(
+         MetricTrackerRlAction(
+             payload_generators=[
+                 PayloadGeneratorMethod(
+                     method_name="get_median_filter_rows",
+                     metric_name="app_open_to_order_placed",
+                     threshold_multiplier=2.0,
+                 )
+             ]
+         )
+     )
+     assert result.generated_rows is not None
+     assert result.status in {"evaluated", "in_progress", "completed"}
+
+
+ def test_generator_submission_path_supports_family_specific_methods():
+     env = MetricTrackerRlEnvironment()
+     env.reset(task_id="hard_mixed_multi", scenario_family="funnel_break")
+     result = env.step(
+         MetricTrackerRlAction(
+             payload_generators=[
+                 PayloadGeneratorMethod(
+                     method_name="get_funnel_break_rows",
+                     threshold_multiplier=1.0,
+                 )
+             ]
+         )
+     )
+     assert result.analysis_result is not None
+     assert result.analysis_result["result"]["generated_rows"] is not None
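
The suite assumes the repository is importable as `metric_tracker_rl` (for example via an editable install). A programmatic way to run it, equivalent to invoking pytest on this file:

```python
# Programmatic equivalent of `pytest -q tests/test_metric_tracker_rl.py`.
import sys

import pytest

sys.exit(pytest.main(["-q", "tests/test_metric_tracker_rl.py"]))
```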
uv.lock ADDED
The diff for this file is too large to render. See raw diff