Spaces:
Running
Running
File size: 7,037 Bytes
08c19c7 9c888b7 08c19c7 671787b 08c19c7 9c888b7 08c19c7 9c888b7 08c19c7 9c888b7 08c19c7 9c888b7 08c19c7 9c888b7 cf983b8 9c888b7 08c19c7 9c888b7 08c19c7 9c888b7 08c19c7 9c888b7 7203787 056cf7b 7203787 056cf7b 7203787 056cf7b 7203787 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 | """
data_loader.py
--------------
Loads and indexes smart contract data from JSON files.
Task 1 helpers — vulnerable function sampling
Task 2 helpers — property function sampling, natspec, similar-rule lookup
Task 3 helpers — sampling of vulnerable functions that also have a property
"""
import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple
# Directory containing this module. It is resolved to an absolute path so the
# default locations below stay valid even if the process changes its working
# directory after import (a bare __file__ may be a relative path).
DATA_DIR = os.path.dirname(os.path.abspath(__file__))
# Default data files; both are looked up directly inside DATA_DIR.
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
DEFAULT_CSV_PATH = os.path.join(DATA_DIR, "properties.csv")
# ────────────────────────────────────────────────────────────────
# Core loaders
# ────────────────────────────────────────────────────────────────
def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
    """Load and return all contracts from the JSON dataset.

    Args:
        path: Location of the JSON file to read. Defaults to
            DEFAULT_CONTRACTS_FILE.

    Returns:
        The parsed JSON document. The annotation declares a list of contract
        dicts, but the content is returned exactly as parsed, without
        validation.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file content is not valid JSON.
    """
    # Decode explicitly as UTF-8 instead of the platform's locale encoding,
    # so non-ASCII text in the dataset loads the same way on every system.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def get_function_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Find a function entry in a contract by name, ignoring case.

    Walks contract["functions"] in order (treated as empty when the key is
    absent) and compares each entry's "name" with *name* after lower-casing
    both sides.

    Returns:
        The first matching function dict, or None when nothing matches.
    """
    matches = (
        candidate
        for candidate in contract.get("functions", [])
        if candidate["name"].lower() == name.lower()
    )
    # next() stops at the first hit, so later entries are never inspected.
    return next(matches, None)
def get_state_variable_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Find a state-variable entry in a contract by name, ignoring case.

    Walks contract["state_variables"] in order (treated as empty when the key
    is absent) and compares each entry's "name" with *name* after
    lower-casing both sides.

    Returns:
        The first matching state-variable dict, or None when nothing matches.
    """
    matches = (
        variable
        for variable in contract.get("state_variables", [])
        if variable["name"].lower() == name.lower()
    )
    # next() stops at the first hit, so later entries are never inspected.
    return next(matches, None)
def list_function_names(contract: Dict[str, Any]) -> List[str]:
    """Collect the "name" of every function entry in the contract.

    Names are returned in the order the entries appear in
    contract["functions"]; a contract without that key yields an empty list.
    """
    names: List[str] = []
    for entry in contract.get("functions", []):
        names.append(entry["name"])
    return names
def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
    """Collect the "name" of every state-variable entry in the contract.

    Names are returned in the order the entries appear in
    contract["state_variables"]; a contract without that key yields an empty
    list.
    """
    names: List[str] = []
    for entry in contract.get("state_variables", []):
        names.append(entry["name"])
    return names
# ────────────────────────────────────────────────────────────────
# Task 1 helpers
# ────────────────────────────────────────────────────────────────
def get_all_vulnerable_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Flatten the dataset into (contract, function) pairs flagged vulnerable.

    A function is kept when its "vulnerable" field is truthy; a missing field
    counts as not vulnerable. Pairs follow dataset order: contracts first,
    then the functions inside each contract.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("vulnerable", False)
    ]
