File size: 7,771 Bytes
fdd45f1
 
36c262c
b5bdc31
 
 
 
 
36c262c
b5bdc31
 
 
36c262c
b5bdc31
 
 
fdd45f1
36c262c
fdd45f1
 
b5bdc31
 
 
fdd45f1
36c262c
b5bdc31
 
 
 
 
 
36c262c
b5bdc31
 
 
 
36c262c
b5bdc31
36c262c
b5bdc31
36c262c
b5bdc31
 
 
 
 
 
 
 
 
 
36c262c
b5bdc31
 
 
 
 
 
36c262c
b5bdc31
 
 
 
36c262c
b5bdc31
36c262c
 
b5bdc31
 
 
 
 
 
 
 
 
36c262c
b5bdc31
 
 
 
 
 
36c262c
b5bdc31
36c262c
b5bdc31
 
 
 
 
36c262c
b5bdc31
 
 
36c262c
b5bdc31
 
 
 
36c262c
b5bdc31
 
36c262c
b5bdc31
 
36c262c
b5bdc31
 
 
 
36c262c
b5bdc31
36c262c
 
b5bdc31
 
 
fdd45f1
b5bdc31
 
 
 
 
 
fdd45f1
b5bdc31
 
36c262c
 
b5bdc31
 
 
 
 
36c262c
b5bdc31
 
 
 
36c262c
 
b5bdc31
 
 
36c262c
b5bdc31
fdd45f1
b5bdc31
36c262c
b5bdc31
 
 
 
 
36c262c
b5bdc31
 
36c262c
fdd45f1
b5bdc31
 
 
fdd45f1
b5bdc31
 
36c262c
b5bdc31
 
 
36c262c
fdd45f1
b5bdc31
36c262c
b5bdc31
 
 
 
 
36c262c
b5bdc31
36c262c
 
b5bdc31
36c262c
fdd45f1
36c262c
 
 
 
b5bdc31
fdd45f1
36c262c
 
fdd45f1
b5bdc31
 
 
36c262c
b5bdc31
 
36c262c
b5bdc31
 
 
 
36c262c
b5bdc31
 
 
 
 
 
 
36c262c
b5bdc31
 
 
36c262c
 
fdd45f1
36c262c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
"""
FocusFlow RL Environment — inference.py

LLM agent that:
  1. Reads the full observation (NL events, deadlines, cognitive load)
  2. Produces a chain-of-thought reasoning string
  3. Selects the best action
  4. Logs reward curves for showing training progress

Usage:
  export API_BASE_URL=https://api.groq.com/openai/v1
  export MODEL_NAME=llama-3.1-8b-instant
  export GROQ_API_KEY=your_groq_key_here
  export ENV_BASE_URL=http://localhost:7860
  export TASK_ID=task_1
  python inference.py
"""

import os
import json
import requests
import time
from typing import Optional
from openai import OpenAI

# ── Runtime configuration ─────────────────────────────────────────────────────
# Every setting is read from the process environment; the second argument is
# the value used when the variable is not set.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.1-8b-instant")
ENV_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860")
TASK_ID = os.environ.get("TASK_ID", "task_1")
# Number of episodes main() runs; the environment value must parse as an int.
MAX_EPISODES = int(os.environ.get("MAX_EPISODES", "5"))

# Chat client pointed at API_BASE_URL. OPENAI_API_KEY takes precedence over
# GROQ_API_KEY; the literal "dummy" is passed when neither variable is set.
client = OpenAI(
    base_url=API_BASE_URL,
    api_key=os.environ.get(
        "OPENAI_API_KEY",
        os.environ.get("GROQ_API_KEY", "dummy"),
    ),
)

# System message sent with every LLM request (see call_llm). It lists the
# action names the agent may choose from and the decision heuristics it should
# follow. The dashes below are real em dashes; the previous text contained
# mis-decoded UTF-8 bytes in their place.
SYSTEM_PROMPT = """You are FocusAgent — an AI assistant helping a student stay focused during study sessions.

You will receive an observation about the student's current state and must choose the best action.

AVAILABLE ACTIONS:
- focus              : Stay on task, no special action needed
- block_app          : Block a distracting app (specify app_name)
- take_break         : Take a study break
- defer_event        : Postpone handling a distraction event for later
- respond_to_event   : Immediately handle an urgent event (provide response_text)
- plan_day           : Set a day plan (provide day_plan as a list of steps)
- adjust_energy      : Do something to restore energy/focus (stretch, hydrate, etc.)
- check_app          : (BAD) Give in to distraction — avoid this!
- quit_session       : (BAD) End the session early — avoid this!

DECISION FRAMEWORK:
1. Check pending_event first — is it urgent? Can it be deferred?
2. Check cognitive_load — if > 0.75, consider a break
3. Check deadline_pressure — if > 0.7, prioritise study tasks
4. Block apps proactively, especially high-temptation ones
5. Plan the day at the start of each day (step 1)

CRITICAL: Your `reasoning` field MUST explain:
  - What the most important signal in the observation is
  - Why you chose this action over alternatives
  - How this action serves the long-term goal

Poor reasoning = lower reward. Think carefully."""


