Spaces:
Sleeping
Sleeping
| """ | |
| RhythmEnv Visual Explorer β Life Simulator v2 | |
| Run: python ui/app.py | |
| """ | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| import gradio as gr | |
| from server.rhythm_environment import ( | |
| RhythmEnvironment, MAX_STEPS, METERS, ACTION_EFFECTS, PROFILES | |
| ) | |
| from models import RhythmAction, ActionType | |
# Display label for each slot of the day, indexed by obs.slot.
SLOT_NAMES = "Morning Afternoon Evening Night".split()
# Row labels for the week grid, one icon per slot (same order as SLOT_NAMES).
SLOT_ICONS = ["π ", "βοΈ", "π", "π"]
# Short weekday labels, indexed by obs.day.
DAY_NAMES = "Mon Tue Wed Thu Fri Sat Sun".split()
# Profile dropdown choices. The first three are passed to env.reset(profile=...);
# "random" makes the callbacks call env.reset() without a profile argument.
PROFILE_NAMES = [
    "introvert_morning",
    "extrovert_night_owl",
    "workaholic_stoic",
    "random",
]
# Upper-cased ActionType values offered in the action dropdown.
ACTION_NAMES = [member.value.upper() for member in ActionType]
# Fixed colour per meter, shared by the HTML dots and the trajectory chart.
METER_COLORS = dict(
    vitality="#3b82f6",
    cognition="#8b5cf6",
    progress="#22c55e",
    serenity="#14b8a6",
    connection="#f97316",
)
| # --------------------------------------------------------------------------- | |
| # Global session state | |
| # --------------------------------------------------------------------------- | |
| _env = None | |
| _last_obs = None | |
| _step_log = [] | |
| _meter_history = [] # list of {meter: value} per step | |
| _completed_slots = [] # (day, slot) pairs already acted on | |
def get_env():
    """Return the shared RhythmEnvironment, constructing it on first use."""
    global _env
    if _env is not None:
        return _env
    _env = RhythmEnvironment()
    return _env
| # --------------------------------------------------------------------------- | |
| # HTML β colored meter bars | |
| # --------------------------------------------------------------------------- | |
| def _bar_color(v: float) -> str: | |
| if v < 0.20: | |
| return "#ef4444" | |
| if v < 0.40: | |
| return "#f59e0b" | |
| return "#22c55e" | |
def format_meters_html(obs) -> str:
    """Render the status header plus one coloured bar per meter as HTML.

    The header shows the current day/slot, the step counter, the remaining
    steps and any active event. Each meter row shows the meter's identity
    dot, its name, a bar ``int(value * 100)`` percent wide coloured by
    ``_bar_color`` and the value to two decimals.
    """

    def meter_row(meter: str) -> str:
        # One flex row: dot, label, traffic-light bar, numeric value.
        val = getattr(obs, meter)
        pct = int(val * 100)
        color = _bar_color(val)
        dot = METER_COLORS[meter]
        return f"""
<div style="display:flex;align-items:center;gap:8px;margin:5px 0">
<span style="width:10px;height:10px;border-radius:50%;background:{dot};display:inline-block;flex-shrink:0"></span>
<span style="width:80px;font-size:12px;color:#374151">{meter.capitalize()}</span>
<div style="flex:1;background:#e5e7eb;border-radius:6px;height:16px;overflow:hidden;max-width:260px">
<div style="width:{pct}%;background:{color};height:16px;border-radius:6px;transition:width 0.25s"></div>
</div>
<span style="width:36px;font-size:12px;color:#374151;text-align:right">{val:.2f}</span>
</div>"""

    # Out-of-range indices (e.g. after the final step) get a generic label.
    if obs.day < 7:
        day_name = DAY_NAMES[obs.day]
    else:
        day_name = f"Day {obs.day+1}"
    if obs.slot < 4:
        slot_name = SLOT_NAMES[obs.slot]
    else:
        slot_name = f"Slot {obs.slot}"

    event_bit = ""
    if obs.active_event:
        event_bit = f'<span style="color:#f59e0b;margin-left:8px">β‘ {obs.active_event}</span>'

    header = f"""
<div style="background:#f9fafb;border-radius:10px;padding:14px 16px;font-family:monospace">
<div style="font-size:13px;color:#6b7280;margin-bottom:10px">
π <b>{day_name} {slot_name}</b>
Β· Step {obs.timestep}/{MAX_STEPS}
Β· {obs.remaining_steps} steps left
{event_bit}
</div>
"""
    rows = "".join(meter_row(name) for name in METERS)
    return header + rows + "\n </div>"
