# SmartContractAudit / data/data_loader.py
# ajaxwin
# task1, task2 evaluated
# 671787b
"""
data_loader.py
--------------
Loads and indexes smart contract data from JSON files.
Task 1 helpers – vulnerable function sampling
Task 2 helpers – property function sampling, natspec, similar-rule lookup
"""
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple
# Directory containing this module; the bundled data files live next to it.
# abspath() guards against a relative __file__ (e.g. a script launched from a
# relative path on older interpreters), so these paths stay valid even if the
# process later changes its working directory.
DATA_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
# Not referenced by the helpers in this module; presumably consumed by callers
# (Task 2 property data) — verify before removing.
DEFAULT_CSV_PATH = os.path.join(DATA_DIR, "properties.csv")
# ────────────────────────────────────────────────────────────────
# Core loaders
# ────────────────────────────────────────────────────────────────
def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
    """Load and return all contracts from the JSON dataset.

    Args:
        path: Path to the JSON dataset. Defaults to ``contracts.json`` in the
            same directory as this module.

    Returns:
        The decoded JSON document. It is expected to be a list of contract
        dicts, but this is not validated here.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    # JSON text is UTF-8 (RFC 8259). Pin the encoding so loading does not
    # depend on the platform's locale default (e.g. cp1252 on Windows).
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_function_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Find a function entry of *contract* by name, ignoring case.

    Returns the first dict in ``contract["functions"]`` whose ``"name"``
    equals *name* after lower-casing both sides, or ``None`` when nothing
    matches (including when the contract has no ``"functions"`` key).
    """
    matches = (
        candidate
        for candidate in contract.get("functions", [])
        if candidate["name"].lower() == name.lower()
    )
    return next(matches, None)
def get_state_variable_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Find a state-variable entry of *contract* by name, ignoring case.

    Returns the first dict in ``contract["state_variables"]`` whose
    ``"name"`` equals *name* after lower-casing both sides, or ``None`` when
    nothing matches (including when the key is absent).
    """
    variables = contract.get("state_variables", [])
    return next(
        (var for var in variables if var["name"].lower() == name.lower()),
        None,
    )
def list_function_names(contract: Dict[str, Any]) -> List[str]:
    """Return the ``"name"`` of each entry in ``contract["functions"]``, in order.

    A contract without a ``"functions"`` key yields an empty list.
    """
    functions = contract.get("functions", [])
    return [entry["name"] for entry in functions]
def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
    """Return the ``"name"`` of each entry in ``contract["state_variables"]``, in order.

    A contract without a ``"state_variables"`` key yields an empty list.
    """
    variables = contract.get("state_variables", [])
    return [var["name"] for var in variables]
# ────────────────────────────────────────────────────────────────
# Task 1 helpers
# ────────────────────────────────────────────────────────────────
def get_all_vulnerable_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Flatten *contracts* into (contract, function) pairs for every function
    whose ``"vulnerable"`` flag is truthy (a missing flag counts as False).

    Pairs are ordered by contract, then by function position within it.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("vulnerable", False)
    ]
def sample_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one (contract, vulnerable_function) pair at random for Task 1.

    Args:
        contracts: Contract dicts, e.g. as returned by ``load_contracts``.
        rng: Random source; a fresh unseeded ``random.Random`` is created when
            omitted. Pass a seeded instance for reproducible sampling.

    Raises:
        ValueError: if no function in *contracts* is flagged vulnerable.
    """
    chooser = random.Random() if rng is None else rng
    candidates = get_all_vulnerable_entries(contracts)
    if candidates:
        return chooser.choice(candidates)
    raise ValueError("No vulnerable functions found in dataset.")
