""" environment.py (Task 2 – Property Discovery) ---------------------------------------------- OpenEnv-compliant RL environment. Episode setup: - One function from a Solidity contract that has a known property. - The agent sees: contract description + function name + function signature. - The agent must discover the natural-language property of the function. Actions & rewards: get_function_code -0.06 (always positive topic context) get_function_natspec -0.08 (strongest hint — natspec has param/return docs) get_file_natspec -0.03 (broad contract-level context) get_related_functions -0.06 (shows callers/callees) get_io -0.04 (structured input/output description) get_similar_rule -0.20 (shows a similar property from another contract) submit_property scored 0–5 (ONE attempt, ends episode) repeated_query -0.40 Episode ends when: - submit_property is called (scored), OR - max_steps is reached without submission (reward = -1.0) """ from __future__ import annotations import random from typing import Any, Dict, List, Optional, Set from math import log2, floor from data.data_loader import load_contracts, sample_property_episode from env.base_env import BaseEnv from env.schemas import ( Action, ActionType, Observation, Reward, ResetResult, StateResult, StepResult, ) from .grader import Task2Grader from server.tasks.task2 import actions TASK_ID = "task2_property_discovery" AVAILABLE_ACTIONS = [ ActionType.GET_FUNCTION_CODE, ActionType.GET_FUNCTION_NATSPEC, ActionType.GET_FILE_NATSPEC, ActionType.GET_RELATED_FUNCTIONS, ActionType.GET_SIGNATURE, ActionType.GET_SIMILAR_RULE, ActionType.SUBMIT_PROPERTY, ] class Task2Environment(BaseEnv): """Task 2: Property Discovery.""" def __init__(self, contracts_path: Optional[str] = None) -> None: self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts() self._rng = random.Random() self._max_steps: int = 40 # Episode state – initialised by reset() self._contract: Dict[str, Any] = {} self._target_fn: Dict[str, Any] = {} self._grader: Optional[Task2Grader] = None self._step_count: int = 0 self._cum_reward: float = 0.0 self._done: bool = False self._query_hist: List[str] = [] self._seen: Set[str] = set() # ── OpenEnv interface ──────────────────────────────────────────────────── def reset(self, seed: Optional[int] = None) -> ResetResult: if seed is not None: self._rng.seed(seed) self._contract, self._target_fn = sample_property_episode( self._contracts, self._rng ) self._grader = Task2Grader( function_name=self._target_fn["name"], property=self._target_fn["property"], n = floor(log2(len(self._contract["functions"]))) ) self._step_count = 0 self._cum_reward = 0.0 self._done = False self._query_hist = [] self._seen = set() obs = self._build_obs( last_action=None, last_result=( f"New episode started.\n" f"Contract : {self._contract['contract_name']}\n" f"Function : {self._target_fn['name']} " f"({self._target_fn.get('signature', '')})\n" f"Your task : Discover the natural-language property of " f"'{self._target_fn['name']}' and submit it with submit_property action." ), ) return ResetResult(observation=obs, info={"task_id": TASK_ID}) def step(self, action: Action) -> StepResult: if self._done: raise RuntimeError("Episode is done. Call reset() to start a new episode.") if self._step_count > self._max_steps: raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.") self._step_count += 1 result_text, reward = self._dispatch(action) self._cum_reward += reward.value self._query_hist.append(f"[{action.action_type}] → {result_text[:100]}") obs = self._build_obs( last_action=action.action_type, last_result=result_text, ) return StepResult( observation=obs, reward=reward, done=self._done, info={ "step": self._step_count, "cumulative_reward": self._cum_reward, }, ) def state(self) -> StateResult: return StateResult( task_id=TASK_ID, contract_name=self._contract.get("contract_name", ""), target_function=self._target_fn.get("name", ""), step_count=self._step_count, cumulative_reward=self._cum_reward, done=self._done, query_history=list(self._query_hist), ) # ── Internal helpers ───────────────────────────────────────────────────── def _build_obs(self, last_action: Optional[str], last_result: str) -> Observation: return Observation( task_id=TASK_ID, contract_name=self._contract.get("contract_name", ""), last_action=last_action, last_action_result=last_result, done=self._done, extra={ "target_function": self._target_fn.get("name", ""), "target_signature": self._target_fn.get("signature", ""), "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""), "hint": ( "Discover the property of the target function. " "Use get_function_code, get_function_natspec, or get_similar_rule for hints. " "Submit with submit_property, params={'property': ''}. " "ONE submission attempt only." ), }, ) def _qkey(self, at: str, params: Dict[str, Any]) -> str: return f"{at}:{sorted(params.items())}" def _is_repeated(self, key: str) -> bool: if key in self._seen: return True self._seen.add(key) return False def _dispatch(self, action: Action) -> tuple[str, Reward]: at = action.action_type params = action.params qkey = self._qkey(at, params) handlers = { ActionType.GET_FUNCTION_CODE: actions.get_function_code, ActionType.GET_FUNCTION_NATSPEC: actions.get_function_natspec, ActionType.GET_FILE_NATSPEC: actions.get_file_natspec, ActionType.GET_RELATED_FUNCTIONS: actions.get_related_functions_action, ActionType.GET_SIGNATURE: actions.get_signature, ActionType.GET_SIMILAR_RULE: actions.get_similar_rule_action, ActionType.SUBMIT_PROPERTY: actions.submit_property, } handler = handlers.get(at) if handler is None: return actions.unknown_action(self, qkey, params, at) return handler(self, qkey, params)