| # --------------------------------------------------------------------------- | |
| # HTML β week calendar grid | |
| # --------------------------------------------------------------------------- | |
def format_week_grid(obs) -> str:
    """Render a slot-by-day HTML table showing progress through the week.

    Slot ``s`` on day ``d`` is step ``d * 4 + s``. Steps before
    ``obs.timestep`` (the 0-based step about to be taken) are marked done.
    The step equal to it is highlighted as current unless the episode is
    over. Later steps are left as placeholders.
    """
    now = obs.timestep
    parts = ["""
<div style="background:#f9fafb;border-radius:10px;padding:12px 16px;font-family:monospace;margin-top:8px">
<div style="font-size:12px;color:#6b7280;margin-bottom:8px">Week Progress</div>
<table style="border-collapse:separate;border-spacing:3px;width:100%">
<tr>
<td style="width:24px"></td>"""]
    # Header row: one column per weekday.
    parts.extend(
        f'<td style="text-align:center;font-size:11px;color:#9ca3af;padding:1px 3px">{day}</td>'
        for day in DAY_NAMES
    )
    parts.append("</tr>")

    # Body: one row per slot, one cell per day.
    for slot_idx, icon in enumerate(SLOT_ICONS):
        parts.append(f'<tr><td style="font-size:12px;text-align:center">{icon}</td>')
        for day_idx in range(7):
            step_num = day_idx * 4 + slot_idx
            if step_num < now:
                cell, bg = "β ", "#d1fae5"
            elif step_num == now and not obs.done:
                cell, bg = "π΅", "#dbeafe"
            else:
                cell, bg = "Β·", "transparent"
            parts.append(
                f'<td style="text-align:center;background:{bg};border-radius:3px;padding:1px 3px;font-size:13px">{cell}</td>'
            )
        parts.append("</tr>")

    parts.append("</table></div>")
    return "".join(parts)
| # --------------------------------------------------------------------------- | |
| # Matplotlib β meter trajectory chart | |
| # --------------------------------------------------------------------------- | |
def make_chart(history: list) -> plt.Figure:
    """Plot every meter's trajectory over the episode.

    Args:
        history: one ``{meter: value}`` dict per observation, as produced by
            ``_snap``. It may be empty; then only the frame, threshold line,
            legend and labels are drawn.

    Returns:
        A standalone matplotlib Figure suitable for ``gr.Plot``.

    The figure is constructed directly with ``plt.Figure`` rather than
    ``plt.subplots``. pyplot keeps every figure it creates referenced until
    ``plt.close`` is called, and this function runs on every UI click. The
    pyplot route therefore accumulated one open figure per step.
    """
    fig = plt.Figure(figsize=(7, 3.5))  # not registered with pyplot -> no leak
    ax = fig.subplots()
    fig.patch.set_facecolor("#f9fafb")
    ax.set_facecolor("#f9fafb")

    if history:
        steps = list(range(len(history)))
        for meter, color in METER_COLORS.items():
            vals = [h[meter] for h in history]
            ax.plot(steps, vals, color=color, linewidth=2.0,
                    label=meter.capitalize(), solid_capstyle="round")

    # Critical threshold (same 0.20 cut-off as the red meter bars).
    ax.axhline(y=0.20, color="#ef4444", linestyle="--", linewidth=0.8, alpha=0.4)
    # Proxy patches keep the legend independent of which lines were drawn.
    patches = [mpatches.Patch(color=c, label=m.capitalize()) for m, c in METER_COLORS.items()]
    ax.legend(handles=patches, loc="upper right", fontsize=8, ncol=2,
              framealpha=0.7, edgecolor="#e5e7eb")

    ax.set_xlim(0, MAX_STEPS)
    ax.set_ylim(-0.02, 1.08)
    ax.set_xlabel("Step (1 step = 1 time slot)", fontsize=9, color="#6b7280")
    ax.set_ylabel("Meter value", fontsize=9, color="#6b7280")
    ax.set_title("Life Meters Over the Week", fontsize=11, color="#374151", pad=8)
    ax.tick_params(labelsize=8, colors="#9ca3af")
    for spine in ax.spines.values():
        spine.set_edgecolor("#e5e7eb")
    ax.grid(True, alpha=0.3, color="#d1d5db")
    fig.tight_layout(pad=1.2)
    return fig
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _snap(obs):
    """Copy the current value of every meter on *obs* into a plain dict."""
    snapshot = {}
    for meter_name in METERS:
        snapshot[meter_name] = getattr(obs, meter_name)
    return snapshot