def obs_to_prompt(obs: dict, step: int) -> str:
    """Render an environment observation as the user prompt for the LLM.

    Args:
        obs: Observation dict returned by the environment. The keys
            ``current_phase``, ``time_remaining_seconds``,
            ``sessions_completed``, ``focus_score``, ``active_distractions``,
            ``blocked_apps``, ``cognitive_load``, ``deadline_pressure`` and
            ``last_action_feedback`` are required. ``pending_event``,
            ``day_context``, ``last_action_reward`` and
            ``reasoning_quality_score`` are optional.
        step: Step number shown in the prompt header.

    Returns:
        The multi-section prompt text.

    Raises:
        KeyError: If a required key is missing from ``obs`` (or from a
            non-empty ``pending_event`` / deadline entry).
    """
    event = obs.get("pending_event")
    if event:
        event_str = (
            f"[{event['type']}] {event['description']} "
            f"(urgency={event['urgency']:.2f}, can_defer={event['can_defer']}, "
            f"expires_in={event.get('deadline_steps', 'N/A')} steps)"
        )
    else:
        event_str = "None"

    # `or {}` / `or []` cover an explicit null in the payload as well as a
    # missing key; `.get(key, default)` alone would hand back None and crash.
    dc = obs.get("day_context") or {}
    deadlines = dc.get("pending_deadlines") or []
    # Only deadlines not flagged as completed are listed.
    dl_str = ", ".join(
        f"{d['task']} (due step {d['due_step']})"
        for d in deadlines
        if not d.get("completed")
    ) or "None"

    minutes, seconds = divmod(obs["time_remaining_seconds"], 60)
    # Thresholds mirror the ones quoted in SYSTEM_PROMPT's decision framework.
    load_flag = "⚠ HIGH" if obs["cognitive_load"] > 0.75 else "✓ OK"
    pressure_flag = "⚠ URGENT" if obs["deadline_pressure"] > 0.7 else "✓ OK"

    return f"""=== STEP {step} OBSERVATION ===

SESSION STATE:
  Phase            : {obs['current_phase']}
  Time remaining   : {minutes}m {seconds}s
  Sessions done    : {obs['sessions_completed']}
  Focus score      : {obs['focus_score']:.3f}

ENVIRONMENT:
  Active distractions : {', '.join(obs['active_distractions']) or 'None'}
  Blocked apps        : {', '.join(obs['blocked_apps']) or 'None'}

COGNITIVE & ENERGY:
  Cognitive load    : {obs['cognitive_load']:.2f}  {load_flag}
  Energy level      : {dc.get('energy_level', 1.0):.2f}
  Deadline pressure : {obs['deadline_pressure']:.2f}  {pressure_flag}

PENDING EVENT:
  {event_str}

PENDING DEADLINES:
  {dl_str}

LAST FEEDBACK:
  {obs['last_action_feedback']}
  Last reward      : {obs.get('last_action_reward', 0.0):.4f}
  Reasoning score  : {obs.get('reasoning_quality_score', 0.0):.2f}

Choose your next action and provide clear reasoning."""


