"""
Product automaton: GridWorld × Büchi_1 × ... × Büchi_n
A product state is (grid_pos, aut_state_1, ..., aut_state_n).
We build the reachable part of the graph by BFS from the initial state,
then run Tarjan's SCC algorithm on it.
"""
from collections import defaultdict, deque
from typing import Dict, FrozenSet, List, Optional, Set, Tuple
from .grid_world import GridWorld
from .automata import BuchiAut
# A product state is a tuple: (grid_pos, q1, q2, ..., qn), where grid_pos is
# a grid cell and q_i is the current state of the i-th automaton, in the same
# order the automata are passed to ProductGraph. A plain tuple keeps states
# hashable, so they can be used as dict keys (ProductGraph.state_index).
ProductState = tuple
class ProductGraph:
    """Explicit product of a grid world with one or more Büchi automata.

    Product states are tuples ``(grid_pos, q_1, ..., q_n)``. The graph is
    built in ``__init__`` by BFS from the initial product state, so only
    reachable states are materialised. State ``i`` is ``self.states[i]``;
    ``self.state_index`` maps a state back to its index. ``self.adj`` and
    ``self.radj`` hold forward and reverse edges between state indices.
    """

    def __init__(self, grid: GridWorld, automata: List[BuchiAut]):
        self.grid = grid
        self.automata = automata
        self.n_aut = len(automata)

        # Initial product state: the grid start cell paired with each
        # automaton's initial state (in the order given).
        init_aut = tuple(a.initial for a in automata)
        self.initial: ProductState = (grid.start,) + init_aut

        self.states: List[ProductState] = []
        self.state_index: Dict[ProductState, int] = {}
        self.adj: Dict[int, List[int]] = defaultdict(list)   # forward edges
        self.radj: Dict[int, List[int]] = defaultdict(list)  # reverse edges
        self._build()

    # ── graph construction ────────────────────────────────────────────────

    def _build(self) -> None:
        """Populate ``states``/``adj``/``radj`` by BFS from ``self.initial``.

        A grid move into ``next_pos`` advances every automaton on the label
        of ``next_pos`` (the transition fires when entering the cell). If any
        automaton has no transition for that label, the move is dropped.
        """
        queue = deque([self.initial])
        self._add_state(self.initial)
        while queue:
            ps = queue.popleft()
            src_idx = self.state_index[ps]
            grid_pos, aut_states = ps[0], ps[1:]
            for _, next_pos in self.grid.successors(grid_pos):
                next_aut = self._advance_automata(
                    aut_states, self.grid.label(next_pos)
                )
                if next_aut is None:
                    continue
                next_ps: ProductState = (next_pos,) + next_aut
                if next_ps not in self.state_index:
                    self._add_state(next_ps)
                    queue.append(next_ps)
                dst_idx = self.state_index[next_ps]
                self.adj[src_idx].append(dst_idx)
                self.radj[dst_idx].append(src_idx)

    def _advance_automata(self, aut_states: tuple, label) -> Optional[tuple]:
        """Step every automaton on ``label``; ``None`` if any step is undefined."""
        next_aut = []
        for aut, q in zip(self.automata, aut_states):
            nq = aut.step(q, label)
            if nq is None:
                return None
            next_aut.append(nq)
        return tuple(next_aut)

    def _add_state(self, ps: ProductState) -> int:
        """Append ``ps`` to ``states``, record its index and return it."""
        idx = len(self.states)
        self.states.append(ps)
        self.state_index[ps] = idx
        return idx

    # ── Tarjan's SCC ──────────────────────────────────────────────────────

    def compute_sccs(self) -> List[List[int]]:
        """Return the strongly connected components, largest first.

        Uses an iterative Tarjan's algorithm. Deep graphs therefore neither
        hit Python's recursion limit nor require raising it process-wide.
        Each SCC is a list of state indices. Components of equal size keep
        Tarjan's emission order, because the final sort is stable.
        """
        index: Dict[int, int] = {}
        lowlink: Dict[int, int] = {}
        on_stack: Set[int] = set()
        tarjan_stack: List[int] = []
        sccs: List[List[int]] = []
        counter = 0

        for root in range(len(self.states)):
            if root in index:
                continue
            index[root] = lowlink[root] = counter
            counter += 1
            tarjan_stack.append(root)
            on_stack.add(root)
            # Explicit call stack of (vertex, iterator over its successors);
            # the iterator remembers where to resume after a "recursive call".
            work = [(root, iter(self.adj.get(root, ())))]
            while work:
                v, succ = work[-1]
                descended = False
                for w in succ:
                    if w not in index:
                        index[w] = lowlink[w] = counter
                        counter += 1
                        tarjan_stack.append(w)
                        on_stack.add(w)
                        work.append((w, iter(self.adj.get(w, ()))))
                        descended = True
                        break
                    if w in on_stack:
                        lowlink[v] = min(lowlink[v], index[w])
                if descended:
                    continue
                # All successors of v explored: v's "call" returns.
                work.pop()
                if lowlink[v] == index[v]:
                    scc = []
                    while True:
                        w = tarjan_stack.pop()
                        on_stack.discard(w)
                        scc.append(w)
                        if w == v:
                            break
                    sccs.append(scc)
                if work:
                    caller = work[-1][0]
                    lowlink[caller] = min(lowlink[caller], lowlink[v])

        sccs.sort(key=len, reverse=True)
        return sccs

    # ── SCC reward analysis ───────────────────────────────────────────────

    def scc_satisfied_specs(self, scc: List[int], rewards: List[float]) -> Tuple[float, Set[int]]:
        """Collect the specs that have an accepting automaton state in ``scc``.

        Spec ``i`` counts as satisfied when some product state in ``scc`` has
        automaton ``i`` in an accepting state. Whether ``scc`` is nontrivial
        is not checked here.

        Returns:
            ``(total_reward, satisfied)``, where ``total_reward`` is the sum
            of ``rewards[i]`` over the satisfied spec indices ``i``.
        """
        satisfied: Set[int] = set()
        for idx in scc:
            aut_states = self.states[idx][1:]
            for i, (aut, q) in enumerate(zip(self.automata, aut_states)):
                if aut.is_accepting(q):
                    satisfied.add(i)
        total = sum(rewards[i] for i in satisfied)
        return total, satisfied

    def is_nontrivial_scc(self, scc: List[int]) -> bool:
        """True if ``scc`` has more than one state, or one state with a self-loop."""
        if len(scc) > 1:
            return True
        if not scc:
            return False
        v = scc[0]
        # .get avoids inserting empty entries into the adj defaultdict.
        return v in self.adj.get(v, ())

    # ── reachability ──────────────────────────────────────────────────────

    def reachable_from_initial(self) -> Set[int]:
        """Return the indices of all states reachable from the initial state."""
        visited: Set[int] = set()
        queue = deque([self.state_index[self.initial]])
        while queue:
            v = queue.popleft()
            if v in visited:
                continue
            visited.add(v)
            for w in self.adj.get(v, ()):
                if w not in visited:
                    queue.append(w)
        return visited

    def bfs_path(self, src: int, targets: Set[int]) -> Optional[List[int]]:
        """Shortest path (state indices) from ``src`` to any state in ``targets``.

        Returns ``[src]`` when ``src`` is itself a target, and ``None`` when no
        target is reachable.
        """
        if src in targets:
            return [src]
        parent: Dict[int, Optional[int]] = {src: None}
        queue = deque([src])
        while queue:
            v = queue.popleft()
            for w in self.adj.get(v, ()):
                if w in parent:
                    continue
                parent[w] = v
                if w in targets:
                    return self._reconstruct(parent, w)
                queue.append(w)
        return None

    def find_cycle_through(self, scc_set: Set[int], required_accepting: List[Set[int]]) -> Optional[List[int]]:
        """Find a cycle inside ``scc_set`` that visits one accepting state per spec.

        For each set in ``required_accepting``, one member lying inside
        ``scc_set`` is chosen as a checkpoint. Sets with no member inside the
        SCC are skipped. The cycle is chained from BFS segments:
        ``start -> cp_0 -> ... -> cp_k -> start``.

        Returns:
            A list of state indices with ``first == last`` and at least one
            edge. Returns ``None`` if ``scc_set`` is empty or no such cycle
            exists, e.g. a single state without a self-loop.
        """
        if not scc_set:
            return None
        start = next(iter(scc_set))

        checkpoints = []
        for acc_set in required_accepting:
            candidates = acc_set & scc_set
            if candidates:
                checkpoints.append(next(iter(candidates)))

        waypoints = [start] + checkpoints + [start]
        full_path = [start]
        for src, dst in zip(waypoints, waypoints[1:]):
            seg = self._bfs_in_scc(src, {dst}, scc_set)
            if seg is None:
                return None
            full_path.extend(seg[1:])  # seg[0] is the junction already on the path

        if len(full_path) == 1:
            # Every waypoint coincided with start (or there were none), so the
            # chain is the empty walk [start]. Look for a real cycle instead.
            return self._closed_walk_in_scc(start, scc_set)
        return full_path

    def _bfs_in_scc(self, src: int, targets: Set[int], scc_set: Set[int]) -> Optional[List[int]]:
        """BFS from ``src`` to any target, only traversing states in ``scc_set``."""
        if src in targets:
            return [src]
        parent: Dict[int, Optional[int]] = {src: None}
        queue = deque([src])
        while queue:
            v = queue.popleft()
            for w in self.adj.get(v, ()):
                if w not in scc_set or w in parent:
                    continue
                parent[w] = v
                if w in targets:
                    return self._reconstruct(parent, w)
                queue.append(w)
        return None

    def _closed_walk_in_scc(self, start: int, scc_set: Set[int]) -> Optional[List[int]]:
        """Shortest cycle ``start -> ... -> start`` (>= 1 edge) inside ``scc_set``."""
        parent: Dict[int, int] = {}
        queue = deque([start])
        while queue:
            v = queue.popleft()
            for w in self.adj.get(v, ()):
                if w not in scc_set:
                    continue
                if w == start:
                    path = [start]
                    cur = v
                    while cur != start:
                        path.append(cur)
                        cur = parent[cur]
                    path.append(start)
                    path.reverse()
                    return path
                if w not in parent:
                    parent[w] = v
                    queue.append(w)
        return None

    @staticmethod
    def _reconstruct(parent: Dict[int, Optional[int]], end: int) -> List[int]:
        """Follow ``parent`` links back from ``end``; return the forward path."""
        path = []
        cur: Optional[int] = end
        while cur is not None:
            path.append(cur)
            cur = parent[cur]
        path.reverse()
        return path