def _step_line(obs, action_name: str) -> str:
    """Format one step-log line for the action that produced *obs*.

    *obs* is the observation returned by ``env.step``. The week grid in this
    file treats ``obs.timestep`` as the 0-based index of the step about to
    be taken, laid out as ``day * 4 + slot``. After a step it therefore
    counts the steps taken so far, and ``obs.day``/``obs.slot`` describe the
    following slot. The action being logged was played in slot index
    ``obs.timestep - 1``, so the day/slot label is derived from that index.
    Previously the label came from ``obs.day``/``obs.slot``, which labelled
    each action with the next slot.
    NOTE(review): relies on the env advancing day/slot with timestep — confirm.

    Args:
        obs: post-step observation (reads timestep, day, slot, reward,
            active_event).
        action_name: action label, left-justified to 15 columns.

    Returns:
        The formatted log line, with the active event appended if any.
    """
    sign = "+" if obs.reward >= 0 else ""
    slots_per_day = len(SLOT_NAMES)
    if obs.timestep > 0:
        day_idx, slot_idx = divmod(obs.timestep - 1, slots_per_day)
    else:
        # No step taken yet; nothing earlier to attribute, so use obs as-is.
        day_idx, slot_idx = obs.day, obs.slot
    day = DAY_NAMES[day_idx] if day_idx < len(DAY_NAMES) else f"D{day_idx}"
    slot = SLOT_NAMES[slot_idx] if slot_idx < slots_per_day else f"S{slot_idx}"
    line = f"Step {obs.timestep:>2} [{day} {slot}] {action_name:<15} {sign}{obs.reward:.3f}"
    if obs.active_event:
        line += f" β‘{obs.active_event}"
    return line
| # --------------------------------------------------------------------------- | |
| # Tab 1 callbacks | |
| # --------------------------------------------------------------------------- | |
# Number of components every Play-tab callback returns, in this order.
OUTPUTS_COUNT = len(("meters_html", "week_grid", "chart", "log", "score"))
def reset_episode(profile_name: str, seed_str: str):
    """Start a new episode for the Play tab and render its initial state.

    Args:
        profile_name: a PROFILE_NAMES entry. "random" calls ``env.reset``
            without a profile argument.
        seed_str: raw text from the seed box. Blank, None or non-integer
            input falls back to seed 42.

    Returns:
        (meters_html, week_grid_html, chart, log_text, score_text).
    """
    global _last_obs, _step_log, _meter_history
    # Tolerate None as well as "" (the original crashed on None.strip()).
    text = (seed_str or "").strip()
    try:
        seed = int(text) if text else 42
    except ValueError:
        seed = 42

    env = get_env()
    if profile_name == "random":
        _last_obs = env.reset(seed=seed)
    else:
        _last_obs = env.reset(seed=seed, profile=profile_name)

    # Use MAX_STEPS rather than a hard-coded 28 so the message stays in sync.
    _step_log = [f"βΆ Profile: {env._profile['name']} | Seed: {seed} | {MAX_STEPS} steps to go"]
    _meter_history = [_snap(_last_obs)]
    return (
        format_meters_html(_last_obs),
        format_week_grid(_last_obs),
        make_chart(_meter_history),
        "\n".join(_step_log),
        "β",
    )
