"""Tests for the metric_tracker_rl data generator, evaluator, shared analysis toolkit, task registry, and environment."""
from __future__ import annotations

from metric_tracker_rl import MetricTrackerRlAction
from metric_tracker_rl.analysis_tools import AnalysisContext, SharedAnalysisToolkit
from metric_tracker_rl.evaluation import evaluate_submission
from metric_tracker_rl.models import MetricSubmissionRow, PayloadGeneratorMethod
from metric_tracker_rl.server.data_generator import ALL_SCENARIO_FAMILIES, EpisodeConfig, MetricDataGenerator
from metric_tracker_rl.server.metric_tracker_rl_environment import MetricTrackerRlEnvironment
from metric_tracker_rl.tasks import DEFAULT_TASK_ORDER, TASKS, get_task_spec
def _toolkit_for(seed: int = 11, scenario_family: str = "mixed") -> tuple[SharedAnalysisToolkit, list[MetricSubmissionRow]]:
    """Build a SharedAnalysisToolkit and the expected rows for a medium-difficulty episode.

    Generates one episode with the given seed and scenario family, wraps its
    metrics in an AnalysisContext, and returns the toolkit together with the
    episode's expected submission rows.
    """
    data_gen = MetricDataGenerator()
    episode_config = EpisodeConfig(
        seed=seed,
        scenario_family=scenario_family,
        difficulty="medium",
        anomaly_density="medium",
        anomaly_count=5,
    )
    episode = data_gen.generate_episode(episode_config)
    context = AnalysisContext(
        daily_metrics=episode.daily_metrics,
        hourly_metrics=episode.hourly_metrics,
        conversion_definitions=list(data_gen.config.conversion_definitions),
        config=episode.config.__dict__,
    )
    return SharedAnalysisToolkit(context), episode.expected_rows
def test_seed_reproducibility():
    """Two episodes generated from the same config must be byte-identical."""
    generator = MetricDataGenerator()
    config = EpisodeConfig(seed=17, scenario_family="mixed", difficulty="hard", anomaly_density="high")
    runs = [generator.generate_episode(config) for _ in range(2)]

    def dumped(rows):
        # Compare via model_dump so pydantic equality semantics don't matter.
        return [row.model_dump() for row in rows]

    assert dumped(runs[0].daily_metrics) == dumped(runs[1].daily_metrics)
    assert dumped(runs[0].hourly_metrics) == dumped(runs[1].hourly_metrics)
    assert dumped(runs[0].expected_rows) == dumped(runs[1].expected_rows)
def test_anomaly_variety():
    """Dedicated scenario families emit only their own anomaly type; "mixed" blends several."""
    generator = MetricDataGenerator()

    def episode_for(family):
        return generator.generate_episode(
            EpisodeConfig(
                seed=7,
                scenario_family=family,
                difficulty="medium",
                anomaly_density="medium",
                anomaly_count=5,
            )
        )

    # Skip the first entry of ALL_SCENARIO_FAMILIES, then collect anomaly types per family.
    family_results = {}
    for family in ALL_SCENARIO_FAMILIES[1:]:
        family_results[family] = {row.anomaly_type for row in episode_for(family).expected_rows}

    for dedicated in (
        "rate_drop_from_median",
        "rate_spike_from_median",
        "absolute_drop_in_event_count",
        "absolute_spike_in_event_count",
        "funnel_break",
        "hourly_traffic_mix_shift",
        "instrumentation_data_quality_issue",
    ):
        assert family_results[dedicated] == {dedicated}

    mixed = episode_for("mixed")
    mixed_types = {row.anomaly_type for row in mixed.expected_rows}
    assert len(mixed.expected_rows) == 5
    # The mixed family draws only from the four median-based anomaly types.
    assert mixed_types <= {
        "rate_drop_from_median",
        "rate_spike_from_median",
        "absolute_drop_in_event_count",
        "absolute_spike_in_event_count",
    }
    assert len(mixed_types) >= 2
