""" environment.py (Task 3 – Rule Checker) ----------------------------------------- OpenEnv-compliant RL environment. Episode setup ───────────── - A Solidity contract is selected that contains at least one function violating a known property. - The agent sees: contract description + the property in natural English. - The agent must identify which function breaks that property. Observation at reset ──────────────────── extra.property_english – the violated property in plain English extra.hint – instructions for the agent Actions & rewards ───────────────── list_functions -0.05 see all function names get_function_metadata -0.05 signature / visibility / modifiers / params get_function_code -0.10 full Solidity source of any function get_state_variables -0.05 list or inspect state variables get_call_graph -0.08 function call graph get_property_specification -0.03 formal pre/post-condition version of property submit_function terminal: +5.0 / +1.5 / -1.5 (ONE attempt) repeated_query -0.40 Difficulty: Easy The property text directly names the invariant broken; reading 2-3 functions should let most agents identify the culprit quickly. """ from __future__ import annotations import random from typing import Any, Dict, List, Optional, Set from data.data_loader import load_contracts, sample_task3_episode from env.base_env import BaseEnv from env.schemas import ( Action, ActionType, Observation, Reward, ResetResult, StateResult, StepResult, ) from .grader import Task3Grader from server.tasks.task3 import actions TASK_ID = "task3_rule_checker" AVAILABLE_ACTIONS = [ ActionType.LIST_FUNCTIONS, ActionType.GET_FUNCTION_METADATA, ActionType.GET_FUNCTION_CODE, ActionType.GET_STATE_VARIABLE, ActionType.GET_CALL_GRAPH, ActionType.GET_PROPERTY_SPECIFICATION, ActionType.SUBMIT_FUNCTION, ] class Task3Environment(BaseEnv): """Task 3: Rule Checker — identify the function that violates a given property.""" def __init__(self, contracts_path: Optional[str] = None) -> None: self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts() self._rng = random.Random() self._max_steps = 20 # Episode state — initialised by reset() self._contract: Dict[str, Any] = {} self._target_fn: Dict[str, Any] = {} self._grader: Optional[Task3Grader] = None self._step_count: int = 0 self._cum_reward: float = 0.0 self._done: bool = False self._query_hist: List[str] = [] self._seen: Set[str] = set() # ── OpenEnv interface ───────────────────────────────────────────────────── def reset(self, seed: Optional[int] = None) -> ResetResult: if seed is not None: self._rng.seed(seed) self._contract, self._target_fn = sample_task3_episode( self._contracts, self._rng ) self._grader = Task3Grader( target_function=self._target_fn, property_specification=self._target_fn.get("property_specification", ""), max_steps = self._max_steps ) self._step_count = 0 self._cum_reward = 0.0 self._done = False self._query_hist = [] self._seen = set() obs = self._build_obs( last_action=None, last_result=( f"New episode started.\n" f"Contract : {self._contract['contract_name']}\n\n" f"Property : {self._target_fn.get('property', '')}\n\n" f"Find the function in this contract that violates the property above.\n" f"Use list_functions then get_function_code to investigate.\n" f"Submit with submit_function, params={{\"function_name\": \"...\"}}.\n" f"Only ONE submission allowed." ), ) return ResetResult(observation=obs, info={"task_id": TASK_ID}) def step(self, action: Action) -> StepResult: if self._done: raise RuntimeError("Episode is done. Call reset() to start a new episode.") if self._step_count > self._max_steps: raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.") self._step_count += 1 result_text, reward = self._dispatch(action) self._cum_reward += reward.value self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}") obs = self._build_obs( last_action=action.action_type, last_result=result_text, ) return StepResult( observation=obs, reward=reward, done=self._done, info={"step": self._step_count, "cumulative_reward": self._cum_reward}, ) def state(self) -> StateResult: return StateResult( task_id=TASK_ID, contract_name=self._contract.get("contract_name", ""), target_function=self._target_fn.get("name", ""), step_count=self._step_count, cumulative_reward=self._cum_reward, done=self._done, query_history=list(self._query_hist), ) # ── Internal helpers ────────────────────────────────────────────────────── def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation: return Observation( task_id=TASK_ID, contract_name=self._contract.get("contract_name", ""), last_action=last_action, last_action_result=last_result, done=self._done, extra={ "property_english": self._target_fn.get("property", ""), "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""), "hint": ( "Read the property, then inspect function code to find which one violates it. " "Submit with: submit_function, params={'function_name': ''}. " "ONE submission per episode." ), }, ) def _qkey(self, at: str, params: Dict[str, Any]) -> str: return f"{at}:{sorted(params.items())}" def _is_repeated(self, key: str) -> bool: if key in self._seen: return True self._seen.add(key) return False def _dispatch(self, action: Action) -> tuple[str, Reward]: at = action.action_type params = action.params qkey = self._qkey(at, params) # Mapping from ActionType to handler function handlers = { ActionType.LIST_FUNCTIONS: actions.list_functions, ActionType.GET_FUNCTION_METADATA: actions.get_function_metadata, ActionType.GET_FUNCTION_CODE: actions.get_function_code, ActionType.GET_STATE_VARIABLE: actions.get_state_variable, ActionType.GET_CALL_GRAPH: actions.get_call_graph, ActionType.GET_PROPERTY_SPECIFICATION: actions.get_property_specification, ActionType.SUBMIT_FUNCTION: actions.submit_function, } handler = handlers.get(at) if handler is None: return actions.unknown_action(self, qkey, params, at) return handler(self, qkey, params)