hannan2859r committed on
Commit
36c262c
·
verified ·
1 Parent(s): ebf4b94

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +54 -60
inference.py CHANGED
@@ -1,43 +1,44 @@
1
  """
2
  FocusFlow RL Environment β€” inference.py
3
-
4
  LLM agent that:
5
  1. Reads the full observation (NL events, deadlines, cognitive load)
6
  2. Produces a chain-of-thought reasoning string
7
  3. Selects the best action
8
  4. Logs reward curves for showing training progress
9
-
10
  Usage:
11
  export API_BASE_URL=https://api.groq.com/openai/v1
12
  export MODEL_NAME=llama-3.1-8b-instant
 
13
  export ENV_BASE_URL=http://localhost:7860
14
  export TASK_ID=task_1
15
  python inference.py
16
  """
17
-
18
  import os
19
  import json
20
  import requests
21
  import time
22
  from typing import Optional
23
  from openai import OpenAI
24
-
25
  # ── Config ────────────────────────────────────────────────────────────────────
26
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
27
  MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
28
  ENV_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
29
  TASK_ID = os.getenv("TASK_ID", "task_1")
30
  MAX_EPISODES = int(os.getenv("MAX_EPISODES", "5"))
31
-
32
  client = OpenAI(
33
  api_key=os.getenv("OPENAI_API_KEY", os.getenv("GROQ_API_KEY", "dummy")),
34
  base_url=API_BASE_URL,
35
  )
36
-
37
  SYSTEM_PROMPT = """You are FocusAgent β€” an AI assistant helping a student stay focused during study sessions.
38
-
39
  You will receive an observation about the student's current state and must choose the best action.
40
-
41
  AVAILABLE ACTIONS:
42
  - focus : Stay on task, no special action needed
43
  - block_app : Block a distracting app (specify app_name)
@@ -48,22 +49,22 @@ AVAILABLE ACTIONS:
48
  - adjust_energy : Do something to restore energy/focus (stretch, hydrate, etc.)
49
  - check_app : (BAD) Give in to distraction β€” avoid this!
50
  - quit_session : (BAD) End the session early β€” avoid this!
51
-
52
  DECISION FRAMEWORK:
53
  1. Check pending_event first β€” is it urgent? Can it be deferred?
54
  2. Check cognitive_load β€” if > 0.75, consider a break
55
  3. Check deadline_pressure β€” if > 0.7, prioritise study tasks
56
  4. Block apps proactively, especially high-temptation ones
57
  5. Plan the day at the start of each day (step 1)
58
-
59
  CRITICAL: Your `reasoning` field MUST explain:
60
  - What the most important signal in the observation is
61
  - Why you chose this action over alternatives
62
  - How this action serves the long-term goal
63
-
64
  Poor reasoning = lower reward. Think carefully."""
65
-
66
-
67
  def obs_to_prompt(obs: dict, step: int) -> str:
68
  event_str = "None"
69
  if obs.get("pending_event"):
@@ -73,45 +74,45 @@ def obs_to_prompt(obs: dict, step: int) -> str:
73
  f"(urgency={e['urgency']:.2f}, can_defer={e['can_defer']}, "
74
  f"expires_in={e.get('deadline_steps', 'N/A')} steps)"
75
  )
76
-
77
  dc = obs.get("day_context", {})
78
  deadlines = dc.get("pending_deadlines", [])
79
  dl_str = ", ".join(
80
  f"{d['task']} (due step {d['due_step']})"
81
  for d in deadlines if not d.get("completed")
82
  ) or "None"
83
-
84
  return f"""=== STEP {step} OBSERVATION ===
85
-
86
  SESSION STATE:
87
  Phase : {obs['current_phase']}
88
  Time remaining : {obs['time_remaining_seconds']//60}m {obs['time_remaining_seconds']%60}s
89
  Sessions done : {obs['sessions_completed']}
90
  Focus score : {obs['focus_score']:.3f}
91
-
92
  ENVIRONMENT:
93
  Active distractions : {', '.join(obs['active_distractions']) or 'None'}
94
  Blocked apps : {', '.join(obs['blocked_apps']) or 'None'}
95
-
96
  COGNITIVE & ENERGY:
97
  Cognitive load : {obs['cognitive_load']:.2f} {'⚠ HIGH' if obs['cognitive_load'] > 0.75 else 'βœ“ OK'}
98
  Energy level : {dc.get('energy_level', 1.0):.2f}
99
  Deadline pressure : {obs['deadline_pressure']:.2f} {'⚠ URGENT' if obs['deadline_pressure'] > 0.7 else 'βœ“ OK'}
100
-
101
  PENDING EVENT:
102
  {event_str}
103
-
104
  PENDING DEADLINES:
105
  {dl_str}
106
-
107
  LAST FEEDBACK:
108
  {obs['last_action_feedback']}
109
  Last reward : {obs.get('last_action_reward', 0.0):.4f}
110
  Reasoning score : {obs.get('reasoning_quality_score', 0.0):.2f}
111
-
112
  Choose your next action and provide clear reasoning."""
