Spaces:
Sleeping
Sleeping
Update inference.py
Browse files- inference.py +54 -60
inference.py
CHANGED
|
@@ -1,43 +1,44 @@
|
|
| 1 |
"""
|
| 2 |
FocusFlow RL Environment β inference.py
|
| 3 |
-
|
| 4 |
LLM agent that:
|
| 5 |
1. Reads the full observation (NL events, deadlines, cognitive load)
|
| 6 |
2. Produces a chain-of-thought reasoning string
|
| 7 |
3. Selects the best action
|
| 8 |
4. Logs reward curves for showing training progress
|
| 9 |
-
|
| 10 |
Usage:
|
| 11 |
export API_BASE_URL=https://api.groq.com/openai/v1
|
| 12 |
export MODEL_NAME=llama-3.1-8b-instant
|
|
|
|
| 13 |
export ENV_BASE_URL=http://localhost:7860
|
| 14 |
export TASK_ID=task_1
|
| 15 |
python inference.py
|
| 16 |
"""
|
| 17 |
-
|
| 18 |
import os
|
| 19 |
import json
|
| 20 |
import requests
|
| 21 |
import time
|
| 22 |
from typing import Optional
|
| 23 |
from openai import OpenAI
|
| 24 |
-
|
| 25 |
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 27 |
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
|
| 28 |
ENV_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 29 |
TASK_ID = os.getenv("TASK_ID", "task_1")
|
| 30 |
MAX_EPISODES = int(os.getenv("MAX_EPISODES", "5"))
|
| 31 |
-
|
| 32 |
client = OpenAI(
|
| 33 |
api_key=os.getenv("OPENAI_API_KEY", os.getenv("GROQ_API_KEY", "dummy")),
|
| 34 |
base_url=API_BASE_URL,
|
| 35 |
)
|
| 36 |
-
|
| 37 |
SYSTEM_PROMPT = """You are FocusAgent β an AI assistant helping a student stay focused during study sessions.
|
| 38 |
-
|
| 39 |
You will receive an observation about the student's current state and must choose the best action.
|
| 40 |
-
|
| 41 |
AVAILABLE ACTIONS:
|
| 42 |
- focus : Stay on task, no special action needed
|
| 43 |
- block_app : Block a distracting app (specify app_name)
|
|
@@ -48,22 +49,22 @@ AVAILABLE ACTIONS:
|
|
| 48 |
- adjust_energy : Do something to restore energy/focus (stretch, hydrate, etc.)
|
| 49 |
- check_app : (BAD) Give in to distraction β avoid this!
|
| 50 |
- quit_session : (BAD) End the session early β avoid this!
|
| 51 |
-
|
| 52 |
DECISION FRAMEWORK:
|
| 53 |
1. Check pending_event first β is it urgent? Can it be deferred?
|
| 54 |
2. Check cognitive_load β if > 0.75, consider a break
|
| 55 |
3. Check deadline_pressure β if > 0.7, prioritise study tasks
|
| 56 |
4. Block apps proactively, especially high-temptation ones
|
| 57 |
5. Plan the day at the start of each day (step 1)
|
| 58 |
-
|
| 59 |
CRITICAL: Your `reasoning` field MUST explain:
|
| 60 |
- What the most important signal in the observation is
|
| 61 |
- Why you chose this action over alternatives
|
| 62 |
- How this action serves the long-term goal
|
| 63 |
-
|
| 64 |
Poor reasoning = lower reward. Think carefully."""
|
| 65 |
-
|
| 66 |
-
|
| 67 |
def obs_to_prompt(obs: dict, step: int) -> str:
|
| 68 |
event_str = "None"
|
| 69 |
if obs.get("pending_event"):
|
|
@@ -73,45 +74,45 @@ def obs_to_prompt(obs: dict, step: int) -> str:
|
|
| 73 |
f"(urgency={e['urgency']:.2f}, can_defer={e['can_defer']}, "
|
| 74 |
f"expires_in={e.get('deadline_steps', 'N/A')} steps)"
|
| 75 |
)
|
| 76 |
-
|
| 77 |
dc = obs.get("day_context", {})
|
| 78 |
deadlines = dc.get("pending_deadlines", [])
|
| 79 |
dl_str = ", ".join(
|
| 80 |
f"{d['task']} (due step {d['due_step']})"
|
| 81 |
for d in deadlines if not d.get("completed")
|
| 82 |
) or "None"
|
| 83 |
-
|
| 84 |
return f"""=== STEP {step} OBSERVATION ===
|
| 85 |
-
|
| 86 |
SESSION STATE:
|
| 87 |
Phase : {obs['current_phase']}
|
| 88 |
Time remaining : {obs['time_remaining_seconds']//60}m {obs['time_remaining_seconds']%60}s
|
| 89 |
Sessions done : {obs['sessions_completed']}
|
| 90 |
Focus score : {obs['focus_score']:.3f}
|
| 91 |
-
|
| 92 |
ENVIRONMENT:
|
| 93 |
Active distractions : {', '.join(obs['active_distractions']) or 'None'}
|
| 94 |
Blocked apps : {', '.join(obs['blocked_apps']) or 'None'}
|
| 95 |
-
|
| 96 |
COGNITIVE & ENERGY:
|
| 97 |
Cognitive load : {obs['cognitive_load']:.2f} {'β HIGH' if obs['cognitive_load'] > 0.75 else 'β OK'}
|
| 98 |
Energy level : {dc.get('energy_level', 1.0):.2f}
|
| 99 |
Deadline pressure : {obs['deadline_pressure']:.2f} {'β URGENT' if obs['deadline_pressure'] > 0.7 else 'β OK'}
|
| 100 |
-
|
| 101 |
PENDING EVENT:
|
| 102 |
{event_str}
|
| 103 |
-
|
| 104 |
PENDING DEADLINES:
|
| 105 |
{dl_str}
|
| 106 |
-
|
| 107 |
LAST FEEDBACK:
|
| 108 |
{obs['last_action_feedback']}
|
| 109 |
Last reward : {obs.get('last_action_reward', 0.0):.4f}
|
| 110 |
Reasoning score : {obs.get('reasoning_quality_score', 0.0):.2f}
|
| 111 |
-
|
| 112 |
Choose your next action and provide clear reasoning."""
|
| 113 |
-
|
| 114 |
-
|
| 115 |
def call_llm(prompt: str) -> dict:
|
| 116 |
"""Call the LLM and parse its action response."""
|
| 117 |
response = client.chat.completions.create(
|
|
@@ -125,40 +126,37 @@ def call_llm(prompt: str) -> dict:
|
|
| 125 |
)
|
| 126 |
raw = response.choices[0].message.content
|
| 127 |
return json.loads(raw)
|
| 128 |
-
|
| 129 |
-
|
| 130 |
def run_episode(task_id: str, episode: int) -> dict:
|
| 131 |
"""Run a single episode. Returns episode stats."""
|
| 132 |
-
# Reset environment
|
| 133 |
r = requests.post(f"{ENV_URL}/reset", params={"task_id": task_id})
|
| 134 |
r.raise_for_status()
|
| 135 |
obs = r.json()
|
| 136 |
-
|
| 137 |
step = 0
|
| 138 |
total_reward = 0.0
|
| 139 |
reward_history = []
|
| 140 |
done = False
|
| 141 |
-
|
|
|
|
| 142 |
print(f"\n{'='*60}")
|
| 143 |
print(f"EPISODE {episode+1} | Task: {task_id}")
|
| 144 |
print(f"{'='*60}")
|
| 145 |
-
|
| 146 |
while not done:
|
| 147 |
step += 1
|
| 148 |
prompt = obs_to_prompt(obs, step)
|
| 149 |
-
|
| 150 |
-
# LLM decides action
|
| 151 |
try:
|
| 152 |
action = call_llm(prompt)
|
| 153 |
except Exception as e:
|
| 154 |
print(f" [LLM ERROR] {e} β defaulting to focus")
|
| 155 |
action = {"action_type": "focus", "reasoning": "LLM call failed, defaulting to focus."}
|
| 156 |
-
|
| 157 |
-
# Ensure reasoning is present
|
| 158 |
if not action.get("reasoning") or len(action["reasoning"].strip()) < 10:
|
| 159 |
action["reasoning"] = "Staying focused to complete the session efficiently."
|
| 160 |
-
|
| 161 |
-
# Step environment
|
| 162 |
try:
|
| 163 |
resp = requests.post(f"{ENV_URL}/step", json=action)
|
| 164 |
resp.raise_for_status()
|
|
@@ -166,49 +164,47 @@ def run_episode(task_id: str, episode: int) -> dict:
|
|
| 166 |
except Exception as e:
|
| 167 |
print(f" [ENV ERROR] {e}")
|
| 168 |
break
|
| 169 |
-
|
| 170 |
obs = result
|
| 171 |
reward = result.get("last_action_reward", 0.0)
|
| 172 |
done = result.get("done", False)
|
| 173 |
-
|
| 174 |
total_reward += reward
|
| 175 |
reward_history.append(reward)
|
| 176 |
-
|
| 177 |
print(
|
| 178 |
f" Step {step:3d} | action={action['action_type']:<18} "
|
| 179 |
f"reward={reward:+.4f} | cumulative={total_reward:.4f} | "
|
| 180 |
f"reasoning_q={result.get('reasoning_quality_score', 0):.2f}"
|
| 181 |
)
|
| 182 |
-
|
| 183 |
time.sleep(0.1) # rate limit
|
| 184 |
-
|
| 185 |
-
success = result.get("success", False) if
|
| 186 |
print(f"\n Episode {episode+1} done. Total reward: {total_reward:.4f} | Success: {success}")
|
| 187 |
-
|
| 188 |
return {
|
| 189 |
-
"episode":
|
| 190 |
-
"total_reward":
|
| 191 |
-
"steps":
|
| 192 |
-
"success":
|
| 193 |
"reward_history": reward_history,
|
| 194 |
}
|
| 195 |
-
|
| 196 |
-
|
| 197 |
def main():
|
| 198 |
print(f"FocusFlow Agent | Model: {MODEL_NAME} | Task: {TASK_ID}")
|
| 199 |
print(f"Environment: {ENV_URL}")
|
| 200 |
print()
|
| 201 |
-
|
| 202 |
-
# Health check
|
| 203 |
h = requests.get(f"{ENV_URL}/health")
|
| 204 |
print(f"Health: {h.json()}\n")
|
| 205 |
-
|
| 206 |
all_stats = []
|
| 207 |
for ep in range(MAX_EPISODES):
|
| 208 |
stats = run_episode(TASK_ID, ep)
|
| 209 |
all_stats.append(stats)
|
| 210 |
-
|
| 211 |
-
# Summary
|
| 212 |
rewards = [s["total_reward"] for s in all_stats]
|
| 213 |
print(f"\n{'='*60}")
|
| 214 |
print(f"SUMMARY over {MAX_EPISODES} episodes:")
|
|
@@ -216,13 +212,11 @@ def main():
|
|
| 216 |
print(f" Mean : {sum(rewards)/len(rewards):.4f}")
|
| 217 |
print(f" Best : {max(rewards):.4f}")
|
| 218 |
print(f" Success : {sum(s['success'] for s in all_stats)}/{MAX_EPISODES}")
|
| 219 |
-
|
| 220 |
-
# Save for reward curve plotting
|
| 221 |
with open("reward_log.json", "w") as f:
|
| 222 |
json.dump(all_stats, f, indent=2)
|
| 223 |
print("\nReward log saved to reward_log.json")
|
| 224 |
-
|
| 225 |
-
|
| 226 |
if __name__ == "__main__":
|
| 227 |
-
main()
|
| 228 |
-
|
|
|
|
| 1 |
"""
|
| 2 |
FocusFlow RL Environment β inference.py
|
| 3 |
+
|
| 4 |
LLM agent that:
|
| 5 |
1. Reads the full observation (NL events, deadlines, cognitive load)
|
| 6 |
2. Produces a chain-of-thought reasoning string
|
| 7 |
3. Selects the best action
|
| 8 |
4. Logs reward curves for showing training progress
|
| 9 |
+
|
| 10 |
Usage:
|
| 11 |
export API_BASE_URL=https://api.groq.com/openai/v1
|
| 12 |
export MODEL_NAME=llama-3.1-8b-instant
|
| 13 |
+
export GROQ_API_KEY=your_groq_key_here
|
| 14 |
export ENV_BASE_URL=http://localhost:7860
|
| 15 |
export TASK_ID=task_1
|
| 16 |
python inference.py
|
| 17 |
"""
|
| 18 |
+
|
| 19 |
import os
|
| 20 |
import json
|
| 21 |
import requests
|
| 22 |
import time
|
| 23 |
from typing import Optional
|
| 24 |
from openai import OpenAI
|
| 25 |
+
|
| 26 |
# ββ Config ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 27 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
|
| 28 |
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
|
| 29 |
ENV_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860")
|
| 30 |
TASK_ID = os.getenv("TASK_ID", "task_1")
|
| 31 |
MAX_EPISODES = int(os.getenv("MAX_EPISODES", "5"))
|
| 32 |
+
|
| 33 |
client = OpenAI(
|
| 34 |
api_key=os.getenv("OPENAI_API_KEY", os.getenv("GROQ_API_KEY", "dummy")),
|
| 35 |
base_url=API_BASE_URL,
|
| 36 |
)
|
| 37 |
+
|
| 38 |
SYSTEM_PROMPT = """You are FocusAgent β an AI assistant helping a student stay focused during study sessions.
|
| 39 |
+
|
| 40 |
You will receive an observation about the student's current state and must choose the best action.
|
| 41 |
+
|
| 42 |
AVAILABLE ACTIONS:
|
| 43 |
- focus : Stay on task, no special action needed
|
| 44 |
- block_app : Block a distracting app (specify app_name)
|
|
|
|
| 49 |
- adjust_energy : Do something to restore energy/focus (stretch, hydrate, etc.)
|
| 50 |
- check_app : (BAD) Give in to distraction β avoid this!
|
| 51 |
- quit_session : (BAD) End the session early β avoid this!
|
| 52 |
+
|
| 53 |
DECISION FRAMEWORK:
|
| 54 |
1. Check pending_event first β is it urgent? Can it be deferred?
|
| 55 |
2. Check cognitive_load β if > 0.75, consider a break
|
| 56 |
3. Check deadline_pressure β if > 0.7, prioritise study tasks
|
| 57 |
4. Block apps proactively, especially high-temptation ones
|
| 58 |
5. Plan the day at the start of each day (step 1)
|
| 59 |
+
|
| 60 |
CRITICAL: Your `reasoning` field MUST explain:
|
| 61 |
- What the most important signal in the observation is
|
| 62 |
- Why you chose this action over alternatives
|
| 63 |
- How this action serves the long-term goal
|
| 64 |
+
|
| 65 |
Poor reasoning = lower reward. Think carefully."""
|
| 66 |
+
|
| 67 |
+
|
| 68 |
def obs_to_prompt(obs: dict, step: int) -> str:
|
| 69 |
event_str = "None"
|
| 70 |
if obs.get("pending_event"):
|
|
|
|
| 74 |
f"(urgency={e['urgency']:.2f}, can_defer={e['can_defer']}, "
|
| 75 |
f"expires_in={e.get('deadline_steps', 'N/A')} steps)"
|
| 76 |
)
|
| 77 |
+
|
| 78 |
dc = obs.get("day_context", {})
|
| 79 |
deadlines = dc.get("pending_deadlines", [])
|
| 80 |
dl_str = ", ".join(
|
| 81 |
f"{d['task']} (due step {d['due_step']})"
|
| 82 |
for d in deadlines if not d.get("completed")
|
| 83 |
) or "None"
|
| 84 |
+
|
| 85 |
return f"""=== STEP {step} OBSERVATION ===
|
| 86 |
+
|
| 87 |
SESSION STATE:
|
| 88 |
Phase : {obs['current_phase']}
|
| 89 |
Time remaining : {obs['time_remaining_seconds']//60}m {obs['time_remaining_seconds']%60}s
|
| 90 |
Sessions done : {obs['sessions_completed']}
|
| 91 |
Focus score : {obs['focus_score']:.3f}
|
| 92 |
+
|
| 93 |
ENVIRONMENT:
|
| 94 |
Active distractions : {', '.join(obs['active_distractions']) or 'None'}
|
| 95 |
Blocked apps : {', '.join(obs['blocked_apps']) or 'None'}
|
| 96 |
+
|
| 97 |
COGNITIVE & ENERGY:
|
| 98 |
Cognitive load : {obs['cognitive_load']:.2f} {'β HIGH' if obs['cognitive_load'] > 0.75 else 'β OK'}
|
| 99 |
Energy level : {dc.get('energy_level', 1.0):.2f}
|
| 100 |
Deadline pressure : {obs['deadline_pressure']:.2f} {'β URGENT' if obs['deadline_pressure'] > 0.7 else 'β OK'}
|
| 101 |
+
|
| 102 |
PENDING EVENT:
|
| 103 |
{event_str}
|
| 104 |
+
|
| 105 |
PENDING DEADLINES:
|
| 106 |
{dl_str}
|
| 107 |
+
|
| 108 |
LAST FEEDBACK:
|
| 109 |
{obs['last_action_feedback']}
|
| 110 |
Last reward : {obs.get('last_action_reward', 0.0):.4f}
|
| 111 |
Reasoning score : {obs.get('reasoning_quality_score', 0.0):.2f}
|
| 112 |
+
|
| 113 |
Choose your next action and provide clear reasoning."""
|
| 114 |
+
|
| 115 |
+
|
| 116 |
def call_llm(prompt: str) -> dict:
|
| 117 |
"""Call the LLM and parse its action response."""
|
| 118 |
response = client.chat.completions.create(
|
|
|
|
| 126 |
)
|
| 127 |
raw = response.choices[0].message.content
|
| 128 |
return json.loads(raw)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
def run_episode(task_id: str, episode: int) -> dict:
    """Run a single episode against the FocusFlow environment server.

    Resets the environment, then loops: build a prompt from the latest
    observation, ask the LLM for an action, POST it to ``/step``, and
    accumulate rewards until the environment reports ``done`` (or a
    transport error aborts the episode early).

    Args:
        task_id: Environment task identifier passed to ``/reset``.
        episode: Zero-based episode index (displayed as ``episode + 1``).

    Returns:
        Episode stats dict with keys ``episode``, ``total_reward``,
        ``steps``, ``success`` and the per-step ``reward_history``.

    Raises:
        requests.HTTPError / requests.Timeout: if the initial ``/reset``
        call fails (step errors are caught and end the episode instead).
    """
    # Reset environment. A timeout is set so a hung server fails fast
    # instead of blocking the agent forever (requests has no default timeout).
    r = requests.post(f"{ENV_URL}/reset", params={"task_id": task_id}, timeout=30)
    r.raise_for_status()
    obs = r.json()

    step = 0
    total_reward = 0.0
    reward_history = []
    done = False
    result = {}  # last /step payload; stays {} if the very first step errors out

    print(f"\n{'='*60}")
    print(f"EPISODE {episode+1} | Task: {task_id}")
    print(f"{'='*60}")

    while not done:
        step += 1
        prompt = obs_to_prompt(obs, step)

        # LLM decides the action; fall back to a safe default on any failure
        # so one bad completion doesn't kill the whole episode.
        try:
            action = call_llm(prompt)
        except Exception as e:
            print(f"  [LLM ERROR] {e} — defaulting to focus")
            action = {"action_type": "focus", "reasoning": "LLM call failed, defaulting to focus."}

        # The environment scores reasoning quality, so never send an empty
        # or trivially short reasoning string.
        if not action.get("reasoning") or len(action["reasoning"].strip()) < 10:
            action["reasoning"] = "Staying focused to complete the session efficiently."

        # Step the environment; same fail-fast timeout as /reset.
        try:
            resp = requests.post(f"{ENV_URL}/step", json=action, timeout=30)
            resp.raise_for_status()
            # NOTE(review): this assignment is hidden between diff hunks in the
            # rendered source; reconstructed from the dataflow below (``result``
            # must be the JSON payload of the /step response) — confirm.
            result = resp.json()
        except Exception as e:
            print(f"  [ENV ERROR] {e}")
            break

        obs = result
        reward = result.get("last_action_reward", 0.0)
        done = result.get("done", False)

        total_reward += reward
        reward_history.append(reward)

        print(
            f"  Step {step:3d} | action={action['action_type']:<18} "
            f"reward={reward:+.4f} | cumulative={total_reward:.4f} | "
            f"reasoning_q={result.get('reasoning_quality_score', 0):.2f}"
        )

        time.sleep(0.1)  # rate limit

    # ``result`` is {} when the loop broke before any successful step.
    success = result.get("success", False) if result else False
    print(f"\n  Episode {episode+1} done. Total reward: {total_reward:.4f} | Success: {success}")

    return {
        "episode": episode + 1,
        "total_reward": round(total_reward, 4),
        "steps": step,
        "success": success,
        "reward_history": reward_history,
    }
