Spaces:
Sleeping
Sleeping
File size: 7,313 Bytes
6abc8c5 be8eade 6abc8c5 be8eade 6abc8c5 be8eade 6abc8c5 be8eade 6abc8c5 be8eade 6abc8c5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 | """Ephemeral generated app sandbox operations."""
from __future__ import annotations
import difflib
import json
from pathlib import Path
from typing import Any
try:
from ..models import CyberSecurityOWASPState
from ..safety import is_local_route
from ..validators import is_path_allowed, simulate_request
except ImportError: # pragma: no cover
from models import CyberSecurityOWASPState
from safety import is_local_route
from validators import is_path_allowed, simulate_request
class AppSandbox:
    """Encapsulates all generated workspace reads, patches, and local requests.

    Path-based reads and writes go through ``_resolve_path`` (and therefore
    ``is_path_allowed``); the request helpers reject any path that
    ``is_local_route`` does not accept. Patch diffs, touched files and
    request traces are recorded on the shared ``state`` object.
    """

    # Line bodies treated as "blank" when they appear unprefixed inside a hunk.
    _BLANK_LINES = ("\n", "\r\n", "\r")

    def __init__(self, state: CyberSecurityOWASPState):
        # Shared episode state: the methods below read it and mutate it in place.
        self.state = state

    @property
    def workspace(self) -> Path:
        """Root directory of the generated app, from ``hidden_facts["workspace"]``."""
        return Path(str(self.state.hidden_facts["workspace"]))

    def read_file(self, path: str) -> str:
        """Return the UTF-8 text of *path*.

        Raises:
            ValueError: if ``is_path_allowed`` rejects *path*.
        """
        return self._resolve_path(path).read_text(encoding="utf-8")

    def search_code(self, query: str) -> str:
        """Case-insensitively search ``hidden_facts["editable_files"]`` for *query*.

        Returns:
            One ``"<file>:<line>: <text>"`` entry per matching line (1-based line
            numbers), joined by newlines, or ``"No matches."``.

        Raises:
            ValueError: if *query* is empty.
        """
        if not query:
            raise ValueError("query is required")
        needle = query.lower()  # constant for the whole search
        results: list[str] = []
        for rel in self.state.hidden_facts.get("editable_files", []):
            text = (self.workspace / rel).read_text(encoding="utf-8")
            for idx, line in enumerate(text.splitlines(), start=1):
                if needle in line.lower():
                    results.append(f"{rel}:{idx}: {line}")
        return "\n".join(results) or "No matches."

    def patch_file(self, path: str, *, content: str | None = None, diff: str | None = None) -> dict[str, str]:
        """Overwrite *path* with *content*, or apply the unified *diff* to it.

        *content* takes precedence when both are given. On success,
        ``state.patch_diff`` is overwritten with this call's before/after diff
        (it is not cumulative), ``state.patch_attempt_count`` is incremented and
        *path* is added once to ``state.metrics["files_touched"]``.

        Returns:
            ``{"path": path, "diff": <unified diff of this change>}``.

        Raises:
            ValueError: if *path* is not writable, neither *content* nor a
                non-blank *diff* was given, or the diff does not apply.
        """
        target = self._resolve_path(path, write=True)
        before = target.read_text(encoding="utf-8")
        if content is not None:
            target.write_text(content, encoding="utf-8")
        else:
            self._apply_unified_diff(target, diff or "")
        after = target.read_text(encoding="utf-8")
        patch_diff = "".join(
            difflib.unified_diff(
                before.splitlines(True),
                after.splitlines(True),
                fromfile=path,
                tofile=path,
            )
        )
        self.state.patch_diff = patch_diff
        self.state.patch_attempt_count += 1
        files_touched = self.state.metrics.setdefault("files_touched", [])
        if path not in files_touched:
            files_touched.append(path)
        return {"path": path, "diff": patch_diff}

    def read_openapi(self) -> str:
        """Render ``visible_facts["workspace_summary"]["routes"]`` as OpenAPI 3.1 JSON.

        Each route becomes ``paths[route["path"]][method.lower()] =
        {"x-public": bool(route.get("public", False))}``; output keys are sorted.
        """
        routes = self.state.visible_facts.get("workspace_summary", {}).get("routes", [])
        paths: dict[str, Any] = {}
        for route in routes:
            paths.setdefault(route["path"], {})[route["method"].lower()] = {
                "x-public": bool(route.get("public", False))
            }
        return json.dumps(
            {
                "openapi": "3.1.0",
                "info": {"title": "Generated invoices app", "version": "0.1.0"},
                "paths": paths,
            },
            indent=2,
            sort_keys=True,
        )

    def send_local_request(self, method: str, path: str, user_id: str | None = None) -> dict[str, Any]:
        """Simulate ``method path`` as *user_id* and record it in ``state.request_trace``.

        Returns:
            The simulated response dict with a ``trace_id`` entry added (keys
            from the response win on collision).

        Raises:
            ValueError: if *path* is not a local route.
        """
        if not is_local_route(path):
            raise ValueError("send_local_request only accepts local route paths")
        response = simulate_request(self.state, method, path, user_id)
        trace_id = self._record_request_trace(
            method=method,
            path=path,
            user_id=user_id,
            status=int(response.get("status", 0) or 0),
        )
        return {"trace_id": trace_id, **response}

    def compare_identities(
        self,
        method: str,
        path: str,
        first_user_id: str,
        second_user_id: str,
    ) -> dict[str, Any]:
        """Simulate the same request as two users and record both in one trace entry.

        Returns:
            ``{"trace_id": ..., "first": <response>, "second": <response>}``.

        Raises:
            ValueError: if *path* is not a local route.
        """
        if not is_local_route(path):
            raise ValueError("compare_identities only accepts local route paths")
        first = simulate_request(self.state, method, path, first_user_id)
        second = simulate_request(self.state, method, path, second_user_id)
        trace_id = self._record_request_trace(
            method=method,
            path=path,
            user_id=first_user_id,
            status=int(first.get("status", 0) or 0),
            comparison_user_id=second_user_id,
            comparison_status=int(second.get("status", 0) or 0),
        )
        return {
            "trace_id": trace_id,
            "first": first,
            "second": second,
        }

    def _is_unauthorized_success(self, path: str, user_id: str | None, status: int) -> bool:
        """Return True when ``owner_user_id`` got HTTP 200 on a path containing ``other_invoice_id``.

        Presumably this is the cross-user invoice access the scenario plants
        (NOTE(review): confirm against the scenario generator).
        """
        hidden = self.state.hidden_facts
        raw_invoice_id = hidden.get("other_invoice_id")
        other_invoice_id = "" if raw_invoice_id is None else str(raw_invoice_id)
        owner_user_id = hidden.get("owner_user_id")
        # Without both facts nothing can be detected: "" is a substring of every
        # path, and a missing owner (None) would equal an anonymous user_id.
        if not other_invoice_id or owner_user_id is None:
            return False
        return other_invoice_id in path and user_id == owner_user_id and status == 200

    def _record_request_trace(
        self,
        *,
        method: str,
        path: str,
        user_id: str | None,
        status: int,
        comparison_user_id: str | None = None,
        comparison_status: int | None = None,
    ) -> str:
        """Append one entry to ``state.request_trace`` and return its ``req_NNN`` id.

        ``unauthorized_success`` is True if either the primary identity or (when
        both comparison fields are given) the comparison identity matched
        ``_is_unauthorized_success``.
        """
        trace_id = f"req_{len(self.state.request_trace) + 1:03d}"
        unauthorized_success = self._is_unauthorized_success(path, user_id, status)
        if comparison_user_id is not None and comparison_status is not None:
            unauthorized_success = unauthorized_success or self._is_unauthorized_success(
                path, comparison_user_id, comparison_status
            )
        self.state.request_trace.append(
            {
                "trace_id": trace_id,
                "method": method.upper(),
                "path": path,
                "user_id": user_id,
                "status": status,
                "comparison_user_id": comparison_user_id,
                "comparison_status": comparison_status,
                "unauthorized_success": unauthorized_success,
            }
        )
        return trace_id

    def _resolve_path(self, path: str, *, write: bool = False) -> Path:
        """Validate *path* with ``is_path_allowed`` and return it under the workspace.

        Raises:
            ValueError: carrying the validator's message when *path* is rejected.
        """
        allowed, normalized_or_error = is_path_allowed(self.state, path, write=write)
        if not allowed:
            raise ValueError(normalized_or_error)
        return self.workspace / normalized_or_error

    def _apply_unified_diff(self, path: Path, diff: str) -> None:
        """Apply the unified *diff* to the file at *path* in place.

        Lines before the first ``@@`` header (``---``/``+++`` headers, prose)
        are skipped. Hunks must be in ascending, non-overlapping order. Context
        (`` ``) and removed (``-``) lines must match the file, ignoring trailing
        whitespace. Inside a hunk, a bare empty line is read as a blank context
        line whose leading space was stripped. Other unrecognized lines are
        ignored. The file is written only after every hunk has applied.

        Raises:
            ValueError: if *diff* is blank or has no hunks, a hunk header is
                malformed, hunks overlap or start past the end of the file, or
                a context/removed line does not match the file.
        """
        if not diff.strip():
            raise ValueError("diff or content is required")
        original = path.read_text(encoding="utf-8").splitlines(True)
        lines = diff.splitlines(True)
        # Terminate an unterminated final line so a trailing "+text" cannot fuse
        # with the next file line (the "\ No newline" marker removes it again).
        if lines[-1].splitlines()[0] == lines[-1]:
            lines[-1] += "\n"
        output: list[str] = []
        old_index = 0
        saw_hunk = False
        i = 0
        while i < len(lines):
            header = lines[i]
            i += 1
            if not header.startswith("@@"):
                continue
            saw_hunk = True
            anchor = self._hunk_anchor(header)
            if anchor < old_index:
                raise ValueError(f"hunks overlap or are out of order: {header.strip()}")
            if anchor > len(original):
                raise ValueError(f"hunk starts past the end of the file: {header.strip()}")
            # Copy the untouched lines between the previous hunk and this one.
            output.extend(original[old_index:anchor])
            body_start = i
            while i < len(lines) and not lines[i].startswith("@@"):
                i += 1
            old_index = self._apply_hunk(original, anchor, lines[body_start:i], output)
        if not saw_hunk:
            raise ValueError("diff contains no @@ hunks")
        output.extend(original[old_index:])
        path.write_text("".join(output), encoding="utf-8")

    @staticmethod
    def _hunk_anchor(header: str) -> int:
        """Return the 0-based index in the old file at which a hunk begins.

        ``@@ -start[,count] ...`` uses a 1-based *start*, except that an empty
        old range (``count == 0``) names the line *after which* the hunk
        inserts, so there *start* is already the 0-based anchor.
        """
        try:
            old_range = header.split()[1]
            if not old_range.startswith("-"):
                raise ValueError(old_range)
            start_text, _, count_text = old_range[1:].partition(",")
            start = int(start_text)
            count = int(count_text) if count_text else 1
        except (IndexError, ValueError) as exc:
            raise ValueError(f"malformed hunk header: {header.strip()}") from exc
        anchor = start if count == 0 else start - 1
        if anchor < 0 or count < 0:
            raise ValueError(f"malformed hunk header: {header.strip()}")
        return anchor

    def _apply_hunk(self, original: list[str], old_index: int, body: list[str], output: list[str]) -> int:
        """Append one hunk's result to *output*; return the next unread index of *original*."""
        end = len(body)
        # Bare blank lines at the end of a hunk are separators (e.g. a diff
        # string ending in "\n\n"), not context lines.
        while end and body[end - 1] in self._BLANK_LINES:
            end -= 1
        previous_op = ""
        for raw in body[:end]:
            if raw in self._BLANK_LINES:
                op, text = " ", raw  # blank context line whose leading " " was stripped
            else:
                op, text = raw[:1], raw[1:]
            if op in (" ", "-"):
                if old_index >= len(original):
                    raise ValueError("diff runs past the end of the file")
                if original[old_index].rstrip() != text.rstrip():
                    raise ValueError(
                        f"diff does not match file at line {old_index + 1}: "
                        f"expected {text.rstrip()!r}, found {original[old_index].rstrip()!r}"
                    )
                if op == " ":
                    output.append(original[old_index])
                old_index += 1
            elif op == "+":
                output.append(text)
            elif op == "\\":
                # "\ No newline at end of file": the preceding added line has no terminator.
                if previous_op == "+":
                    output[-1] = output[-1].rstrip("\r\n")
            else:
                continue  # unrecognized line: ignored
            previous_op = op
        return old_index
|