Spaces:
Sleeping
Sleeping
| """ | |
FocusFlow RL Environment - inference.py
| LLM agent that: | |
| 1. Reads the full observation (NL events, deadlines, cognitive load) | |
| 2. Produces a chain-of-thought reasoning string | |
| 3. Selects the best action | |
| 4. Logs reward curves for showing training progress | |
| Usage: | |
| export API_BASE_URL=https://api.groq.com/openai/v1 | |
| export MODEL_NAME=llama-3.1-8b-instant | |
| export GROQ_API_KEY=your_groq_key_here | |
| export ENV_BASE_URL=http://localhost:7860 | |
| export TASK_ID=task_1 | |
| python inference.py | |
| """ | |
| import os | |
| import json | |
| import requests | |
| import time | |
| from typing import Optional | |
| from openai import OpenAI | |
# -- Config -------------------------------------------------------------------
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "llama-3.1-8b-instant")
# Trailing slashes are stripped so f"{ENV_URL}/reset" never becomes
# "...//reset", which path-exact routers may reject (e.g. with a 404).
ENV_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860").rstrip("/")
TASK_ID = os.getenv("TASK_ID", "task_1")
# Number of episodes main() runs; must parse as an int.
MAX_EPISODES = int(os.getenv("MAX_EPISODES", "5"))

# OPENAI_API_KEY takes precedence over GROQ_API_KEY when both are set.
# The "dummy" fallback presumably exists so the client can be constructed
# without any credentials configured; confirm requests then fail as expected.
client = OpenAI(
    api_key=os.getenv("OPENAI_API_KEY", os.getenv("GROQ_API_KEY", "dummy")),
    base_url=API_BASE_URL,
)
# System prompt sent with every LLM call. It lists the available actions, a
# decision framework (thresholds match the flags in the observation prompt),
# the JSON keys the reply must use, and the requirement to justify the choice.
SYSTEM_PROMPT = """You are FocusAgent - an AI assistant helping a student stay focused during study sessions.
You will receive an observation about the student's current state and must choose the best action.
AVAILABLE ACTIONS:
- focus : Stay on task, no special action needed
- block_app : Block a distracting app (specify app_name)
- take_break : Take a study break
- defer_event : Postpone handling a distraction event for later
- respond_to_event : Immediately handle an urgent event (provide response_text)
- plan_day : Set a day plan (provide day_plan as a list of steps)
- adjust_energy : Do something to restore energy/focus (stretch, hydrate, etc.)
- check_app : (BAD) Give in to distraction - avoid this!
- quit_session : (BAD) End the session early - avoid this!
RESPONSE FORMAT:
A JSON object with `action_type` (one of the action names above) and `reasoning`, plus app_name / response_text / day_plan when the chosen action needs them.
DECISION FRAMEWORK:
1. Check pending_event first - is it urgent? Can it be deferred?
2. Check cognitive_load - if > 0.75, consider a break
3. Check deadline_pressure - if > 0.7, prioritise study tasks
4. Block apps proactively, especially high-temptation ones
5. Plan the day at the start of each day (step 1)
CRITICAL: Your `reasoning` field MUST explain:
- What the most important signal in the observation is
- Why you chose this action over alternatives
- How this action serves the long-term goal
Poor reasoning = lower reward. Think carefully."""
def _value_or(mapping: dict, key: str, default):
    """Return ``mapping[key]``, or *default* when the key is missing or ``None``.

    JSON ``null`` decodes to ``None``, which ``dict.get(key, default)`` would
    pass straight through to a ``:.2f`` format spec or ``str.join`` and crash.
    """
    value = mapping.get(key)
    return default if value is None else value