def call_llm(prompt: str) -> dict:
    """Ask the chat model for the next action and return it as a dict.

    Sends ``SYSTEM_PROMPT`` as the system message and ``prompt`` (with a
    JSON-only instruction appended) as the user message to ``MODEL_NAME``
    through the module-level ``client``, requesting a JSON-object response.

    Args:
        prompt: Observation text for the current step.

    Returns:
        The JSON object decoded from the model's reply.

    Raises:
        ValueError: If the reply has no content, is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass), or
            decodes to something other than a JSON object.

    Errors raised by the client call itself propagate unchanged.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.3,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user",   "content": prompt + "\n\nRespond ONLY with valid JSON matching FocusAction schema."}
        ],
        response_format={"type": "json_object"},
    )
    raw = response.choices[0].message.content
    # The content field may be empty/None; fail with a clear message instead
    # of the TypeError json.loads(None) would raise.
    if not raw:
        raise ValueError("LLM returned an empty response")

    action = json.loads(raw)
    # Valid JSON is not necessarily an object; callers use dict methods on
    # the result, so reject anything else here.
    if not isinstance(action, dict):
        raise ValueError(
            f"LLM response is not a JSON object (got {type(action).__name__})"
        )
    return action


def _normalize_action(action) -> dict:
    """Return an action dict that always has an ``action_type`` and usable ``reasoning``."""
    if not isinstance(action, dict):
        action = {}
    # Same fallback action that is used when the LLM call itself fails.
    if not action.get("action_type"):
        action["action_type"] = "focus"
    reasoning = action.get("reasoning")
    # Replace missing, non-string, or very short reasoning with a stock sentence.
    if not isinstance(reasoning, str) or len(reasoning.strip()) < 10:
        action["reasoning"] = "Staying focused to complete the session efficiently."
    return action


def run_episode(task_id: str, episode: int, timeout: float = 30.0) -> dict:
    """Run a single episode against the environment server.

    Resets the environment via ``POST {ENV_URL}/reset``, then repeatedly
    builds a prompt from the latest observation, asks the LLM for an action,
    and submits it via ``POST {ENV_URL}/step`` until the response reports
    ``done`` or a step request fails.

    Args:
        task_id: Task identifier passed to the reset endpoint.
        episode: Zero-based episode index (printed and returned as index + 1).
        timeout: Seconds to wait on each HTTP request before giving up, so a
            stalled server cannot hang the run indefinitely.

    Returns:
        A dict with ``episode``, ``total_reward`` (rounded to 4 places),
        ``steps``, ``success`` and ``reward_history``.

    Raises:
        requests.RequestException: If the reset request fails or times out.
            Failures of step requests are printed and end the episode instead.
    """
    r = requests.post(
        f"{ENV_URL}/reset", params={"task_id": task_id}, timeout=timeout
    )
    r.raise_for_status()
    obs = r.json()

    step           = 0
    total_reward   = 0.0
    reward_history = []
    done           = False
    result         = {}

    print(f"\n{'='*60}")
    print(f"EPISODE {episode+1} | Task: {task_id}")
    print(f"{'='*60}")

    while not done:
        step += 1
        prompt = obs_to_prompt(obs, step)

        try:
            action = call_llm(prompt)
        except Exception as e:
            # Best effort: any LLM failure degrades to the neutral action
            # rather than aborting the episode.
            print(f"  [LLM ERROR] {e} — defaulting to focus")
            action = {"action_type": "focus", "reasoning": "LLM call failed, defaulting to focus."}

        # Guard against malformed model output (non-dict, missing
        # action_type, non-string reasoning) before it is sent or printed.
        action = _normalize_action(action)

        try:
            resp = requests.post(f"{ENV_URL}/step", json=action, timeout=timeout)
            resp.raise_for_status()
            result = resp.json()
        except Exception as e:
            print(f"  [ENV ERROR] {e}")
            break

        # The step response doubles as the next observation.
        obs    = result
        # `or 0.0` also covers an explicit null reward in the response.
        reward = result.get("last_action_reward") or 0.0
        done   = result.get("done", False)

        total_reward += reward
        reward_history.append(reward)

        print(
            f"  Step {step:3d} | action={str(action['action_type']):<18} "
            f"reward={reward:+.4f} | cumulative={total_reward:.4f} | "
            f"reasoning_q={result.get('reasoning_quality_score', 0):.2f}"
        )

        time.sleep(0.1)   # rate limit

    success = result.get("success", False) if result else False
    print(f"\n  Episode {episode+1} done. Total reward: {total_reward:.4f} | Success: {success}")

    return {
        "episode":        episode + 1,
        "total_reward":   round(total_reward, 4),
        "steps":          step,
        "success":        success,
        "reward_history": reward_history,
    }


def main():
    """Run ``MAX_EPISODES`` episodes of ``TASK_ID`` and report the results.

    Prints the configuration and the environment's ``/health`` response,
    runs the episodes sequentially, prints a reward summary, and writes the
    per-episode stats to ``reward_log.json`` in the working directory.
    """
    print(f"FocusFlow Agent | Model: {MODEL_NAME} | Task: {TASK_ID}")
    print(f"Environment: {ENV_URL}")
    print()

    # Bounded wait so an unreachable server fails fast instead of hanging.
    h = requests.get(f"{ENV_URL}/health", timeout=10)
    print(f"Health: {h.json()}\n")

    all_stats = []
    for ep in range(MAX_EPISODES):
        stats = run_episode(TASK_ID, ep)
        all_stats.append(stats)

    rewards = [s["total_reward"] for s in all_stats]
    print(f"\n{'='*60}")
    print(f"SUMMARY over {MAX_EPISODES} episodes:")
    if rewards:
        print(f"  Rewards : {[f'{r:.3f}' for r in rewards]}")
        print(f"  Mean    : {sum(rewards)/len(rewards):.4f}")
        print(f"  Best    : {max(rewards):.4f}")
        print(f"  Success : {sum(s['success'] for s in all_stats)}/{MAX_EPISODES}")
    else:
        # MAX_EPISODES <= 0: mean and max are undefined on an empty list.
        print("  No episodes were run.")

    with open("reward_log.json", "w", encoding="utf-8") as f:
        json.dump(all_stats, f, indent=2)
    print("\nReward log saved to reward_log.json")


if __name__ == "__main__":
    main()