File size: 4,547 Bytes
5c1bb37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from __future__ import annotations

import hashlib
from typing import Any

from _common import cell_set, safe_div, universe_cells, valid_candidates


def deterministic_random_key(candidate_id: str, seed: int = 2026) -> str:
    """Return a reproducible pseudo-random sort key for ``candidate_id``.

    The key is the hexadecimal SHA-256 digest of the UTF-8 text
    ``"<seed>:<candidate_id>"``. The same (seed, id) pair always yields the
    same key; changing ``seed`` changes the keys and hence the ordering.
    """
    hasher = hashlib.sha256()
    hasher.update(f"{seed}:{candidate_id}".encode("utf-8"))
    return hasher.hexdigest()


# Method names scored by ``_method_score``; keep the two in sync.
# "random_seeded" is handled separately inside ``select_views``.
_SCORED_METHODS = frozenset({"greedy_coverage", "single_view_probe", "low_conflict", "cm_evs"})


def select_views(
    candidates: list[dict[str, Any]],
    method: str,
    budget: int,
    lambda_conflict: float = 0.35,
    probe_weight: float = 0.15,
    seed: int = 2026,
) -> tuple[list[dict[str, Any]], set[str]]:
    """Select up to ``budget`` views from ``candidates`` using ``method``.

    Only candidates returned by ``valid_candidates`` are considered. A view's
    ``marginal_gain`` is the number of its cells not yet covered, divided by
    the size of the cell universe of those valid candidates (at least 1).

    Methods:
        * ``"random_seeded"`` -- order candidates by
          ``deterministic_random_key(candidate_id, seed)`` and take the first
          ``budget``; every selected row gets ``score`` 0.0.
        * ``"greedy_coverage"``, ``"single_view_probe"``, ``"low_conflict"``,
          ``"cm_evs"`` -- greedy: each round picks the not-yet-selected
          candidate with the highest score from ``_method_score``.
          ``lambda_conflict`` only affects ``low_conflict``/``cm_evs`` and
          ``probe_weight`` only affects ``cm_evs``.

    Args:
        candidates: Candidate view records. ``candidate_id`` is required;
            ``single_view_probe_coverage`` and ``conflict_prior`` default to
            0.0 when absent (greedy methods only).
        method: One of the method names above.
        budget: Maximum number of views to select; values <= 0 select nothing.
        lambda_conflict: Penalty weight on ``conflict_prior``.
        probe_weight: Weight on ``single_view_probe_coverage`` for ``cm_evs``.
        seed: Seed for the ``random_seeded`` ordering.

    Returns:
        ``(selected, covered)``: the selected rows in selection order (copies
        annotated by ``_selected_row`` with rank/marginal_gain/score) and the
        set of cells they cover.

    Raises:
        ValueError: If ``method`` is not a known selection method.
    """
    # Validate up front so a bad method name fails loudly even when there are
    # no valid candidates or the budget is zero.
    if method != "random_seeded" and method not in _SCORED_METHODS:
        raise ValueError(f"Unknown selection method: {method}")
    # A negative budget must select nothing; without the clamp, order[:budget]
    # below would silently drop only the last |budget| candidates.
    budget = max(0, budget)

    valid = valid_candidates(candidates)
    universe_size = max(1, len(universe_cells(valid)))
    selected: list[dict[str, Any]] = []
    covered: set[str] = set()

    if method == "random_seeded":
        order = sorted(valid, key=lambda row: deterministic_random_key(str(row["candidate_id"]), seed))
        for rank, candidate in enumerate(order[:budget]):
            new_cells = cell_set(candidate) - covered
            covered.update(new_cells)
            selected.append(_selected_row(candidate, rank, safe_div(len(new_cells), universe_size), 0.0))
        return selected, covered

    if budget == 0:
        return selected, covered

    # Per-candidate features do not change between rounds, so compute them once
    # rather than once per candidate per round.
    pool = [
        (
            str(candidate["candidate_id"]),
            candidate,
            cell_set(candidate),
            float(candidate.get("single_view_probe_coverage", 0.0)),
            float(candidate.get("conflict_prior", 0.0)),
        )
        for candidate in valid
    ]
    selected_ids: set[str] = set()

    for rank in range(budget):
        best = None  # (score_tuple, candidate, cells, marginal_gain)
        for cid, candidate, cells, probe, conflict in pool:
            if cid in selected_ids:
                continue
            marginal_gain = safe_div(len(cells - covered), universe_size)
            score = _method_score(method, marginal_gain, probe, conflict, lambda_conflict, probe_weight)
            # Ties on score break toward larger marginal gain, then lower
            # conflict, then the larger candidate_id, so selection is deterministic.
            score_tuple = (score, marginal_gain, -conflict, cid)
            if best is None or score_tuple > best[0]:
                best = (score_tuple, candidate, cells, marginal_gain)

        if best is None:
            break  # every valid candidate id has already been selected (or there are none)
        score_tuple, candidate, cells, marginal_gain = best
        selected_ids.add(score_tuple[3])
        covered.update(cells)
        selected.append(_selected_row(candidate, rank, marginal_gain, float(score_tuple[0])))

    return selected, covered


