Spaces:
Running
Running
| """Actions for Task 1: Targeted Vulnerability Detection. | |
| """ | |
| from typing import Any, Dict, Tuple | |
| from env.schemas import ActionType, Reward | |
| from data.data_loader import ( | |
| list_function_names, | |
| get_function_by_name, | |
| list_state_variable_names, | |
| get_state_variable_by_name, | |
| ) | |
def list_functions(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle LIST_FUNCTIONS: name every function in the active contract.

    A query already seen (per ``ctx._is_repeated``) is answered with the
    REPEATED cost instead. ``params`` is not used by this handler.
    """
    if ctx._is_repeated(qkey):
        repeated = Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
        return "Repeated query.", repeated
    contract = ctx._contract
    listing = ", ".join(list_function_names(contract))
    message = f"Functions in {contract['contract_name']}: {listing}"
    cost = Reward(value=ActionType.LIST_FUNCTIONS.cost, reason="list_functions cost", partial=True)
    return message, cost
def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_FUNCTION_CODE: return the source of ``params["function_name"]``.

    An unknown name still costs GET_FUNCTION_CODE, and the reply lists the
    function names that do exist. Repeated queries cost REPEATED.
    """
    requested = params.get("function_name", "")
    if ctx._is_repeated(qkey):
        repeated = Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
        return "Repeated query.", repeated
    contract = ctx._contract
    entry = get_function_by_name(contract, requested)
    if entry is not None:
        source = entry.get("code", "// no code available")
        fetched = Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Fetched code", partial=True)
        return f"// {entry['name']}\n{source}", fetched
    message = f"Function '{requested}' not found. Available: {list_function_names(contract)}"
    miss = Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="Wrong/unknown function name", partial=True)
    return message, miss
def get_function_summary(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_FUNCTION_SUMMARY: return the stored comment for a function.

    The function is looked up by ``params["function_name"]``. A missing
    ``comment`` key falls back to "No summary available.". An unknown name
    still costs GET_FUNCTION_SUMMARY.
    """
    requested = params.get("function_name", "")
    if ctx._is_repeated(qkey):
        repeated = Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
        return "Repeated query.", repeated
    entry = get_function_by_name(ctx._contract, requested)
    if entry is not None:
        summary = entry.get("comment", "No summary available.")
        fetched = Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Fetched summary", partial=True)
        return f"Summary of '{entry['name']}': {summary}", fetched
    miss = Reward(value=ActionType.GET_FUNCTION_SUMMARY.cost, reason="Wrong function name", partial=True)
    return f"Function '{requested}' not found.", miss
def get_file_metadata(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_FILE_METADATA: summarise the contract's name and metadata.

    Reports the contract name, the ``solidity_version`` and the
    ``description`` from the contract's ``metadata`` dict. Missing metadata
    fields render as "N/A".
    """
    if ctx._is_repeated(qkey):
        repeated = Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
        return "Repeated query.", repeated
    metadata = ctx._contract.get("metadata", {})
    fields = [
        f"Contract: {ctx._contract['contract_name']}",
        f"Solidity: {metadata.get('solidity_version', 'N/A')}",
        f"Description: {metadata.get('description', 'N/A')}",
    ]
    cost = Reward(value=ActionType.GET_FILE_METADATA.cost, reason="get_file_metadata cost", partial=True)
    return " | ".join(fields), cost
def get_state_variable(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_STATE_VARIABLE: describe one state variable, or list them all.

    If ``params["variable_name"]`` is empty or absent, every state variable
    name is listed. Otherwise the named variable's type, visibility, name and
    description are returned. Every outcome except a repeated query costs
    GET_STATE_VARIABLE.
    """
    wanted = params.get("variable_name", "")
    if ctx._is_repeated(qkey):
        repeated = Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
        return "Repeated query.", repeated
    contract = ctx._contract
    if wanted:
        variable = get_state_variable_by_name(contract, wanted)
        if variable is None:
            miss = Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Unknown state variable", partial=True)
            return f"Variable '{wanted}' not found.", miss
        line = f"{variable['type']} {variable['visibility']} {variable['name']}: {variable.get('description', '')}"
        found = Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="get_state_variable cost", partial=True)
        return line, found
    listing = ", ".join(list_state_variable_names(contract))
    listed = Reward(value=ActionType.GET_STATE_VARIABLE.cost, reason="Listed state variables", partial=True)
    return f"State variables: {listing}", listed
def get_call_graph(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_CALL_GRAPH: render the contract's ``call_graph`` mapping.

    Each entry is shown as ``caller → [callee, ...]``, and entries are
    separated by "; ". A contract without a ``call_graph`` key yields an
    empty listing.
    """
    if ctx._is_repeated(qkey):
        repeated = Reward(value=ActionType.REPEATED.cost, reason="Repeated query", partial=True)
        return "Repeated query.", repeated
    graph = ctx._contract.get("call_graph", {})
    edges = [f"{caller} → [{', '.join(callees)}]" for caller, callees in graph.items()]
    rendered = "; ".join(edges)
    cost = Reward(value=ActionType.GET_CALL_GRAPH.cost, reason="get_call_graph cost", partial=True)
    return f"Call graph: {rendered}", cost
def submit(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle SUBMIT action for Task 1.

    Expected params
    ---------------
    function_name : str – name of the vulnerable function
    vulnerability_type: str – short description of the vulnerability

    Returns
    -------
    (message, reward)
        * If ``ctx._done`` is already set, the reply refuses the submission
          and costs RESUBMIT.
        * If either param is missing or blank, the reply reports a malformed
          submission with value 0.0. ``ctx._done`` is left unset, so the
          caller may submit again.
        * Otherwise ``ctx._done`` is set, the grader scores the answer, and
          the reply reveals the canonical answer. ``reward.value`` is the
          grader's score.
    """
    if ctx._done:
        return (
            "Only ONE submission is allowed.",
            Reward(value=ActionType.RESUBMIT.cost,
                   reason="Second submit_function attempt",
                   partial=False),
        )
    # Coerce to str so a None / non-string value produces the malformed-submission
    # reply below instead of raising AttributeError on .strip().
    fn_name = str(params.get("function_name") or "").strip()
    vuln_type = str(params.get("vulnerability_type") or "").strip()
    if not fn_name or not vuln_type:
        return (
            "submit_function requires both 'function_name' and "
            "'vulnerability_type' in params.",
            Reward(value=0.0, reason="Malformed submission", partial=False),
        )
    ctx._done = True
    # NOTE: attribute is spelled `_cummulative_cost` on the context object; kept as-is.
    score = ctx._grader.grade(fn_name, vuln_type, ctx._step_count, ctx._cummulative_cost)
    # The original interpolated `get_canonical_answer` without calling it, which
    # shows "<bound method ...>" if it is a method. Call it when callable; still
    # accept a plain attribute/property value.
    canonical = ctx._grader.get_canonical_answer
    if callable(canonical):
        canonical = canonical()
    return (
        f"Correct Answer: {canonical}",
        Reward(
            value=score,
            reason=f"submit_function score={score:.1f}",
            partial=False,
        ),
    )
def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple[str, Reward]:
    """Fallback handler: report an unrecognised action type at the UNKNOWN cost.

    ``ctx``, ``qkey`` and ``params`` are unused here. Presumably they are
    accepted so this handler shares the other handlers' call shape.
    """
    message = f"Unknown action type: {action_type}"
    penalty = Reward(value=ActionType.UNKNOWN.cost, reason="Unknown action", partial=True)
    return message, penalty