Spaces:
Running
Running
| """ | |
| environment.py (Task 1 – Targeted Vulnerability Detection) | |
| ------------------------------------------------------------ | |
| Full OpenEnv-compliant environment. | |
| Episode flow: | |
| 1. reset() selects a random (contract, vulnerable_function) pair. | |
| 2. The agent receives an Observation with the contract description. | |
| 3. The agent uses actions to explore the contract (each costs a small penalty). | |
| 4. When the agent submits, the Grader scores the answer and the episode ends. | |
| """ | |
| from __future__ import annotations | |
| from math import floor, log2 | |
| import random | |
| from typing import Any, Dict, List, Optional, Set | |
| from data.data_loader import load_contracts, sample_episode | |
| from env.base_env import BaseEnv | |
| from env.schemas import ( | |
| Action, | |
| ActionType, | |
| Observation, | |
| Reward, | |
| ResetResult, | |
| StateResult, | |
| StepResult, | |
| ) | |
| from server.tasks.task1 import actions | |
| from .grader import Task1Grader | |
# Task identifier echoed in every observation, state snapshot and reset info.
TASK_ID = "task1_vuln_detection"

# Action types offered in this task. All but the last are exploration
# queries; SUBMIT hands an answer to the grader and ends the episode.
AVAILABLE_ACTIONS = [
    # --- exploration ---
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_FUNCTION_SUMMARY,
    ActionType.GET_FILE_METADATA,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    # --- terminal ---
    ActionType.SUBMIT,
]
class Task1Environment(BaseEnv):
    """Task 1: Targeted Vulnerability Detection.

    Each episode pairs one contract with one of its vulnerable functions
    (picked by ``sample_episode``). The agent explores the contract through
    the actions routed in ``_dispatch`` and finishes by submitting an answer
    that ``Task1Grader`` scores. Episodes are also cut off once
    ``_max_steps`` steps have been taken.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contract corpus and set up empty episode state.

        Args:
            contracts_path: Optional path handed to ``load_contracts``. When
                falsy, ``load_contracts()`` is called with no argument.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        # Maximum number of agent steps allowed per episode.
        self._max_steps: int = 40
        # Episode state (initialised by reset)
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task1Grader] = None
        self._step_count: int = 0
        # Running sum of per-step reward values, reported as "cumulative_reward".
        # NOTE(review): name is misspelled and says "cost"; kept unchanged because
        # action handlers receive this instance and may read it -- verify before
        # renaming.
        self._cummulative_cost: float = 0.0
        self._done: bool = False
        # Human-readable log of each action with its (truncated) result.
        self._query_history: List[str] = []
        # Query keys already recorded by _is_repeated.
        self._seen_queries: Set[str] = set()

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------

    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode by sampling a random vulnerable function.

        Args:
            seed: If given, the environment's RNG (which is passed to
                ``sample_episode``) is reseeded before sampling.

        Returns:
            ResetResult with the opening observation and
            ``info={"task_id": TASK_ID}``.
        """
        if seed is not None:
            self._rng.seed(seed)
        self._contract, self._target_fn = sample_episode(self._contracts, self._rng)
        self._grader = Task1Grader(
            target_function=self._target_fn["name"],
            vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
            # floor(log2(number of functions in the contract)); its meaning is
            # defined by Task1Grader. NOTE(review): log2 raises ValueError if the
            # contract lists no functions.
            n=floor(log2(len(self._contract["functions"]))),
        )
        self._step_count = 0
        self._cummulative_cost = 0.0
        self._done = False
        self._query_history = []
        self._seen_queries = set()
        obs = self._build_observation(
            last_action=None,
            last_result=(
                f"New episode started. Contract: {self._contract['contract_name']}. "
                f"Use 'list_functions' to explore the contract."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Execute one agent action.

        The episode becomes done when a handler sets ``_done`` (presumably the
        submit handler) or when the step budget ``_max_steps`` is used up; in
        the latter case the last permitted step already reports ``done=True``.

        Args:
            action: The agent's action; ``action.action_type`` selects the
                handler and ``action.params`` is forwarded to it.

        Returns:
            StepResult with the new observation, this step's reward, the done
            flag and ``info`` holding the step number and cumulative reward.

        Raises:
            RuntimeError: If the episode is already done or the step budget is
                exhausted; call ``reset()`` to start a new episode.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # ``>=`` rather than ``>``: the counter is incremented below, so ``>``
        # allowed one step beyond _max_steps.
        if self._step_count >= self._max_steps:
            raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")
        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cummulative_cost += reward.value
        self._query_history.append(f"[{action.action_type}] → {result_text[:200]}")
        # Signal truncation via ``done`` on the final permitted step, so the agent
        # is told the episode ended instead of hitting RuntimeError next call.
        if self._step_count >= self._max_steps:
            self._done = True
        obs = self._build_observation(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": self._cummulative_cost,
            },
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode state (history is copied)."""
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cummulative_cost,
            done=self._done,
            query_history=list(self._query_history),
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _build_observation(
        self,
        last_action: Optional[str],
        last_result: str,
    ) -> Observation:
        """Assemble the Observation for the current contract and done flag.

        ``extra`` carries the contract's Solidity version (empty string when
        missing) and a fixed hint describing the submit format.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Identify the vulnerable function and its issue. "
                    "Submit with action_type='submit', params={'function_name': '...', "
                    "'vulnerability_type': '...'}"
                ),
            },
        )

    def _query_key(self, action_type: str, params: Dict[str, Any]) -> str:
        """Build a hashable key for repeated-query detection.

        Params are sorted by key so that argument order does not matter.
        """
        return f"{action_type}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Return True if ``key`` was seen before; otherwise record it and return False."""
        if key in self._seen_queries:
            return True
        self._seen_queries.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in ``actions``.

        Handlers are called as ``handler(env, query_key, params)`` and return
        ``(result_text, reward)``. Unrecognised action types go to
        ``actions.unknown_action``, which additionally receives the type.
        """
        at = action.action_type
        params = action.params
        qkey = self._query_key(at, params)
        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_FUNCTION_SUMMARY: actions.get_function_summary,
            ActionType.GET_FILE_METADATA: actions.get_file_metadata,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.SUBMIT: actions.submit,
        }
        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)