ajaxwin
refactor: Improved grading logic for task 2
f78cba2
"""
Actions for Task 2: Property Inference.
Defines the handler for each action type the agent can take in the environment: querying the
target function's code, NatSpec, signature and related functions, the file-level NatSpec, a
similar rule, and submitting the inferred property.
Each handler receives:
ctx: context object containing episode state and data
qkey: a unique key representing the specific query (used for tracking repeated queries)
params: parameters for the action, such as the submitted property text for SUBMIT_PROPERTY
"""
from typing import Any, Dict, Tuple
from data.data_loader import get_function_by_name, get_related_functions
from utils import PropertyRetriever
from env.schemas import ActionType, Reward
# Shared retriever, constructed once at import time and reused by get_similar_rule_action.
# That handler calls load_model() before every lookup, so the model may be loaded lazily
# rather than here — NOTE(review): confirm load_model() is idempotent/cheap on repeat calls.
PropertyRetrieverInstance = PropertyRetriever() # Load once at module level
# TODO: Could be split into separate signature and visibility queries.
def get_function_code(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_FUNCTION_CODE action.

    Returns the source of the episode's target function (``ctx._target_fn``),
    or a placeholder when the record has no ``code`` entry, charged
    ``ActionType.GET_FUNCTION_CODE.cost``. A repeated query is answered with a
    fixed message and charged ``ActionType.REPEATED.cost`` instead.

    ``params`` is accepted only to keep a uniform handler signature.
    """
    already_asked = ctx._is_repeated(qkey)
    if not already_asked:
        source_text = ctx._target_fn.get("code", "// no code available")
        charge = Reward(value=ActionType.GET_FUNCTION_CODE.cost, reason="get_function_code cost")
        return source_text, charge
    return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query")
# TODO: Can separate comment and output_property(output_comment)
def get_function_natspec(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_FUNCTION_NATSPEC action.

    Returns the NatSpec of the target function, falling back to its plain
    ``comment`` and then to a placeholder. When the function record carries an
    ``output_property``, it is appended as "Expected output".

    Charged ``ActionType.GET_FUNCTION_NATSPEC.cost``; a repeated query is
    charged ``ActionType.REPEATED.cost`` instead. ``params`` is unused.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query")
    fn = ctx._target_fn
    name = fn["name"]
    natspec = fn.get("natspec") or fn.get("comment") or "No NatSpec available."
    out_prop = fn.get("output_property", "")
    result = f"NatSpec for '{name}':\n{natspec}"
    if out_prop:
        result += f"\n\nExpected output: {out_prop}"
    # Bill the function-level action this handler serves, not the file-level
    # GET_FILE_NATSPEC action (handled by get_file_natspec).
    return result, Reward(
        value=ActionType.GET_FUNCTION_NATSPEC.cost,
        reason="get_function_natspec cost",
    )
def get_file_natspec(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_FILE_NATSPEC action.

    Returns the contract-level NatSpec taken from ``ctx._contract["metadata"]``,
    falling back to the metadata ``description`` and then to a placeholder,
    prefixed with the contract name. Charged ``ActionType.GET_FILE_NATSPEC.cost``;
    a repeated query is charged ``ActionType.REPEATED.cost`` instead.
    ``params`` is unused.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query")
    # ``or {}`` also covers a "metadata" key that is present but None.
    meta = ctx._contract.get("metadata") or {}
    # Use an ``or`` chain for every fallback so an empty/None description also
    # yields the placeholder rather than an empty observation.
    natspec = meta.get("natspec") or meta.get("description") or "No file NatSpec available."
    return (
        f"File NatSpec for {ctx._contract['contract_name']}:\n{natspec}",
        Reward(value=ActionType.GET_FILE_NATSPEC.cost, reason="get_file_natspec cost"),
    )
def get_related_functions_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_RELATED_FUNCTIONS action.

    Lists the functions that ``get_related_functions`` reports as related to
    the target function, one bullet per function showing its signature (or
    its name when no signature is recorded) and, when present, its comment.
    Related names that ``get_function_by_name`` cannot resolve are skipped;
    if none resolve, the "no related functions" message is returned instead
    of an empty list.

    Charged ``ActionType.GET_RELATED_FUNCTIONS.cost``; a repeated query is
    charged ``ActionType.REPEATED.cost`` instead. ``params`` is unused.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query")
    name = ctx._target_fn["name"]
    summaries = []
    for related_name in get_related_functions(ctx._contract, name) or []:
        rfn = get_function_by_name(ctx._contract, related_name)
        if not rfn:
            continue
        sig = rfn.get("signature", related_name)
        comment = rfn.get("comment", "")
        # Keep a visible separator so the signature and comment don't run together.
        summaries.append(f" • {sig} — {comment}" if comment else f" • {sig}")
    if summaries:
        text = f"Related functions for '{name}':\n" + "\n".join(summaries)
    else:
        text = f"No related functions found for '{name}'."
    return text, Reward(value=ActionType.GET_RELATED_FUNCTIONS.cost, reason="get_related_functions cost")
def get_signature(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_SIGNATURE action.

    Returns the target function's ``signature`` string, or a placeholder when
    the record has none, so the observation is always a ``str``. Charged
    ``ActionType.GET_SIGNATURE.cost``; a repeated query is charged
    ``ActionType.REPEATED.cost`` instead. ``params`` is unused.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query")
    # Fall back to a placeholder: a missing signature would otherwise return
    # None as the observation text.
    sig = ctx._target_fn.get("signature") or "No signature available."
    return sig, Reward(value=ActionType.GET_SIGNATURE.cost, reason="get_signature cost")
def get_similar_rule_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle GET_SIMILAR_RULE action.

    Asks the shared module-level ``PropertyRetrieverInstance`` for a rule
    similar to the target function's ``code``. A hit is returned and charged
    ``ActionType.GET_SIMILAR_RULE.cost``; a miss (``None``) returns an
    explanatory message with a fixed -0.20 reward. A repeated query is charged
    ``ActionType.REPEATED.cost`` without consulting the retriever.
    ``params`` is unused.
    """
    if ctx._is_repeated(qkey):
        return "Repeated query.", Reward(value=ActionType.REPEATED.cost, reason="Repeated query")
    retriever = PropertyRetrieverInstance
    retriever.load_model()  # make sure the model is ready before the lookup
    found = retriever.get_similar_property(ctx._target_fn["code"])
    if found is not None:
        return found, Reward(value=ActionType.GET_SIMILAR_RULE.cost, reason="get_similar_rule cost")
    # NOTE(review): the miss cost is hard-coded rather than taken from
    # ActionType — confirm whether it should track GET_SIMILAR_RULE.cost.
    return (
        "No similar rule available for this function.",
        Reward(value=-0.20, reason="get_similar_rule cost (not found)"),
    )
def submit_property(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Reward]:
    """Handle SUBMIT_PROPERTY action for Task 2.

    Expected params
    ---------------
    property : str – natural-language property describing the function's behaviour

    A missing, non-string, or empty/whitespace-only ``property`` is rejected
    as a malformed submission: it is charged ``ActionType.RESUBMIT.cost`` and
    the episode is NOT ended. Otherwise the episode is marked done and the
    stripped property is scored by ``ctx._grader.grade``; the returned score
    becomes the reward value and the observation text is empty.
    """
    raw = (params or {}).get("property")
    # Check the type first: a non-string payload (None, a number, a list)
    # would otherwise crash on .strip() instead of being reported as malformed.
    submitted_property = raw.strip() if isinstance(raw, str) else ""
    if not submitted_property:
        return (
            "submit_property requires a non-empty 'property' string in params.",
            Reward(value=ActionType.RESUBMIT.cost, reason="Malformed submission", partial=False),
        )
    ctx._done = True
    score, confidence = ctx._grader.grade(submitted_property, ctx._step_count, ctx._cum_reward)
    return "", Reward(
        value=score,
        reason=f"submit_property confidence={confidence} score={score:.3f}",
        partial=False,
    )
def unknown_action(ctx: Any, qkey: str, params: Dict, action_type: str) -> Tuple[str, Reward]:
    """Fallback for unknown actions.

    Ends the episode (sets ``ctx._done``) and returns a message naming the
    rejected ``action_type`` and listing the valid ``ActionType`` values,
    charged ``ActionType.UNKNOWN.cost``. ``qkey`` and ``params`` are unused.
    """
    ctx._done = True
    valid = [a.value for a in ActionType]
    # Use implicit string concatenation: a backslash continuation inside the
    # literal would embed the next source line's indentation in the message.
    message = (
        f"Unknown action type: '{action_type}'. Valid: {valid}. "
        "Reset environment to start again."
    )
    return message, Reward(value=ActionType.UNKNOWN.cost, reason="Unknown action")