File size: 4,931 Bytes
ba4f7c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Maximum-reward lasso planner.

Given a product automaton and per-spec rewards, finds:
  1. The SCC with maximum total reward that is reachable from the initial state
  2. A lasso path: prefix (initial → SCC) + cycle (within SCC visiting accepting states)

Returns the path as a sequence of grid positions plus a result summary.
"""

from dataclasses import dataclass
from typing import Dict, List, Optional, Set, Tuple

from .grid_world import GridWorld
from .automata import BuchiAut
from .product import ProductGraph


@dataclass
class PlanResult:
    """Outcome of a call to ``plan``.

    When ``success`` is True, ``path`` is the lasso (prefix followed by one
    traversal of the cycle) and ``cycle_start_idx`` marks where the repeating
    part begins. When ``plan`` fails, ``success`` is False, ``path`` is empty,
    every spec index is listed in ``violated`` and ``message`` gives the reason.
    """
    path: List[Tuple[int, int]]       # grid positions (prefix + one full cycle)
    cycle_start_idx: int               # index in path where cycle begins
    satisfied: List[int]               # indices of satisfied specs
    violated: List[int]                # indices of violated specs
    total_reward: float                # reward of the chosen SCC (0 on failure)
    max_possible_reward: float         # sum of all spec rewards
    spec_names: List[str]              # automaton names, in spec-index order
    spec_rewards: List[float]          # copy of the per-spec rewards passed in
    success: bool                      # True only when a lasso path was built
    message: str                       # human-readable summary or failure reason


def _failure_result(
    num_specs: int,
    rewards: List[float],
    spec_names: List[str],
    max_possible: float,
    message: str,
) -> PlanResult:
    """Build the PlanResult returned when no usable lasso could be produced.

    The path is empty and every one of the *num_specs* specs is reported as
    violated.
    """
    return PlanResult(
        path=[], cycle_start_idx=0,
        satisfied=[], violated=list(range(num_specs)),
        total_reward=0, max_possible_reward=max_possible,
        spec_names=spec_names, spec_rewards=list(rewards),
        success=False,
        message=message,
    )


def plan(
    grid: GridWorld,
    automata: List[BuchiAut],
    rewards: List[float],
) -> PlanResult:
    """Plan a maximum-reward lasso path through the product of grid and specs.

    The product graph of *grid* and *automata* is built. Among its nontrivial
    SCCs that are reachable from the initial state, the one with the highest
    reward (as scored by ``ProductGraph.scc_satisfied_specs``) is selected.
    The result path is:

    * a BFS prefix from the initial state into that SCC, followed by
    * one traversal of a cycle found by ``ProductGraph.find_cycle_through``.
      The cycle is constrained by the SCC's accepting states for every
      satisfied spec.

    Args:
        grid: Grid world the product graph is built over.
        automata: One Büchi automaton per spec; spec ``i`` is ``automata[i]``.
        rewards: Reward per spec. Presumably aligned index-for-index with
            *automata* (TODO confirm the caller guarantees equal lengths).

    Returns:
        A PlanResult. On any failure, ``success`` is False, the path is
        empty, all specs are listed as violated, and ``message`` says which
        stage failed.
    """
    spec_names = [a.name for a in automata]
    max_possible = sum(rewards)

    def fail(message: str) -> PlanResult:
        return _failure_result(
            len(automata), rewards, spec_names, max_possible, message
        )

    pg = ProductGraph(grid, automata)
    sccs = pg.compute_sccs()
    reachable = pg.reachable_from_initial()
    init_idx = pg.state_index[pg.initial]

    # Score every nontrivial SCC that the initial state can reach.
    candidates = []
    for scc in sccs:
        if not any(v in reachable for v in scc):
            continue
        if not pg.is_nontrivial_scc(scc):
            continue
        reward, satisfied_set = pg.scc_satisfied_specs(scc, rewards)
        candidates.append((reward, satisfied_set, scc))

    if not candidates:
        return fail(
            "No reachable accepting cycle found. Check for obstacles blocking all paths."
        )

    # Highest reward wins; on ties the earliest SCC from compute_sccs() is kept.
    best_reward, satisfied_set, best_scc = max(candidates, key=lambda c: c[0])
    best_scc_set = set(best_scc)
    violated_set = set(range(len(automata))) - satisfied_set

    # For each satisfied spec: the states of the best SCC whose automaton-i
    # component is accepting. Component 0 of a product state is the grid
    # position (extracted below), and component 1 + i belongs to automaton i.
    required_accepting = [
        {v for v in best_scc_set if aut.is_accepting(pg.states[v][1 + i])}
        for i, aut in enumerate(automata)
        if i in satisfied_set
    ]

    # Prefix: initial state -> any state in the best SCC.
    prefix_path = pg.bfs_path(init_idx, best_scc_set)
    if prefix_path is None:
        return fail("Could not find path to best SCC (graph error).")

    # Cycle inside the SCC through the required accepting states.
    cycle = pg.find_cycle_through(best_scc_set, required_accepting)
    if cycle is None:
        return fail("Could not construct cycle within SCC.")

    # Join the prefix end to the cycle start (they may differ).
    entry = prefix_path[-1]
    if entry == cycle[0]:
        full_prod_path = prefix_path + cycle[1:]
        cycle_start_idx = len(prefix_path) - 1
    else:
        bridge = pg.bfs_path(entry, {cycle[0]})
        if bridge is None:
            # Both endpoints are in the same SCC, so this should not happen.
            # Splicing anyway would skip cycle[0] and yield a path that jumps
            # between non-adjacent states, so report the failure instead.
            return fail("Could not connect prefix to cycle within SCC (graph error).")
        # bridge starts at `entry`, so drop the prefix's duplicate last state.
        full_prod_path = prefix_path[:-1] + bridge + cycle[1:]
        # Index of cycle[0], which is the last element of `bridge`.
        cycle_start_idx = len(prefix_path) + len(bridge) - 2

    grid_path = [pg.states[v][0] for v in full_prod_path]

    return PlanResult(
        path=grid_path,
        cycle_start_idx=cycle_start_idx,
        satisfied=sorted(satisfied_set),
        violated=sorted(violated_set),
        total_reward=best_reward,
        max_possible_reward=max_possible,
        spec_names=spec_names,
        spec_rewards=list(rewards),
        success=True,
        message=f"Plan found! Satisfies {len(satisfied_set)}/{len(automata)} specs "
                f"(reward {best_reward:.0f}/{max_possible:.0f}).",
    )