113
-
114
-
115
  def call_llm(prompt: str) -> dict:
116
  """Call the LLM and parse its action response."""
117
  response = client.chat.completions.create(
@@ -125,40 +126,37 @@ def call_llm(prompt: str) -> dict:
125
  )
126
  raw = response.choices[0].message.content
127
  return json.loads(raw)
128
-
129
-
130
  def run_episode(task_id: str, episode: int) -> dict:
131
  """Run a single episode. Returns episode stats."""
132
- # Reset environment
133
  r = requests.post(f"{ENV_URL}/reset", params={"task_id": task_id})
134
  r.raise_for_status()
135
  obs = r.json()
136
-
137
  step = 0
138
  total_reward = 0.0
139
  reward_history = []
140
  done = False
141
-
 
142
  print(f"\n{'='*60}")
143
  print(f"EPISODE {episode+1} | Task: {task_id}")
144
  print(f"{'='*60}")
145
-
146
  while not done:
147
  step += 1
148
  prompt = obs_to_prompt(obs, step)
149
-
150
- # LLM decides action
151
  try:
152
  action = call_llm(prompt)
153
  except Exception as e:
154
  print(f" [LLM ERROR] {e} β€” defaulting to focus")
155
  action = {"action_type": "focus", "reasoning": "LLM call failed, defaulting to focus."}
156
-
157
- # Ensure reasoning is present
158
  if not action.get("reasoning") or len(action["reasoning"].strip()) < 10:
159
  action["reasoning"] = "Staying focused to complete the session efficiently."
160
-
161
- # Step environment
162
  try:
163
  resp = requests.post(f"{ENV_URL}/step", json=action)
164
  resp.raise_for_status()
@@ -166,49 +164,47 @@ def run_episode(task_id: str, episode: int) -> dict:
166
  except Exception as e:
167
  print(f" [ENV ERROR] {e}")
168
  break
169
-
170
  obs = result
171
  reward = result.get("last_action_reward", 0.0)
172
  done = result.get("done", False)
173
-
174
  total_reward += reward
175
  reward_history.append(reward)
176
-
177
  print(
178
  f" Step {step:3d} | action={action['action_type']:<18} "
179
  f"reward={reward:+.4f} | cumulative={total_reward:.4f} | "
180
  f"reasoning_q={result.get('reasoning_quality_score', 0):.2f}"
181
  )
182
-
183
  time.sleep(0.1) # rate limit
184
-
185
- success = result.get("success", False) if not isinstance(result, Exception) else False
186
  print(f"\n Episode {episode+1} done. Total reward: {total_reward:.4f} | Success: {success}")
187
-
188
  return {
189
- "episode": episode + 1,
190
- "total_reward": round(total_reward, 4),
191
- "steps": step,
192
- "success": success,
193
  "reward_history": reward_history,
194
  }
195
-
196
-
197
  def main():
198
  print(f"FocusFlow Agent | Model: {MODEL_NAME} | Task: {TASK_ID}")
199
  print(f"Environment: {ENV_URL}")
200
  print()
201
-
202
- # Health check
203
  h = requests.get(f"{ENV_URL}/health")
204
  print(f"Health: {h.json()}\n")
205
-
206
  all_stats = []
207
  for ep in range(MAX_EPISODES):
208
  stats = run_episode(TASK_ID, ep)
209
  all_stats.append(stats)
210
-
211
- # Summary
212
  rewards = [s["total_reward"] for s in all_stats]
213
  print(f"\n{'='*60}")
214
  print(f"SUMMARY over {MAX_EPISODES} episodes:")
@@ -216,13 +212,11 @@ def main():
216
  print(f" Mean : {sum(rewards)/len(rewards):.4f}")
217
  print(f" Best : {max(rewards):.4f}")
218
  print(f" Success : {sum(s['success'] for s in all_stats)}/{MAX_EPISODES}")
