File size: 7,313 Bytes
6abc8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
6abc8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
 
 
 
 
 
 
 
6abc8c5
 
 
 
 
 
 
 
 
 
be8eade
 
 
 
 
 
 
 
 
 
6abc8c5
be8eade
 
 
6abc8c5
 
be8eade
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6abc8c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""Ephemeral generated app sandbox operations."""

from __future__ import annotations

import difflib
import json
from pathlib import Path
from typing import Any

try:
    from ..models import CyberSecurityOWASPState
    from ..safety import is_local_route
    from ..validators import is_path_allowed, simulate_request
except ImportError:  # pragma: no cover
    from models import CyberSecurityOWASPState
    from safety import is_local_route
    from validators import is_path_allowed, simulate_request


class AppSandbox:
    """Encapsulates all generated workspace reads, patches, and local requests.

    Every operation works against the episode ``state`` handed to the
    constructor: files are resolved relative to ``hidden_facts["workspace"]``,
    patches update ``patch_diff`` / ``patch_attempt_count`` / ``metrics``, and
    simulated requests are appended to ``request_trace``.
    """

    def __init__(self, state: CyberSecurityOWASPState):
        """Bind the sandbox to the episode ``state`` it reads and mutates."""
        self.state = state

    @property
    def workspace(self) -> Path:
        """Root directory of the workspace, from ``hidden_facts["workspace"]``."""
        return Path(str(self.state.hidden_facts["workspace"]))

    def read_file(self, path: str) -> str:
        """Return the UTF-8 text of ``path`` inside the workspace.

        Raises:
            ValueError: if ``is_path_allowed`` rejects ``path``.
        """
        return self._resolve_path(path).read_text(encoding="utf-8")

    def search_code(self, query: str) -> str:
        """Case-insensitively search the editable files for ``query``.

        Only the files listed under ``hidden_facts["editable_files"]`` are
        scanned. Each hit is reported as ``"<file>:<line number>: <line>"``
        (line numbers start at 1), one hit per output line.

        Returns:
            The newline-joined hits, or ``"No matches."`` when there are none.

        Raises:
            ValueError: if ``query`` is empty.
        """
        if not query:
            raise ValueError("query is required")
        # Lower-case the needle once instead of once per scanned line.
        needle = query.lower()
        results: list[str] = []
        for rel in self.state.hidden_facts.get("editable_files", []):
            text = (self.workspace / rel).read_text(encoding="utf-8")
            for idx, line in enumerate(text.splitlines(), start=1):
                if needle in line.lower():
                    results.append(f"{rel}:{idx}: {line}")
        return "\n".join(results) or "No matches."

    def patch_file(self, path: str, *, content: str | None = None, diff: str | None = None) -> dict[str, str]:
        """Overwrite or patch an existing workspace file and record the change.

        When ``content`` is given it replaces the whole file and ``diff`` is
        ignored; otherwise ``diff`` is applied as a unified diff. On success
        the resulting before/after unified diff is stored in
        ``state.patch_diff``, ``state.patch_attempt_count`` is incremented and
        ``path`` is added to ``state.metrics["files_touched"]`` if not already
        present.

        Returns:
            ``{"path": path, "diff": <unified diff of the change>}``.

        Raises:
            ValueError: if the path is not allowed for writing, if neither
                ``content`` nor a non-blank ``diff`` is supplied, or if the
                diff cannot be applied.
        """
        target = self._resolve_path(path, write=True)
        before = target.read_text(encoding="utf-8")
        if content is not None:
            target.write_text(content, encoding="utf-8")
        else:
            self._apply_unified_diff(target, diff or "")
        after = target.read_text(encoding="utf-8")
        patch_diff = "".join(
            difflib.unified_diff(
                before.splitlines(True),
                after.splitlines(True),
                fromfile=path,
                tofile=path,
            )
        )
        self.state.patch_diff = patch_diff
        self.state.patch_attempt_count += 1
        files_touched = self.state.metrics.setdefault("files_touched", [])
        if path not in files_touched:
            files_touched.append(path)
        return {"path": path, "diff": patch_diff}

    def read_openapi(self) -> str:
        """Render the visible route list as a minimal OpenAPI 3.1 JSON document.

        Routes come from ``visible_facts["workspace_summary"]["routes"]``; each
        contributes ``paths[<path>][<lower-cased method>] = {"x-public": bool}``.
        The JSON is indented by two spaces with sorted keys.
        """
        routes = self.state.visible_facts.get("workspace_summary", {}).get("routes", [])
        paths: dict[str, Any] = {}
        for route in routes:
            paths.setdefault(route["path"], {})[route["method"].lower()] = {
                "x-public": bool(route.get("public", False))
            }
        return json.dumps(
            {
                "openapi": "3.1.0",
                "info": {"title": "Generated invoices app", "version": "0.1.0"},
                "paths": paths,
            },
            indent=2,
            sort_keys=True,
        )

    def send_local_request(self, method: str, path: str, user_id: str | None = None) -> dict[str, Any]:
        """Simulate one request against the generated app and trace it.

        Returns:
            The simulated response with an added ``"trace_id"`` key.

        Raises:
            ValueError: if ``is_local_route`` rejects ``path``.
        """
        if not is_local_route(path):
            raise ValueError("send_local_request only accepts local route paths")
        response = simulate_request(self.state, method, path, user_id)
        trace_id = self._record_request_trace(
            method=method,
            path=path,
            user_id=user_id,
            # A missing or falsy status is recorded as 0.
            status=int(response.get("status", 0) or 0),
        )
        return {"trace_id": trace_id, **response}

    def compare_identities(
        self,
        method: str,
        path: str,
        first_user_id: str,
        second_user_id: str,
    ) -> dict[str, Any]:
        """Simulate the same request as two users and trace both outcomes.

        Both simulations are stored in a single trace entry: the first user
        fills the primary fields and the second user the ``comparison_*``
        fields.

        Returns:
            ``{"trace_id": ..., "first": <response>, "second": <response>}``.

        Raises:
            ValueError: if ``is_local_route`` rejects ``path``.
        """
        if not is_local_route(path):
            raise ValueError("compare_identities only accepts local route paths")
        first = simulate_request(self.state, method, path, first_user_id)
        second = simulate_request(self.state, method, path, second_user_id)
        trace_id = self._record_request_trace(
            method=method,
            path=path,
            user_id=first_user_id,
            status=int(first.get("status", 0) or 0),
            comparison_user_id=second_user_id,
            comparison_status=int(second.get("status", 0) or 0),
        )
        return {
            "trace_id": trace_id,
            "first": first,
            "second": second,
        }

    def _record_request_trace(
        self,
        *,
        method: str,
        path: str,
        user_id: str | None,
        status: int,
        comparison_user_id: str | None = None,
        comparison_status: int | None = None,
    ) -> str:
        """Append one entry to ``state.request_trace`` and return its id.

        Ids are ``req_001``, ``req_002``, ... based on the current trace
        length. ``unauthorized_success`` is set when ``path`` contains the
        hidden ``other_invoice_id``, the requesting user equals the hidden
        ``owner_user_id`` and the status is 200; the comparison identity is
        checked the same way when both comparison values are supplied.
        """
        trace_id = f"req_{len(self.state.request_trace) + 1:03d}"
        hidden = self.state.hidden_facts
        other_invoice_id = str(hidden.get("other_invoice_id", ""))
        owner_user_id = hidden.get("owner_user_id")

        def hit_other_invoice(requester: str | None, code: int | None) -> bool:
            # An empty id is a substring of every path, so a missing
            # ``other_invoice_id`` must never count as a match.
            return (
                bool(other_invoice_id)
                and other_invoice_id in path
                and requester == owner_user_id
                and code == 200
            )

        unauthorized_success = hit_other_invoice(user_id, status)
        if comparison_user_id is not None and comparison_status is not None:
            unauthorized_success = unauthorized_success or hit_other_invoice(
                comparison_user_id, comparison_status
            )
        self.state.request_trace.append(
            {
                "trace_id": trace_id,
                "method": method.upper(),
                "path": path,
                "user_id": user_id,
                "status": status,
                "comparison_user_id": comparison_user_id,
                "comparison_status": comparison_status,
                "unauthorized_success": unauthorized_success,
            }
        )
        return trace_id

    def _resolve_path(self, path: str, *, write: bool = False) -> Path:
        """Validate ``path`` with ``is_path_allowed`` and anchor it in the workspace.

        ``is_path_allowed`` returns ``(allowed, value)``; ``value`` is joined
        onto the workspace root when allowed and used as the error message
        otherwise.

        Raises:
            ValueError: if the path is not allowed.
        """
        allowed, normalized_or_error = is_path_allowed(self.state, path, write=write)
        if not allowed:
            raise ValueError(normalized_or_error)
        return self.workspace / normalized_or_error

    @staticmethod
    def _parse_hunk_old_range(header: str) -> tuple[int, int]:
        """Return ``(start, count)`` of the old-file range in an ``@@`` header.

        ``count`` defaults to 1 when the header omits it (``@@ -3 +3 @@``).

        Raises:
            ValueError: if the header has no well-formed ``-start[,count]`` field.
        """
        fields = header.split()
        if len(fields) < 2 or not fields[1].startswith("-"):
            raise ValueError(f"malformed diff hunk header: {header.strip()!r}")
        start_text, _, count_text = fields[1][1:].partition(",")
        try:
            start = int(start_text)
            count = int(count_text) if count_text else 1
        except ValueError:
            raise ValueError(f"malformed diff hunk header: {header.strip()!r}") from None
        if start < 0 or count < 0:
            raise ValueError(f"malformed diff hunk header: {header.strip()!r}")
        return start, count

    def _apply_unified_diff(self, path: Path, diff: str) -> None:
        """Apply the unified ``diff`` to the file at ``path`` in place.

        Hunks are positioned purely by the old-file line number in their
        ``@@`` header; context and removed lines are not compared against the
        file's content. Everything before the first hunk header (such as
        ``---`` / ``+++`` lines) is skipped. The file is written only after
        the whole diff has been processed, so a rejected diff leaves it
        untouched.

        Raises:
            ValueError: if ``diff`` is blank, a hunk header is malformed, hunks
                overlap or are out of order, or a hunk runs past the end of
                the file.
        """
        if not diff.strip():
            raise ValueError("diff or content is required")
        original = path.read_text(encoding="utf-8").splitlines(True)
        output: list[str] = []
        old_index = 0
        lines = diff.splitlines(True)
        i = 0
        while i < len(lines):
            line = lines[i]
            if not line.startswith("@@"):
                i += 1
                continue
            old_start, old_count = self._parse_hunk_old_range(line)
            # A zero-length old range names the line *before* the insertion
            # point, so the hunk begins after it rather than on it.
            hunk_start = old_start if old_count == 0 else old_start - 1
            if not old_index <= hunk_start <= len(original):
                raise ValueError(
                    f"diff hunk at old line {old_start} is out of range or overlaps a previous hunk"
                )
            output.extend(original[old_index:hunk_start])
            old_index = hunk_start
            i += 1
            previous_marker = ""
            while i < len(lines) and not lines[i].startswith("@@"):
                hunk_line = lines[i]
                marker = hunk_line[:1]
                if marker in (" ", "-"):
                    if old_index >= len(original):
                        raise ValueError("diff hunk extends past the end of the file")
                    if marker == " ":
                        # Keep the file's own copy of a context line.
                        output.append(original[old_index])
                    old_index += 1
                elif marker == "+":
                    output.append(hunk_line[1:])
                elif marker == "\\":
                    # "\ No newline at end of file" after an added line means
                    # that line must be written without its terminator.
                    if previous_marker == "+" and output and output[-1].endswith("\n"):
                        output[-1] = output[-1][:-1]
                previous_marker = marker
                i += 1
        output.extend(original[old_index:])
        path.write_text("".join(output), encoding="utf-8")