def sample_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one random (contract, vulnerable_function) pair for Task 1.

    Args:
        contracts: Dataset to draw from.
        rng: Random generator to use; a fresh random.Random() is created when
            None is passed.

    Returns:
        One pair chosen from get_all_vulnerable_entries(contracts).

    Raises:
        ValueError: If the dataset contains no vulnerable functions.
    """
    picker = random.Random() if rng is None else rng
    pool = get_all_vulnerable_entries(contracts)
    if pool:
        return picker.choice(pool)
    raise ValueError("No vulnerable functions found in dataset.")
# ────────────────────────────────────────────────────────────────
# Task 2 helpers
# ────────────────────────────────────────────────────────────────
def get_all_property_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Return (contract, function) pairs that form the Task 2 episode pool.

    A function is included when both conditions hold:
      - its "property" field is present and not None, and
      - its "vulnerable" field is falsy or missing.

    Args:
        contracts: Dataset to scan.

    Returns:
        A flat list of (contract, function) pairs in dataset order.
    """
    # Vulnerability is tested by truthiness, the same way
    # get_all_vulnerable_entries tests it. An identity check against False
    # would drop functions whose flag is null/0, leaving them in no pool.
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("property") is not None and not fn.get("vulnerable", False)
    ]
def sample_property_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one random (contract, function-with-property) pair for Task 2.

    Args:
        contracts: Dataset to draw from.
        rng: Random generator to use; a fresh random.Random() is created when
            None is passed.

    Returns:
        One pair chosen from get_all_property_entries(contracts).

    Raises:
        ValueError: If the dataset contains no eligible functions.
    """
    picker = random.Random() if rng is None else rng
    pool = get_all_property_entries(contracts)
    if pool:
        return picker.choice(pool)
    raise ValueError("No functions with properties found in dataset.")
def get_related_functions(
    contract: Dict[str, Any],
    function_name: str,
) -> List[str]:
    """Return the names of functions directly linked to *function_name*.

    Two kinds of neighbours are collected from contract["call_graph"], a
    mapping of caller name to a list of callee names (treated as empty when
    the key is absent):
      - callees: names listed under the entry for *function_name*
      - callers: keys whose callee list mentions *function_name*

    Only names that are also present in contract["functions"] are kept.
    All name comparisons ignore case, in both directions, consistent with
    the case-insensitive name lookups used elsewhere in this module.

    Args:
        contract: Contract dict holding "functions" and "call_graph".
        function_name: Name of the function whose neighbours are wanted.

    Returns:
        A sorted list of unique names, spelled as they appear in the call
        graph.

    Raises:
        KeyError: If an entry in contract["functions"] has no "name".
    """
    name_lower = function_name.lower()
    call_graph: Dict[str, List[str]] = contract.get("call_graph", {})
    # Lower-cased names of the contract's own functions, built once so every
    # membership check below is a set lookup instead of a list scan.
    known = {fn["name"].lower() for fn in contract.get("functions", [])}

    related = set()
    for caller, callees in call_graph.items():
        # Forward edges: this key is the requested function, so its callees
        # are related. The key is matched case-insensitively; an exact dict
        # lookup would miss a differently-cased function_name.
        if caller.lower() == name_lower:
            related.update(c for c in callees if c.lower() in known)
        # Reverse edges: this key calls the requested function.
        if caller.lower() in known and any(
            c.lower() == name_lower for c in callees
        ):
            related.add(caller)
    return sorted(related)
# ────────────────────────────────────────────────────────────────
# Task 3 helpers
# ────────────────────────────────────────────────────────────────
def get_all_task3_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """Return (contract, function) pairs that form the Task 3 pool.

    Starts from get_all_vulnerable_entries(contracts) and keeps only the
    pairs whose function has a "property" field that is not None.
    """
    return [
        (contract, fn)
        for contract, fn in get_all_vulnerable_entries(contracts)
        if fn.get("property", None) is not None
    ]
def sample_task3_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one random (contract, vulnerable_function) pair for Task 3.

    Args:
        contracts: Dataset to draw from.
        rng: Random generator to use; a fresh random.Random() is created when
            None is passed.

    Returns:
        One pair chosen from get_all_task3_entries(contracts).

    Raises:
        ValueError: If the dataset contains no Task 3 entries.
    """
    picker = random.Random() if rng is None else rng
    pool = get_all_task3_entries(contracts)
    if pool:
        return picker.choice(pool)
    raise ValueError("No Task 3 entries found in dataset.")
|