"""Actions for Task 1: Targeted Vulnerability Detection. """ from typing import Any, Dict, Tuple from env.schemas import ActionType, Reward from data.data_loader import ( list_function_names, get_function_by_name, list_state_variable_names, get_state_variable_by_name, ) def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle LIST_FUNCTIONS action.""" if ctx._is_repeated(qkey): return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True) names = list_function_names(ctx._contract) return ( f"Functions in {ctx._contract['contract_name']}: {', '.join(names)}", Reward(value=ActionType.LIST_FUNCTIONS.cost, reason="list_functions cost", partial=True), ) def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle GET_FUNCTION_CODE action.""" fn_name = params.get("function_name", "") if ctx._is_repeated(qkey): return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True) fn = get_function_by_name(ctx._contract, fn_name) if fn is None: return ( f"Function '{fn_name}' not found. Available: {list_function_names(ctx._contract)}", Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Wrong/unknown function name", partial=True), ) code = fn.get("code", "// no code available") return ( f"// {fn['name']}\n{code}", Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Fetched code", partial=True), ) def get_function_summary(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle GET_FUNCTION_SUMMARY action.""" fn_name = params.get("function_name", "") if ctx._is_repeated(qkey): return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True) fn = get_function_by_name(ctx._contract, fn_name) if fn is None: return ( f"Function '{fn_name}' not found.", Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Wrong function name", partial=True), ) comment = fn.get("comment", "No summary available.") return ( f"Summary of '{fn['name']}': {comment}", Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Fetched summary", partial=True), ) def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle GET_FILE_METADATA action.""" if ctx._is_repeated(qkey): return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True) meta = ctx._contract.get("metadata", {}) result = ( f"Contract: {ctx._contract['contract_name']} | " f"Solidity: {meta.get('solidity_version', 'N/A')} | " f"Description: {meta.get('description', 'N/A')}" ) return result, Reward(value=ActionType.GET_FILE_METADATA.cost, reason="get_file_metadata cost", partial=True) def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle GET_STATE_VARIABLE action.""" var_name = params.get("variable_name", "") if ctx._is_repeated(qkey): return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True) if not var_name: names = list_state_variable_names(ctx._contract) return ( f"State variables: {', '.join(names)}", Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Listed state variables", partial=True), ) sv = get_state_variable_by_name(ctx._contract, var_name) if sv is None: return ( f"Variable '{var_name}' not found.", Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Unknown state variable", partial=True), ) return ( f"{sv['type']} {sv['visibility']} {sv['name']}: {sv.get('description', '')}", Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="get_state_variable cost", partial=True), ) def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle GET_CALL_GRAPH action.""" if ctx._is_repeated(qkey): return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True) cg = ctx._contract.get("call_graph", {}) cg_str = "; ".join(f"{fn} → [{', '.join(callees)}]" for fn, callees in cg.items()) return ( f"Call graph: {cg_str}", Reward(value=ActionType.GET_CALL_GRAPH.cost, reason="get_call_graph cost", partial=True), ) def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]: """Handle SUBMIT action for Task 1. Expected params --------------- function_name : str – name of the vulnerable function vulnerability_type: str – short description of the vulnerability """ if ctx._done: return ( "Only ONE submission is allowed.", Reward(value=ActionType.RESUBMIT.cost, reason="Second submit_function attempt", partial=False), ) fn_name = params.get("function_name", "").strip() vuln_type = params.get("vulnerability_type", "").strip() if not fn_name or not vuln_type: return ( "submit_function requires both 'function_name' and " "'vulnerability_type' in params.", Reward(value=0.0, reason="Malformed submission", partial=False), ) ctx._done = True score = ctx._grader.grade(fn_name, vuln_type, ctx._step_count, ctx._cummulative_cost) return (f"Correct Answer: {ctx._grader.get_canonical_answer}"), Reward( value=score, reason=f"submit_function score={score:.1f}", partial=False, ) def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple[str, Reward]: """Fallback for unknown actions.""" return ( f"Unknown action type: {action_type}", Reward(value=ActionType.UNKNOWN.cost, reason="Unknown action", partial=True), )