from __future__ import annotations import hashlib from typing import Any from _common import cell_set, safe_div, universe_cells, valid_candidates def deterministic_random_key(candidate_id: str, seed: int = 2026) -> str: return hashlib.sha256(f"{seed}:{candidate_id}".encode("utf-8")).hexdigest() def select_views( candidates: list[dict[str, Any]], method: str, budget: int, lambda_conflict: float = 0.35, probe_weight: float = 0.15, seed: int = 2026, ) -> tuple[list[dict[str, Any]], set[str]]: valid = valid_candidates(candidates) universe = universe_cells(valid) universe_size = max(1, len(universe)) selected: list[dict[str, Any]] = [] selected_ids: set[str] = set() covered: set[str] = set() if method == "random_seeded": order = sorted(valid, key=lambda row: deterministic_random_key(str(row["candidate_id"]), seed)) for rank, candidate in enumerate(order[:budget]): cells = cell_set(candidate) new_cells = cells - covered covered.update(new_cells) selected.append(_selected_row(candidate, rank, safe_div(len(new_cells), universe_size), 0.0)) return selected, covered for rank in range(budget): best_candidate = None best_tuple = None best_gain = 0.0 for candidate in valid: cid = str(candidate["candidate_id"]) if cid in selected_ids: continue cells = cell_set(candidate) marginal_gain = safe_div(len(cells - covered), universe_size) probe = float(candidate.get("single_view_probe_coverage", 0.0)) conflict = float(candidate.get("conflict_prior", 0.0)) if method == "greedy_coverage": score = marginal_gain + 0.02 * probe - 0.01 * conflict elif method == "single_view_probe": score = probe + 0.05 * marginal_gain - 0.01 * conflict elif method == "low_conflict": score = marginal_gain - lambda_conflict * conflict elif method == "cm_evs": score = marginal_gain + probe_weight * probe - lambda_conflict * conflict else: raise ValueError(f"Unknown selection method: {method}") score_tuple = (score, marginal_gain, -conflict, cid) if best_tuple is None or score_tuple > best_tuple: best_tuple = score_tuple best_candidate = candidate best_gain = marginal_gain if best_candidate is None: break cid = str(best_candidate["candidate_id"]) selected_ids.add(cid) covered.update(cell_set(best_candidate)) selected.append(_selected_row(best_candidate, rank, best_gain, float(best_tuple[0]))) return selected, covered def summarize_selection( candidates: list[dict[str, Any]], selected: list[dict[str, Any]], covered: set[str], method: str, budget: int, lambda_conflict: float = 0.35, source_label: str | None = None, ) -> dict[str, Any]: valid = valid_candidates(candidates) universe = universe_cells(valid) scene_id = str(valid[0].get("scene_id", "unknown")) if valid else "unknown" source = source_label or (str(valid[0].get("source", "unknown")) if valid else "unknown") mean_conflict = safe_div(sum(float(row.get("conflict_prior", 0.0)) for row in selected), len(selected)) mean_probe = safe_div(sum(float(row.get("single_view_probe_coverage", 0.0)) for row in selected), len(selected)) runtime = sum(float(row.get("runtime_s", 0.0)) for row in selected) coverage = safe_div(len(covered), max(1, len(universe))) return { "source": source, "scene_id": scene_id, "method": method, "budget": budget, "selected_views": len(selected), "coverage": round(coverage, 6), "coverage_per_view": round(safe_div(coverage, len(selected)), 6), "mean_conflict_prior": round(mean_conflict, 6), "mean_probe_coverage": round(mean_probe, 6), "estimated_runtime_s": round(runtime, 6), "lambda_conflict": lambda_conflict, "selected_ids": ";".join(str(row["candidate_id"]) for row in selected), } def _selected_row(candidate: dict[str, Any], rank: int, marginal_gain: float, score: float) -> dict[str, Any]: row = dict(candidate) row["rank"] = rank row["marginal_gain"] = round(float(marginal_gain), 6) row["score"] = round(float(score), 6) return row