def _method_score(
    method: str,
    marginal_gain: float,
    probe: float,
    conflict: float,
    lambda_conflict: float,
    probe_weight: float,
) -> float:
    """Return the greedy objective of one candidate under ``method``.

    Raises:
        ValueError: If ``method`` is not one of ``_SCORED_METHODS``.
    """
    if method == "greedy_coverage":
        return marginal_gain + 0.02 * probe - 0.01 * conflict
    if method == "single_view_probe":
        return probe + 0.05 * marginal_gain - 0.01 * conflict
    if method == "low_conflict":
        return marginal_gain - lambda_conflict * conflict
    if method == "cm_evs":
        return marginal_gain + probe_weight * probe - lambda_conflict * conflict
    raise ValueError(f"Unknown selection method: {method}")


def summarize_selection(
    candidates: list[dict[str, Any]],
    selected: list[dict[str, Any]],
    covered: set[str],
    method: str,
    budget: int,
    lambda_conflict: float = 0.35,
    source_label: str | None = None,
) -> dict[str, Any]:
    """Build a flat summary record describing one view selection.

    ``scene_id`` is read from the first valid candidate (``"unknown"`` if
    there is none). ``source`` is ``source_label`` when it is truthy,
    otherwise the first valid candidate's ``source`` (or ``"unknown"``).
    ``coverage`` is ``len(covered)`` over the size of the valid candidates'
    cell universe (at least 1). The conflict/probe figures are per-row means
    over ``selected`` computed with ``safe_div`` (presumably 0 for an empty
    selection -- see ``_common``), and ``estimated_runtime_s`` sums the rows'
    ``runtime_s``. Float metrics are rounded to 6 decimals.
    """
    valid = valid_candidates(candidates)
    universe = universe_cells(valid)

    if valid:
        scene_id = str(valid[0].get("scene_id", "unknown"))
    else:
        scene_id = "unknown"

    if source_label:
        source = source_label
    elif valid:
        source = str(valid[0].get("source", "unknown"))
    else:
        source = "unknown"

    view_count = len(selected)
    conflict_sum = sum(float(row.get("conflict_prior", 0.0)) for row in selected)
    probe_sum = sum(float(row.get("single_view_probe_coverage", 0.0)) for row in selected)
    total_runtime = sum(float(row.get("runtime_s", 0.0)) for row in selected)
    coverage = safe_div(len(covered), max(1, len(universe)))

    return {
        "source": source,
        "scene_id": scene_id,
        "method": method,
        "budget": budget,
        "selected_views": view_count,
        "coverage": round(coverage, 6),
        "coverage_per_view": round(safe_div(coverage, view_count), 6),
        "mean_conflict_prior": round(safe_div(conflict_sum, view_count), 6),
        "mean_probe_coverage": round(safe_div(probe_sum, view_count), 6),
        "estimated_runtime_s": round(total_runtime, 6),
        "lambda_conflict": lambda_conflict,
        "selected_ids": ";".join(str(row["candidate_id"]) for row in selected),
    }


def _selected_row(candidate: dict[str, Any], rank: int, marginal_gain: float, score: float) -> dict[str, Any]:
    """Return a shallow copy of ``candidate`` annotated with selection metadata.

    Adds (or overwrites) ``rank``, ``marginal_gain`` and ``score``; the two
    numeric values are coerced to float and rounded to 6 decimals. The input
    dict is left unmodified.
    """
    return {
        **candidate,
        "rank": rank,
        "marginal_gain": round(float(marginal_gain), 6),
        "score": round(float(score), 6),
    }