File size: 8,198 Bytes
02f3ccd
 
c42abce
d9073dc
93557f4
d9073dc
 
dc7554a
c42abce
02f3ccd
 
 
 
 
cf86516
d9073dc
 
 
 
f234241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc7554a
f234241
dc7554a
f234241
 
dc7554a
 
 
 
02f3ccd
 
dc7554a
 
02f3ccd
 
 
f234241
02f3ccd
 
dc7554a
 
 
 
 
02f3ccd
 
dc7554a
 
 
d9073dc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc7554a
d9073dc
93557f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc7554a
 
 
d9073dc
dc7554a
 
d9073dc
dc7554a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02f3ccd
 
 
 
1ab78f9
c42abce
 
 
1ab78f9
c42abce
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
from __future__ import annotations

import os
import time
import traceback
from functools import lru_cache
from threading import Lock
from typing import Dict, List, Tuple

import gradio as gr

from environment.env import WorkLifeFirewallEnv


# Minimum gap, in seconds, enforced between user-triggered episode runs.
# Overridable via the MIN_EVENT_INTERVAL_SECONDS environment variable.
MIN_EVENT_INTERVAL_SECONDS = float(os.getenv("MIN_EVENT_INTERVAL_SECONDS", "3.0"))
# Lock serializing all throttle checks across request threads.
_RATE_LIMIT_LOCK = Lock()
# time.monotonic() timestamp of the most recent throttled request (0.0 = never).
_LAST_EVENT_TS = 0.0


def _action_for_policy(policy_style: str, event_id: str) -> str:
    strategic_actions = {
        "E1_staging": "I will fix staging first, post an incident update in 15 minutes, and share ETA.",
        "E2_slack": "I will respond async after I stabilize staging and batch replies at 11:30 AM.",
        "E3_client_email": "I acknowledge the urgency and will send a concrete recovery timeline by today EOD.",
        "E4_leave": "I am escalating leave approval with context and requesting a decision by tomorrow noon.",
        "E5_appraisal": "I will block 90 minutes tomorrow and submit appraisal before Thursday EOD.",
        "E6_oncall": "I cannot swap on-call this week; I can help async with runbook notes.",
        "E7_standup": "I will skip the 10:30 PM standup and send an async status update instead.",
    }
    people_pleaser_actions = {
        "E1_staging": "Sure, I will handle it now and stay online until everything is done.",
        "E2_slack": "Sure, I will reply to everyone immediately.",
        "E3_client_email": "Sure, I will take full ownership and deliver whatever is needed tonight.",
        "E4_leave": "No worries, I can postpone leave if needed.",
        "E5_appraisal": "Sure, I will do appraisal tonight after work.",
        "E6_oncall": "Sure, I will take your on-call shift again.",
        "E7_standup": "Sure, I will attend the 10:30 PM standup.",
    }
    balanced_actions = {
        "E1_staging": "I will handle staging now and share progress updates every 30 minutes.",
        "E2_slack": "I will prioritize urgent Slack items first and answer the rest asynchronously.",
        "E3_client_email": "I will send a calm status note with next steps and timeline.",
        "E4_leave": "I will follow up respectfully on leave approval and ask for a clear response date.",
        "E5_appraisal": "I will reserve focused time this week and finish appraisal before Friday.",
        "E6_oncall": "I cannot fully swap this time, but I can help with handover notes.",
        "E7_standup": "I will share async updates and join only if there is a critical blocker.",
    }

    if policy_style == "strategic":
        return strategic_actions.get(event_id, "I will send a clear async update and commit a timeline.")
    if policy_style == "people_pleaser":
        return people_pleaser_actions.get(event_id, "Sure, I will do it now.")
    return balanced_actions.get(event_id, "I will handle this with a clear plan.")


def _run_single_episode(policy_style: str, seed: int, randomize_order: bool) -> Tuple[str, Dict[str, object], Dict[str, float]]:
    """Play one full episode with the given scripted policy style.

    Returns a tuple of (newline-joined per-step log, final environment state,
    last non-empty rubric component scores reported by the env).
    """
    env = WorkLifeFirewallEnv(randomize_order=randomize_order, seed=seed)
    observation = env.reset()
    finished = False
    step_lines: List[str] = []
    latest_components: Dict[str, float] = {}

    while not finished:
        current_event = observation["event"]
        reply = _action_for_policy(policy_style, current_event["id"])

        observation, step_reward, finished, step_info = env.step(reply)
        step_lines.append(
            f"{current_event['id']} | reward={step_reward:.3f} | action={reply}"
        )
        # Keep the most recent non-empty component breakdown.
        reported = step_info.get("components")
        if reported:
            latest_components = reported

    return "\n".join(step_lines), env.state(), latest_components


def _throttle_event_requests() -> None:
    """Block until MIN_EVENT_INTERVAL_SECONDS has passed since the last call.

    Callers are serialized on a global lock (so the sleep also queues any
    concurrent requests), and the shared timestamp is updated on exit.
    """
    global _LAST_EVENT_TS
    with _RATE_LIMIT_LOCK:
        elapsed = time.monotonic() - _LAST_EVENT_TS
        wait_for = MIN_EVENT_INTERVAL_SECONDS - elapsed
        if wait_for > 0:
            time.sleep(wait_for)
        _LAST_EVENT_TS = time.monotonic()


