"""
AdAuditEnv — main environment class.

Wires together publisher_engine, fraud_engine, response_generator,
step_reward, and grader into the OpenEnv Environment interface.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Optional
from uuid import uuid4

from openenv.core.env_server.interfaces import Environment

try:
    from ..models import (
        AdAuditAction,
        AdAuditObservation,
        AdAuditState,
        BudgetStatus,
        DailyPublisherMetrics,
        PublisherState,
    )
    from .fraud_engine import (
        decay_suspicion,
        get_adaptation_stage,
        update_suspicion,
    )
    from .publisher_engine import generate_daily_traffic
    from .response_generator import (
        generate_alerts,
        generate_investigation_metrics,
        generate_trend_summary,
    )
    from .step_reward import compute_step_reward
    from .grader import grade_episode
except ImportError:
    from models import (  # type: ignore[no-redef]
        AdAuditAction,
        AdAuditObservation,
        AdAuditState,
        BudgetStatus,
        DailyPublisherMetrics,
        PublisherState,
    )
    from server.fraud_engine import (  # type: ignore[no-redef]
        decay_suspicion,
        get_adaptation_stage,
        update_suspicion,
    )
    from server.publisher_engine import generate_daily_traffic  # type: ignore[no-redef]
    from server.response_generator import (  # type: ignore[no-redef]
        generate_alerts,
        generate_investigation_metrics,
        generate_trend_summary,
    )
    from server.step_reward import compute_step_reward  # type: ignore[no-redef]
    from server.grader import grade_episode  # type: ignore[no-redef]

# Directory holding the per-task case definitions (JSON files).
CASES_DIR = Path(__file__).resolve().parent.parent / "cases"

# Known task ids and the case file each one loads.
TASK_MAP = {
    "easy": "easy.json",
    "medium": "medium.json",
    "hard": "hard.json",
}

# An episode ends once this many days have been simulated.
EPISODE_DAYS = 14


class _PubInternal:
    """Hidden per-publisher state (not exposed via /state)."""

    __slots__ = (
        "is_fraudulent",
        "fraud_type",
        "suspicion_level",
        "adaptation_stage",
        "total_fraudulent_spend",
        "total_legitimate_spend",
        "total_legitimate_revenue",
    )

    def __init__(
        self,
        is_fraudulent: bool = False,
        fraud_type: Optional[str] = None,
    ):
        # Ground truth taken from the case file.
        self.is_fraudulent = is_fraudulent
        self.fraud_type = fraud_type
        # Updated by investigations and by the daily decay.
        self.suspicion_level = 0.0
        self.adaptation_stage = "normal"
        # Running totals accumulated from each day's generated traffic.
        self.total_fraudulent_spend = 0.0
        self.total_legitimate_spend = 0.0
        self.total_legitimate_revenue = 0.0


class AdAuditEnv(Environment[AdAuditAction, AdAuditObservation, AdAuditState]):
    """OpenEnv-compatible RL environment for ad fraud detection."""

    SUPPORTS_CONCURRENT_SESSIONS = True

    @classmethod
    def get_tasks(cls) -> List[str]:
        """Return the task ids listed in ``TASK_MAP``."""
        return list(TASK_MAP.keys())

    # Order in which tasks are served when reset() is given no task id.
    _TASK_CYCLE = ["easy", "medium", "hard"]

    def __init__(self) -> None:
        super().__init__()
        self._case: Dict[str, Any] = {}
        self._state = AdAuditState()
        self._pub_cfgs: Dict[str, Dict[str, Any]] = {}
        self._pub_names: Dict[str, str] = {}
        self._pub_internal: Dict[str, _PubInternal] = {}
        self._daily_logs: Dict[str, List[Dict[str, Any]]] = {}
        self._step_action: Optional[AdAuditAction] = None
        self._cycle_index: int = 0
        self._invalid_action: bool = False
        # True once a terminal (done=True) observation has been produced.
        self._episode_over: bool = False
        # Reason recorded by the first _finalize() call; None until then.
        self._final_reason: Optional[str] = None

    # ------------------------------------------------------------------
    # reset
    # ------------------------------------------------------------------

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> AdAuditObservation:
        """Load a case, rebuild all per-episode state and return the day-1 observation.

        The task is chosen from ``kwargs["task_id"]``, else ``episode_id``,
        else the next entry of ``_TASK_CYCLE``. ``seed`` is accepted but not
        used by this method.

        Raises:
            ValueError: if the task id resolves to a file name containing
                path components (it would escape ``CASES_DIR``).
            FileNotFoundError: if the case file does not exist.
        """
        task_id = kwargs.get("task_id") or episode_id
        if not task_id:
            task_id = self._TASK_CYCLE[self._cycle_index % len(self._TASK_CYCLE)]
            self._cycle_index += 1

        case_name = TASK_MAP.get(task_id, f"{task_id}.json")
        # task_id may come from a client: only a bare file name is allowed,
        # so "../x" or an absolute path cannot read outside CASES_DIR.
        if Path(case_name).name != case_name:
            raise ValueError(f"invalid task_id: {task_id!r}")
        case_file = CASES_DIR / case_name
        # Explicit encoding: do not depend on the platform default.
        with open(case_file, encoding="utf-8") as f:
            self._case = json.load(f)

        campaign = self._case["campaign"]
        publishers = self._case["publishers"]

        pub_states: List[PublisherState] = []
        self._pub_cfgs = {}
        self._pub_names = {}
        self._pub_internal = {}
        self._daily_logs = {}

        for pub_id, cfg in publishers.items():
            self._pub_cfgs[pub_id] = cfg
            self._pub_names[pub_id] = cfg.get("name", pub_id)
            self._daily_logs[pub_id] = []
            pub_states.append(PublisherState(
                publisher_id=pub_id,
                name=cfg.get("name", pub_id),
                # Default: split the budget evenly across publishers.
                budget_allocation=cfg.get("budget_allocation", 1.0 / len(publishers)),
            ))
            self._pub_internal[pub_id] = _PubInternal(
                is_fraudulent=cfg.get("is_fraudulent", False),
                fraud_type=cfg.get("fraud_type"),
            )

        self._state = AdAuditState(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
            case_id=self._case.get("case_id", task_id),
            current_day=0,
            publishers=pub_states,
            investigation_budget_total=campaign.get("investigation_budget", 8),
            investigation_budget_used=0,
        )
        self._step_action = None
        self._invalid_action = False
        self._episode_over = False
        self._final_reason = None

        return self._advance_day()

    # ------------------------------------------------------------------
    # step
    # ------------------------------------------------------------------

    def step(
        self,
        action: AdAuditAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> AdAuditObservation:
        """Apply one agent action and advance the simulation by one day.

        ``submit_report`` ends the episode immediately. Unknown action types
        and rejected investigate/flag actions are marked invalid but still
        consume a day. Once the episode has ended, every further call returns
        the terminal observation instead of simulating more days.
        ``timeout_s`` and ``kwargs`` are accepted but not used here.
        """
        # Covers day-limit, budget-exhaustion and submit_report endings.
        if self._episode_over or self._state.current_day >= EPISODE_DAYS:
            return self._finalize("Campaign ended.")

        investigation_result: Optional[Dict[str, Any]] = None
        self._invalid_action = False
        at = action.action_type

        self._state.action_history.append(action.model_dump(exclude_none=True))
        self._step_action = action

        if at == "monitor":
            pass
        elif at == "investigate_publisher":
            investigation_result = self._handle_investigate(action)
            if investigation_result and "error" in investigation_result:
                self._invalid_action = True
        elif at == "flag_fraud":
            ps = self._get_pub_state(action.publisher_id)
            if ps is None or ps.is_flagged or not action.fraud_type:
                self._invalid_action = True
            else:
                self._handle_flag_fraud(action)
        elif at == "submit_report":
            return self._finalize("Agent submitted report.")
        else:
            self._invalid_action = True

        return self._advance_day(investigation_result=investigation_result)

    # ------------------------------------------------------------------
    # state property
    # ------------------------------------------------------------------

    @property
    def state(self) -> AdAuditState:
        """Current public episode state."""
        return self._state

    # ------------------------------------------------------------------
    # Action handlers
    # ------------------------------------------------------------------

    def _handle_investigate(self, action: AdAuditAction) -> Optional[Dict[str, Any]]:
        """Run one investigation tool against a publisher.

        Returns a dict with an ``"error"`` key when the request is rejected
        (missing fields, unknown or already-flagged publisher, no budget
        left); no budget is consumed in that case. Otherwise one unit of
        investigation budget is spent, the tool is recorded on the publisher,
        a fraudulent publisher's suspicion/adaptation stage is updated, and
        the result of ``generate_investigation_metrics`` is returned.
        """
        pub_id = action.publisher_id
        tool = action.tool

        if not pub_id or not tool:
            return {"error": "publisher_id and tool are required"}

        ps = self._get_pub_state(pub_id)
        if ps is None:
            valid = [p.publisher_id for p in self._state.publishers]
            return {"error": f"unknown publisher_id: {pub_id}. Valid IDs: {valid}"}

        if ps.is_flagged:
            return {"error": f"{pub_id} is already flagged."}

        budget_remaining = (
            self._state.investigation_budget_total
            - self._state.investigation_budget_used
        )
        if budget_remaining <= 0:
            return {"error": "no investigation budget remaining"}

        self._state.investigation_budget_used += 1

        cfg = self._pub_cfgs.get(pub_id, {})
        hi = self._pub_internal[pub_id]

        if tool not in ps.tools_used:
            ps.tools_used.append(tool)

        # Only fraudulent publishers react to being investigated.
        if hi.is_fraudulent:
            hi.suspicion_level = update_suspicion(
                hi.suspicion_level,
                tool,
                cfg.get("suspicion_reactivity", 1.0),
            )
            hi.adaptation_stage = get_adaptation_stage(hi.suspicion_level)

        return generate_investigation_metrics(
            case_id=self._state.case_id,
            publisher_id=pub_id,
            publisher_cfg=cfg,
            tool_name=tool,
            adaptation_stage=hi.adaptation_stage,
        )

    def _handle_flag_fraud(self, action: AdAuditAction) -> None:
        """Mark a publisher as flagged and record the flag with its correctness.

        The caller (``step``) has already checked that the publisher exists,
        is not yet flagged, and that a fraud type was supplied.
        """
        pub_id = action.publisher_id
        ps = self._get_pub_state(pub_id)
        hi = self._pub_internal[pub_id]

        ps.is_flagged = True
        # The flag is dated to the day that is simulated next.
        ps.day_flagged = self._state.current_day + 1

        is_correct = hi.is_fraudulent
        # The fraud type only counts when the publisher really is fraudulent.
        type_correct = (action.fraud_type == hi.fraud_type) if is_correct else False

        self._state.flags_submitted.append({
            "publisher_id": pub_id,
            "fraud_type": action.fraud_type,
            "evidence": action.evidence or [],
            "day": self._state.current_day + 1,
            "correct": is_correct,
            "type_correct": type_correct,
        })

    # ------------------------------------------------------------------
    # Day advancement
    # ------------------------------------------------------------------

    def _advance_day(
        self,
        investigation_result: Optional[Dict[str, Any]] = None,
    ) -> AdAuditObservation:
        """Simulate the next day and build the observation for it.

        Steps: bump the day counter, decay suspicion of fraudulent publishers
        that were not successfully investigated by the last action, generate
        traffic for every publisher, compute the step reward for the last
        action, and assemble budget/status/termination fields. When the
        episode terminates here, the grader result is stored on the state
        (it is not folded into the step reward).
        """
        self._state.current_day += 1
        self._state.step_count = self._state.current_day
        day = self._state.current_day

        # Decay suspicion for publishers NOT investigated today.
        # A rejected investigation (no budget, missing tool, ...) never
        # reached the publisher, so it must not shield it from decay.
        investigated_today = set()
        if self._state.action_history and not self._invalid_action:
            last = self._state.action_history[-1]
            if last.get("action_type") == "investigate_publisher" and last.get("publisher_id"):
                investigated_today.add(last["publisher_id"])

        for ps in self._state.publishers:
            hi = self._pub_internal[ps.publisher_id]
            if hi.is_fraudulent and ps.publisher_id not in investigated_today:
                hi.suspicion_level = decay_suspicion(hi.suspicion_level)
                hi.adaptation_stage = get_adaptation_stage(hi.suspicion_level)

        # Generate traffic
        daily_traffic: List[Dict[str, Any]] = []
        daily_metrics: List[DailyPublisherMetrics] = []
        campaign = self._case["campaign"]
        # Only needed by the disabled alert generation (see TODO below).
        benchmarks = campaign.get("industry_benchmarks", {})

        daily_fraud_spend = 0.0

        for ps in self._state.publishers:
            cfg = self._pub_cfgs.get(ps.publisher_id, {})
            hi = self._pub_internal[ps.publisher_id]

            traffic = generate_daily_traffic(
                day=day,
                publisher_cfg=cfg,
                budget_allocation=ps.budget_allocation,
                adaptation_stage=hi.adaptation_stage,
                is_paused=ps.is_flagged,
            )
            daily_traffic.append(traffic)
            self._daily_logs[ps.publisher_id].append(traffic)

            hi.total_legitimate_spend += traffic["legitimate_spend"]
            hi.total_fraudulent_spend += traffic["fraudulent_spend"]
            hi.total_legitimate_revenue += traffic["legitimate_revenue"]

            # Fraud spend only counts against the agent while unflagged.
            if hi.is_fraudulent and not ps.is_flagged:
                daily_fraud_spend += traffic["fraudulent_spend"]

            daily_metrics.append(DailyPublisherMetrics(
                publisher_id=ps.publisher_id,
                name=ps.name,
                impressions=traffic["impressions"],
                clicks=traffic["clicks"],
                conversions=traffic["conversions"],
                spend=traffic["spend"],
                ctr=traffic["ctr"],
                cvr=traffic["cvr"],
            ))

        # --- Compute step reward ---
        action = self._step_action
        total_budget = campaign["total_budget"]

        if action is None:
            # First observation after reset(): no action to reward.
            step_reward = 0.0
        elif self._invalid_action:
            step_reward = compute_step_reward(
                action_type="invalid",
                daily_fraud_spend=daily_fraud_spend,
                total_budget=total_budget,
                day=day,
                episode_days=EPISODE_DAYS,
            )
        elif action.action_type == "flag_fraud":
            last_flag = self._state.flags_submitted[-1] if self._state.flags_submitted else {}
            step_reward = compute_step_reward(
                action_type="flag_fraud",
                flag_correct=last_flag.get("correct"),
                flag_type_correct=last_flag.get("type_correct"),
                daily_fraud_spend=daily_fraud_spend,
                total_budget=total_budget,
                day=day,
                episode_days=EPISODE_DAYS,
            )
        elif action.action_type == "investigate_publisher":
            pub_cfg = self._pub_cfgs.get(action.publisher_id, {})
            step_reward = compute_step_reward(
                action_type="investigate_publisher",
                publisher_cfg=pub_cfg,
                daily_fraud_spend=daily_fraud_spend,
                total_budget=total_budget,
                day=day,
                episode_days=EPISODE_DAYS,
            )
        else:
            step_reward = compute_step_reward(
                action_type=action.action_type,
                daily_fraud_spend=daily_fraud_spend,
                total_budget=total_budget,
                day=day,
                episode_days=EPISODE_DAYS,
            )

        # Trend + alerts
        trend_data = ""  # TODO: generate_trend_summary(self._daily_logs, self._pub_names, day)
        raw_metrics = [
            {
                "publisher_id": m.publisher_id,
                "ctr": m.ctr,
                "cvr": m.cvr,
                "impressions": m.impressions,
                "clicks": m.clicks,
            }
            for m in daily_metrics
        ]
        alerts = []  # TODO: generate_alerts(raw_metrics, benchmarks, self._pub_names)

        cumulative_metrics = self._compute_cumulative_metrics()

        total_spend = sum(
            hi.total_legitimate_spend + hi.total_fraudulent_spend
            for hi in self._pub_internal.values()
        )
        budget_status = BudgetStatus(
            total_campaign_budget=total_budget,
            spent_so_far=round(total_spend, 2),
            remaining=round(total_budget - total_spend, 2),
            investigation_budget_remaining=(
                self._state.investigation_budget_total
                - self._state.investigation_budget_used
            ),
            daily_spend_rate=round(total_spend / day, 2) if day > 0 else 0.0,
        )

        pub_status = {
            ps.publisher_id: ("flagged" if ps.is_flagged else "active")
            for ps in self._state.publishers
        }

        # Termination
        done = False
        done_reason: Optional[str] = None
        if day >= EPISODE_DAYS:
            done = True
            done_reason = f"Campaign ended (day {EPISODE_DAYS})."
        elif budget_status.remaining <= 0:
            done = True
            done_reason = "Campaign budget exhausted."

        # On episode end, compute grader (stored separately, not in step reward)
        if done:
            # Remember termination so later step() calls stop simulating.
            self._episode_over = True
            self._state.grader_inputs = grade_episode(
                self._build_grader_state(),
                self._case,
            )

        # Step rewards are clamped to [0, 1] before being recorded.
        step_reward = round(max(0.0, min(1.0, step_reward)), 4)

        self._state.daily_rewards.append(step_reward)
        self._state.cumulative_reward += step_reward

        return self._apply_transform(AdAuditObservation(
            day=day,
            campaign_day_total=EPISODE_DAYS,
            daily_metrics=daily_metrics,
            cumulative_metrics=cumulative_metrics,
            trend_data=trend_data,
            investigation_results=investigation_result,
            alerts=alerts,
            budget_status=budget_status,
            publisher_status=pub_status,
            cumulative_reward=round(self._state.cumulative_reward, 4),
            done=done,
            done_reason=done_reason,
            reward=step_reward,
        ))

    # ------------------------------------------------------------------
    # Finalize
    # ------------------------------------------------------------------

    def _finalize(self, reason: str) -> AdAuditObservation:
        """Grade the episode and return the terminal observation.

        The first call runs the grader, records ``reason`` and adds the
        grader's ``final_score`` to the reward history exactly once. Later
        calls return the same terminal observation (with the first reason)
        without grading or accumulating again, so repeated ``step()`` calls
        after the end cannot inflate ``cumulative_reward``.
        """
        first_call = self._final_reason is None
        if first_call:
            self._state.grader_inputs = grade_episode(
                self._build_grader_state(),
                self._case,
            )
            self._final_reason = reason
            self._episode_over = True

        grader_score = self._state.grader_inputs.get("final_score", 0.0)
        step_reward = grader_score

        if first_call:
            self._state.daily_rewards.append(step_reward)
            self._state.cumulative_reward += step_reward

        # Route through _apply_transform like every other observation.
        return self._apply_transform(AdAuditObservation(
            day=self._state.current_day,
            campaign_day_total=EPISODE_DAYS,
            trend_data=f"Episode complete. Grader score: {grader_score:.4f}",
            done=True,
            done_reason=self._final_reason,
            reward=step_reward,
            cumulative_reward=round(self._state.cumulative_reward, 4),
            publisher_status={
                ps.publisher_id: ("flagged" if ps.is_flagged else "active")
                for ps in self._state.publishers
            },
        ))

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_pub_state(self, pub_id: Optional[str]) -> Optional[PublisherState]:
        """Return the public state entry for ``pub_id``, or None if unknown."""
        for ps in self._state.publishers:
            if ps.publisher_id == pub_id:
                return ps
        return None

    def _build_grader_state(self) -> Dict[str, Any]:
        """Build the state dict the grader expects, including hidden fields."""
        state_dict = self._state.model_dump()
        # Enrich publisher entries with hidden internal state for grading
        for pub_dict in state_dict["publishers"]:
            pub_id = pub_dict["publisher_id"]
            hi = self._pub_internal[pub_id]
            pub_dict["is_fraudulent"] = hi.is_fraudulent
            pub_dict["fraud_type"] = hi.fraud_type
            pub_dict["suspicion_level"] = hi.suspicion_level
            pub_dict["total_fraudulent_spend"] = hi.total_fraudulent_spend
            pub_dict["total_legitimate_spend"] = hi.total_legitimate_spend
        return state_dict

    def _compute_cumulative_metrics(self) -> List[DailyPublisherMetrics]:
        """Aggregate each publisher's daily logs into campaign-to-date totals.

        Publishers with no logged days are skipped. CTR and CVR are
        recomputed from the summed counts (0.0 when the denominator is 0).
        """
        result = []
        for ps in self._state.publishers:
            logs = self._daily_logs.get(ps.publisher_id, [])
            if not logs:
                continue
            total_imp = sum(d["impressions"] for d in logs)
            total_clicks = sum(d["clicks"] for d in logs)
            total_conv = sum(d["conversions"] for d in logs)
            total_spend = sum(d["spend"] for d in logs)
            ctr = total_clicks / total_imp if total_imp > 0 else 0.0
            cvr = total_conv / total_clicks if total_clicks > 0 else 0.0
            result.append(DailyPublisherMetrics(
                publisher_id=ps.publisher_id,
                name=ps.name,
                impressions=total_imp,
                clicks=total_clicks,
                conversions=total_conv,
                spend=round(total_spend, 2),
                ctr=round(ctr, 4),
                cvr=round(cvr, 4),
            ))
        return result


# Alias used by app.py and server/__init__.py
AdAuditEnvironment = AdAuditEnv