Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Phase 4 reward: weighted (0.35 / 0.35 / 0.30) with potential-style deltas, critical-queue | |
| shaping, full sub-scores even on invalid steps (+ explicit invalid penalty), and mild output | |
| scaling. | |
| """ | |
| from __future__ import annotations | |
| import math | |
| from datetime import datetime, timedelta, timezone | |
| from typing import Any | |
| try: | |
| from ..models import GhostexecAction, RewardBreakdown, WorldState | |
| except ImportError: | |
| from models import GhostexecAction, RewardBreakdown, WorldState | |
# Fixed channel weights (per hackathon rules): conflict / relationship / task.
W_CONFLICT = 0.35
W_REL = 0.35
W_TASK = 0.30
# Raw conflict units (pre-weight) are clamped to keep invalid / idle steps from exploding.
CONFLICT_RAW_CAP: float = 6.0
# Scales the weighted sum of the three channels (weights stay fixed per hackathon rules).
WEIGHTED_OUTPUT_SCALE: float = 0.48
# Tone misfit penalties kept small vs outcome terms (~<20% of a strong +2 conflict step after weights).
TONE_PENALTY_CASUAL_ANGRY_BOARD: float = 0.35
TONE_PENALTY_FORMAL_PERSONAL: float = 0.08
# Per-event micro-bonuses layered onto the main channel scores.
_RESOLVE_MICRO_BONUS: float = 0.12  # per conflict pair resolved
_CRITICAL_PER_EMAIL_BONUS: float = 0.22  # per critical email newly replied
_RESCHEDULE_VALID_MICRO_BONUS: float = 0.10  # valid reschedule_meeting action
_SEND_MESSAGE_VALID_MICRO_BONUS: float = 0.08  # valid send_message to a known contact
_COMPLETE_TASK_VALID_MICRO_BONUS: float = 0.06  # valid complete_task action
_DELEGATE_TASK_VALID_MICRO_BONUS: float = 0.10  # valid delegate_task action
# Flat adjustment applied to every do_nothing step (see apply_do_nothing_penalty_floor).
_DO_NOTHING_STRICT_PENALTY: float = -0.15
# Per-term caps for the shaping channels (pre-output-scale units).
_SYNERGY_CAP: float = 0.40
_TRADEOFF_CAP: float = 0.30
_POTENTIAL_CAP: float = 0.25
_SCAFFOLD_CAP: float = 0.35
# Total shaping may not exceed this multiple of |base objective| (+ small slack).
_SHAPING_TO_BASE_BUDGET: float = 1.25
_QUALITY_CAP: float = 0.28
# Reply micro-bonus by email priority (doubled for VIP senders in score_relationship).
_REPLY_PRIORITY_MICRO_BONUS: dict[str, float] = {
    "critical": 0.30,
    "high": 0.15,
    "normal": 0.04,
    "low": 0.02,
}
# Ordinal mood scale used to detect mood improvements / declines (higher = better).
_MOOD_RANK: dict[str, int] = {
    "happy": 4,
    "neutral": 3,
    "annoyed": 2,
    "angry": 1,
    "furious": 0,
}
| def _parse_dt(value: str) -> datetime: | |
| if value.endswith("Z"): | |
| return datetime.fromisoformat(value[:-1]).replace(tzinfo=timezone.utc) | |
| dt = datetime.fromisoformat(value) | |
| if dt.tzinfo is None: | |
| return dt.replace(tzinfo=timezone.utc) | |
| return dt | |
def _meeting_end(m: Any) -> datetime:
    """Return meeting *m*'s end time: parsed start plus its duration in minutes."""
    return _parse_dt(m.start) + timedelta(minutes=m.duration_minutes)