219
-
220
- # Save for reward curve plotting
221
  with open("reward_log.json", "w") as f:
222
  json.dump(all_stats, f, indent=2)
223
  print("\nReward log saved to reward_log.json")
224
-
225
-
226
  if __name__ == "__main__":
227
- main()
228
-
 
1
  """
2
  FocusFlow RL Environment β€” inference.py
3
+
4
  LLM agent that:
5
  1. Reads the full observation (NL events, deadlines, cognitive load)
6
  2. Produces a chain-of-thought reasoning string
7
  3. Selects the best action
8
  4. Logs reward curves for showing training progress
9
+
10
  Usage:
11
  export API_BASE_URL=https://api.groq.com/openai/v1
12
  export MODEL_NAME=llama-3.1-8b-instant
13
+ export GROQ_API_KEY=your_groq_key_here
14
  export ENV_BASE_URL=http://localhost:7860
15
  export TASK_ID=task_1
16
  python inference.py
17
  """
18
+
19
  import os
20
  import json
21
  import requests
22
  import time
23
  from typing import Optional
24
  from openai import OpenAI
25
+
26
  # ── Config ────────────────────────────────────────────────────────────────────
27
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
28
  MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
29
  ENV_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
30
  TASK_ID = os.getenv("TASK_ID", "task_1")
31
  MAX_EPISODES = int(os.getenv("MAX_EPISODES", "5"))
32
+
33
  client = OpenAI(
34
  api_key=os.getenv("OPENAI_API_KEY", os.getenv("GROQ_API_KEY", "dummy")),
35
  base_url=API_BASE_URL,
36
  )
37
+
38
  SYSTEM_PROMPT = """You are FocusAgent β€” an AI assistant helping a student stay focused during study sessions.
39
+
40
  You will receive an observation about the student's current state and must choose the best action.
41
+
42
  AVAILABLE ACTIONS:
43
  - focus : Stay on task, no special action needed
44
  - block_app : Block a distracting app (specify app_name)
 
49
  - adjust_energy : Do something to restore energy/focus (stretch, hydrate, etc.)
50
  - check_app : (BAD) Give in to distraction β€” avoid this!
51
  - quit_session : (BAD) End the session early β€” avoid this!
52
+
53
  DECISION FRAMEWORK:
54
  1. Check pending_event first β€” is it urgent? Can it be deferred?
55
  2. Check cognitive_load β€” if > 0.75, consider a break
56
  3. Check deadline_pressure β€” if > 0.7, prioritise study tasks
57
  4. Block apps proactively, especially high-temptation ones
58
  5. Plan the day at the start of each day (step 1)
59
+
60
  CRITICAL: Your `reasoning` field MUST explain:
61
  - What the most important signal in the observation is
62
  - Why you chose this action over alternatives
63
  - How this action serves the long-term goal
64
+
65
  Poor reasoning = lower reward. Think carefully."""
66
+
67
+
68
  def obs_to_prompt(obs: dict, step: int) -> str:
69
  event_str = "None"
70
  if obs.get("pending_event"):
 
74
  f"(urgency={e['urgency']:.2f}, can_defer={e['can_defer']}, "
75
  f"expires_in={e.get('deadline_steps', 'N/A')} steps)"
76
  )
77
+
78
  dc = obs.get("day_context", {})
79
  deadlines = dc.get("pending_deadlines", [])
80
  dl_str = ", ".join(
81
  f"{d['task']} (due step {d['due_step']})"
82
  for d in deadlines if not d.get("completed")
83
  ) or "None"
84
+
85
  return f"""=== STEP {step} OBSERVATION ===
86
+
87
  SESSION STATE:
88
  Phase : {obs['current_phase']}
89
  Time remaining : {obs['time_remaining_seconds']//60}m {obs['time_remaining_seconds']%60}s
90
  Sessions done : {obs['sessions_completed']}
91
  Focus score : {obs['focus_score']:.3f}
92
+
93
  ENVIRONMENT:
94
  Active distractions : {', '.join(obs['active_distractions']) or 'None'}
95
  Blocked apps : {', '.join(obs['blocked_apps']) or 'None'}
96
+
97
  COGNITIVE & ENERGY:
98
  Cognitive load : {obs['cognitive_load']:.2f} {'⚠ HIGH' if obs['cognitive_load'] > 0.75 else 'βœ“ OK'}
99
  Energy level : {dc.get('energy_level', 1.0):.2f}
100
  Deadline pressure : {obs['deadline_pressure']:.2f} {'⚠ URGENT' if obs['deadline_pressure'] > 0.7 else 'βœ“ OK'}
101
+
102
  PENDING EVENT:
103
  {event_str}
104
+
105
  PENDING DEADLINES:
106
  {dl_str}
107
+
108
  LAST FEEDBACK:
109
  {obs['last_action_feedback']}
110
  Last reward : {obs.get('last_action_reward', 0.0):.4f}
111
  Reasoning score : {obs.get('reasoning_quality_score', 0.0):.2f}
112
+
113
  Choose your next action and provide clear reasoning."""
