"""
data_loader.py
--------------
Loads and indexes smart contract data from JSON files.

Task 1 helpers  – vulnerable function sampling
Task 2 helpers  – property function sampling, natspec, similar-rule lookup
Task 3 helpers  – sampling of vulnerable functions that also have a property
"""

import json
import os
import random
from typing import Any, Dict, List, Optional, Tuple


# Directory that contains this module; the default dataset paths below are
# built relative to it. abspath() keeps them valid even when __file__ is a
# relative path and the working directory changes after import.
DATA_DIR = os.path.dirname(os.path.abspath(__file__))
DEFAULT_CONTRACTS_FILE = os.path.join(DATA_DIR, "contracts.json")
DEFAULT_CSV_PATH = os.path.join(DATA_DIR, "properties.csv")


# ────────────────────────────────────────────────────────────────
# Core loaders
# ────────────────────────────────────────────────────────────────

def load_contracts(path: str = DEFAULT_CONTRACTS_FILE) -> List[Dict[str, Any]]:
    """Load and return all contracts from the JSON dataset.

    Args:
        path: Location of the JSON file; defaults to ``DEFAULT_CONTRACTS_FILE``.

    Returns:
        Whatever ``json.load`` decodes from the file (annotated as a list of
        contract dictionaries).

    Raises:
        OSError: If the file cannot be opened.
        ValueError: If the content is not valid UTF-8 or not valid JSON
            (``UnicodeDecodeError`` and ``json.JSONDecodeError`` are both
            subclasses of ``ValueError``).
    """
    # JSON is UTF-8 by specification; pin the encoding so decoding does not
    # depend on the platform's locale default.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_function_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Find a function entry in ``contract`` by name, ignoring case.

    Returns the first entry of ``contract["functions"]`` whose ``"name"``
    equals ``name`` case-insensitively, or ``None`` when nothing matches
    (including when the contract has no ``"functions"`` key).
    """
    matches = (
        candidate
        for candidate in contract.get("functions", [])
        if candidate["name"].lower() == name.lower()
    )
    # First match wins; fall back to None when the generator is exhausted.
    return next(matches, None)


def get_state_variable_by_name(
    contract: Dict[str, Any], name: str
) -> Optional[Dict[str, Any]]:
    """Find a state variable entry in ``contract`` by name, ignoring case.

    Returns the first entry of ``contract["state_variables"]`` whose
    ``"name"`` equals ``name`` case-insensitively, or ``None`` when nothing
    matches (including when the key is absent).
    """
    variables = contract.get("state_variables", [])
    return next(
        (var for var in variables if var["name"].lower() == name.lower()),
        None,
    )


def list_function_names(contract: Dict[str, Any]) -> List[str]:
    """List the ``"name"`` of every entry in ``contract["functions"]``.

    Names are returned in the order the entries appear; a contract without a
    ``"functions"`` key yields an empty list.
    """
    names: List[str] = []
    for entry in contract.get("functions", []):
        names.append(entry["name"])
    return names


def list_state_variable_names(contract: Dict[str, Any]) -> List[str]:
    """List the ``"name"`` of every entry in ``contract["state_variables"]``.

    Names are returned in the order the entries appear; a contract without a
    ``"state_variables"`` key yields an empty list.
    """
    names: List[str] = []
    for entry in contract.get("state_variables", []):
        names.append(entry["name"])
    return names


# ────────────────────────────────────────────────────────────────
# Task 1 helpers
# ────────────────────────────────────────────────────────────────

def get_all_vulnerable_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Collect every (contract, function) pair whose function has a truthy
    'vulnerable' field.

    Pairs are listed contract by contract, in the order the functions appear.
    Contracts without a 'functions' key and functions without a 'vulnerable'
    key contribute nothing.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("vulnerable", False)
    ]


def sample_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one random (contract, vulnerable_function) pair for Task 1.

    Args:
        contracts: Contract dictionaries to draw from.
        rng: Random source to use; a fresh ``random.Random()`` is created
            when omitted.

    Raises:
        ValueError: If the dataset contains no vulnerable function.
    """
    picker = random.Random() if rng is None else rng
    candidates = get_all_vulnerable_entries(contracts)
    if candidates:
        return picker.choice(candidates)
    raise ValueError("No vulnerable functions found in dataset.")