| def _overlap(a0: datetime, a1: datetime, b0: datetime, b1: datetime) -> bool: | |
| return a0 < b1 and b0 < a1 | |
def meeting_conflicts(world: WorldState) -> list[dict[str, Any]]:
    """List every overlapping pair of non-cancelled meetings with its overlap window."""
    live = [m for m in world.meetings if not m.cancelled]
    conflicts: list[dict[str, Any]] = []
    for idx, first in enumerate(live):
        f0, f1 = _parse_dt(first.start), _meeting_end(first)
        # Only look forward so each unordered pair is reported exactly once.
        for second in live[idx + 1 :]:
            s0, s1 = _parse_dt(second.start), _meeting_end(second)
            if not _overlap(f0, f1, s0, s1):
                continue
            conflicts.append(
                {
                    "meeting_a": first.id,
                    "meeting_b": second.id,
                    "overlap_start": max(f0, s0).isoformat(),
                    "overlap_end": min(f1, s1).isoformat(),
                }
            )
    return conflicts
| def _pair_set(rows: list[dict[str, Any]]) -> set[frozenset[str]]: | |
| return {frozenset((r["meeting_a"], r["meeting_b"])) for r in rows} | |
| def _attendee_moods_ok(world: WorldState, pair: frozenset[str]) -> bool: | |
| names: set[str] = set() | |
| for mid in pair: | |
| m = next((x for x in world.meetings if x.id == mid), None) | |
| if m: | |
| names.update(m.attendees) | |
| for n in names: | |
| c = next((x for x in world.contacts if x.name == n), None) | |
| if c is None: | |
| continue | |
| if c.mood not in ("happy", "neutral"): | |
| return False | |
| return True | |
def score_conflict_resolution(
    before: WorldState,
    after: WorldState,
    action: GhostexecAction,
    *,
    action_ok: bool,
) -> float:
    """Score the change in conflicting meeting pairs between two world states.

    Each resolved pair earns +2 plus a micro-bonus (and +1 more if all
    attendees' moods stayed happy/neutral); each newly created pair costs -3.
    A valid reschedule_meeting action gets a small extra nudge.
    """
    pairs_before = _pair_set(meeting_conflicts(before))
    pairs_after = _pair_set(meeting_conflicts(after))
    score = 0.0
    for resolved_pair in pairs_before - pairs_after:
        score += 2.0 + _RESOLVE_MICRO_BONUS
        if _attendee_moods_ok(after, resolved_pair):
            score += 1.0
    for _ in pairs_after - pairs_before:
        score -= 3.0
    if action_ok and action.action_type == "reschedule_meeting":
        score += _RESCHEDULE_VALID_MICRO_BONUS
    return score
def critical_unreplied_count(world: WorldState) -> int:
    """Number of critical-priority emails still awaiting a reply."""
    pending = [e for e in world.emails if e.priority == "critical" and not e.replied]
    return len(pending)
def score_critical_queue_bonus(before: WorldState, after: WorldState) -> float:
    """Bonus per critical email cleared this step; never negative for regressions."""
    cleared = critical_unreplied_count(before) - critical_unreplied_count(after)
    if cleared <= 0:
        return 0.0
    return _CRITICAL_PER_EMAIL_BONUS * cleared
