# echo-ultimate / env/self_consistency.py
# Uploaded by Vikaspandey582003 via huggingface_hub (commit acb327b, verified).
"""
ECHO ULTIMATE — Self-Consistency Confidence Checker.
Samples N answers for the same question. If the answers disagree, the mean
stated confidence is automatically scaled down by a factor of
(1 - CONSISTENCY_DISCOUNT).
This is a key innovation over the base ECHO environment.
In training: disabled (too slow, adds noise).
In demo: enabled (impressive, shows genuine uncertainty awareness).
"""
import logging
from collections import Counter
from dataclasses import dataclass, field
from typing import Callable, Optional
from config import cfg
from env.parser import parse_response, ParseResult
# Module-level logger; SelfConsistencyChecker.check() uses it to report failed samples.
logger = logging.getLogger(__name__)
@dataclass
class ConsistencyResult:
    """Result of self-consistency checking for one question.

    Fields are filled in by SelfConsistencyChecker.check(); the defaults
    describe the neutral result returned when no samples were collected.
    """

    # Per-sample answers as parsed (original casing, not normalised).
    answers: list[str] = field(default_factory=list)
    # Per-sample parsed confidences, in sample order.
    confidences: list[int] = field(default_factory=list)
    # Majority-vote answer, in the casing of its first occurrence.
    final_answer: str = ""
    # Mean confidence, discounted when the samples disagree.
    final_confidence: int = 50
    # Fraction of samples whose normalised answer matches the majority.
    agreement_rate: float = 1.0
    # True when agreement_rate < 1.0, i.e. the consistency discount was applied.
    was_adjusted: bool = False
    # Confidence points removed from the mean by the discount (0 if not adjusted).
    adjustment_amount: int = 0
    # ParseResult object for each sample, in sample order.
    parse_results: list[ParseResult] = field(default_factory=list)
class SelfConsistencyChecker:
    """
    Multi-sample confidence adjustment.

    Algorithm:
      1. Generate n_samples responses for the same prompt.
      2. Parse each into (confidence, answer).
      3. Find the majority-vote answer (compared case- and
         whitespace-insensitively; ties go to the answer seen first).
      4. agreement_rate = fraction of samples matching the majority.
      5. If agreement_rate < 1.0:
             final_confidence = max(CONFIDENCE_MIN,
                 round(mean_confidence * (1 - CONSISTENCY_DISCOUNT)))
         else:
             final_confidence = mean_confidence (unchanged)
      6. Return a ConsistencyResult with final_answer and final_confidence.
    """

    def __init__(self, n_samples: int = cfg.SELF_CONSISTENCY_SAMPLES) -> None:
        """
        Args:
            n_samples: default number of generations per check() call. The
                default value is read from cfg when this class is defined.
        """
        self.n_samples = n_samples
        # Captured once here; later changes to cfg are not picked up.
        self.discount = cfg.CONSISTENCY_DISCOUNT

    def check(
        self,
        prompt: str,
        generate_fn: Callable[[str], str],
        n_samples: Optional[int] = None,
    ) -> ConsistencyResult:
        """
        Run n_samples generations and return a consistency-adjusted result.

        Args:
            prompt: formatted question prompt.
            generate_fn: callable(prompt) -> raw LLM output string.
            n_samples: override for the default sample count. ``None`` uses
                ``self.n_samples``. ``0`` (or a negative value) generates
                nothing and yields a neutral result (confidence 50, empty
                answer).

        Returns:
            ConsistencyResult. A sample whose generation or parsing raised is
            logged and recorded as a placeholder (confidence 50, empty answer),
            so it still takes part in the vote.
        """
        # `is None` rather than `or`: an explicit 0 must not silently fall
        # back to the default sample count.
        n = self.n_samples if n_samples is None else n_samples

        parsed_list: list[ParseResult] = []
        for i in range(n):
            try:
                parsed = parse_response(generate_fn(prompt))
            except Exception as exc:
                # Best-effort: one failing sample must not abort the whole check.
                logger.warning("SelfConsistencyChecker sample %d failed: %s", i, exc)
                parsed = ParseResult(confidence=50, answer="", raw="")
            parsed_list.append(parsed)

        if not parsed_list:
            return ConsistencyResult(final_confidence=50, final_answer="")

        # Vote on normalised answers so "Paris" and " paris" count as one.
        normalized = [pr.answer.strip().lower() for pr in parsed_list]
        confidences = [pr.confidence for pr in parsed_list]
        majority_key, majority_count = Counter(normalized).most_common(1)[0]
        agreement_rate = majority_count / len(normalized)

        # Report the majority answer in the casing of its first occurrence.
        final_answer = next(
            pr.answer
            for pr, key in zip(parsed_list, normalized)
            if key == majority_key
        )

        mean_conf = round(sum(confidences) / len(confidences))

        # Apply the discount only if the answers disagree.
        was_adjusted = agreement_rate < 1.0
        if was_adjusted:
            discounted = round(mean_conf * (1.0 - self.discount))
            final_confidence = max(cfg.CONFIDENCE_MIN, discounted)
        else:
            final_confidence = mean_conf
        # Measured after clamping, so it always equals the reduction actually
        # applied (format_explanation reports "reduced by X to Y").
        adjustment_amount = mean_conf - final_confidence

        return ConsistencyResult(
            answers=[pr.answer for pr in parsed_list],
            confidences=confidences,
            final_answer=final_answer,
            final_confidence=final_confidence,
            agreement_rate=agreement_rate,
            was_adjusted=was_adjusted,
            adjustment_amount=adjustment_amount,
            parse_results=parsed_list,
        )

    def format_explanation(self, result: ConsistencyResult) -> str:
        """Return a human-readable summary of a consistency check result."""
        if not result.answers:
            # Avoid the misleading "All 0 samples agreed" wording.
            return (
                "⚠️ No samples were generated → "
                f"confidence left at {result.final_confidence}%"
            )
        if not result.was_adjusted:
            return (
                f"✅ All {len(result.answers)} samples agreed → "
                f"confidence unchanged at {result.final_confidence}%"
            )
        return (
            f"⚠️ Samples disagreed (agreement={result.agreement_rate:.0%}) → "
            f"confidence reduced by {result.adjustment_amount}% "
            f"to {result.final_confidence}%\n"
            f" Samples: {result.answers}"
        )