def obs_to_prompt(obs: dict, step: int) -> str:
    """Render an environment observation as the user prompt for the LLM.

    Args:
        obs: Observation dict from the environment. The core session fields
            (``current_phase``, ``time_remaining_seconds``,
            ``sessions_completed``, ``focus_score``, ``cognitive_load``,
            ``deadline_pressure``) are required. Every other field may be
            missing or ``null`` and then falls back to a neutral default.
        step: Step number shown in the prompt header.

    Returns:
        A multi-line text summary of the observation, ending with an
        instruction to choose the next action.

    Raises:
        KeyError: if one of the required core fields is missing.
    """
    event = obs.get("pending_event")
    if event:
        event_str = (
            f"[{event.get('type', 'unknown')}] {event.get('description', '')} "
            f"(urgency={_value_or(event, 'urgency', 0.0):.2f}, "
            f"can_defer={event.get('can_defer', 'N/A')}, "
            f"expires_in={event.get('deadline_steps', 'N/A')} steps)"
        )
    else:
        event_str = "None"

    day_context = _value_or(obs, "day_context", {})
    open_deadlines = [
        f"{d.get('task', '?')} (due step {d.get('due_step', '?')})"
        for d in _value_or(day_context, "pending_deadlines", [])
        if not d.get("completed")
    ]
    deadlines_str = ", ".join(open_deadlines) or "None"

    # int() first so float seconds render as "25m 0s", not "25.0m 0.0s".
    minutes, seconds = divmod(int(obs["time_remaining_seconds"]), 60)
    load = obs["cognitive_load"]
    pressure = obs["deadline_pressure"]
    # Plain-ASCII markers keep the prompt free of encoding-sensitive symbols.
    # The thresholds match the DECISION FRAMEWORK in the system prompt.
    load_flag = "[HIGH]" if load > 0.75 else "[OK]"
    pressure_flag = "[URGENT]" if pressure > 0.7 else "[OK]"
    distractions = ", ".join(_value_or(obs, "active_distractions", [])) or "None"
    blocked = ", ".join(_value_or(obs, "blocked_apps", [])) or "None"
    energy = _value_or(day_context, "energy_level", 1.0)
    last_reward = _value_or(obs, "last_action_reward", 0.0)
    reasoning_score = _value_or(obs, "reasoning_quality_score", 0.0)

    return f"""=== STEP {step} OBSERVATION ===
SESSION STATE:
Phase : {obs['current_phase']}
Time remaining : {minutes}m {seconds}s
Sessions done : {obs['sessions_completed']}
Focus score : {obs['focus_score']:.3f}
ENVIRONMENT:
Active distractions : {distractions}
Blocked apps : {blocked}
COGNITIVE & ENERGY:
Cognitive load : {load:.2f} {load_flag}
Energy level : {energy:.2f}
Deadline pressure : {pressure:.2f} {pressure_flag}
PENDING EVENT:
{event_str}
PENDING DEADLINES:
{deadlines_str}
LAST FEEDBACK:
{obs.get('last_action_feedback', '')}
Last reward : {last_reward:.4f}
Reasoning score : {reasoning_score:.2f}
Choose your next action and provide clear reasoning."""
# Appended to every user prompt. The system prompt lists action names, but
# the model also needs the exact JSON keys the environment's /step expects.
_FORMAT_INSTRUCTIONS = (
    "\n\nRespond ONLY with valid JSON matching FocusAction schema: an object with "
    '"action_type" (one of the AVAILABLE ACTIONS) and "reasoning" (string), plus, '
    'when relevant, "app_name" (string, for block_app), '
    '"response_text" (string, for respond_to_event) or '
    '"day_plan" (list of strings, for plan_day).'
)