# ────────────────────────────────────────────────────────────────
# Task 2 helpers
# ────────────────────────────────────────────────────────────────

def get_all_property_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Returns a flat list of (contract, function) pairs where
    function['property'] is not None and the function is not vulnerable.
    Used by Task 2 to populate the episode pool.

    "Not vulnerable" uses the same truthiness test as
    get_all_vulnerable_entries (a missing, None, False or 0 'vulnerable'
    field all count as non-vulnerable), so a function with a property always
    lands in exactly one of the two pools.
    """
    return [
        (contract, fn)
        for contract in contracts
        for fn in contract.get("functions", [])
        if fn.get("property") is not None and not fn.get("vulnerable", False)
    ]


def sample_property_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one random (contract, function-with-property) pair for Task 2.

    Args:
        contracts: Contract dictionaries to draw from.
        rng: Random source to use; a fresh ``random.Random()`` is created
            when omitted.

    Raises:
        ValueError: If the Task 2 pool built from ``contracts`` is empty.
    """
    picker = random.Random() if rng is None else rng
    candidates = get_all_property_entries(contracts)
    if candidates:
        return picker.choice(candidates)
    raise ValueError("No functions with properties found in dataset.")


def get_related_functions(
    contract: Dict[str, Any],
    function_name: str,
) -> List[str]:
    """
    Returns function names that are related to the given function:
    - Functions that it calls (from call_graph)
    - Functions that call it (reverse call_graph lookup)

    All name comparisons are case-insensitive, in both directions. Only names
    that match a function listed in contract['functions'] are returned; they
    are reported with the spelling used in the call graph, de-duplicated and
    sorted.
    """
    name_lower = function_name.lower()
    cg: Dict[str, List[str]] = contract.get("call_graph", {})
    # Lower-cased names of the contract's own functions, built once so each
    # membership test below is a set lookup instead of a scan.
    known = {fn["name"].lower() for fn in contract.get("functions", [])}
    related = set()

    for caller_name, callees in cg.items():
        # Direct callees: match the call-graph key case-insensitively, the
        # same way the reverse lookup below matches callee names.
        if caller_name.lower() == name_lower:
            related.update(c for c in callees if c.lower() in known)

        # Reverse: this caller invokes the target function.
        if caller_name.lower() in known and any(
            c.lower() == name_lower for c in callees
        ):
            related.add(caller_name)

    return sorted(related)

# ────────────────────────────────────────────────────────────────
# Task 3 helpers
# ────────────────────────────────────────────────────────────────

def get_all_task3_entries(
    contracts: List[Dict[str, Any]],
) -> List[Tuple[Dict[str, Any], Dict[str, Any]]]:
    """
    Collect the (contract, function) pairs whose function is vulnerable
    and whose 'property' field is not None.

    The vulnerable pool comes from get_all_vulnerable_entries; its order
    is preserved.
    """
    return [
        (contract, fn)
        for contract, fn in get_all_vulnerable_entries(contracts)
        if fn.get("property", None) is not None
    ]


def sample_task3_episode(
    contracts: List[Dict[str, Any]],
    rng: Optional[random.Random] = None,
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Pick one random (contract, vulnerable_function) pair for Task 3.

    Args:
        contracts: Contract dictionaries to draw from.
        rng: Random source to use; a fresh ``random.Random()`` is created
            when omitted.

    Raises:
        ValueError: If the Task 3 pool built from ``contracts`` is empty.
    """
    picker = random.Random() if rng is None else rng
    candidates = get_all_task3_entries(contracts)
    if candidates:
        return picker.choice(candidates)
    raise ValueError("No Task 3 entries found in dataset.")