| def _classify_tone(text: str) -> str: | |
| t = text.lower() | |
| if any(w in t for w in ("sorry", "apologize", "apologies", "my mistake")): | |
| return "apologetic" | |
| if any(w in t for w in ("dear ", "sincerely", "best regards", "respectfully", "cordially")): | |
| return "formal" | |
| if any(w in t for w in ("hey", "lol", "haha", "👋", "no worries", "cheers")): | |
| return "casual" | |
| if any(w in t for w in ("must", "immediately", "asap", "non-negotiable", "demand")): | |
| return "assertive" | |
| return "neutral" | |
def score_relationship(
    before: WorldState,
    after: WorldState,
    action: GhostexecAction,
    *,
    action_ok: bool,
    relationship_suppressed_for_email_to: frozenset[str] | None = None,
) -> float:
    """Score contact-mood changes plus reply/message quality signals.

    Mood improvements earn +1 (+3 for VIPs, importance >= 4); declines cost
    -2 (-4 for VIPs). Valid replies with a non-empty body add a priority-based
    micro-bonus (doubled for VIP senders) minus tone-misfit penalties. Valid
    send_message to a known contact earns a small micro-bonus.

    NOTE(review): when the reply targets a sender in the suppression set, this
    returns 0.0 outright, discarding the mood deltas already accumulated —
    presumably intentional suppression semantics; confirm with the env design.
    """
    rel_sup = relationship_suppressed_for_email_to or frozenset()
    s = 0.0
    # Compare moods contact-by-contact between the two snapshots.
    before_map = {c.name: c for c in before.contacts}
    after_map = {c.name: c for c in after.contacts}
    for name, ca in after_map.items():
        cb = before_map.get(name)
        if not cb:
            # Contacts that only exist in `after` contribute nothing.
            continue
        ra, rb = _MOOD_RANK[ca.mood], _MOOD_RANK[cb.mood]
        vip = ca.importance >= 4
        if ra > rb:
            s += 3.0 if vip else 1.0
        elif ra < rb:
            # Declines are penalized more heavily than improvements are rewarded.
            s -= 4.0 if vip else 2.0
    if action.action_type == "reply_email" and action.email_id:
        em = next((e for e in before.emails if e.id == action.email_id), None)
        if em and em.sender in rel_sup:
            # Suppressed sender: the whole relationship channel is zeroed for this step.
            return 0.0
        if em:
            if action_ok and (action.message_body or "").strip():
                pri = (em.priority or "").lower()
                micro = _REPLY_PRIORITY_MICRO_BONUS.get(pri, 0.0)
                if em.sender_relationship == "VIP":
                    micro *= 2.0
                s += micro
                tone = _classify_tone(action.message_body)
                contact = next((c for c in before.contacts if c.name == em.sender), None)
                # Penalize a casual tone toward an upset board member.
                if (
                    contact
                    and contact.relationship_type == "board_member"
                    and contact.mood in ("angry", "furious", "annoyed")
                    and tone == "casual"
                ):
                    s -= TONE_PENALTY_CASUAL_ANGRY_BOARD
                # Mild penalty for being overly formal with a personal contact.
                if em.sender_relationship == "personal" and tone == "formal":
                    s -= TONE_PENALTY_FORMAL_PERSONAL
    if action_ok and action.action_type == "send_message" and action.contact_name:
        known_contact = any(c.name == action.contact_name for c in before.contacts)
        if known_contact and (action.message_body or "").strip():
            s += _SEND_MESSAGE_VALID_MICRO_BONUS
    return s
def _overdue_tasks(world: WorldState) -> list[Any]:
    """Tasks past their deadline and not yet done, per the world's simulation clock."""
    now = _parse_dt(world.simulation_time)
    return [
        task
        for task in world.tasks
        if task.status != "done" and _parse_dt(task.deadline) < now
    ]
def score_task_completion(
    before: WorldState,
    after: WorldState,
    action: GhostexecAction,
    *,
    action_ok: bool,
) -> float:
    """Score task transitions: completions (on time vs late), new overdues, delegation."""
    score = 0.0
    now = _parse_dt(after.simulation_time)
    # Key both snapshots by task id so we compare each task with its prior self.
    prior = {t.id: t for t in before.tasks}
    latest = {t.id: t for t in after.tasks}
    for task_id, current in latest.items():
        previous = prior.get(task_id)
        if previous is None:
            continue
        # Letting a task slip into overdue costs more than a late finish earns.
        if previous.status not in ("overdue", "done") and current.status == "overdue":
            score -= 2.0
        if previous.status != "done" and current.status == "done":
            finished_on_time = _parse_dt(previous.deadline) >= now
            score += 2.0 if finished_on_time else 0.5
        # Credit fresh delegation only when the delegate is a lower-importance contact.
        if current.delegated_to and not previous.delegated_to:
            delegate = next(
                (c for c in after.contacts if c.name == current.delegated_to), None
            )
            if delegate is not None and delegate.importance <= 3:
                score += 1.0
    if action_ok and action.action_type == "complete_task":
        score += _COMPLETE_TASK_VALID_MICRO_BONUS
    if action_ok and action.action_type == "delegate_task":
        score += _DELEGATE_TASK_VALID_MICRO_BONUS
    return score
