"""
Pydantic models for the AdAudit environment.
Defines Action, Observation, and State types that conform to the OpenEnv spec.
"""
from typing import Any, Dict, List, Literal, Optional
from pydantic import BaseModel, Field, field_validator
from openenv.core.env_server.types import Action, Observation, State
# ---------------------------------------------------------------------------
# Action
# ---------------------------------------------------------------------------
class AdAuditAction(Action):
    """Single discrete action the agent takes each day.

    ``action_type`` selects the action; every other field is optional and
    is grouped below under the action it accompanies.
    """

    action_type: Literal[
        "monitor",
        "investigate_publisher",
        "flag_fraud",
        "submit_report",
        "invalid",
    ] = Field(..., description="The type of action to take")
    publisher_id: Optional[str] = Field(
        default=None, description="Target publisher for investigate/flag actions"
    )

    # investigate_publisher
    tool: Optional[Literal[
        "click_timestamps",
        "ip_distribution",
        "device_fingerprints",
        "referral_urls",
        "viewability_scores",
        "conversion_quality",
    ]] = Field(default=None, description="Investigation tool to use")

    # flag_fraud
    fraud_type: Optional[Literal[
        "bot_traffic",
        "domain_spoofing",
        "click_injection",
    ]] = Field(default=None, description="Fraud type to flag")
    evidence: Optional[List[str]] = Field(
        default=None,
        description="Evidence tool names, comma-separated (e.g. click_timestamps, ip_distribution)",
    )

    @field_validator("evidence", mode="before")
    @classmethod
    def _coerce_evidence(cls, v):
        """Normalise a string ``evidence`` value into a list of names.

        Runs before Pydantic's own validation of the field.

        Args:
            v: Raw incoming value. Anything that is not a ``str`` is
                returned untouched so normal field validation handles it.

        Returns:
            The decoded list when ``v`` is a JSON array; otherwise the
            comma-separated items of ``v`` with surrounding brackets,
            whitespace and quotes removed and empty items dropped.
        """
        if not isinstance(v, str):
            return v

        # Kept local (as in the original) so the block does not depend on a
        # module-level ``json`` import.
        import json

        try:
            parsed = json.loads(v)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            parsed = None
        if isinstance(parsed, list):
            return parsed

        # Not a JSON array: treat as a bare or loosely bracketed list such as
        # "click_timestamps" or "[click_timestamps, 'ip_distribution']".
        # Strip *all* surrounding whitespace (not only spaces) together with
        # the brackets, otherwise a trailing newline leaves a stray "]".
        stripped = v.strip("[] \t\r\n\f\v")
        names = (
            part.strip().strip("'\"").strip() for part in stripped.split(",")
        )
        # Filter after cleaning so inputs like "''" do not yield empty names.
        return [name for name in names if name]

    # submit_report
    summary: Optional[str] = Field(default=None)
# ---------------------------------------------------------------------------
# Observation helpers
# ---------------------------------------------------------------------------
class DailyPublisherMetrics(BaseModel):
    """Traffic metrics for one publisher on one day."""

    # Identity of the publisher the row describes.
    publisher_id: str
    name: str

    # Raw volume counters.
    impressions: int
    clicks: int
    conversions: int

    # Money spent on this publisher for the row's period.
    spend: float

    # Ratio fields. Presumably click-through rate and conversion rate;
    # the formulas are not visible here -- confirm against the producer.
    ctr: float
    cvr: float
class BudgetStatus(BaseModel):
    """Campaign and investigation budget snapshot."""

    # Campaign money figures (floats).
    total_campaign_budget: float
    spent_so_far: float
    remaining: float

    # An int, unlike the monetary fields above -- presumably a count of
    # investigations still available; confirm against the environment.
    investigation_budget_remaining: int

    daily_spend_rate: float
# ---------------------------------------------------------------------------
# Observation
# ---------------------------------------------------------------------------
class AdAuditObservation(Observation):
    """What the agent sees after each step.

    Inherits ``done``, ``reward``, and ``metadata`` from the OpenEnv
    ``Observation`` base class. ``reward`` carries the daily P&L.
    """

    # NOTE(review): the description advertises days 1-30 while
    # ``campaign_day_total`` defaults to 14 -- confirm which is intended.
    day: int = Field(..., description="Current campaign day (1-30)")
    campaign_day_total: int = 14

    # Per-publisher metric rows: the current day's and the running totals.
    daily_metrics: List[DailyPublisherMetrics] = Field(default_factory=list)
    cumulative_metrics: List[DailyPublisherMetrics] = Field(default_factory=list)
    trend_data: str = Field(default="", description="Trend summary")

    # Stays None unless an investigation result is attached to this step.
    investigation_results: Optional[Dict[str, Any]] = Field(
        default=None, description="Structured metrics from investigation tool"
    )
    alerts: List[str] = Field(default_factory=list)
    budget_status: Optional[BudgetStatus] = Field(default=None)
    publisher_status: Dict[str, str] = Field(
        default_factory=dict,
        description="publisher_id -> active|flagged",
    )

    cumulative_reward: float = 0.0
    done_reason: Optional[str] = None
# ---------------------------------------------------------------------------
# State (hidden — used for grading / debugging)
# ---------------------------------------------------------------------------
class PublisherState(BaseModel):
    """Public publisher state (visible via /state)."""

    # Identity.
    publisher_id: str
    name: str

    # Flagging status; ``day_flagged`` stays None until a day is recorded.
    is_flagged: bool = False
    budget_allocation: float = 0.0

    # Tool names recorded for this publisher -- presumably the investigation
    # tools already run against it; confirm against the environment.
    tools_used: List[str] = Field(default_factory=list)
    day_flagged: Optional[int] = None
class AdAuditState(State):
    """Full internal state for debugging and grading.

    Inherits ``episode_id`` and ``step_count`` from OpenEnv ``State``.
    """

    # Episode identification and progress.
    case_id: str = ""
    current_day: int = 0

    # Per-publisher state and the log of actions taken so far.
    publishers: List[PublisherState] = Field(default_factory=list)
    action_history: List[Dict[str, Any]] = Field(default_factory=list)

    # Reward bookkeeping.
    daily_rewards: List[float] = Field(default_factory=list)
    cumulative_reward: float = 0.0

    # Investigation budget counters (ints).
    investigation_budget_total: int = 0
    investigation_budget_used: int = 0

    # Free-form records; their dict schemas are not visible in this file.
    flags_submitted: List[Dict[str, Any]] = Field(default_factory=list)
    grader_inputs: Dict[str, Any] = Field(default_factory=dict)
|