|
| 193 |
+
|
| 194 |
+
|
| 195 |
def main():
|
| 196 |
print(f"FocusFlow Agent | Model: {MODEL_NAME} | Task: {TASK_ID}")
|
| 197 |
print(f"Environment: {ENV_URL}")
|
| 198 |
print()
|
| 199 |
+
|
|
|
|
| 200 |
h = requests.get(f"{ENV_URL}/health")
|
| 201 |
print(f"Health: {h.json()}\n")
|
| 202 |
+
|
| 203 |
all_stats = []
|
| 204 |
for ep in range(MAX_EPISODES):
|
| 205 |
stats = run_episode(TASK_ID, ep)
|
| 206 |
all_stats.append(stats)
|
| 207 |
+
|
|
|
|
| 208 |
rewards = [s["total_reward"] for s in all_stats]
|
| 209 |
print(f"\n{'='*60}")
|
| 210 |
print(f"SUMMARY over {MAX_EPISODES} episodes:")
|
|
|
|
| 212 |
print(f" Mean : {sum(rewards)/len(rewards):.4f}")
|
| 213 |
print(f" Best : {max(rewards):.4f}")
|
| 214 |
print(f" Success : {sum(s['success'] for s in all_stats)}/{MAX_EPISODES}")
|
| 215 |
+
|
|
|
|
| 216 |
with open("reward_log.json", "w") as f:
|
| 217 |
json.dump(all_stats, f, indent=2)
|
| 218 |
print("\nReward log saved to reward_log.json")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
if __name__ == "__main__":
|
| 222 |
+
main()
|
|
|