def catastrophic(world: WorldState) -> bool:
    """Episode failure test: any furious VIP, or more than three open critical emails."""
    if any(c.mood == "furious" and c.importance >= 4 for c in world.contacts):
        return True
    open_critical = sum(1 for e in world.emails if e.priority == "critical" and not e.replied)
    return open_critical > 3
def _scaffold_learning_signal(
    before: WorldState,
    after: WorldState,
    action: GhostexecAction,
    *,
    action_ok: bool,
    step_index: int | None,
    max_steps: int | None,
) -> float:
    """Small action-appropriateness signal: did the chosen action type move its metric?

    Valid, non-idle actions get full credit when their natural metric improved,
    or a token credit for merely attempting a live problem. The signal is
    amplified early in the episode, damped late, and clamped to +/- _SCAFFOLD_CAP.
    """
    if not action_ok or action.action_type == "do_nothing":
        return 0.0

    def _credit(metric_before: int, metric_after: int, hit: float, attempt: float) -> float:
        # Full credit for measurable progress; token credit for engaging an open issue.
        if metric_after < metric_before:
            return hit
        if metric_before > 0:
            return attempt
        return 0.0

    signal = 0.0
    if action.action_type == "reply_email":
        signal += _credit(
            critical_unreplied_count(before), critical_unreplied_count(after), 0.16, 0.05
        )
    if action.action_type in ("reschedule_meeting", "cancel_meeting"):
        signal += _credit(
            len(meeting_conflicts(before)), len(meeting_conflicts(after)), 0.15, 0.04
        )
    if action.action_type in ("complete_task", "delegate_task"):
        signal += _credit(
            len(_overdue_tasks(before)), len(_overdue_tasks(after)), 0.12, 0.03
        )

    # Early episode shaping slightly amplified for better exploration guidance.
    if step_index is not None and max_steps and max_steps > 0:
        progress = max(0.0, min(1.0, step_index / max_steps))
        if progress <= 0.33:
            signal *= 1.20
        elif progress >= 0.85:
            signal *= 0.90
    return max(-_SCAFFOLD_CAP, min(_SCAFFOLD_CAP, signal))
def _state_potential(world: WorldState) -> float:
    """Potential function for shaping: less operational pressure => higher potential."""
    # Weighted pressure from open critical emails, meeting conflicts,
    # overdue tasks, and a small stress term.
    pressure = (
        1.15 * critical_unreplied_count(world)
        + 0.90 * len(meeting_conflicts(world))
        + 0.55 * len(_overdue_tasks(world))
        + 0.02 * float(world.stress)
    )
    return -pressure
def _budgeted_shaping_total(base_weighted_inner: float, shaping_total_inner: float) -> float:
    """Clamp total shaping to a budget proportional to the base objective's magnitude."""
    # Keep shaping informative but bounded against the base objective to avoid exploit loops.
    limit = _SHAPING_TO_BASE_BUDGET * (abs(base_weighted_inner) + 0.05)
    if shaping_total_inner > limit:
        return limit
    if shaping_total_inner < -limit:
        return -limit
    return shaping_total_inner
def _quality_separation_signal(
    *,
    c: float,
    r: float,
    t: float,
    action: GhostexecAction,
    action_ok: bool,
) -> float:
    """Widen the gap between clearly good and clearly bad valid, non-idle actions."""
    if not action_ok or action.action_type == "do_nothing":
        return 0.0
    combined = W_CONFLICT * c + W_REL * r + W_TASK * t
    # Symmetric thresholds: strong steps earn the full cap, moderate ones a token nudge.
    if combined >= 0.90:
        return _QUALITY_CAP
    if combined >= 0.35:
        return 0.12
    if combined <= -0.90:
        return -_QUALITY_CAP
    if combined <= -0.35:
        return -0.12
    return 0.0