def call_llm(prompt: str) -> dict:
    """Ask the LLM for the next action and return its parsed JSON reply.

    Uses JSON mode (``response_format={"type": "json_object"}``) with the
    module-level ``SYSTEM_PROMPT`` and appends explicit key names to the
    user message, because ``action_type`` is not named anywhere else.

    Args:
        prompt: Rendered observation text for the current step.

    Returns:
        The model's reply decoded as a JSON object (dict).

    Raises:
        ValueError: if the reply is empty, is not valid JSON
            (``json.JSONDecodeError`` subclasses ``ValueError``), or decodes
            to something other than a JSON object.

    Exceptions raised by the OpenAI client propagate unchanged.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.3,
        messages=[
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": prompt + _FORMAT_INSTRUCTIONS},
        ],
        response_format={"type": "json_object"},
    )
    raw = response.choices[0].message.content
    if not raw:
        # content can be None/empty (e.g. refusals); json.loads(None) would
        # raise an unhelpful TypeError instead.
        raise ValueError("LLM returned an empty response")
    action = json.loads(raw)
    if not isinstance(action, dict):
        raise ValueError(f"LLM returned JSON {type(action).__name__}, expected an object")
    return action
def run_episode(task_id: str, episode: int, max_steps: Optional[int] = None) -> dict:
    """Run a single episode against the environment and return its stats.

    Resets the environment for *task_id*. It then repeatedly renders the
    observation, asks the LLM for an action and posts it to ``/step``. The
    loop stops when the environment reports ``done``, when a ``/step``
    request fails, or when *max_steps* steps have been taken. Malformed LLM
    replies are repaired, or replaced with a ``focus`` action, instead of
    aborting the episode.

    Args:
        task_id: Task identifier sent to ``/reset``.
        episode: 0-based episode index (displayed 1-based).
        max_steps: Optional safety cap on the number of steps. ``None``, the
            default, keeps stepping until the environment reports ``done``.

    Returns:
        A dict with these keys:

        - ``episode``: the 1-based episode number.
        - ``total_reward``: the summed reward, rounded to 4 decimals.
        - ``steps``: the number of steps attempted.
        - ``success``: taken from the last successful ``/step`` response.
        - ``reward_history``: the per-step rewards.

    Raises:
        requests.RequestException: if ``/reset`` fails or returns an error
            status.
    """
    http_timeout = 30  # seconds; without a timeout a stalled server hangs the run
    fallback_action = {"action_type": "focus", "reasoning": "LLM call failed, defaulting to focus."}

    reset = requests.post(f"{ENV_URL}/reset", params={"task_id": task_id}, timeout=http_timeout)
    reset.raise_for_status()
    obs = reset.json()

    step = 0
    total_reward = 0.0
    reward_history = []
    done = False
    result = {}  # last successful /step response; stays {} if none succeeded

    print(f"\n{'='*60}")
    print(f"EPISODE {episode+1} | Task: {task_id}")
    print(f"{'='*60}")

    while not done:
        if max_steps is not None and step >= max_steps:
            print(f" [STOP] reached max_steps={max_steps} before the environment reported done")
            break
        step += 1
        prompt = obs_to_prompt(obs, step)

        try:
            action = call_llm(prompt)
            if not isinstance(action, dict):
                raise ValueError(f"expected a JSON object, got {type(action).__name__}")
        except Exception as e:  # any LLM/parse failure degrades to a safe action
            print(f" [LLM ERROR] {e} - defaulting to focus")
            action = dict(fallback_action)

        # Repair fields the environment and the log line below depend on.
        action_type = action.get("action_type")
        if not isinstance(action_type, str) or not action_type.strip():
            print(f" [LLM WARN] missing/invalid action_type {action_type!r} - defaulting to focus")
            action["action_type"] = "focus"
        reasoning = action.get("reasoning")
        if not isinstance(reasoning, str) or len(reasoning.strip()) < 10:
            action["reasoning"] = "Staying focused to complete the session efficiently."

        try:
            resp = requests.post(f"{ENV_URL}/step", json=action, timeout=http_timeout)
            resp.raise_for_status()
            payload = resp.json()
        except (requests.RequestException, ValueError) as e:
            print(f" [ENV ERROR] {e}")
            break
        if not isinstance(payload, dict):
            print(f" [ENV ERROR] expected a JSON object from /step, got {type(payload).__name__}")
            break

        result = obs = payload
        # `or 0.0` also covers an explicit null, which would break `+=`.
        reward = result.get("last_action_reward") or 0.0
        done = bool(result.get("done", False))
        total_reward += reward
        reward_history.append(reward)

        quality = result.get("reasoning_quality_score") or 0
        print(
            f" Step {step:3d} | action={action['action_type']:<18} "
            f"reward={reward:+.4f} | cumulative={total_reward:.4f} | "
            f"reasoning_q={quality:.2f}"
        )
        time.sleep(0.1)  # rate limit

    success = result.get("success", False) if result else False
    print(f"\n Episode {episode+1} done. Total reward: {total_reward:.4f} | Success: {success}")
    return {
        "episode": episode + 1,
        "total_reward": round(total_reward, 4),
        "steps": step,
        "success": success,
        "reward_history": reward_history,
    }
def main():
    """Run ``MAX_EPISODES`` episodes of ``TASK_ID``, print a summary, save a log.

    Per-episode stats from ``run_episode`` are written as JSON to
    ``reward_log.json`` in the current working directory, overwriting any
    previous file. An empty list is written when no episodes ran.
    """
    print(f"FocusFlow Agent | Model: {MODEL_NAME} | Task: {TASK_ID}")
    print(f"Environment: {ENV_URL}")
    print()

    health = requests.get(f"{ENV_URL}/health", timeout=30)
    try:
        health_info = health.json()
    except ValueError:
        # A non-JSON body (e.g. an HTML error page) is shown instead of crashing.
        health_info = health.text
    print(f"Health: {health_info}\n")

    all_stats = [run_episode(TASK_ID, ep) for ep in range(MAX_EPISODES)]

    if all_stats:
        rewards = [s["total_reward"] for s in all_stats]
        n_episodes = len(all_stats)
        # Count truthy values so a null `success` from the env cannot break sum().
        successes = sum(1 for s in all_stats if s["success"])
        print(f"\n{'='*60}")
        print(f"SUMMARY over {n_episodes} episodes:")
        print(f" Rewards : {[f'{r:.3f}' for r in rewards]}")
        print(f" Mean : {sum(rewards)/n_episodes:.4f}")
        print(f" Best : {max(rewards):.4f}")
        print(f" Success : {successes}/{n_episodes}")
    else:
        # MAX_EPISODES <= 0: avoid ZeroDivisionError / max() of an empty list.
        print("\nNo episodes were run (MAX_EPISODES <= 0).")

    with open("reward_log.json", "w", encoding="utf-8") as f:
        json.dump(all_stats, f, indent=2)
    print("\nReward log saved to reward_log.json")


if __name__ == "__main__":
    main()