Spaces:
Running
Running
"""MEDUSA scenario generator.

Produces randomised Bronze A (Fact) and Bronze B (Dimension) DataFrames to
drive each training episode. Four canonical scenarios cover the principal
failure modes described in the MEDUSA blueprint.
"""
| from __future__ import annotations | |
| import random | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional | |
| import numpy as np | |
| import pandas as pd | |
| # --------------------------------------------------------------------------- | |
| # Scenario dataclass | |
| # --------------------------------------------------------------------------- | |
@dataclass
class Scenario:
    """One episode's worth of Bronze source data + configuration.

    A plain data carrier produced by ``ScenarioGenerator``. The
    ``@dataclass`` decorator supplies the keyword ``__init__`` that the
    generator's ``Scenario(id=..., bronze_a=..., ...)`` calls rely on;
    without it those calls raise ``TypeError``.
    """

    id: str
    bronze_a: pd.DataFrame  # Fact table (source of truth for volume)
    bronze_b: pd.DataFrame  # Dimension table (must be unique on key)
    join_key: str  # Column name used to join A and B
    tracked_cols: List[str]  # Columns in B that require SCD-2 history
    is_stale_a: bool  # Whether Source A is past the freshness threshold
    is_stale_b: bool  # Whether Source B is past the freshness threshold
    time_delta_a: float  # Hours since Source A was last refreshed
    time_delta_b: float  # Hours since Source B was last refreshed
    new_cols_a: List[str]  # Extra columns in A not in Silver yet
    new_cols_b: List[str]  # Extra columns in B not in Silver yet
    description: str = ""  # Human-readable summary of the scenario
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
# Freshness cutoff: a source refreshed more than this many hours ago is
# considered stale (mirrored as ScenarioGenerator.STALE_THRESHOLD).
_STALE_THRESHOLD_HOURS = 6.0
| def _make_fact( | |
| rng: random.Random, | |
| n_rows: int, | |
| key_col: str, | |
| null_ratio: float = 0.0, | |
| extra_cols: Optional[List[str]] = None, | |
| ) -> pd.DataFrame: | |
| """Create a synthetic Fact (Bronze A) DataFrame.""" | |
| keys = [f"K{i:04d}" for i in rng.sample(range(1, n_rows * 2), n_rows)] | |
| # Inject nulls into the key | |
| null_mask = rng.sample(range(n_rows), int(n_rows * null_ratio)) | |
| for idx in null_mask: | |
| keys[idx] = None # type: ignore[call-overload] | |
| data = { | |
| key_col: keys, | |
| "fact_value": [rng.uniform(0, 1000) for _ in range(n_rows)], | |
| "fact_category": [rng.choice(["A", "B", "C"]) for _ in range(n_rows)], | |
| "created_at": pd.date_range("2024-01-01", periods=n_rows, freq="h"), | |
| } | |
| for col in (extra_cols or []): | |
| data[col] = [rng.uniform(0, 100) for _ in range(n_rows)] | |
| return pd.DataFrame(data) | |
| def _make_dim( | |
| rng: random.Random, | |
| n_rows: int, | |
| key_col: str, | |
| null_ratio: float = 0.0, | |
| uniqueness: float = 1.0, # < 1.0 means some keys are duplicated | |
| match_keys: Optional[List[str]] = None, # If given, use these as the key pool | |
| extra_cols: Optional[List[str]] = None, | |
| tracked_cols: Optional[List[str]] = None, | |
| ) -> pd.DataFrame: | |
| """Create a synthetic Dimension (Bronze B) DataFrame.""" | |
| if match_keys: | |
| # Choose from overlap pool to control referential integrity | |
| available = list(match_keys) | |
| if len(available) >= n_rows: | |
| keys = rng.sample(available, n_rows) | |
| else: | |
| keys = [rng.choice(available) for _ in range(n_rows)] | |
| else: | |
| keys = [f"K{i:04d}" for i in rng.sample(range(1, n_rows * 3), n_rows)] | |
| # Inject duplicates (lower uniqueness) | |
| if uniqueness < 1.0: | |
| n_dupes = int(n_rows * (1 - uniqueness)) | |
| for i in rng.sample(range(n_rows), n_dupes): | |
| keys[i] = keys[rng.randint(0, i - 1)] if i > 0 else keys[0] | |
| # Inject nulls | |
| null_mask = rng.sample(range(n_rows), int(n_rows * null_ratio)) | |
| for idx in null_mask: | |
| keys[idx] = None # type: ignore[call-overload] | |
| data: dict = {key_col: keys, "dim_name": [f"Name_{k}" for k in keys]} | |
| for col in (tracked_cols or []): | |
| data[col] = [rng.choice(["x", "y", "z"]) for _ in range(n_rows)] | |
| for col in (extra_cols or []): | |
| data[col] = [rng.uniform(0, 100) for _ in range(n_rows)] | |
| return pd.DataFrame(data) | |
| # --------------------------------------------------------------------------- | |
| # Scenario Generator | |
| # --------------------------------------------------------------------------- | |
class ScenarioGenerator:
    """Generates Bronze A/B DataFrames for MEDUSA episodes.

    ``generate`` maps seeds 0-3 directly onto the four canonical
    variants and picks one pseudo-randomly for any other seed. Each
    variant has its own private builder; ``_canonical`` dispatches to
    them. (The unused ``np.random.default_rng`` local from the original
    implementation has been removed — all randomness flows through one
    ``random.Random`` instance.)
    """

    STALE_THRESHOLD = _STALE_THRESHOLD_HOURS  # Hours before a source counts as stale
    JOIN_KEY = "entity_id"  # Join column shared by Bronze A and B
    TRACKED_COLS = ["dim_status"]  # Dimension columns requiring SCD-2 history
    # Four canonical scenario types
    CANONICAL: List[str] = ["clean", "dirty_keys", "stale", "schema_drift"]

    def __init__(self, n_fact_rows: int = 200, n_dim_rows: int = 150):
        self.n_fact_rows = n_fact_rows
        self.n_dim_rows = n_dim_rows

    def generate(self, seed: Optional[int] = None) -> Scenario:
        """Generate a random scenario. Canonical scenarios cycle through seeds 0-3."""
        rng = random.Random(seed)
        if seed is not None and 0 <= seed < len(self.CANONICAL):
            return self._canonical(self.CANONICAL[seed], seed)
        variant = rng.choice(self.CANONICAL)
        return self._canonical(variant, seed)

    def _canonical(self, variant: str, seed: Optional[int]) -> Scenario:
        """Build the scenario for *variant*, seeded deterministically by *seed*."""
        rng = random.Random(seed)
        builders = {
            "clean": self._build_clean,
            "dirty_keys": self._build_dirty_keys,
            "stale": self._build_stale,
        }
        # Unknown variants fall through to schema_drift, matching the
        # original if/elif chain's trailing else.
        return builders.get(variant, self._build_schema_drift)(rng, seed)

    def _build_clean(self, rng: random.Random, seed: Optional[int]) -> Scenario:
        """Fresh sources, unique keys, ~100% referential match."""
        key = self.JOIN_KEY
        fact = _make_fact(rng, self.n_fact_rows, key, null_ratio=0.0)
        valid_keys = fact[key].dropna().tolist()
        dim = _make_dim(rng, self.n_dim_rows, key, null_ratio=0.0, uniqueness=1.0,
                        match_keys=valid_keys, tracked_cols=self.TRACKED_COLS)
        return Scenario(
            id=f"clean_{seed}",
            bronze_a=fact, bronze_b=dim,
            join_key=key, tracked_cols=self.TRACKED_COLS,
            is_stale_a=False, is_stale_b=False,
            time_delta_a=1.0, time_delta_b=2.0,
            new_cols_a=[], new_cols_b=[],
            description="Clean scenario: fresh, unique keys, high match rate.",
        )

    def _build_dirty_keys(self, rng: random.Random, seed: Optional[int]) -> Scenario:
        """Null-heavy join keys plus whitespace noise; no trimming/casting yet."""
        key = self.JOIN_KEY
        fact = _make_fact(rng, self.n_fact_rows, key, null_ratio=0.25)
        # Pad ~30% of the non-null keys with whitespace to simulate dirty input.
        fact[key] = fact[key].apply(
            lambda k: f" {k} " if k and rng.random() < 0.3 else k
        )
        dim = _make_dim(rng, self.n_dim_rows, key, null_ratio=0.15, uniqueness=0.85,
                        tracked_cols=self.TRACKED_COLS)
        return Scenario(
            id=f"dirty_keys_{seed}",
            bronze_a=fact, bronze_b=dim,
            join_key=key, tracked_cols=self.TRACKED_COLS,
            is_stale_a=False, is_stale_b=False,
            time_delta_a=2.0, time_delta_b=3.0,
            new_cols_a=[], new_cols_b=[],
            description="Dirty keys: nulls + whitespace in join keys.",
        )

    def _build_stale(self, rng: random.Random, seed: Optional[int]) -> Scenario:
        """Source A is well past the freshness threshold; B is fresh."""
        key = self.JOIN_KEY
        fact = _make_fact(rng, self.n_fact_rows, key, null_ratio=0.0)
        valid_keys = fact[key].dropna().tolist()
        dim = _make_dim(rng, self.n_dim_rows, key, null_ratio=0.0, uniqueness=1.0,
                        match_keys=valid_keys, tracked_cols=self.TRACKED_COLS)
        td_a = rng.uniform(8.0, 24.0)  # definitely stale (threshold is 6h)
        td_b = rng.uniform(0.5, 4.0)
        return Scenario(
            id=f"stale_{seed}",
            bronze_a=fact, bronze_b=dim,
            join_key=key, tracked_cols=self.TRACKED_COLS,
            is_stale_a=td_a > self.STALE_THRESHOLD,
            is_stale_b=td_b > self.STALE_THRESHOLD,
            time_delta_a=td_a, time_delta_b=td_b,
            new_cols_a=[], new_cols_b=[],
            description=f"Stale scenario: Source A is {td_a:.1f}h old.",
        )

    def _build_schema_drift(self, rng: random.Random, seed: Optional[int]) -> Scenario:
        """New columns in A and/or B not yet registered in Silver."""
        key = self.JOIN_KEY
        extra_a = ["new_metric_a"]
        extra_b = ["new_attr_b"]
        fact = _make_fact(rng, self.n_fact_rows, key, null_ratio=0.0,
                          extra_cols=extra_a)
        valid_keys = fact[key].dropna().tolist()
        dim = _make_dim(rng, self.n_dim_rows, key, null_ratio=0.0, uniqueness=1.0,
                        match_keys=valid_keys,
                        tracked_cols=self.TRACKED_COLS, extra_cols=extra_b)
        return Scenario(
            id=f"schema_drift_{seed}",
            bronze_a=fact, bronze_b=dim,
            join_key=key, tracked_cols=self.TRACKED_COLS,
            is_stale_a=False, is_stale_b=False,
            time_delta_a=1.0, time_delta_b=1.5,
            new_cols_a=extra_a, new_cols_b=extra_b,
            description="Schema drift: new columns in A and B.",
        )