114
+
115
+
116
  def call_llm(prompt: str) -> dict:
117
  """Call the LLM and parse its action response."""
118
  response = client.chat.completions.create(
 
126
  )
127
  raw = response.choices[0].message.content
128
  return json.loads(raw)
129
+
130
+
131
  def run_episode(task_id: str, episode: int) -> dict:
132
  """Run a single episode. Returns episode stats."""
 
133
  r = requests.post(f"{ENV_URL}/reset", params={"task_id": task_id})
134
  r.raise_for_status()
135
  obs = r.json()
136
+
137
  step = 0
138
  total_reward = 0.0
139
  reward_history = []
140
  done = False
141
+ result = {}
142
+
143
  print(f"\n{'='*60}")
144
  print(f"EPISODE {episode+1} | Task: {task_id}")
145
  print(f"{'='*60}")
146
+
147
  while not done:
148
  step += 1
149
  prompt = obs_to_prompt(obs, step)
150
+
 
151
  try:
152
  action = call_llm(prompt)
153
  except Exception as e:
154
  print(f" [LLM ERROR] {e} β€” defaulting to focus")
155
  action = {"action_type": "focus", "reasoning": "LLM call failed, defaulting to focus."}
156
+
 
157
  if not action.get("reasoning") or len(action["reasoning"].strip()) < 10:
158
  action["reasoning"] = "Staying focused to complete the session efficiently."
159
+
 
160
  try:
161
  resp = requests.post(f"{ENV_URL}/step", json=action)
162
  resp.raise_for_status()
 
164
  except Exception as e:
165
  print(f" [ENV ERROR] {e}")
166
  break
167
+
168
  obs = result
169
  reward = result.get("last_action_reward", 0.0)
170
  done = result.get("done", False)
171
+
172
  total_reward += reward
173
  reward_history.append(reward)
174
+
175
  print(
176
  f" Step {step:3d} | action={action['action_type']:<18} "
177
  f"reward={reward:+.4f} | cumulative={total_reward:.4f} | "
178
  f"reasoning_q={result.get('reasoning_quality_score', 0):.2f}"
179
  )
180
+
181
  time.sleep(0.1) # rate limit
182
+
183
+ success = result.get("success", False) if result else False
184
  print(f"\n Episode {episode+1} done. Total reward: {total_reward:.4f} | Success: {success}")
185
+
186
  return {
187
+ "episode": episode + 1,
188
+ "total_reward": round(total_reward, 4),
189
+ "steps": step,
190
+ "success": success,
191
  "reward_history": reward_history,
192
  }
193
+
194
+
195
  def main():
196
  print(f"FocusFlow Agent | Model: {MODEL_NAME} | Task: {TASK_ID}")
197
  print(f"Environment: {ENV_URL}")
198
  print()
199
+
 
200
  h = requests.get(f"{ENV_URL}/health")
201
  print(f"Health: {h.json()}\n")
202
+
203
  all_stats = []
204
  for ep in range(MAX_EPISODES):
205
  stats = run_episode(TASK_ID, ep)
206
  all_stats.append(stats)
207
+
 
208
  rewards = [s["total_reward"] for s in all_stats]
209
  print(f"\n{'='*60}")
210
  print(f"SUMMARY over {MAX_EPISODES} episodes:")
 
212
  print(f" Mean : {sum(rewards)/len(rewards):.4f}")
213
  print(f" Best : {max(rewards):.4f}")
214
  print(f" Success : {sum(s['success'] for s in all_stats)}/{MAX_EPISODES}")
215
+
 
216
  with open("reward_log.json", "w") as f:
217
  json.dump(all_stats, f, indent=2)
218
  print("\nReward log saved to reward_log.json")
219
+
220
+
221
  if __name__ == "__main__":
222
+ main()