def take_action(action_str: str):
    """Apply one manually chosen action and re-render the Play tab.

    Args:
        action_str: an ACTION_NAMES entry (upper-cased ActionType value).
            None or "" (e.g. a cleared dropdown) is rejected without stepping.

    Returns:
        The five Play-tab outputs. Before the first reset, a warning goes to
        the meter panel. When the episode is already done, or no action is
        selected, the current state is re-rendered unchanged with an
        explanatory message in the score box.
    """
    global _last_obs, _step_log, _meter_history
    if _last_obs is None:
        return "β οΈ Reset the episode first.", "", make_chart([]), "β", "β"

    def _current_view(message):
        # Re-render the existing state without stepping the environment.
        return (
            format_meters_html(_last_obs),
            format_week_grid(_last_obs),
            make_chart(_meter_history),
            "\n".join(_step_log[-22:]),
            message,
        )

    if _last_obs.done:
        return _current_view("Episode done β press Reset to play again.")
    if not action_str:
        # Previously None.lower() raised AttributeError here.
        return _current_view("Choose an action first.")

    env = get_env()
    obs = env.step(RhythmAction(action_type=ActionType(action_str.lower())))
    _last_obs = obs
    _meter_history.append(_snap(obs))
    _step_log.append(_step_line(obs, action_str))

    if obs.done:
        final = obs.reward_breakdown.get("final_score", 0.0)
        _step_log.extend(["β" * 52, f"β Final score: {final:.4f}"])
        score = f"Final: {final:.4f}"
    else:
        score = f"Step reward: {obs.reward:+.4f}"

    return (
        format_meters_html(obs),
        format_week_grid(obs),
        make_chart(_meter_history),
        "\n".join(_step_log[-22:]),
        score,
    )
def _run_auto(profile_name: str, seed_str: str, strategy: str):
    """Play a whole episode automatically and render the final state.

    Args:
        profile_name: a PROFILE_NAMES entry. "random" calls ``env.reset``
            without a profile argument.
        seed_str: seed text. Blank, None or non-integer input falls back to 42.
        strategy: "heuristic" uses ``training.inference_eval.heuristic_action``.
            Any other value draws uniformly from ActionType using
            ``random.Random(seed + 999)``, so random runs are reproducible per
            seed.

    Returns:
        The five Play-tab outputs for the finished episode.
    """
    global _last_obs, _step_log, _meter_history
    import random as _random

    text = (seed_str or "").strip()
    try:
        seed = int(text) if text else 42
    except ValueError:
        seed = 42

    if strategy == "heuristic":
        # Imported only when needed, so the random baseline still works when
        # the training package cannot be imported.
        from training.inference_eval import heuristic_action as choose
    else:
        rng = _random.Random(seed + 999)
        all_actions = list(ActionType)

        def choose(_obs):
            return rng.choice(all_actions)

    env = get_env()
    if profile_name == "random":
        obs = env.reset(seed=seed)
    else:
        obs = env.reset(seed=seed, profile=profile_name)
    _last_obs = obs
    _step_log = [f"βΆ Auto-run ({strategy}) | Profile: {env._profile['name']} | Seed: {seed}"]
    _meter_history = [_snap(obs)]

    while not obs.done:
        action_type = choose(obs)
        obs = env.step(RhythmAction(action_type=action_type))
        _last_obs = obs
        _meter_history.append(_snap(obs))
        _step_log.append(_step_line(obs, action_type.value.upper()))

    final = obs.reward_breakdown.get("final_score", 0.0)
    _step_log.extend(["β" * 52, f"β Final score: {final:.4f}"])
    return (
        format_meters_html(obs),
        format_week_grid(obs),
        make_chart(_meter_history),
        "\n".join(_step_log[-25:]),
        f"Final: {final:.4f}",
    )
def run_heuristic(p, s):
    """Auto-play a full episode (profile *p*, seed text *s*) with the heuristic baseline."""
    return _run_auto(p, s, strategy="heuristic")