def aggregate_scores(
    conflict: float,
    relationship: float,
    task: float,
    *,
    conflict_raw: float,
    critical_queue_bonus: float,
    weighted_inner: float,
    weighted_base_only: float,
    shaping_synergy: float,
    shaping_tradeoff: float,
    shaping_potential: float,
    shaping_scaffold: float,
    shaping_quality: float,
    action_ok: bool,
    episode_done: bool,
    world_after: WorldState,
) -> RewardBreakdown:
    """Assemble the final RewardBreakdown from channel scores and shaping terms.

    `weighted_inner` is expected to already include the (budget-clamped)
    shaping sum; `shaping_total` and the per-channel shaping fields are scaled
    and reported here for observability only. Adds the invalid-step penalty,
    end-of-episode low-stress bonus, and catastrophic penalty to form `final`.
    """
    weighted = WEIGHTED_OUTPUT_SCALE * weighted_inner
    weighted_base_only_scaled = WEIGHTED_OUTPUT_SCALE * weighted_base_only
    shaping_total = WEIGHTED_OUTPUT_SCALE * (
        shaping_synergy + shaping_tradeoff + shaping_potential + shaping_scaffold + shaping_quality
    )
    # Diagnostic ratio of shaping magnitude to base objective magnitude, capped at 10x.
    denom = abs(weighted_base_only_scaled) + 1e-6
    shaping_ratio = min(10.0, abs(shaping_total) / denom)
    # Flat penalty for invalid actions (sub-scores above are still computed in full).
    inv = 0.0
    if not action_ok:
        inv = -0.25
    # Terminal-step adjustments only.
    bonus = 0.0
    cata = 0.0
    if episode_done:
        if world_after.stress < 40:
            bonus = 10.0
        if catastrophic(world_after):
            cata = -15.0
    final = weighted + inv + bonus + cata
    return RewardBreakdown(
        conflict_raw=conflict_raw,
        critical_queue_bonus=critical_queue_bonus,
        conflict=conflict,
        relationship=relationship,
        task=task,
        shaping_synergy=WEIGHTED_OUTPUT_SCALE * shaping_synergy,
        shaping_tradeoff=WEIGHTED_OUTPUT_SCALE * shaping_tradeoff,
        shaping_potential=WEIGHTED_OUTPUT_SCALE * shaping_potential,
        shaping_scaffold=WEIGHTED_OUTPUT_SCALE * shaping_scaffold,
        shaping_quality=WEIGHTED_OUTPUT_SCALE * shaping_quality,
        shaping_total=shaping_total,
        shaping_to_base_ratio=shaping_ratio,
        weighted_base=weighted,
        output_scale=WEIGHTED_OUTPUT_SCALE,
        invalid_step_adjustment=inv,
        episode_completion_bonus=bonus,
        catastrophic_penalty=cata,
        # The do_nothing floor is filled in later by apply_do_nothing_penalty_floor.
        do_nothing_floor=0.0,
        final=final,
    )
def apply_do_nothing_penalty_floor(
    action: GhostexecAction,
    breakdown: RewardBreakdown,
) -> RewardBreakdown:
    """Apply the strict idle penalty to do_nothing steps; pass other actions through."""
    if action.action_type != "do_nothing":
        return breakdown
    penalty = _DO_NOTHING_STRICT_PENALTY
    return breakdown.model_copy(
        update={
            "do_nothing_floor": penalty,
            "final": breakdown.final + penalty,
        },
    )
