"""
ECHO ULTIMATE — Self-Consistency Confidence Checker.

Samples N answers for the same question. If the answers disagree, the stated
confidence is automatically reduced by CONSISTENCY_DISCOUNT. This is a key
innovation over the base ECHO environment.

In training: disabled (too slow, adds noise).
In demo: enabled (impressive, shows genuine uncertainty awareness).
"""
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import Callable, Optional

from config import cfg
from env.parser import parse_response, ParseResult

logger = logging.getLogger(__name__)

@dataclass
class ConsistencyResult:
"""Result of self-consistency checking for one question."""
answers: list[str] = field(default_factory=list)
confidences: list[int] = field(default_factory=list)
final_answer: str = ""
final_confidence: int = 50
agreement_rate: float = 1.0
was_adjusted: bool = False
adjustment_amount: int = 0
parse_results: list = field(default_factory=list)
class SelfConsistencyChecker:
"""
Multi-sample confidence adjustment.
Algorithm:
1. Generate n_samples responses for the same prompt
2. Parse each into (confidence, answer)
3. Find majority-vote answer
4. agreement_rate = fraction of samples matching majority
5. If agreement_rate < 1.0:
final_confidence = round(mean_confidence * (1 - CONSISTENCY_DISCOUNT))
else:
final_confidence = mean_confidence (unchanged)
6. Return ConsistencyResult with final_answer and final_confidence
"""

    def __init__(self, n_samples: int = cfg.SELF_CONSISTENCY_SAMPLES) -> None:
        self.n_samples = n_samples
        self.discount = cfg.CONSISTENCY_DISCOUNT

    def check(
        self,
        prompt: str,
        generate_fn: Callable[[str], str],
        n_samples: Optional[int] = None,
    ) -> ConsistencyResult:
        """
        Run n_samples generations and return a consistency-adjusted result.

        Args:
            prompt: formatted question prompt
            generate_fn: callable(prompt) -> raw LLM output string
            n_samples: override for the default sample count
        """
        n = n_samples or self.n_samples
        parsed_list: list[ParseResult] = []
        answers: list[str] = []
        confidences: list[int] = []
        for i in range(n):
            try:
                raw = generate_fn(prompt)
                parsed = parse_response(raw)
            except Exception as exc:
                logger.warning("SelfConsistencyChecker sample %d failed: %s", i, exc)
                # ParseResult is already imported at module level; substitute a
                # neutral default so one failed sample does not abort the run.
                parsed = ParseResult(confidence=50, answer="", raw="")
            parsed_list.append(parsed)
            answers.append(parsed.answer.strip().lower())
            confidences.append(parsed.confidence)

        if not answers:
            return ConsistencyResult(final_confidence=50, final_answer="")

        # Majority-vote answer (case-insensitive)
        counter = Counter(answers)
        majority_answer_lower, majority_count = counter.most_common(1)[0]
        agreement_rate = majority_count / n

        # Recover the original-cased answer for the majority
        final_answer = ""
        for pr in parsed_list:
            if pr.answer.strip().lower() == majority_answer_lower:
                final_answer = pr.answer
                break

        mean_conf = round(sum(confidences) / len(confidences))

        # Apply the discount if the samples disagree
        was_adjusted = agreement_rate < 1.0
        if was_adjusted:
            adjusted = round(mean_conf * (1.0 - self.discount))
            final_confidence = max(cfg.CONFIDENCE_MIN, adjusted)
            # Compute the reported reduction after clamping so the number in
            # format_explanation() matches the confidence actually returned.
            adjustment_amount = mean_conf - final_confidence
        else:
            final_confidence = mean_conf
            adjustment_amount = 0
        return ConsistencyResult(
            answers=[pr.answer for pr in parsed_list],
            confidences=confidences,
            final_answer=final_answer,
            final_confidence=final_confidence,
            agreement_rate=agreement_rate,
            was_adjusted=was_adjusted,
            adjustment_amount=adjustment_amount,
            parse_results=parsed_list,
        )

    def format_explanation(self, result: ConsistencyResult) -> str:
        """Human-readable explanation of the consistency check result."""
        if not result.was_adjusted:
            return (
                f"✅ All {len(result.answers)} samples agreed → "
                f"confidence unchanged at {result.final_confidence}%"
            )
        return (
            f"⚠️ Samples disagreed (agreement={result.agreement_rate:.0%}) → "
            f"confidence reduced by {result.adjustment_amount}% "
            f"to {result.final_confidence}%\n"
            f"   Samples: {result.answers}"
        )
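

if __name__ == "__main__":
    # Minimal usage sketch. `fake_generate` is a hypothetical stand-in for a
    # real LLM call, and its "Answer: ... / Confidence: ..." output format is
    # an assumption; the real raw format is whatever env.parser.parse_response
    # expects. A parse failure is tolerated here, since check() catches
    # exceptions and substitutes a neutral default sample.
    import random

    def fake_generate(prompt: str) -> str:
        answer = random.choice(["Paris", "Paris", "Lyon"])
        return f"Answer: {answer}\nConfidence: {random.randint(60, 95)}"

    checker = SelfConsistencyChecker(n_samples=5)
    result = checker.check("What is the capital of France?", fake_generate)
    print(checker.format_explanation(result))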