def test_evaluator_scores_perfect_submission():
    """Submitting exactly the expected rows earns the capped perfect score."""
    _, expected_rows = _toolkit_for()
    result = evaluate_submission(expected_rows, expected_rows)
    breakdown = result.reward_breakdown
    assert result.is_perfect is True
    # The perfect score is deliberately capped just below 1.0.
    assert 0.0 < breakdown.total_score < 1.0
    assert breakdown.total_score == 0.999999
    assert breakdown.extra_rows == 0
    assert breakdown.duplicate_rows == 0
    assert breakdown.invalid_rows == 0
def test_named_task_registry_covers_easy_medium_hard():
    """The task registry exposes exactly three tasks spanning all difficulties."""
    expected_order = (
        "easy_single_spike",
        "medium_mixed_pair",
        "hard_mixed_multi",
    )
    assert DEFAULT_TASK_ORDER == expected_order
    assert len(TASKS) == 3
    difficulties = {TASKS[task_id].difficulty for task_id in DEFAULT_TASK_ORDER}
    assert difficulties == {"easy", "medium", "hard"}
    for task_id in DEFAULT_TASK_ORDER:
        # Every registered task must name a grader.
        assert TASKS[task_id].grader_name
def test_task_grader_scores_perfect_submission():
    """A task's own grader awards the capped perfect score for an exact submission."""
    task = get_task_spec("medium_mixed_pair")
    episode = MetricDataGenerator().generate_episode(task.build_episode_config())
    graded = task.grade_submission(episode.expected_rows, episode.expected_rows)
    assert graded.is_perfect is True
    assert 0.0 < graded.reward_breakdown.total_score < 1.0
    assert graded.reward_breakdown.total_score == 0.999999
def test_duplicate_and_extra_rows_are_penalized():
    """Duplicated and fabricated rows are each counted once and lower the score."""
    _, expected_rows = _toolkit_for()
    first = expected_rows[0]
    fabricated = MetricSubmissionRow(
        date=first.date,
        entity_type="event_count",
        entity_name="nonexistent_metric",
        anomaly_type="absolute_spike_in_event_count",
        detection_method="compare_count_to_median",
        baseline_value=100.0,
        observed_value=120.0,
        delta_value=20.0,
        severity="low",
    )
    # One legitimate row submitted twice plus one row for a metric that doesn't exist.
    result = evaluate_submission([first, first, fabricated], expected_rows)
    assert result.is_perfect is False
    assert result.reward_breakdown.duplicate_rows == 1
    assert result.reward_breakdown.extra_rows == 1
    assert result.reward_breakdown.total_score < 1.0
def test_shared_methods_behave_consistently():
    """Overview, suspicious dates, detection helpers, and stat helpers all agree."""
    toolkit, expected_rows = _toolkit_for(seed=3, scenario_family="mixed")
    overview = toolkit.task_overview()
    suspicious = toolkit.list_suspicious_dates(limit=5)
    first_row = expected_rows[0]

    assert overview["payload_schema"][0] == "date"
    advertised = {entry["name"] for entry in overview["available_methods"]}
    for required in (
        "show_raw_data",
        "get_median_filter_rows",
        "get_funnel_break_rows",
        "get_hourly_traffic_mix_shift_rows",
        "get_instrumentation_data_quality_issue_rows",
        "payload_generator",
    ):
        assert required in advertised
    assert len(suspicious["dates"]) == 5

    # Re-run the detection method recorded on the first expected row and
    # confirm it reproduces that row's finding.
    method = first_row.detection_method
    if method == "compare_rate_to_median":
        outcome = toolkit.compare_rate_to_median(first_row.date, first_row.entity_name)
        assert outcome["anomaly_type"] == first_row.anomaly_type
    elif method == "compare_count_to_median":
        outcome = toolkit.compare_count_to_median(first_row.date, first_row.entity_name)
        assert outcome["anomaly_type"] == first_row.anomaly_type
    elif method == "detect_funnel_break":
        outcome = toolkit.detect_funnel_break(first_row.date)
        matches = [c for c in outcome["candidates"] if c["entity_name"] == first_row.entity_name]
        assert matches
    elif method == "check_impossible_counts":
        outcome = toolkit.check_impossible_counts(first_row.date)
        assert outcome["issue_count"] > 0
    else:
        outcome = toolkit.hourly_rows_for_date(first_row.date)
        assert outcome["found"] is True

    raw = toolkit.show_raw_data(limit=3)
    assert raw["returned_rows"] == 3
    median_stats = toolkit.get_metric_median("app_open_to_order_placed")
    std_stats = toolkit.get_metric_std_dev_from_median("app_open_to_order_placed")
    assert median_stats["sample_size"] > 0
    assert std_stats["std_dev_from_median"] >= 0