# ────────────────────────────────────────────────────────────────
# Task 2 helpers
# ────────────────────────────────────────────────────────────────
def get_all_property_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Returns a flat list of (contract, function) pairs for the Task 2 pool.

    A function qualifies when ``function['property']`` is present and not
    None AND the function is not flagged vulnerable. The vulnerability test
    uses truthiness, like ``get_all_vulnerable_entries``, so a missing,
    ``False``, ``None`` or ``0`` flag all count as "not vulnerable". Every
    function therefore lands in exactly one of the two pools (given a
    property).

    Pairs are ordered by contract, then by function position within it.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        # Truthiness rather than ``is False``: an identity check would drop
        # functions whose flag is None/0 from both the Task 1 and Task 2 pools.
        if fn.get("property") is not None and not fn.get("vulnerable", False)
    ]
def sample_property_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one (contract, function-with-property) pair at random for Task 2.

    The candidate pool comes from ``get_all_property_entries``.

    Args:
        contracts: Contract dicts, e.g. as returned by ``load_contracts``.
        rng: Random source; a fresh unseeded ``random.Random`` is created when
            omitted. Pass a seeded instance for reproducible sampling.

    Raises:
        ValueError: if the pool is empty.
    """
    chooser = random.Random() if rng is None else rng
    candidates = get_all_property_entries(contracts)
    if candidates:
        return chooser.choice(candidates)
    raise ValueError("No functions with properties found in dataset.")
def get_related_functions(
    contract: Dict[str, Any],
    function_name: str,
) -> List[str]:
    """
    Returns function names that are related to the given function:
      - Functions that it calls (from call_graph)
      - Functions that call it (reverse call_graph lookup)

    *function_name* is matched case-insensitively in both directions,
    consistent with ``get_function_by_name``. Only names that correspond to
    a function of *contract* are kept; other call targets (e.g. external
    calls) are dropped. Names are returned as spelled in the call graph,
    de-duplicated and sorted.

    Args:
        contract: Contract dict. Reads ``"functions"`` and the optional
            ``"call_graph"`` (mapping caller name -> list of callee names).
        function_name: Function whose callers/callees are wanted (any case).

    Returns:
        Sorted list of related function names. The list is empty when the
        contract has no call graph or the function has no known relatives.
    """
    name_lower = function_name.lower()
    cg: Dict[str, List[str]] = contract.get("call_graph", {})
    # Lower-cased names of the contract's own functions, built once so each
    # membership test is O(1) rather than a linear scan per call-graph name.
    known = {fn["name"].lower() for fn in contract.get("functions", [])}
    related = set()
    for caller_name, callees in cg.items():
        # Direct callees. The key is matched case-insensitively so callees
        # are found however the caller spells function_name, mirroring the
        # reverse lookup below.
        if caller_name.lower() == name_lower:
            related.update(c for c in callees if c.lower() in known)
        # Reverse: this caller invokes function_name.
        if caller_name.lower() in known and any(
            c.lower() == name_lower for c in callees
        ):
            related.add(caller_name)
    return sorted(related)
# ────────────────────────────────────────────────────────────────
# Task 3 helpers
# ────────────────────────────────────────────────────────────────
def get_all_task3_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Returns (contract, function) pairs for Task 3.

    The pairs are the vulnerable entries from ``get_all_vulnerable_entries``
    whose function has a ``property`` field that is present and not None.
    Their order follows ``get_all_vulnerable_entries``.
    """
    return [
        pair
        for pair in get_all_vulnerable_entries(contracts)
        if pair[1].get("property") is not None
    ]
def sample_task3_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one (contract, vulnerable_function) pair at random for Task 3.

    The candidate pool comes from ``get_all_task3_entries``, which holds
    vulnerable functions that also carry a property.

    Args:
        contracts: Contract dicts, e.g. as returned by ``load_contracts``.
        rng: Random source; a fresh unseeded ``random.Random`` is created when
            omitted. Pass a seeded instance for reproducible sampling.

    Raises:
        ValueError: if the pool is empty.
    """
    chooser = random.Random() if rng is None else rng
    candidates = get_all_task3_entries(contracts)
    if candidates:
        return chooser.choice(candidates)
    raise ValueError("No Task 3 entries found in dataset.")