@lru_cache(maxsize=256)
def _cached_episode(policy_style: str, seed: int, randomize_order: bool) -> Tuple[str, Dict[str, object], Dict[str, float]]:
    """Memoized wrapper around :func:`_run_single_episode`.

    Episodes are fully determined by (policy_style, seed, randomize_order),
    so identical requests can be served from the cache.

    NOTE(review): cache hits return the *same* state/components dict objects
    to every caller; presumably safe because current callers only read them —
    confirm before adding any caller that mutates the result.
    """
    # Cache deterministic episodes so repeated button clicks do not trigger repeated backend work.
    return _run_single_episode(policy_style, seed, randomize_order)


def run_episode(policy_style: str, seed: int, randomize_order: bool):
    """Gradio callback: run one (cached) episode and format its outputs.

    Returns (episode log text, markdown outcome summary, rubric score rows).
    Any failure is rendered into the same three outputs instead of raising,
    so the UI always gets something to display.
    """
    _throttle_event_requests()
    try:
        log_text, final_state, component_scores = _cached_episode(
            policy_style, seed, randomize_order
        )
        full_log = "\n".join(("Episode started.", log_text))
        outcome_md = (
            "### Outcome\n"
            f"- Friday energy: **{final_state['energy_pct']}%**\n"
            f"- Sprint health: **{final_state['sprint_health_pct']}%**\n"
            f"- Leave status: **{final_state['leave_status']}**\n"
        )
        score_rows = [
            [name, round(score, 3)]
            for name, score in sorted(component_scores.items())
        ] or [["(none)", 0.0]]
        return full_log, outcome_md, score_rows
    except Exception as exc:
        failure_log = "Episode failed:\n" + "".join(
            traceback.format_exception(type(exc), exc, exc.__traceback__)
        )
        failure_md = (
            "### Outcome\n"
            f"- Error: **{type(exc).__name__}**\n"
            f"- Message: **{exc}**\n"
        )
        return failure_log, failure_md, [["error", 1.0]]


def compare_policies(seed: int, randomize_order: bool):
    """Gradio callback: run every policy style on the same seed and tabulate outcomes.

    Returns one row per policy with final state metrics and two key rubric scores.
    """
    _throttle_event_requests()
    table = []
    for style in ("strategic", "balanced", "people_pleaser"):
        _, final_state, scores = _cached_episode(style, seed, randomize_order)
        table.append([
            style,
            final_state["energy_pct"],
            final_state["sprint_health_pct"],
            final_state["leave_status"],
            round(float(scores.get("boundary_setting", 0.0)), 3),
            round(float(scores.get("communication_quality", 0.0)), 3),
        ])
    return table


# UI definition. Component construction order matters: event handlers below
# reference the component objects created above them.
with gr.Blocks(title="Work-Life Firewall") as demo:
    gr.Markdown(
        "# Work-Life Firewall\n"
        "Train and evaluate boundary-setting behavior in a realistic software work-week simulation."
    )

    # Episode configuration controls (shared by both buttons).
    with gr.Row():
        policy_style = gr.Radio(
            choices=["strategic", "balanced", "people_pleaser"],
            value="balanced",
            label="Policy style",
        )
        seed = gr.Slider(minimum=1, maximum=9999, value=42, step=1, label="Seed")
        randomize_order = gr.Checkbox(value=False, label="Randomize event order")

    with gr.Row():
        run_btn = gr.Button("Run Single Episode", variant="primary")
        compare_btn = gr.Button("Compare All Policies")

    # Outputs for a single-episode run.
    with gr.Row():
        episode_log = gr.Textbox(label="Episode log", lines=14)
        outcome_md = gr.Markdown()

    component_table = gr.Dataframe(
        headers=["component", "score"],
        datatype=["str", "number"],
        row_count=(5, "dynamic"),
        label="Rubric component scores",
    )

    # Output for the three-policy comparison; exactly one row per policy.
    comparison_table = gr.Dataframe(
        headers=[
            "policy",
            "friday_energy_pct",
            "sprint_health_pct",
            "leave_status",
            "boundary_setting",
            "communication_quality",
        ],
        datatype=["str", "number", "number", "str", "number", "number"],
        row_count=(3, "fixed"),
        label="Policy comparison",
    )

    run_btn.click(
        fn=run_episode,
        inputs=[policy_style, seed, randomize_order],
        outputs=[episode_log, outcome_md, component_table],
    )

    compare_btn.click(
        fn=compare_policies,
        inputs=[seed, randomize_order],
        outputs=[comparison_table],
    )

    # Clickable example configurations; not pre-cached, so each click runs live.
    gr.Examples(
        examples=[
            ["strategic", 42, False],
            ["balanced", 42, False],
            ["people_pleaser", 42, False],
            ["strategic", 7, True],
        ],
        inputs=[policy_style, seed, randomize_order],
        outputs=[episode_log, outcome_md, component_table],
        fn=run_episode,
        cache_examples=False,
    )


if __name__ == "__main__":
    # SPACE_ID is set (non-empty) on Hugging Face Spaces; a gradio share
    # link is only created when running outside a Space.
    running_on_space = bool(os.environ.get("SPACE_ID"))
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", "7860")),
        share=not running_on_space,
        show_error=True,
    )