def test_debug_mode_is_gated():
    """Debug payloads stay hidden until debug mode is switched on explicitly."""
    env = MetricTrackerRlEnvironment()
    observation = env.reset()
    assert observation.debug is None
    assert observation.daily_metrics == []
    assert observation.hourly_metrics == []

    # Exporting a snapshot with debug mode off must raise.
    gate_message = None
    try:
        env.export_debug_snapshot()
    except RuntimeError as exc:
        gate_message = str(exc)
    assert gate_message is not None, "Expected debug snapshot to be gated."
    assert "Debug mode is disabled" in gate_message

    env.set_debug_mode(True)
    debug_observation = env.reset()
    snapshot = env.export_debug_snapshot()
    assert debug_observation.debug is not None
    assert "expected_payload" in snapshot
    assert "applied_synthetic_generators" in snapshot
def test_reset_exposes_synthetic_generator_metadata():
    """A default reset advertises the easy task and the generator method list."""
    observation = MetricTrackerRlEnvironment().reset()
    assert observation.task_id == "easy_single_spike"
    assert len(observation.available_tasks) == 3
    generator_methods = observation.available_synthetic_generator_methods
    assert generator_methods
    assert generator_methods[0].name == "metric_stddev_shift"
    # Nothing has been applied yet on a fresh reset.
    assert observation.applied_synthetic_generators == []
def test_named_task_reset_updates_instruction_and_config():
    """Resetting with a task id propagates its config fields and instruction text."""
    observation = MetricTrackerRlEnvironment().reset(task_id="hard_mixed_multi")
    config = observation.config
    assert observation.task_id == "hard_mixed_multi"
    assert config["task_id"] == "hard_mixed_multi"
    assert config["grader_name"] == "deterministic_exact_match"
    assert config["difficulty"] == "hard"
    assert observation.instruction == get_task_spec("hard_mixed_multi").instruction
def test_custom_reset_anomalies_support_specific_dates_and_stddev_factor():
    """Custom anomaly specs pin the date, metric, stddev factor, and direction applied."""
    requested_anomalies = [
        {
            "method_name": "metric_stddev_shift",
            "metric_name": "orders_placed",
            "date": "2026-03-20",
            "stddev_factor": 2.5,
            "direction": "down",
        },
        {
            "method_name": "metric_stddev_shift",
            "metric_name": "app_open_to_order_placed",
            "date": "2026-03-25",
            "stddev_factor": 2.0,
            "direction": "up",
        },
    ]
    env = MetricTrackerRlEnvironment()
    observation = env.reset(
        seed=21,
        scenario_family="mixed",
        anomaly_count=2,
        anomalies=requested_anomalies,
    )
    applied = {item.date: item for item in observation.applied_synthetic_generators}
    assert "2026-03-20" in applied
    assert "2026-03-25" in applied
    down_shift = applied["2026-03-20"]
    assert down_shift.metric_name == "orders_placed"
    assert down_shift.stddev_factor == 2.5
    # The threshold is the metric's deviation scaled by the factor, rounded to 4 places.
    assert down_shift.threshold_value == round(down_shift.std_dev_from_median * 2.5, 4)
    assert applied["2026-03-25"].metric_type == "conversion_rate"
def test_analysis_methods_run_through_step_api():
    """Stepping with an analysis action surfaces the named method's result."""
    env = MetricTrackerRlEnvironment()
    env.reset()
    action = MetricTrackerRlAction(
        analysis_method="list_suspicious_dates",
        analysis_args={"limit": 3},
    )
    observation = env.step(action)
    analysis = observation.analysis_result
    assert analysis is not None
    assert analysis["method"] == "list_suspicious_dates"
    assert len(analysis["result"]["dates"]) == 3