def run_random(p, s):
    """Auto-play a full episode (profile *p*, seed text *s*) with random actions."""
    return _run_auto(p, s, strategy="random")
| # --------------------------------------------------------------------------- | |
| # Reference tab helpers | |
| # --------------------------------------------------------------------------- | |
def show_action_effects() -> str:
    """Render ACTION_EFFECTS as a fixed-width text table.

    One column per meter (headed by its first three letters) and one row per
    action, each delta shown with an explicit sign to two decimals.
    """
    separator = "β" * 52
    header = f"{'Action':<15}" + "".join(f" {m[:3]:>6}" for m in METERS)
    rows = [
        f"{action:<15}" + "".join(f" {deltas[m]:>+6.2f}" for m in METERS)
        for action, deltas in ACTION_EFFECTS.items()
    ]
    return "\n".join([header, separator, *rows])
def _profile_modifier_lines(prof):
    """Yield the 'Key hidden modifiers' bullet lines for one profile dict."""
    if prof.get("morning_cognition_bonus"):
        yield f" β’ Morning: cognition/progress Γ{prof['morning_cognition_bonus']} (peak window)"
    if prof.get("evening_night_cognition_bonus"):
        yield f" β’ Evening/Night: cognition/progress Γ{prof['evening_night_cognition_bonus']} (peak zone)"
    if prof.get("morning_penalty"):
        yield f" β’ Morning: cognition/progress Γ{prof['morning_penalty']} (groggy zone)"
    social_vitality = prof.get("social_vitality_multiplier", 1.0)
    if social_vitality != 1.0:
        yield f" β’ Social vitality drain Γ{social_vitality}"
    if prof.get("binge_shame"):
        yield " β’ Binge watch: shame spiral β0.15 serenity"
    if prof.get("progress_serenity_bonus"):
        yield f" β’ Work gives serenity +{prof['progress_serenity_bonus']} (meaning)"
    if prof.get("idle_serenity_decay"):
        yield f" β’ Idle drains serenity β{prof['idle_serenity_decay']} (guilt)"
    if prof.get("work_vitality_recovery"):
        yield f" β’ Work recovers vitality +{prof['work_vitality_recovery']} (energized)"
    if prof.get("solo_serenity_bonus"):
        yield f" β’ Solo time gives serenity +{prof['solo_serenity_bonus']} (recharge)"
    social_connection = prof.get("social_connection_multiplier", 1.0)
    if social_connection != 1.0:
        yield f" β’ Social connection Γ{social_connection}"
    # Required key: a profile without it raises KeyError, as before.
    yield f" β’ Connection passive decay: β{prof['connection_decay_rate']}/step"


def show_profiles() -> str:
    """Render every entry of PROFILES as a plain-text reference card.

    Each card has a ruled, upper-cased title, then the profile's reward
    weights as 20-column bars with percentages. Last come the hidden
    modifiers it defines: truthy keys, or multipliers that differ from 1.0.
    """
    rule = "β" * 52
    out = []
    for prof in PROFILES:
        weights = prof["reward_weights"]
        out += [f"\n{rule}", f" {prof['name'].upper()}", rule]
        out.append(" Reward weights (hidden from agent):")
        for meter, share in weights.items():
            bar = "β" * int(share * 20)
            out.append(f" {meter:<12} {bar:<20} {share:.0%}")
        out.append("\n Key hidden modifiers:")
        out.extend(_profile_modifier_lines(prof))
    return "\n".join(out)
