# ghostexec/server/reward.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Phase 4 reward: weighted (0.35 / 0.35 / 0.30) with potential-style deltas, critical-queue
shaping, full sub-scores even on invalid steps (+ explicit invalid penalty), and mild output
scaling.
"""
from __future__ import annotations
import math
from datetime import datetime, timedelta, timezone
from typing import Any
try:
from ..models import GhostexecAction, RewardBreakdown, WorldState
except ImportError:  # fall back to a flat import when run outside the package
from models import GhostexecAction, RewardBreakdown, WorldState
W_CONFLICT = 0.35
W_REL = 0.35
W_TASK = 0.30
# Raw conflict units (pre-weight) are clamped to keep invalid / idle steps from exploding.
CONFLICT_RAW_CAP: float = 6.0
# Scales the weighted sum of the three channels (weights stay fixed per hackathon rules).
WEIGHTED_OUTPUT_SCALE: float = 0.48
# Tone-misfit penalties are kept small relative to outcome terms (under ~20% of a strong +2.0 conflict step after weighting).
TONE_PENALTY_CASUAL_ANGRY_BOARD: float = 0.35
TONE_PENALTY_FORMAL_PERSONAL: float = 0.08
_RESOLVE_MICRO_BONUS: float = 0.12
_CRITICAL_PER_EMAIL_BONUS: float = 0.22
_RESCHEDULE_VALID_MICRO_BONUS: float = 0.10
_SEND_MESSAGE_VALID_MICRO_BONUS: float = 0.08
_COMPLETE_TASK_VALID_MICRO_BONUS: float = 0.06
_DELEGATE_TASK_VALID_MICRO_BONUS: float = 0.10
_DO_NOTHING_STRICT_PENALTY: float = -0.15
_SYNERGY_CAP: float = 0.40
_TRADEOFF_CAP: float = 0.30
_POTENTIAL_CAP: float = 0.25
_SCAFFOLD_CAP: float = 0.35
_SHAPING_TO_BASE_BUDGET: float = 1.25
_QUALITY_CAP: float = 0.28
_REPLY_PRIORITY_MICRO_BONUS: dict[str, float] = {
"critical": 0.30,
"high": 0.15,
"normal": 0.04,
"low": 0.02,
}
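# Example: a valid reply to a critical email whose sender_relationship is "VIP"
# earns 0.30 * 2 = 0.60 in score_relationship(), since the VIP multiplier
# doubles the priority micro-bonus.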
_MOOD_RANK: dict[str, int] = {
"happy": 4,
"neutral": 3,
"annoyed": 2,
"angry": 1,
"furious": 0,
}
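# Higher rank = better mood. score_relationship() adds +3 (VIP) / +1 per contact
# whose mood improved and -4 (VIP) / -2 per contact whose mood worsened, where
# "VIP" means importance >= 4; a VIP sliding from neutral to angry costs -4.0.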
def _parse_dt(value: str) -> datetime:
if value.endswith("Z"):
return datetime.fromisoformat(value[:-1]).replace(tzinfo=timezone.utc)
dt = datetime.fromisoformat(value)
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
def _meeting_end(m: Any) -> datetime:
start = _parse_dt(m.start)
return start + timedelta(minutes=m.duration_minutes)
def _overlap(a0: datetime, a1: datetime, b0: datetime, b1: datetime) -> bool:
return a0 < b1 and b0 < a1
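# _overlap uses strict '<' on both ends, so intervals are effectively half-open:
# back-to-back meetings (one ends exactly when the next starts) do not conflict.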
def meeting_conflicts(world: WorldState) -> list[dict[str, Any]]:
active = [m for m in world.meetings if not m.cancelled]
out: list[dict[str, Any]] = []
for i, a in enumerate(active):
a0, a1 = _parse_dt(a.start), _meeting_end(a)
for b in active[i + 1 :]:
b0, b1 = _parse_dt(b.start), _meeting_end(b)
if _overlap(a0, a1, b0, b1):
o0, o1 = max(a0, b0), min(a1, b1)
out.append(
{
"meeting_a": a.id,
"meeting_b": b.id,
"overlap_start": o0.isoformat(),
"overlap_end": o1.isoformat(),
}
)
return out
def _pair_set(rows: list[dict[str, Any]]) -> set[frozenset[str]]:
return {frozenset((r["meeting_a"], r["meeting_b"])) for r in rows}
def _attendee_moods_ok(world: WorldState, pair: frozenset[str]) -> bool:
names: set[str] = set()
for mid in pair:
m = next((x for x in world.meetings if x.id == mid), None)
if m:
names.update(m.attendees)
for n in names:
c = next((x for x in world.contacts if x.name == n), None)
if c is None:
continue
if c.mood not in ("happy", "neutral"):
return False
return True
def score_conflict_resolution(
before: WorldState,
after: WorldState,
action: GhostexecAction,
*,
action_ok: bool,
) -> float:
b = _pair_set(meeting_conflicts(before))
a = _pair_set(meeting_conflicts(after))
s = 0.0
    for pair in b - a:
        s += 2.0 + _RESOLVE_MICRO_BONUS
        if _attendee_moods_ok(after, pair):
s += 1.0
for _ in a - b:
s -= 3.0
if action_ok and action.action_type == "reschedule_meeting":
s += _RESCHEDULE_VALID_MICRO_BONUS
return s
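# Worked example for score_conflict_resolution (hypothetical step): resolving one
# conflict pair whose attendees all end up happy/neutral scores
# 2.0 + 0.12 + 1.0 = 3.12, plus 0.10 more if the action was a valid
# reschedule_meeting; creating a new conflict costs a flat -3.0 per pair.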
def critical_unreplied_count(world: WorldState) -> int:
return sum(1 for e in world.emails if e.priority == "critical" and not e.replied)
def score_critical_queue_bonus(before: WorldState, after: WorldState) -> float:
reduction = critical_unreplied_count(before) - critical_unreplied_count(after)
return _CRITICAL_PER_EMAIL_BONUS * max(0, reduction)
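# Example: clearing two critical unreplied emails in one step yields
# 2 * 0.22 = 0.44; regressions (reduction < 0) are floored at zero rather than
# penalized here.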
def _classify_tone(text: str) -> str:
t = text.lower()
if any(w in t for w in ("sorry", "apologize", "apologies", "my mistake")):
return "apologetic"
if any(w in t for w in ("dear ", "sincerely", "best regards", "respectfully", "cordially")):
return "formal"
if any(w in t for w in ("hey", "lol", "haha", "👋", "no worries", "cheers")):
return "casual"
if any(w in t for w in ("must", "immediately", "asap", "non-negotiable", "demand")):
return "assertive"
return "neutral"
def score_relationship(
before: WorldState,
after: WorldState,
action: GhostexecAction,
*,
action_ok: bool,
relationship_suppressed_for_email_to: frozenset[str] | None = None,
) -> float:
rel_sup = relationship_suppressed_for_email_to or frozenset()
s = 0.0
before_map = {c.name: c for c in before.contacts}
after_map = {c.name: c for c in after.contacts}
for name, ca in after_map.items():
cb = before_map.get(name)
if not cb:
continue
ra, rb = _MOOD_RANK[ca.mood], _MOOD_RANK[cb.mood]
vip = ca.importance >= 4
if ra > rb:
s += 3.0 if vip else 1.0
elif ra < rb:
s -= 4.0 if vip else 2.0
if action.action_type == "reply_email" and action.email_id:
em = next((e for e in before.emails if e.id == action.email_id), None)
if em and em.sender in rel_sup:
return 0.0
if em:
if action_ok and (action.message_body or "").strip():
pri = (em.priority or "").lower()
micro = _REPLY_PRIORITY_MICRO_BONUS.get(pri, 0.0)
if em.sender_relationship == "VIP":
micro *= 2.0
s += micro
tone = _classify_tone(action.message_body)
contact = next((c for c in before.contacts if c.name == em.sender), None)
if (
contact
and contact.relationship_type == "board_member"
and contact.mood in ("angry", "furious", "annoyed")
and tone == "casual"
):
s -= TONE_PENALTY_CASUAL_ANGRY_BOARD
if em.sender_relationship == "personal" and tone == "formal":
s -= TONE_PENALTY_FORMAL_PERSONAL
if action_ok and action.action_type == "send_message" and action.contact_name:
known_contact = any(c.name == action.contact_name for c in before.contacts)
if known_contact and (action.message_body or "").strip():
s += _SEND_MESSAGE_VALID_MICRO_BONUS
return s
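# Worked example for score_relationship (hypothetical step): a valid reply to a
# high-priority email from a non-VIP contact (importance below 4,
# sender_relationship not "VIP") whose mood rose from annoyed to neutral scores
# 1.0 + 0.15 = 1.15. Separately, a casual-toned reply to a board member who was
# annoyed/angry/furious before the step loses 0.35.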
def _overdue_tasks(world: WorldState) -> list[Any]:
now = _parse_dt(world.simulation_time)
out = []
for t in world.tasks:
if t.status == "done":
continue
if _parse_dt(t.deadline) < now:
out.append(t)
return out
def score_task_completion(
before: WorldState,
after: WorldState,
action: GhostexecAction,
*,
action_ok: bool,
) -> float:
s = 0.0
now = _parse_dt(after.simulation_time)
before_tasks = {t.id: t for t in before.tasks}
after_tasks = {t.id: t for t in after.tasks}
for tid, ta in after_tasks.items():
tb = before_tasks.get(tid)
if not tb:
continue
        if tb.status not in ("overdue", "done") and ta.status == "overdue":
s -= 2.0
if tb.status != "done" and ta.status == "done":
dl = _parse_dt(tb.deadline)
if dl >= now:
s += 2.0
else:
s += 0.5
if (not tb.delegated_to) and ta.delegated_to:
de = next((c for c in after.contacts if c.name == ta.delegated_to), None)
if de and de.importance <= 3:
s += 1.0
if action_ok and action.action_type == "complete_task":
s += _COMPLETE_TASK_VALID_MICRO_BONUS
if action_ok and action.action_type == "delegate_task":
s += _DELEGATE_TASK_VALID_MICRO_BONUS
return s
def catastrophic(world: WorldState) -> bool:
    vip_furious = any(c.importance >= 4 and c.mood == "furious" for c in world.contacts)
    return vip_furious or critical_unreplied_count(world) > 3
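# Example: a single furious contact with importance >= 4, or a queue of four or
# more unreplied critical emails, makes the state catastrophic; if the episode
# ends in that state, aggregate_scores() applies the -15.0 penalty.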
def _scaffold_learning_signal(
before: WorldState,
after: WorldState,
action: GhostexecAction,
*,
action_ok: bool,
step_index: int | None,
max_steps: int | None,
) -> float:
if not action_ok:
return 0.0
if action.action_type == "do_nothing":
return 0.0
s = 0.0
critical_before = critical_unreplied_count(before)
critical_after = critical_unreplied_count(after)
conflict_before = len(meeting_conflicts(before))
conflict_after = len(meeting_conflicts(after))
overdue_before = len(_overdue_tasks(before))
overdue_after = len(_overdue_tasks(after))
if action.action_type == "reply_email":
if critical_after < critical_before:
s += 0.16
elif critical_before > 0:
s += 0.05
if action.action_type in ("reschedule_meeting", "cancel_meeting"):
if conflict_after < conflict_before:
s += 0.15
elif conflict_before > 0:
s += 0.04
if action.action_type in ("complete_task", "delegate_task"):
if overdue_after < overdue_before:
s += 0.12
elif overdue_before > 0:
s += 0.03
    # Early-episode shaping is slightly amplified to guide exploration; late-episode shaping is damped.
if step_index is not None and max_steps and max_steps > 0:
frac = max(0.0, min(1.0, step_index / max_steps))
if frac <= 0.33:
s *= 1.20
elif frac >= 0.85:
s *= 0.90
return max(-_SCAFFOLD_CAP, min(_SCAFFOLD_CAP, s))
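# Example: a reply_email that clears a critical email earns a 0.16 scaffold
# signal; at step 3 of 20 (frac 0.15 <= 0.33) it is amplified to 0.192, while at
# step 18 of 20 (frac 0.90 >= 0.85) it is damped to 0.144.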
def _state_potential(world: WorldState) -> float:
conflicts = len(meeting_conflicts(world))
critical_open = critical_unreplied_count(world)
overdue = len(_overdue_tasks(world))
stress = float(world.stress)
# Lower operational pressure => higher potential.
return -(
1.15 * critical_open
+ 0.90 * conflicts
+ 0.55 * overdue
+ 0.02 * stress
)
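# Example: holding everything else fixed, replying to one critical email raises
# the potential by 1.15, so the raw potential delta for that step is +1.15;
# compute_step_reward() then clamps it to _POTENTIAL_CAP = 0.25 before it joins
# the other shaping terms.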
def _budgeted_shaping_total(base_weighted_inner: float, shaping_total_inner: float) -> float:
# Keep shaping informative but bounded against the base objective to avoid exploit loops.
budget = _SHAPING_TO_BASE_BUDGET * (abs(base_weighted_inner) + 0.05)
return max(-budget, min(budget, shaping_total_inner))
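# Example: with a base weighted term of 0.50 the shaping budget is
# 1.25 * (0.50 + 0.05) = 0.6875, so a raw shaping total of +1.0 is clipped to
# +0.6875; near-zero base steps get almost no shaping (budget ~0.0625).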
def _quality_separation_signal(
*,
c: float,
r: float,
t: float,
action: GhostexecAction,
action_ok: bool,
) -> float:
# Amplify distance between clearly good vs clearly bad valid actions.
if not action_ok or action.action_type == "do_nothing":
return 0.0
base = W_CONFLICT * c + W_REL * r + W_TASK * t
if base >= 0.90:
return _QUALITY_CAP
if base >= 0.35:
return 0.12
if base <= -0.90:
return -_QUALITY_CAP
if base <= -0.35:
return -0.12
return 0.0
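# Example: a valid action whose base 0.35*c + 0.35*r + 0.30*t reaches 1.0 earns
# the full +0.28; a milder base of 0.5 earns +0.12; anything strictly inside
# (-0.35, 0.35) is left untouched.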
def aggregate_scores(
conflict: float,
relationship: float,
task: float,
*,
conflict_raw: float,
critical_queue_bonus: float,
weighted_inner: float,
weighted_base_only: float,
shaping_synergy: float,
shaping_tradeoff: float,
shaping_potential: float,
shaping_scaffold: float,
shaping_quality: float,
action_ok: bool,
episode_done: bool,
world_after: WorldState,
) -> RewardBreakdown:
weighted = WEIGHTED_OUTPUT_SCALE * weighted_inner
weighted_base_only_scaled = WEIGHTED_OUTPUT_SCALE * weighted_base_only
shaping_total = WEIGHTED_OUTPUT_SCALE * (
shaping_synergy + shaping_tradeoff + shaping_potential + shaping_scaffold + shaping_quality
)
denom = abs(weighted_base_only_scaled) + 1e-6
shaping_ratio = min(10.0, abs(shaping_total) / denom)
inv = 0.0
if not action_ok:
inv = -0.25
bonus = 0.0
cata = 0.0
if episode_done:
if world_after.stress < 40:
bonus = 10.0
if catastrophic(world_after):
cata = -15.0
final = weighted + inv + bonus + cata
return RewardBreakdown(
conflict_raw=conflict_raw,
critical_queue_bonus=critical_queue_bonus,
conflict=conflict,
relationship=relationship,
task=task,
shaping_synergy=WEIGHTED_OUTPUT_SCALE * shaping_synergy,
shaping_tradeoff=WEIGHTED_OUTPUT_SCALE * shaping_tradeoff,
shaping_potential=WEIGHTED_OUTPUT_SCALE * shaping_potential,
shaping_scaffold=WEIGHTED_OUTPUT_SCALE * shaping_scaffold,
shaping_quality=WEIGHTED_OUTPUT_SCALE * shaping_quality,
shaping_total=shaping_total,
shaping_to_base_ratio=shaping_ratio,
weighted_base=weighted,
output_scale=WEIGHTED_OUTPUT_SCALE,
invalid_step_adjustment=inv,
episode_completion_bonus=bonus,
catastrophic_penalty=cata,
do_nothing_floor=0.0,
final=final,
)
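# Example: with weighted_inner = 0.90 on a valid, non-terminal step the final
# reward is 0.48 * 0.90 = 0.432; the same weighted_inner with action_ok=False
# still keeps its full sub-scores but lands at 0.432 - 0.25 = 0.182.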
def apply_do_nothing_penalty_floor(
action: GhostexecAction,
breakdown: RewardBreakdown,
) -> RewardBreakdown:
if action.action_type != "do_nothing":
return breakdown
floor_delta = _DO_NOTHING_STRICT_PENALTY
new_final = breakdown.final + floor_delta
return breakdown.model_copy(
update={"do_nothing_floor": floor_delta, "final": new_final},
)
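# Example: a do_nothing step whose breakdown.final was +0.05 ends at
# 0.05 - 0.15 = -0.10, so idling is never reward-neutral.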
def compute_step_reward(
before: WorldState,
after: WorldState,
action: GhostexecAction,
*,
action_ok: bool,
episode_done: bool,
relationship_suppressed_for_email_to: frozenset[str] | None = None,
reward_mode: str = "full",
step_index: int | None = None,
max_steps: int | None = None,
) -> RewardBreakdown:
c_core = score_conflict_resolution(before, after, action, action_ok=action_ok)
crit_b = score_critical_queue_bonus(before, after)
c_raw = c_core + crit_b
c = max(-CONFLICT_RAW_CAP, min(CONFLICT_RAW_CAP, c_raw))
r = score_relationship(
before,
after,
action,
action_ok=action_ok,
relationship_suppressed_for_email_to=relationship_suppressed_for_email_to,
)
t = score_task_completion(before, after, action, action_ok=action_ok)
weighted_base_only = W_CONFLICT * c + W_REL * r + W_TASK * t
weighted_inner = weighted_base_only
synergy = 0.0
tradeoff_penalty = 0.0
potential_progress = 0.0
scaffold_signal = 0.0
quality_signal = 0.0
if reward_mode in ("full", "shaping"):
# Bounded nonlinear shaping to speed learning without overpowering base channels.
if c > 0.0 and r > 0.0:
synergy += min(_SYNERGY_CAP, 0.18 * math.tanh(0.35 * c * r))
if t > 0.0 and (c > 0.0 or r > 0.0):
bridge = max(c, 0.0) + max(r, 0.0)
synergy += min(_SYNERGY_CAP, 0.10 * math.tanh(0.25 * t * bridge))
if c < -0.5 and r < -0.5:
tradeoff_penalty -= min(_TRADEOFF_CAP, 0.12 * math.tanh(0.25 * abs(c * r)))
if t < -0.5 and (c < 0.0 or r < 0.0):
debt = abs(t) * (abs(min(c, 0.0)) + abs(min(r, 0.0)))
tradeoff_penalty -= min(_TRADEOFF_CAP, 0.08 * math.tanh(0.18 * debt))
potential_progress = max(
-_POTENTIAL_CAP,
min(_POTENTIAL_CAP, _state_potential(after) - _state_potential(before)),
)
scaffold_signal = _scaffold_learning_signal(
before,
after,
action,
action_ok=action_ok,
step_index=step_index,
max_steps=max_steps,
)
quality_signal = _quality_separation_signal(
c=c,
r=r,
t=t,
action=action,
action_ok=action_ok,
)
shaping_total_inner = (
synergy + tradeoff_penalty + potential_progress + scaffold_signal + quality_signal
)
weighted_inner += _budgeted_shaping_total(weighted_base_only, shaping_total_inner)
bd = aggregate_scores(
c,
r,
t,
conflict_raw=c_raw,
critical_queue_bonus=crit_b,
weighted_inner=weighted_inner,
weighted_base_only=weighted_base_only,
shaping_synergy=synergy,
shaping_tradeoff=tradeoff_penalty,
shaping_potential=potential_progress,
shaping_scaffold=scaffold_signal,
shaping_quality=quality_signal,
action_ok=action_ok,
episode_done=episode_done,
world_after=after,
)
return apply_do_nothing_penalty_floor(action, bd)
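# Typical call site (hypothetical driver code; WorldState / GhostexecAction come
# from models.py and their construction is environment-specific):
#
#     bd = compute_step_reward(
#         before, after, action,
#         action_ok=True, episode_done=False,
#         step_index=3, max_steps=20,
#     )
#     step_reward = bd.final  # scalar handed to the learner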