def compute_step_reward(
    before: WorldState,
    after: WorldState,
    action: GhostexecAction,
    *,
    action_ok: bool,
    episode_done: bool,
    relationship_suppressed_for_email_to: frozenset[str] | None = None,
    reward_mode: str = "full",
    step_index: int | None = None,
    max_steps: int | None = None,
) -> RewardBreakdown:
    """Compute the full per-step reward breakdown for one before/after transition.

    Pipeline: score the three base channels (conflict is clamped to
    +/- CONFLICT_RAW_CAP after adding the critical-queue bonus), optionally add
    bounded shaping terms (synergy, tradeoff, potential delta, scaffold,
    quality) when reward_mode is "full" or "shaping", clamp the shaping sum to
    a budget relative to the base objective, aggregate everything into a
    RewardBreakdown, then apply the do_nothing penalty floor.
    """
    c_core = score_conflict_resolution(before, after, action, action_ok=action_ok)
    crit_b = score_critical_queue_bonus(before, after)
    c_raw = c_core + crit_b
    # Clamp the raw conflict channel so a single step cannot dominate the reward.
    c = max(-CONFLICT_RAW_CAP, min(CONFLICT_RAW_CAP, c_raw))
    r = score_relationship(
        before,
        after,
        action,
        action_ok=action_ok,
        relationship_suppressed_for_email_to=relationship_suppressed_for_email_to,
    )
    t = score_task_completion(before, after, action, action_ok=action_ok)
    weighted_base_only = W_CONFLICT * c + W_REL * r + W_TASK * t
    weighted_inner = weighted_base_only
    synergy = 0.0
    tradeoff_penalty = 0.0
    potential_progress = 0.0
    scaffold_signal = 0.0
    quality_signal = 0.0
    if reward_mode in ("full", "shaping"):
        # Bounded nonlinear shaping to speed learning without overpowering base channels.
        # NOTE(review): _SYNERGY_CAP / _TRADEOFF_CAP bound each term separately,
        # so two additive terms can together exceed one cap — confirm intended.
        if c > 0.0 and r > 0.0:
            synergy += min(_SYNERGY_CAP, 0.18 * math.tanh(0.35 * c * r))
        if t > 0.0 and (c > 0.0 or r > 0.0):
            bridge = max(c, 0.0) + max(r, 0.0)
            synergy += min(_SYNERGY_CAP, 0.10 * math.tanh(0.25 * t * bridge))
        if c < -0.5 and r < -0.5:
            tradeoff_penalty -= min(_TRADEOFF_CAP, 0.12 * math.tanh(0.25 * abs(c * r)))
        if t < -0.5 and (c < 0.0 or r < 0.0):
            debt = abs(t) * (abs(min(c, 0.0)) + abs(min(r, 0.0)))
            tradeoff_penalty -= min(_TRADEOFF_CAP, 0.08 * math.tanh(0.18 * debt))
        # Potential-style delta: rewards net reduction in operational pressure.
        potential_progress = max(
            -_POTENTIAL_CAP,
            min(_POTENTIAL_CAP, _state_potential(after) - _state_potential(before)),
        )
        scaffold_signal = _scaffold_learning_signal(
            before,
            after,
            action,
            action_ok=action_ok,
            step_index=step_index,
            max_steps=max_steps,
        )
        quality_signal = _quality_separation_signal(
            c=c,
            r=r,
            t=t,
            action=action,
            action_ok=action_ok,
        )
        shaping_total_inner = (
            synergy + tradeoff_penalty + potential_progress + scaffold_signal + quality_signal
        )
        # Budget-clamped shaping is folded into the inner weighted objective;
        # aggregate_scores reports the unclamped per-term values for observability.
        weighted_inner += _budgeted_shaping_total(weighted_base_only, shaping_total_inner)
    bd = aggregate_scores(
        c,
        r,
        t,
        conflict_raw=c_raw,
        critical_queue_bonus=crit_b,
        weighted_inner=weighted_inner,
        weighted_base_only=weighted_base_only,
        shaping_synergy=synergy,
        shaping_tradeoff=tradeoff_penalty,
        shaping_potential=potential_progress,
        shaping_scaffold=scaffold_signal,
        shaping_quality=quality_signal,
        action_ok=action_ok,
        episode_done=episode_done,
        world_after=after,
    )
    return apply_do_nothing_penalty_floor(action, bd)