| # --------------------------------------------------------------------------- | |
| # Build UI | |
| # --------------------------------------------------------------------------- | |
# Build the Gradio interface at import time. It has a "Play" tab for stepping
# or auto-running episodes and a read-only "Environment Reference" tab.
# Component creation order inside each context manager determines the layout.
with gr.Blocks(title="RhythmEnv β Life Simulator") as demo:
    gr.Markdown(
        "# RhythmEnv β Life Simulator\n"
        "**Can a lightweight AI learn who you are β without being told?**\n\n"
        "Balance 5 life meters across a 7-day week. "
        "A hidden personality profile secretly changes how every action affects you. "
        "The agent must infer who you are from reward signals alone."
    )
    with gr.Tabs():
        # -- Tab 1: Play ------------------------------------------------------
        with gr.TabItem("βΆ Play"):
            # Episode setup: hidden profile, seed text, reset button.
            with gr.Row():
                profile_dd = gr.Dropdown(
                    choices=PROFILE_NAMES, value="introvert_morning",
                    label="Hidden Profile (visible here for demo β agent cannot see this)",
                    scale=3,
                )
                seed_in = gr.Textbox(label="Seed", value="42", scale=1)
                reset_btn = gr.Button("β³ Reset", variant="primary", scale=1)
            # Cheat-sheet describing the three named profiles.
            gr.Markdown(
                "| Profile | Core trait | What the agent must discover |\n"
                "|---|---|---|\n"
                "| `introvert_morning` | Recharges alone, peaks at dawn |"
                " Social drain Γ3 Β· Morning deep work gives Γ2 progress |\n"
                "| `extrovert_night_owl` | Energised by people, peaks at night |"
                " Morning is a penalty zone Β· Social gives Γ2 connection |\n"
                "| `workaholic_stoic` | Finds meaning in output, resilient |"
                " Idle time drains serenity Β· Work recovers vitality |"
            )
            # Live state: meter bars, week grid and score on the left; chart on the right.
            with gr.Row():
                with gr.Column(scale=2):
                    meters_html = gr.HTML()
                    week_grid_html = gr.HTML()
                    score_display = gr.Textbox(label="Score", interactive=False, lines=1)
                with gr.Column(scale=3):
                    chart_display = gr.Plot(label="Meter Trajectories")
            # Manual stepping: pick one action and apply it.
            with gr.Row():
                action_dd = gr.Dropdown(
                    choices=ACTION_NAMES, value="DEEP_WORK",
                    label="Choose action", scale=4,
                )
                step_btn = gr.Button("βΆ Take Step", variant="primary", scale=1)
            # One-click full-episode baselines.
            with gr.Row():
                heuristic_btn = gr.Button("βΆβΆ Full Episode β Heuristic Baseline")
                random_btn = gr.Button("βΆβΆ Full Episode β Random Baseline")
            log_display = gr.Textbox(
                label="Step Log (last 22 steps)",
                lines=10, interactive=False,
            )
        # -- Tab 2: Environment Reference -------------------------------------
        # Static text built once at startup from ACTION_EFFECTS and PROFILES.
        with gr.TabItem("π Environment Reference"):
            gr.Markdown("### Action Effect Matrix")
            gr.Markdown(
                "Base delta per action on each meter. "
                "Profile modifiers and time-of-day multipliers are applied on top β invisibly."
            )
            gr.Textbox(value=show_action_effects(), lines=14, interactive=False, label="")
            gr.Markdown("### Hidden Personality Profiles")
            gr.Markdown(
                "The agent **cannot see these** during play. "
                "It must infer the active profile through reward patterns β "
                "the core learning challenge of RhythmEnv."
            )
            gr.Textbox(value=show_profiles(), lines=55, interactive=False, label="")
    # -- Wire up ---------------------------------------------------------------
    # Every Play-tab callback returns a 5-tuple that maps onto this list in order.
    _out = [meters_html, week_grid_html, chart_display, log_display, score_display]
    reset_btn.click(reset_episode, inputs=[profile_dd, seed_in], outputs=_out)
    step_btn.click(take_action, inputs=[action_dd], outputs=_out)
    heuristic_btn.click(run_heuristic, inputs=[profile_dd, seed_in], outputs=_out)
    random_btn.click(run_random, inputs=[profile_dd, seed_in], outputs=_out)
if __name__ == "__main__":
    # Default port stays 7862 (the original hard-coded value). Gradio documents
    # GRADIO_SERVER_PORT as the environment equivalent of server_port. Honour
    # it so deployment hosts that expect a specific port can override it
    # without editing this file. An explicit server_port would otherwise
    # silently win over that variable.
    demo.launch(
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", 7862)),
        share=False,
        theme=gr.themes.Soft(),
    )