""" data_loader.py -------------- Loads and indexes smart contract data from JSON files. Task 1 helpers – vulnerable function sampling Task 2 helpers – property function sampling, natspec, similar-rule lookup """ import json import os import random from typing import Any, Dict, List, Optional, Tuple DATA_DIR = os.path.join(os.path.dirname(__file__)) DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json") DEFAULT_CSV_PATH = os.path.join(DATA_DIR, "properties.csv") # ──────────────────────────────────────────────────────────────── # Core loaders # ──────────────────────────────────────────────────────────────── def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]: """Load and return all contracts from the JSON dataset.""" with open(path, "r") as f: return json.load(f) def get_function_by_name( contract: Dict[str, Any], name: str ) -> Optional[Dict[str, Any]]: """Case-insensitive function lookup within a contract.""" for fn in contract.get("functions", []): if fn["name"].lower() == name.lower(): return fn return None def get_state_variable_by_name( contract: Dict[str, Any], name: str ) -> Optional[Dict[str, Any]]: """Case-insensitive state variable lookup.""" for sv in contract.get("state_variables", []): if sv["name"].lower() == name.lower(): return sv return None def list_function_names(contract: Dict[str, Any]) -> List[str]: """Return all function names in the contract.""" return [fn["name"] for fn in contract.get("functions", [])] def list_state_variable_names(contract: Dict[str, Any]) -> List[str]: """Return all state variable names.""" return [sv["name"] for sv in contract.get("state_variables", [])] # ──────────────────────────────────────────────────────────────── # Task 1 helpers # ──────────────────────────────────────────────────────────────── def get_all_vulnerable_entries( contracts: List[Dict[str, Any]], ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]: """ Returns a flat list of (contract, function) pairs where function['vulnerable'] is True. """ entries = [] for contract in contracts: for fn in contract.get("functions", []): if fn.get("vulnerable", False): entries.append((contract, fn)) return entries def sample_episode( contracts: List[Dict[str, Any]], rng: Optional[random.Random] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Randomly selects one (contract, vulnerable_function) pair for Task 1.""" if rng is None: rng = random.Random() entries = get_all_vulnerable_entries(contracts) if not entries: raise ValueError("No vulnerable functions found in dataset.") return rng.choice(entries) # ──────────────────────────────────────────────────────────────── # Task 2 helpers # ──────────────────────────────────────────────────────────────── def get_all_property_entries( contracts: List[Dict[str, Any]], ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]: """ Returns a flat list of (contract, function) pairs where function['property'] is not None. Used by Task 2 to populate the episode pool. """ entries = [] for contract in contracts: for fn in contract.get("functions", []): if fn.get("property", None) is not None and fn.get("vulnerable", False) is False: entries.append((contract, fn)) return entries def sample_property_episode( contracts: List[Dict[str, Any]], rng: Optional[random.Random] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Randomly selects one (contract, function-with-property) pair for Task 2.""" if rng is None: rng = random.Random() entries = get_all_property_entries(contracts) if not entries: raise ValueError("No functions with properties found in dataset.") return rng.choice(entries) def get_related_functions( contract: Dict[str, Any], function_name: str, ) -> List[str]: """ Returns function names that are related to the given function: - Functions that it calls (from call_graph) - Functions that call it (reverse call_graph lookup) """ name_lower = function_name.lower() cg: Dict[str, List[str]] = contract.get("call_graph", {}) related = set() # Direct callees (functions called by this function) for callee in cg.get(function_name, []): # Only include callees that are also functions in this contract if get_function_by_name(contract, callee) is not None: related.add(callee) # Reverse: functions that call this function for caller_name, callees in cg.items(): if any(c.lower() == name_lower for c in callees): if get_function_by_name(contract, caller_name) is not None: related.add(caller_name) return sorted(related) # ──────────────────────────────────────────────────────────────── # Task 3 helpers # ──────────────────────────────────────────────────────────────── def get_all_task3_entries( contracts: List[Dict[str, Any]], ) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]: """ Returns (contract, function) pairs where function is vulnerable and has a property field. """ vulnerable_entries = get_all_vulnerable_entries(contracts) entries = [] for contract, fn in vulnerable_entries: if fn.get("property", None) is not None: entries.append((contract, fn)) return entries def sample_task3_episode( contracts: List[Dict[str, Any]], rng: Optional[random.Random] = None, ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """Randomly selects one (contract, vulnerable_function) pair for Task 3.""" if rng is None: rng = random.Random() entries = get_all_task3_entries(contracts) if not entries: raise ValueError("No Task 3 entries found in dataset.") return rng.choice(entries)