# metric_tracker_rl/tests/test_metric_tracker_rl.py
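"""Tests covering the metric_tracker_rl data generator, analysis toolkit, task registry, evaluator, and RL environment."""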
from __future__ import annotations

from metric_tracker_rl import MetricTrackerRlAction
from metric_tracker_rl.analysis_tools import AnalysisContext, SharedAnalysisToolkit
from metric_tracker_rl.evaluation import evaluate_submission
from metric_tracker_rl.models import MetricSubmissionRow, PayloadGeneratorMethod
from metric_tracker_rl.server.data_generator import ALL_SCENARIO_FAMILIES, EpisodeConfig, MetricDataGenerator
from metric_tracker_rl.server.metric_tracker_rl_environment import MetricTrackerRlEnvironment
from metric_tracker_rl.tasks import DEFAULT_TASK_ORDER, TASKS, get_task_spec


def _toolkit_for(seed: int = 11, scenario_family: str = "mixed") -> tuple[SharedAnalysisToolkit, list[MetricSubmissionRow]]:
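    """Build a SharedAnalysisToolkit plus the expected rows for a freshly generated episode."""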
    generator = MetricDataGenerator()
    episode = generator.generate_episode(
        EpisodeConfig(
            seed=seed,
            scenario_family=scenario_family,
            difficulty="medium",
            anomaly_density="medium",
            anomaly_count=5,
        )
    )
    toolkit = SharedAnalysisToolkit(
        AnalysisContext(
            daily_metrics=episode.daily_metrics,
            hourly_metrics=episode.hourly_metrics,
            conversion_definitions=list(generator.config.conversion_definitions),
            config=episode.config.__dict__,
        )
    )
    return toolkit, episode.expected_rows


def test_seed_reproducibility():
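    """Identical EpisodeConfig seeds must reproduce identical daily, hourly, and expected rows."""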
    generator = MetricDataGenerator()
    config = EpisodeConfig(seed=17, scenario_family="mixed", difficulty="hard", anomaly_density="high")
    first = generator.generate_episode(config)
    second = generator.generate_episode(config)
    assert [row.model_dump() for row in first.daily_metrics] == [row.model_dump() for row in second.daily_metrics]
    assert [row.model_dump() for row in first.hourly_metrics] == [row.model_dump() for row in second.hourly_metrics]
    assert [row.model_dump() for row in first.expected_rows] == [row.model_dump() for row in second.expected_rows]


def test_anomaly_variety():
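    """Each single-family scenario yields only its own anomaly type; the mixed family yields at least two."""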
    generator = MetricDataGenerator()
    family_results = {}
    for family in ALL_SCENARIO_FAMILIES[1:]:
        episode = generator.generate_episode(
            EpisodeConfig(
                seed=7,
                scenario_family=family,
                difficulty="medium",
                anomaly_density="medium",
                anomaly_count=5,
            )
        )
        family_results[family] = {row.anomaly_type for row in episode.expected_rows}
    assert family_results["rate_drop_from_median"] == {"rate_drop_from_median"}
    assert family_results["rate_spike_from_median"] == {"rate_spike_from_median"}
    assert family_results["absolute_drop_in_event_count"] == {"absolute_drop_in_event_count"}
    assert family_results["absolute_spike_in_event_count"] == {"absolute_spike_in_event_count"}
    assert family_results["funnel_break"] == {"funnel_break"}
    assert family_results["hourly_traffic_mix_shift"] == {"hourly_traffic_mix_shift"}
    assert family_results["instrumentation_data_quality_issue"] == {"instrumentation_data_quality_issue"}
    mixed = generator.generate_episode(
        EpisodeConfig(
            seed=7,
            scenario_family="mixed",
            difficulty="medium",
            anomaly_density="medium",
            anomaly_count=5,
        )
    )
    assert len(mixed.expected_rows) == 5
    assert {row.anomaly_type for row in mixed.expected_rows}.issubset(
        {
            "rate_drop_from_median",
            "rate_spike_from_median",
            "absolute_drop_in_event_count",
            "absolute_spike_in_event_count",
        }
    )
    assert len({row.anomaly_type for row in mixed.expected_rows}) >= 2


def test_evaluator_scores_perfect_submission():
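    """Submitting the expected rows verbatim is graded as perfect with a total score of 0.999999 and no penalties."""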
    _, expected_rows = _toolkit_for()
    result = evaluate_submission(expected_rows, expected_rows)
    assert result.is_perfect is True
    assert 0.0 < result.reward_breakdown.total_score < 1.0
    assert result.reward_breakdown.total_score == 0.999999
    assert result.reward_breakdown.extra_rows == 0
    assert result.reward_breakdown.duplicate_rows == 0
    assert result.reward_breakdown.invalid_rows == 0


def test_named_task_registry_covers_easy_medium_hard():
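    """The task registry exposes exactly one easy, one medium, and one hard task, each with a grader name."""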
    assert DEFAULT_TASK_ORDER == (
        "easy_single_spike",
        "medium_mixed_pair",
        "hard_mixed_multi",
    )
    assert len(TASKS) == 3
    assert {TASKS[task_id].difficulty for task_id in DEFAULT_TASK_ORDER} == {"easy", "medium", "hard"}
    assert all(TASKS[task_id].grader_name for task_id in DEFAULT_TASK_ORDER)


def test_task_grader_scores_perfect_submission():
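    """Grading a task's own expected rows through the task spec also yields the perfect 0.999999 score."""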
    generator = MetricDataGenerator()
    task = get_task_spec("medium_mixed_pair")
    episode = generator.generate_episode(task.build_episode_config())
    result = task.grade_submission(episode.expected_rows, episode.expected_rows)
    assert result.is_perfect is True
    assert 0.0 < result.reward_breakdown.total_score < 1.0
    assert result.reward_breakdown.total_score == 0.999999


def test_duplicate_and_extra_rows_are_penalized():
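    """Duplicated and unexpected rows are each counted in the breakdown and pull the total score below 1.0."""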
    _, expected_rows = _toolkit_for()
    extra_row = MetricSubmissionRow(
        date=expected_rows[0].date,
        entity_type="event_count",
        entity_name="nonexistent_metric",
        anomaly_type="absolute_spike_in_event_count",
        detection_method="compare_count_to_median",
        baseline_value=100.0,
        observed_value=120.0,
        delta_value=20.0,
        severity="low",
    )
    submitted = [expected_rows[0], expected_rows[0], extra_row]
    result = evaluate_submission(submitted, expected_rows)
    assert result.is_perfect is False
    assert result.reward_breakdown.duplicate_rows == 1
    assert result.reward_breakdown.extra_rows == 1
    assert result.reward_breakdown.total_score < 1.0


def test_shared_methods_behave_consistently():
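    """The toolkit overview lists the shared methods, and the matching detection helper agrees with the first expected row."""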
    toolkit, expected_rows = _toolkit_for(seed=3, scenario_family="mixed")
    overview = toolkit.task_overview()
    suspicious = toolkit.list_suspicious_dates(limit=5)
    first_row = expected_rows[0]
    assert overview["payload_schema"][0] == "date"
    method_names = {item["name"] for item in overview["available_methods"]}
    assert "show_raw_data" in method_names
    assert "get_median_filter_rows" in method_names
    assert "get_funnel_break_rows" in method_names
    assert "get_hourly_traffic_mix_shift_rows" in method_names
    assert "get_instrumentation_data_quality_issue_rows" in method_names
    assert "payload_generator" in method_names
    assert len(suspicious["dates"]) == 5
    if first_row.detection_method == "compare_rate_to_median":
        result = toolkit.compare_rate_to_median(first_row.date, first_row.entity_name)
        assert result["anomaly_type"] == first_row.anomaly_type
    elif first_row.detection_method == "compare_count_to_median":
        result = toolkit.compare_count_to_median(first_row.date, first_row.entity_name)
        assert result["anomaly_type"] == first_row.anomaly_type
    elif first_row.detection_method == "detect_funnel_break":
        result = toolkit.detect_funnel_break(first_row.date)
        assert any(item["entity_name"] == first_row.entity_name for item in result["candidates"])
    elif first_row.detection_method == "check_impossible_counts":
        result = toolkit.check_impossible_counts(first_row.date)
        assert result["issue_count"] > 0
    else:
        result = toolkit.hourly_rows_for_date(first_row.date)
        assert result["found"] is True
    raw = toolkit.show_raw_data(limit=3)
    assert raw["returned_rows"] == 3
    median_stats = toolkit.get_metric_median("app_open_to_order_placed")
    std_stats = toolkit.get_metric_std_dev_from_median("app_open_to_order_placed")
    assert median_stats["sample_size"] > 0
    assert std_stats["std_dev_from_median"] >= 0


def test_debug_mode_is_gated():
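    """export_debug_snapshot raises until debug mode is enabled; afterwards the snapshot exposes the expected payload."""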
    env = MetricTrackerRlEnvironment()
    observation = env.reset()
    assert observation.debug is None
    assert observation.daily_metrics == []
    assert observation.hourly_metrics == []
    try:
        env.export_debug_snapshot()
    except RuntimeError as exc:
        assert "Debug mode is disabled" in str(exc)
    else:
        raise AssertionError("Expected debug snapshot to be gated.")
    env.set_debug_mode(True)
    debug_observation = env.reset()
    snapshot = env.export_debug_snapshot()
    assert debug_observation.debug is not None
    assert "expected_payload" in snapshot
    assert "applied_synthetic_generators" in snapshot


def test_reset_exposes_synthetic_generator_metadata():
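    """A default reset selects the easy task and advertises the available synthetic generator methods."""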
    env = MetricTrackerRlEnvironment()
    observation = env.reset()
    assert observation.task_id == "easy_single_spike"
    assert len(observation.available_tasks) == 3
    assert observation.available_synthetic_generator_methods
    assert observation.available_synthetic_generator_methods[0].name == "metric_stddev_shift"
    assert observation.applied_synthetic_generators == []


def test_named_task_reset_updates_instruction_and_config():
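    """Resetting with a named task id updates the instruction and the task-related config fields."""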
    env = MetricTrackerRlEnvironment()
    observation = env.reset(task_id="hard_mixed_multi")
    assert observation.task_id == "hard_mixed_multi"
    assert observation.config["task_id"] == "hard_mixed_multi"
    assert observation.config["grader_name"] == "deterministic_exact_match"
    assert observation.config["difficulty"] == "hard"
    assert observation.instruction == get_task_spec("hard_mixed_multi").instruction


def test_custom_reset_anomalies_support_specific_dates_and_stddev_factor():
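    """Custom reset anomalies can pin specific dates and stddev factors, and the applied metadata reflects them."""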
    env = MetricTrackerRlEnvironment()
    observation = env.reset(
        seed=21,
        scenario_family="mixed",
        anomaly_count=2,
        anomalies=[
            {
                "method_name": "metric_stddev_shift",
                "metric_name": "orders_placed",
                "date": "2026-03-20",
                "stddev_factor": 2.5,
                "direction": "down",
            },
            {
                "method_name": "metric_stddev_shift",
                "metric_name": "app_open_to_order_placed",
                "date": "2026-03-25",
                "stddev_factor": 2.0,
                "direction": "up",
            },
        ],
    )
    applied = {item.date: item for item in observation.applied_synthetic_generators}
    assert "2026-03-20" in applied
    assert "2026-03-25" in applied
    assert applied["2026-03-20"].metric_name == "orders_placed"
    assert applied["2026-03-20"].stddev_factor == 2.5
    assert applied["2026-03-20"].threshold_value == round(
        applied["2026-03-20"].std_dev_from_median * 2.5,
        4,
    )
    assert applied["2026-03-25"].metric_type == "conversion_rate"


def test_analysis_methods_run_through_step_api():
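    """Analysis methods can be invoked through the step API via analysis_method and analysis_args."""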
    env = MetricTrackerRlEnvironment()
    env.reset()
    analyzed = env.step(
        MetricTrackerRlAction(
            analysis_method="list_suspicious_dates",
            analysis_args={"limit": 3},
        )
    )
    assert analyzed.analysis_result is not None
    assert analyzed.analysis_result["method"] == "list_suspicious_dates"
    assert len(analyzed.analysis_result["result"]["dates"]) == 3


def test_payload_generator_method_creates_rows():
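    """The median-filter payload generator returns a threshold detail and a list of generated rows."""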
    toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
    result = toolkit.get_median_filter_rows("app_open_to_order_placed", 2.0)
    assert result["details"][0]["threshold"] >= 0
    assert isinstance(result["generated_rows"], list)


def test_payload_generator_method_without_metric_runs_all_metrics():
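    """Omitting the metric name runs the median-filter generator across every metric."""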
    toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
    result = toolkit.get_median_filter_rows_multi(metric_name=None, metric_names=[], threshold_multiplier=2.0)
    assert "app_opens" in result["metric_names"]
    assert "app_open_to_order_placed" in result["metric_names"]
    assert isinstance(result["generated_rows"], list)


def test_family_specific_generator_methods_create_matching_anomaly_types():
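    """Each family-specific generator method produces rows tagged with exactly that anomaly type."""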
    cases = [
        ("rate_drop_from_median", "get_rate_drop_from_median_rows", 1.5),
        ("rate_spike_from_median", "get_rate_spike_from_median_rows", 1.5),
        ("absolute_drop_in_event_count", "get_absolute_drop_in_event_count_rows", 1.5),
        ("absolute_spike_in_event_count", "get_absolute_spike_in_event_count_rows", 1.5),
        ("funnel_break", "get_funnel_break_rows", 1.0),
        ("hourly_traffic_mix_shift", "get_hourly_traffic_mix_shift_rows", 1.0),
        ("instrumentation_data_quality_issue", "get_instrumentation_data_quality_issue_rows", 1.0),
    ]
    for family, method_name, threshold_multiplier in cases:
        toolkit, _ = _toolkit_for(seed=7, scenario_family=family)
        method = getattr(toolkit, method_name)
        if "rate_" in method_name or "event_count" in method_name:
            result = method(metric_name=None, metric_names=[], threshold_multiplier=threshold_multiplier)
        else:
            result = method(threshold_multiplier=threshold_multiplier)
        assert result["generated_rows"], method_name
        assert {row["anomaly_type"] for row in result["generated_rows"]} == {family}


def test_metric_summary_methods_without_metric_run_all_metrics():
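    """The *_multi summary methods cover all metrics and return one result per metric name."""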
    toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
    medians = toolkit.get_metric_median_multi(metric_name=None, metric_names=[])
    stds = toolkit.get_metric_std_dev_from_median_multi(metric_name=None, metric_names=[])
    diffs = toolkit.get_rows_with_abs_diff_from_median_gt_multi(
        metric_name=None,
        metric_names=[],
        threshold=1.0,
    )
    assert "app_opens" in medians["metric_names"]
    assert "app_open_to_order_placed" in stds["metric_names"]
    assert len(medians["results"]) == len(medians["metric_names"])
    assert len(stds["results"]) == len(stds["metric_names"])
    assert len(diffs["results"]) == len(diffs["metric_names"])


def test_generator_submission_path_runs():
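    """Payload generators supplied in a step action produce generated rows and one of the expected statuses."""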
    env = MetricTrackerRlEnvironment()
    env.reset()
    result = env.step(
        MetricTrackerRlAction(
            payload_generators=[
                PayloadGeneratorMethod(
                    method_name="get_median_filter_rows",
                    metric_name="app_open_to_order_placed",
                    threshold_multiplier=2.0,
                )
            ]
        )
    )
    assert result.generated_rows is not None
    assert result.status in {"evaluated", "in_progress", "completed"}


def test_generator_submission_path_supports_family_specific_methods():
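    """Family-specific generator methods also work through the step payload_generators path."""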
    env = MetricTrackerRlEnvironment()
    env.reset(task_id="hard_mixed_multi", scenario_family="funnel_break")
    result = env.step(
        MetricTrackerRlAction(
            payload_generators=[
                PayloadGeneratorMethod(
                    method_name="get_funnel_break_rows",
                    threshold_multiplier=1.0,
                )
            ]
        )
    )
    assert result.analysis_result is not None
    assert result.analysis_result["result"]["generated_rows"] is not None