# SmartContractAudit/server/tasks/task1/environment.py
# ajaxwin
# refactor: Improved grading logic for task 2
# f78cba2
"""
environment.py (Task 1 – Targeted Vulnerability Detection)
------------------------------------------------------------
Full OpenEnv-compliant environment.
Episode flow:
1. reset() selects a random (contract, vulnerable_function) pair.
2. The agent receives an Observation with the contract description.
3. The agent uses actions to explore the contract (each costs a small penalty).
4. When the agent submits, the Grader scores the answer and the episode ends.
"""
from __future__ import annotations
from math import floor, log2
import random
from typing import Any, Dict, List, Optional, Set
from data.data_loader import load_contracts, sample_episode
from env.base_env import BaseEnv
from env.schemas import (
Action,
ActionType,
Observation,
Reward,
ResetResult,
StateResult,
StepResult,
)
from server.tasks.task1 import actions
from .grader import Task1Grader
# Identifier for this task; it is attached to the reset info, the state
# snapshot, and every observation built by Task1Environment.
TASK_ID: str = "task1_vuln_detection"
# Action types belonging to this task. The list matches, one for one, the
# handler table built in Task1Environment._dispatch; it is not referenced
# anywhere else in the portion of the file visible here.
AVAILABLE_ACTIONS: List[ActionType] = [
    ActionType.LIST_FUNCTIONS,
    ActionType.GET_FUNCTION_CODE,
    ActionType.GET_FUNCTION_SUMMARY,
    ActionType.GET_FILE_METADATA,
    ActionType.GET_STATE_VARIABLE,
    ActionType.GET_CALL_GRAPH,
    ActionType.SUBMIT,
]
class Task1Environment(BaseEnv):
    """Task 1: Targeted Vulnerability Detection.

    An episode pairs one contract with one of its vulnerable functions.
    The agent explores the contract through query actions and finishes by
    submitting an answer. Every action is routed to a handler in the
    ``actions`` module, which receives this environment instance and
    returns the result text together with a ``Reward``.
    """

    def __init__(self, contracts_path: Optional[str] = None) -> None:
        """Load the contract dataset and set up empty episode state.

        Args:
            contracts_path: Optional path forwarded to ``load_contracts``.
                When falsy, ``load_contracts`` is called without arguments.
        """
        self._contracts = load_contracts(contracts_path) if contracts_path else load_contracts()
        self._rng = random.Random()
        self._max_steps: int = 40

        # Episode state (initialised by reset)
        self._contract: Dict[str, Any] = {}
        self._target_fn: Dict[str, Any] = {}
        self._grader: Optional[Task1Grader] = None
        self._step_count: int = 0
        # NOTE(review): the attribute name is misspelled, but the handlers in
        # ``actions`` receive this instance and may read it, so the name is
        # kept as-is.
        self._cummulative_cost: float = 0.0
        self._done: bool = False
        self._query_history: List[str] = []
        self._seen_queries: Set[str] = set()

    # ------------------------------------------------------------------
    # OpenEnv interface
    # ------------------------------------------------------------------
    def reset(self, seed: Optional[int] = None) -> ResetResult:
        """Start a new episode by sampling a random vulnerable function.

        Args:
            seed: When not ``None``, re-seeds the environment's private RNG
                before sampling so the episode choice is reproducible.

        Returns:
            A ``ResetResult`` carrying the initial observation and an
            ``info`` dict holding the task id.
        """
        if seed is not None:
            self._rng.seed(seed)

        self._contract, self._target_fn = sample_episode(self._contracts, self._rng)

        # NOTE(review): log2 raises ValueError when the contract lists no
        # functions; presumably sample_episode never returns such a
        # contract -- confirm against the data loader.
        self._grader = Task1Grader(
            target_function=self._target_fn["name"],
            vulnerability_issue=self._target_fn["vulnerability_details"]["issue"],
            n=floor(log2(len(self._contract["functions"]))),
        )

        # Clear all per-episode bookkeeping.
        self._step_count = 0
        self._cummulative_cost = 0.0
        self._done = False
        self._query_history = []
        self._seen_queries = set()

        obs = self._build_observation(
            last_action=None,
            last_result=(
                f"New episode started. Contract: {self._contract['contract_name']}. "
                "Use 'list_functions' to explore the contract."
            ),
        )
        return ResetResult(observation=obs, info={"task_id": TASK_ID})

    def step(self, action: Action) -> StepResult:
        """Execute one agent action.

        Args:
            action: The action to run; its ``action_type`` selects the
                handler and its ``params`` are forwarded to it.

        Returns:
            A ``StepResult`` with the new observation, the reward returned
            by the handler, the done flag, and an ``info`` dict holding the
            step number and the running reward total.

        Raises:
            RuntimeError: If the episode is already done, or if the step
                budget (``_max_steps``) has been used up.
        """
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")
        # The counter is incremented below, so the budget is exhausted as
        # soon as it reaches the limit. Comparing with ">" would let one
        # extra step through.
        if self._step_count >= self._max_steps:
            raise RuntimeError("Exceeded maximum number of steps allowed. Call reset() to start a new episode.")

        self._step_count += 1
        result_text, reward = self._dispatch(action)
        self._cummulative_cost += reward.value
        # Only the first 200 characters of the result are kept in history.
        self._query_history.append(f"[{action.action_type}] → {result_text[:200]}")

        obs = self._build_observation(
            last_action=action.action_type,
            last_result=result_text,
        )
        return StepResult(
            observation=obs,
            reward=reward,
            done=self._done,
            info={
                "step": self._step_count,
                "cumulative_reward": self._cummulative_cost,
            },
        )

    def state(self) -> StateResult:
        """Return a snapshot of the current episode state.

        The query history is copied so callers cannot mutate the
        environment's own list.
        """
        return StateResult(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            target_function=self._target_fn.get("name", ""),
            step_count=self._step_count,
            cumulative_reward=self._cummulative_cost,
            done=self._done,
            query_history=list(self._query_history),
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _build_observation(
        self,
        last_action: Optional[str],
        last_result: str,
    ) -> Observation:
        """Assemble the observation shown to the agent.

        Args:
            last_action: The action type just executed, or ``None`` for the
                observation produced by ``reset``.
            last_result: Text result of that action.

        Returns:
            An ``Observation`` whose ``extra`` dict carries the contract's
            Solidity version (empty string when absent) and a fixed hint
            describing the submit format.
        """
        return Observation(
            task_id=TASK_ID,
            contract_name=self._contract.get("contract_name", ""),
            last_action=last_action,
            last_action_result=last_result,
            done=self._done,
            extra={
                "solidity_version": self._contract.get("metadata", {}).get("solidity_version", ""),
                "hint": (
                    "Identify the vulnerable function and its issue. "
                    "Submit with action_type='submit', params={'function_name': '...', "
                    "'vulnerability_type': '...'}"
                ),
            },
        )

    def _query_key(self, action_type: str, params: Dict[str, Any]) -> str:
        """Build a hashable key for repeated-query detection.

        Parameters are sorted by item so the key does not depend on the
        insertion order of ``params``.
        """
        return f"{action_type}:{sorted(params.items())}"

    def _is_repeated(self, key: str) -> bool:
        """Report whether ``key`` was seen earlier in this episode.

        A key that has not been seen is recorded as a side effect, so a
        second call with the same key returns ``True``.
        """
        if key in self._seen_queries:
            return True
        self._seen_queries.add(key)
        return False

    def _dispatch(self, action: Action) -> tuple[str, Reward]:
        """Route ``action`` to its handler in the ``actions`` module.

        Handlers are called as ``handler(self, query_key, params)``. An
        action type with no handler is passed to ``actions.unknown_action``
        along with the offending type.

        Returns:
            Whatever the handler returns; ``step`` unpacks it as
            ``(result_text, reward)``.
        """
        at = action.action_type
        params = action.params
        qkey = self._query_key(at, params)

        # Mapping from ActionType to handler function
        handlers = {
            ActionType.LIST_FUNCTIONS: actions.list_functions,
            ActionType.GET_FUNCTION_CODE: actions.get_function_code,
            ActionType.GET_FUNCTION_SUMMARY: actions.get_function_summary,
            ActionType.GET_FILE_METADATA: actions.get_file_metadata,
            ActionType.GET_STATE_VARIABLE: actions.get_state_variable,
            ActionType.GET_CALL_GRAPH: actions.get_call_graph,
            ActionType.SUBMIT: actions.submit,
        }

        handler = handlers.get(at)
        if handler is None:
            return actions.unknown_action(self, qkey, params, at)
        return handler(self, qkey, params)