def test_payload_generator_method_creates_rows():
    """The median-filter generator reports a threshold and yields a row list."""
    toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
    outcome = toolkit.get_median_filter_rows("app_open_to_order_placed", 2.0)
    assert outcome["details"][0]["threshold"] >= 0
    assert isinstance(outcome["generated_rows"], list)
def test_payload_generator_method_without_metric_runs_all_metrics():
    """Omitting the metric name fans the generator out across every metric."""
    toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
    outcome = toolkit.get_median_filter_rows_multi(metric_name=None, metric_names=[], threshold_multiplier=2.0)
    for expected_metric in ("app_opens", "app_open_to_order_placed"):
        assert expected_metric in outcome["metric_names"]
    assert isinstance(outcome["generated_rows"], list)
def test_family_specific_generator_methods_create_matching_anomaly_types():
    """Each family-specific generator emits rows tagged only with its own family."""
    cases = [
        ("rate_drop_from_median", "get_rate_drop_from_median_rows", 1.5),
        ("rate_spike_from_median", "get_rate_spike_from_median_rows", 1.5),
        ("absolute_drop_in_event_count", "get_absolute_drop_in_event_count_rows", 1.5),
        ("absolute_spike_in_event_count", "get_absolute_spike_in_event_count_rows", 1.5),
        ("funnel_break", "get_funnel_break_rows", 1.0),
        ("hourly_traffic_mix_shift", "get_hourly_traffic_mix_shift_rows", 1.0),
        ("instrumentation_data_quality_issue", "get_instrumentation_data_quality_issue_rows", 1.0),
    ]
    for family, method_name, multiplier in cases:
        toolkit, _ = _toolkit_for(seed=7, scenario_family=family)
        generator_method = getattr(toolkit, method_name)
        # Rate/count generators accept metric selectors; the others only a multiplier.
        takes_metric_args = "rate_" in method_name or "event_count" in method_name
        if takes_metric_args:
            outcome = generator_method(metric_name=None, metric_names=[], threshold_multiplier=multiplier)
        else:
            outcome = generator_method(threshold_multiplier=multiplier)
        rows = outcome["generated_rows"]
        assert rows, method_name
        assert {row["anomaly_type"] for row in rows} == {family}
def test_metric_summary_methods_without_metric_run_all_metrics():
    """Summary helpers with no metric filter cover every metric consistently."""
    toolkit, _ = _toolkit_for(seed=5, scenario_family="mixed")
    medians = toolkit.get_metric_median_multi(metric_name=None, metric_names=[])
    stds = toolkit.get_metric_std_dev_from_median_multi(metric_name=None, metric_names=[])
    diffs = toolkit.get_rows_with_abs_diff_from_median_gt_multi(
        metric_name=None,
        metric_names=[],
        threshold=1.0,
    )
    assert "app_opens" in medians["metric_names"]
    assert "app_open_to_order_placed" in stds["metric_names"]
    # Each helper returns exactly one result entry per metric it covered.
    for payload in (medians, stds, diffs):
        assert len(payload["results"]) == len(payload["metric_names"])
def test_generator_submission_path_runs():
    """Submitting via a payload generator yields rows and a recognized status."""
    env = MetricTrackerRlEnvironment()
    env.reset()
    action = MetricTrackerRlAction(
        payload_generators=[
            PayloadGeneratorMethod(
                method_name="get_median_filter_rows",
                metric_name="app_open_to_order_placed",
                threshold_multiplier=2.0,
            )
        ]
    )
    observation = env.step(action)
    assert observation.generated_rows is not None
    assert observation.status in {"evaluated", "in_progress", "completed"}
def test_generator_submission_path_supports_family_specific_methods():
    """Family-specific generator methods also work through the step API."""
    env = MetricTrackerRlEnvironment()
    env.reset(task_id="hard_mixed_multi", scenario_family="funnel_break")
    action = MetricTrackerRlAction(
        payload_generators=[
            PayloadGeneratorMethod(
                method_name="get_funnel_break_rows",
                threshold_multiplier=1.0,
            )
        ]
    )
    observation = env.step(action)
    assert observation.analysis_result is not None
    assert observation.analysis_result["result"]["generated_rows"] is not None