hannan2859r committed on
Commit
ebf4b94
·
verified ·
1 Parent(s): f9f5e0d

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +115 -137
environment.py CHANGED
@@ -40,7 +40,6 @@ DISTRACTION_POOL: List[DistractingApp] = [
40
  ]
41
 
42
  # ─── Rich NL distraction events ───────────────────────────────────────────────
43
- # Each has a `correct_action` so the reward function can grade the agent's choice.
44
  EVENT_POOL: List[Dict[str, Any]] = [
45
  {
46
  "type": DistractionType.social_message,
@@ -99,29 +98,31 @@ EVENT_POOL: List[Dict[str, Any]] = [
99
  "hint": "Cognitive fatigue signal β†’ take a break before performance crashes"
100
  },
101
  ]
102
-
 
 
103
  def grade_reasoning(reasoning: str, action_type: str, event: Optional[DistractionEvent]) -> float:
104
  """
105
  Upgraded heuristic grader with anti-spam protections.
 
106
  """
107
  if not reasoning or len(reasoning.strip()) < 10:
108
  return 0.0
109
 
110
- text = reasoning.lower()
111
  words = text.split()
112
-
113
- # ANTI-SPAM: Penalize if the agent is just repeating the same words
114
  unique_ratio = len(set(words)) / max(1, len(words))
115
  if unique_ratio < 0.5:
116
  return 0.0 # Zero score for word salad/spam
117
-
118
- # ANTI-SHORTCUT: Must be at least a somewhat complete thought (e.g., > 5 words)
119
  if len(words) < 5:
120
  return 0.1
121
 
122
  score = 0.3 # baseline for valid reasoning
123
 
124
- # Reward mentioning relevant concepts
125
  focus_keywords = ["focus", "deadline", "study", "priority", "session", "pomodoro"]
126
  context_keywords = ["urgent", "can wait", "defer", "later", "energy", "tired", "break"]
127
  planning_words = ["because", "since", "therefore", "so that", "in order to", "plan"]
@@ -135,8 +136,9 @@ def grade_reasoning(reasoning: str, action_type: str, event: Optional[Distractio
135
  score += 0.2
136
 
137
  return round(min(1.0, score), 3)
138
- # ─── Tasks ────────────────────────────────────────────────────────────────────
139
 
 
 
140
  TASKS = [
141
  {
142
  "id": "task_1",
@@ -146,7 +148,6 @@ TASKS = [
146
  ),
147
  "max_steps": 60,
148
  "success_fn": lambda s: s["sessions_completed"] >= 1 and len(s["apps_checked"]) == 0,
149
- # The bonus fn awards an extra 0.25 on top of the result when the agent's average reasoning score is above 0.7.
150
  "bonus_fn": lambda s: 0.25 if s["reasoning_scores"] and
151
  sum(s["reasoning_scores"]) / len(s["reasoning_scores"]) > 0.7
152
  else 0.0,
@@ -190,51 +191,19 @@ TASKS = [
190
  ]
191
 
192
 
193
- # ─── Reasoning quality grader ─────────────────────────────────────────────────
194
-
195
- def grade_reasoning(reasoning: str, action_type: str, event: Optional[DistractionEvent]) -> float:
196
- """
197
- Simple heuristic grader for reasoning quality (0–1).
198
- Real training would use an LLM-as-judge here.
199
- """
200
- if not reasoning or len(reasoning.strip()) < 10:
201
- return 0.0
202
-
203
- score = 0.3 # baseline for non-empty reasoning
204
-
205
- text = reasoning.lower()
206
-
207
- # Reward mentioning relevant concepts
208
- #It checks how many of these words appear in the reasoning text. More relevant words = higher score.
209
- focus_keywords = ["focus", "deadline", "study", "priority", "session", "pomodoro"]
210
- context_keywords = ["urgent", "can wait", "defer", "later", "energy", "tired", "break"]
211
- planning_words = ["because", "since", "therefore", "so that", "in order to", "plan"]
212
-
213
- score += 0.1 * min(2, sum(1 for k in focus_keywords if k in text)) / 2
214
- score += 0.2 * min(2, sum(1 for k in context_keywords if k in text)) / 2
215
- score += 0.2 * min(2, sum(1 for k in planning_words if k in text)) / 2
216
-
217
- # Bonus: reasoning matches correct action for event
218
- if event and event.correct_action == action_type:
219
- score += 0.2
220
- #If score above 0.5 reward else penalty
221
-
222
- return round(min(1.0, score), 3)
223
-
224
-
225
  # ─── Environment ──────────────────────────────────────────────────────────────
226
-
227
  class FocusFlowEnvironment:
228
  """
229
  OpenEnv-compatible RL environment.
230
 
231
- Key upgrades over v1:
232
  - Rich NL distraction events with urgency & correct_action grading
233
  - Mandatory reasoning field scored by grade_reasoning()
234
  - Multi-day context with energy decay and deadline tracking
235
  - Cognitive load dynamics (overwork → worse performance)
236
  - Deferred events expire after deadline_steps
237
  - plan_day action graded against actual completion
 
238
  """
239
 
240
  def __init__(self, task_id: str = "task_1", seed: int = 42):
@@ -243,7 +212,6 @@ class FocusFlowEnvironment:
243
  self._reset_internal()
244
 
245
  # ── Internal helpers ──────────────────────────────────────────────────────
246
- # Resets everything back to zero to produce a fresh run state.
247
  def _reset_internal(self):
248
  self.step_count = 0
249
  self.max_steps = self.task["max_steps"]
@@ -271,10 +239,9 @@ class FocusFlowEnvironment:
271
  energy_level=1.0,
272
  pending_deadlines=self._generate_deadlines(),
273
  )
274
- # Day plan set by agent via plan_day action
275
  self._agent_day_plan: List[str] = []
276
  self._last_reasoning_score = 0.0
277
-
278
  def _generate_deadlines(self) -> List[Dict[str, Any]]:
279
  deadlines = [
280
  {"task": "Math Assignment", "due_day": 1, "due_step": 45, "completed": False},
@@ -282,35 +249,32 @@ class FocusFlowEnvironment:
282
  {"task": "CS Project Demo", "due_day": 3, "due_step": 200,"completed": False},
283
  ]
284
  return deadlines[:self.task["days"]]
285
- #Randomly picking apps which are not blocked and called at the start when new session begin
286
  def _sample_apps(self, n: int) -> List[str]:
287
  available = [d.name for d in DISTRACTION_POOL if d.name not in self.apps_blocked]
288
  return random.sample(available, min(n, len(available)))
289
-
290
  def _maybe_spawn_event(self) -> Optional[DistractionEvent]:
291
- """25% chance per step to surface a rich NL distraction event."""
292
  if self.pending_event is not None:
293
  return None # one event at a time
294
- if random.random() < 0.25:
295
- raw = random.choice(EVENT_POOL)
296
- event = DistractionEvent(
297
- id=f"evt_{self.step_count}",
298
- type=raw["type"],
299
- description=raw["description"],
300
- urgency=raw["urgency"],
301
- can_defer=raw["can_defer"],
302
- deadline_steps=raw.get("deadline_steps"),
303
- correct_action=raw.get("correct_action", "focus"),
304
- )
305
- return event
306
- return None
307
-
308
  def _tick_event(self):
309
  """Age pending event. Penalise if it expires un-handled."""
310
  if self.pending_event and self.pending_event.deadline_steps is not None:
311
  self.pending_event.deadline_steps -= 1
312
  if self.pending_event.deadline_steps <= 0:
313
- # Event expired
314
  if not self.pending_event.can_defer:
315
  self.deadlines_missed += 1
316
  self.pending_event = None
@@ -327,13 +291,17 @@ class FocusFlowEnvironment:
327
  elif action_type == "adjust_energy":
328
  self.cognitive_load = max(0.0, self.cognitive_load - 0.10)
329
  self.max_cognitive_load = max(self.max_cognitive_load, self.cognitive_load)
330
- #subtract 60 second everytime when it hits 0
331
  def _advance_time(self):
 
 
 
 
 
332
  self.time_remaining -= SECONDS_PER_STEP
333
  if self.time_remaining <= 0:
334
  if self.current_phase == "focus":
335
  self.sessions_completed += 1
336
- self.total_focus_secs += FOCUS_DURATION_SECONDS
337
  # Mark relevant deadlines as completed
338
  for dl in self.day_context.pending_deadlines:
339
  if not dl["completed"] and dl["due_step"] <= self.step_count:
@@ -352,9 +320,9 @@ class FocusFlowEnvironment:
352
  self.current_phase = "focus"
353
  self.time_remaining = FOCUS_DURATION_SECONDS
354
  self.active_distractions = self._sample_apps(2)
355
-
356
  def _compute_reward(self, action: FocusAction) -> Tuple[float, str]:
357
- reward = 0.0
358
  feedback_parts = []
359
 
360
  # ── 1. Reasoning quality (universal) ─────────────────────────────────
@@ -364,23 +332,28 @@ class FocusFlowEnvironment:
364
  self._last_reasoning_score = r_score
365
  self.reasoning_scores.append(r_score)
366
 
367
- reasoning_bonus = (r_score - 0.5) * 0.20 # range: -0.10 to +0.10
 
 
 
 
 
 
 
 
 
 
368
  reward += reasoning_bonus
369
- if r_score < 0.3:
370
- feedback_parts.append(f"⚠ Weak reasoning (score {r_score:.2f}): -0.10 penalty.")
371
- elif r_score > 0.7:
372
- feedback_parts.append(f"βœ“ Good reasoning (score {r_score:.2f}): +0.10 bonus.")
373
 
374
  # ── 2. Action-specific rewards ────────────────────────────────────────
375
  atype = action.action_type
376
- # focus — +0.05 × (1 − cognitive_load × 0.8)
377
  if atype == "focus":
378
  base = 0.05
379
- # Cognitive load penalty: reward shrinks when overloaded
380
  base *= max(0.2, 1.0 - self.cognitive_load * 0.8)
381
  reward += base
382
  feedback_parts.append(f"Focused. Step reward: +{base:.3f} (load={self.cognitive_load:.2f}).")
383
-
384
  elif atype == "block_app":
385
  if action.app_name and action.app_name not in self.apps_blocked:
386
  app_obj = next((d for d in DISTRACTION_POOL if d.name == action.app_name), None)
@@ -400,17 +373,15 @@ class FocusFlowEnvironment:
400
 
401
  elif atype == "take_break":
402
  if self.current_phase == "focus" and self.time_remaining <= 120:
403
- # Well-timed: within 2 min of session end
404
  reward += 0.30
405
  feedback_parts.append("Well-timed break at session boundary: +0.30.")
406
- self.current_phase = "break"
407
  self.time_remaining = SHORT_BREAK_SECONDS
408
  self.breaks_taken += 1
409
  elif self.cognitive_load > 0.75:
410
- # Needed break due to high cognitive load
411
  reward += 0.20
412
  feedback_parts.append(f"Recovery break (load={self.cognitive_load:.2f}): +0.20.")
413
- self.current_phase = "break"
414
  self.time_remaining = SHORT_BREAK_SECONDS
415
  self.breaks_taken += 1
416
  elif self.current_phase == "break":
@@ -419,7 +390,7 @@ class FocusFlowEnvironment:
419
  reward -= 0.10
420
  feedback_parts.append("Premature break: -0.10.")
421
  self.breaks_taken += 1
422
- #whether I can defer this event or not it gives reward based on the differ of the events
423
  elif atype == "defer_event":
424
  if self.pending_event:
425
  if self.pending_event.can_defer:
@@ -436,13 +407,12 @@ class FocusFlowEnvironment:
436
  feedback_parts.append("Cannot defer this event! -0.20 penalty.")
437
  else:
438
  feedback_parts.append("No pending event to defer.")
439
- #This event is urgent to do and take action urgently
440
  elif atype == "respond_to_event":
441
  if self.pending_event:
442
  correct = self.pending_event.correct_action == "respond_to_event"
443
  r = 0.20 if correct else -0.10
444
  reward += r
445
- # Extra: score the response text quality
446
  if action.response_text and len(action.response_text) > 15:
447
  reward += 0.05
448
  feedback_parts.append("Good response text: +0.05.")
@@ -456,10 +426,9 @@ class FocusFlowEnvironment:
456
 
457
  elif atype == "plan_day":
458
  if action.day_plan and len(action.day_plan) >= 2:
459
- # Basic plan quality: does it mention sessions and breaks?
460
- plan_text = " ".join(action.day_plan).lower()
461
- has_sessions = "focus" in plan_text or "study" in plan_text or "session" in plan_text
462
- has_breaks = "break" in plan_text or "rest" in plan_text
463
  has_deadlines = any(
464
  dl["task"].lower().split()[0] in plan_text
465
  for dl in self.day_context.pending_deadlines
@@ -473,7 +442,7 @@ class FocusFlowEnvironment:
473
  else:
474
  reward -= 0.10
475
  feedback_parts.append("Empty or trivial plan: -0.10.")
476
- #If energy is less or cognitive load is greater than the given criteria reward else less reward for minor tasks
477
  elif atype == "adjust_energy":
478
  if self.day_context.energy_level < 0.5 or self.cognitive_load > 0.6:
479
  reward += 0.10
@@ -481,16 +450,15 @@ class FocusFlowEnvironment:
481
  else:
482
  reward += 0.01
483
  feedback_parts.append("Energy fine, minor action: +0.01.")
484
- #It checks for app whether it is in the distraction apps or not if its not give none otherwise give -0.50 penalty
485
  elif atype == "check_app":
486
  app = action.app_name or (
487
  self.active_distractions[0] if self.active_distractions else None
488
  )
489
  if app:
490
  reward -= 0.50
491
- #Which app for checked for later analysis
492
  self.apps_checked.append(app)
493
- self.total_distraction_s += 60#Adds 60 seconds when total time wasted on distractions
494
  self.cognitive_load = min(1.0, self.cognitive_load + 0.10)
495
  feedback_parts.append(f"Gave in to {app}: -0.50 hard penalty.")
496
  else:
@@ -506,45 +474,55 @@ class FocusFlowEnvironment:
506
  feedback_parts.append(f"Unknown action '{atype}': -0.05.")
507
 
508
  return reward, " | ".join(feedback_parts)
509
-
510
  def _compute_deadline_pressure(self) -> float:
511
- # For each uncompleted deadline, it calculates how close you are to missing it.
512
- # At 50+ steps away β†’ pressure = 0.0. At 0 steps away β†’ pressure=1.0.
513
- # Returns the highest pressure across all deadlines.
514
- # This number appears in the observation so the LLM knows when to stop chatting and start studying.
 
515
  if not self.day_context.pending_deadlines:
516
  return 0.0
517
  pressures = []
518
  for dl in self.day_context.pending_deadlines:
519
- if dl.get("completed", False):
520
  continue
521
  steps_left = dl["due_step"] - self.step_count
522
  if steps_left <= 0:
523
  pressures.append(1.0)
524
  else:
525
- pressures.append(max(0.0, 1.0 - (steps_left / 50.0)))
526
  return max(pressures) if pressures else 0.0
527
 
528
  # ── Public OpenEnv API ────────────────────────────────────────────────────
529
  def reset(self) -> FocusObservation:
530
  self._reset_internal()
531
  return FocusObservation(
532
- time_remaining_seconds = self.time_remaining,
533
- current_phase = self.current_phase,
534
- active_distractions = list(self.active_distractions),
535
- blocked_apps = list(self.apps_blocked),
536
- sessions_completed = 0,
537
- focus_score = 0.0,
538
- pending_event = None,
539
- day_context = self.day_context,
540
- cognitive_load = self.cognitive_load,
541
- deadline_pressure = self._compute_deadline_pressure(),
542
- last_action_feedback = f"Environment reset. Task: {self.task['description']}",
543
- last_action_reward = 0.0,
544
- reasoning_quality_score= 0.0,
545
  )
546
- '''The main loop. Every call does this in order:'''
547
  def step(self, action: FocusAction) -> Tuple[FocusObservation, float, bool, dict]:
 
 
 
 
 
 
 
 
 
548
  if self.done:
549
  raise RuntimeError("Episode done. Call reset().")
550
 
@@ -558,12 +536,12 @@ class FocusFlowEnvironment:
558
  # Compute reward
559
  reward, feedback = self._compute_reward(action)
560
 
561
- # Maybe spawn new event (higher chance at high cognitive load)
562
  spawn_chance = 0.25 + 0.15 * self.cognitive_load
563
  if self.pending_event is None and random.random() < spawn_chance:
564
  self.pending_event = self._maybe_spawn_event()
565
 
566
- # Focus score
567
  focus_ratio = (
568
  self.total_focus_secs /
569
  max(1, self.total_focus_secs + self.total_distraction_s)
@@ -612,12 +590,12 @@ class FocusFlowEnvironment:
612
  )
613
 
614
  info = {
615
- "step": self.step_count,
616
- "success": success,
617
- "timed_out": timed_out,
618
- "cumulative": round(self.cumulative_reward, 4),
619
- "deadlines_missed":self.deadlines_missed,
620
- "reasoning_avg": round(
621
  sum(self.reasoning_scores) / max(1, len(self.reasoning_scores)), 3
622
  ),
623
  }
@@ -626,20 +604,20 @@ class FocusFlowEnvironment:
626
 
627
  def state(self) -> FocusState:
628
  return FocusState(
629
- episode_step = self.step_count,
630
- max_steps = self.max_steps,
631
- total_focus_seconds = self.total_focus_secs,
632
- total_distraction_seconds= self.total_distraction_s,
633
- sessions_completed = self.sessions_completed,
634
- breaks_taken = self.breaks_taken,
635
- apps_blocked = list(self.apps_blocked),
636
- apps_checked = list(self.apps_checked),
637
- events_deferred = list(self.events_deferred),
638
- events_responded = list(self.events_responded),
639
- current_phase = self.current_phase,
640
- time_remaining_seconds = self.time_remaining,
641
- cumulative_reward = round(self.cumulative_reward, 4),
642
- day_context = self.day_context,
643
- cognitive_load = round(self.cognitive_load, 3),
644
- done = self.done,
645
  )
 
40
  ]
41
 
42
  # ─── Rich NL distraction events ───────────────────────────────────────────────
 
43
  EVENT_POOL: List[Dict[str, Any]] = [
44
  {
45
  "type": DistractionType.social_message,
 
98
  "hint": "Cognitive fatigue signal β†’ take a break before performance crashes"
99
  },
100
  ]
101
+
102
+
103
+ # ─── Reasoning quality grader (SINGLE definition β€” anti-spam version) ─────────
104
  def grade_reasoning(reasoning: str, action_type: str, event: Optional[DistractionEvent]) -> float:
105
  """
106
  Upgraded heuristic grader with anti-spam protections.
107
+ Returns a score from 0.0 to 1.0.
108
  """
109
  if not reasoning or len(reasoning.strip()) < 10:
110
  return 0.0
111
 
112
+ text = reasoning.lower()
113
  words = text.split()
114
+
115
+ # ANTI-SPAM: Penalize if agent is just repeating the same words
116
  unique_ratio = len(set(words)) / max(1, len(words))
117
  if unique_ratio < 0.5:
118
  return 0.0 # Zero score for word salad/spam
119
+
120
+ # ANTI-SHORTCUT: Must be at least a somewhat complete thought
121
  if len(words) < 5:
122
  return 0.1
123
 
124
  score = 0.3 # baseline for valid reasoning
125
 
 
126
  focus_keywords = ["focus", "deadline", "study", "priority", "session", "pomodoro"]
127
  context_keywords = ["urgent", "can wait", "defer", "later", "energy", "tired", "break"]
128
  planning_words = ["because", "since", "therefore", "so that", "in order to", "plan"]
 
136
  score += 0.2
137
 
138
  return round(min(1.0, score), 3)
 
139
 
140
+
141
+ # ─── Tasks ────────────────────────────────────────────────────────────────────
142
  TASKS = [
143
  {
144
  "id": "task_1",
 
148
  ),
149
  "max_steps": 60,
150
  "success_fn": lambda s: s["sessions_completed"] >= 1 and len(s["apps_checked"]) == 0,
 
151
  "bonus_fn": lambda s: 0.25 if s["reasoning_scores"] and
152
  sum(s["reasoning_scores"]) / len(s["reasoning_scores"]) > 0.7
153
  else 0.0,
 
191
  ]
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # ─── Environment ──────────────────────────────────────────────────────────────
 
195
  class FocusFlowEnvironment:
196
  """
197
  OpenEnv-compatible RL environment.
198
 
199
+ Key features:
200
  - Rich NL distraction events with urgency & correct_action grading
201
  - Mandatory reasoning field scored by grade_reasoning()
202
  - Multi-day context with energy decay and deadline tracking
203
  - Cognitive load dynamics (overwork → worse performance)
204
  - Deferred events expire after deadline_steps
205
  - plan_day action graded against actual completion
206
+ - Per-step focus tracking for real-time focus_score
207
  """
208
 
209
  def __init__(self, task_id: str = "task_1", seed: int = 42):
 
212
  self._reset_internal()
213
 
214
  # ── Internal helpers ──────────────────────────────────────────────────────
 
215
  def _reset_internal(self):
216
  self.step_count = 0
217
  self.max_steps = self.task["max_steps"]
 
239
  energy_level=1.0,
240
  pending_deadlines=self._generate_deadlines(),
241
  )
 
242
  self._agent_day_plan: List[str] = []
243
  self._last_reasoning_score = 0.0
244
+
245
  def _generate_deadlines(self) -> List[Dict[str, Any]]:
246
  deadlines = [
247
  {"task": "Math Assignment", "due_day": 1, "due_step": 45, "completed": False},
 
249
  {"task": "CS Project Demo", "due_day": 3, "due_step": 200,"completed": False},
250
  ]
251
  return deadlines[:self.task["days"]]
252
+
253
  def _sample_apps(self, n: int) -> List[str]:
254
  available = [d.name for d in DISTRACTION_POOL if d.name not in self.apps_blocked]
255
  return random.sample(available, min(n, len(available)))
256
+
257
  def _maybe_spawn_event(self) -> Optional[DistractionEvent]:
258
+ """Spawn a rich NL distraction event. Caller handles probability."""
259
  if self.pending_event is not None:
260
  return None # one event at a time
261
+ raw = random.choice(EVENT_POOL)
262
+ event = DistractionEvent(
263
+ id=f"evt_{self.step_count}",
264
+ type=raw["type"],
265
+ description=raw["description"],
266
+ urgency=raw["urgency"],
267
+ can_defer=raw["can_defer"],
268
+ deadline_steps=raw.get("deadline_steps"),
269
+ correct_action=raw.get("correct_action", "focus"),
270
+ )
271
+ return event
272
+
 
 
273
  def _tick_event(self):
274
  """Age pending event. Penalise if it expires un-handled."""
275
  if self.pending_event and self.pending_event.deadline_steps is not None:
276
  self.pending_event.deadline_steps -= 1
277
  if self.pending_event.deadline_steps <= 0:
 
278
  if not self.pending_event.can_defer:
279
  self.deadlines_missed += 1
280
  self.pending_event = None
 
291
  elif action_type == "adjust_energy":
292
  self.cognitive_load = max(0.0, self.cognitive_load - 0.10)
293
  self.max_cognitive_load = max(self.max_cognitive_load, self.cognitive_load)
294
+
295
  def _advance_time(self):
296
+ """Advance simulation clock by one step (1 minute)."""
297
+ # FIX: Track focus seconds per step, not just per session
298
+ if self.current_phase == "focus":
299
+ self.total_focus_secs += SECONDS_PER_STEP
300
+
301
  self.time_remaining -= SECONDS_PER_STEP
302
  if self.time_remaining <= 0:
303
  if self.current_phase == "focus":
304
  self.sessions_completed += 1
 
305
  # Mark relevant deadlines as completed
306
  for dl in self.day_context.pending_deadlines:
307
  if not dl["completed"] and dl["due_step"] <= self.step_count:
 
320
  self.current_phase = "focus"
321
  self.time_remaining = FOCUS_DURATION_SECONDS
322
  self.active_distractions = self._sample_apps(2)
323
+
324
  def _compute_reward(self, action: FocusAction) -> Tuple[float, str]:
325
+ reward = 0.0
326
  feedback_parts = []
327
 
328
  # ── 1. Reasoning quality (universal) ─────────────────────────────────
 
332
  self._last_reasoning_score = r_score
333
  self.reasoning_scores.append(r_score)
334
 
335
+ # FIX: Stronger penalty for zero reasoning
336
+ if r_score == 0.0:
337
+ reasoning_bonus = -0.15
338
+ feedback_parts.append("⚠ No/spam reasoning: -0.15 hard penalty.")
339
+ else:
340
+ reasoning_bonus = (r_score - 0.5) * 0.20 # range: -0.10 to +0.10
341
+ if r_score < 0.3:
342
+ feedback_parts.append(f"⚠ Weak reasoning (score {r_score:.2f}): penalty applied.")
343
+ elif r_score > 0.7:
344
+ feedback_parts.append(f"βœ“ Good reasoning (score {r_score:.2f}): +bonus.")
345
+
346
  reward += reasoning_bonus
 
 
 
 
347
 
348
  # ── 2. Action-specific rewards ────────────────────────────────────────
349
  atype = action.action_type
350
+
351
  if atype == "focus":
352
  base = 0.05
 
353
  base *= max(0.2, 1.0 - self.cognitive_load * 0.8)
354
  reward += base
355
  feedback_parts.append(f"Focused. Step reward: +{base:.3f} (load={self.cognitive_load:.2f}).")
356
+
357
  elif atype == "block_app":
358
  if action.app_name and action.app_name not in self.apps_blocked:
359
  app_obj = next((d for d in DISTRACTION_POOL if d.name == action.app_name), None)
 
373
 
374
  elif atype == "take_break":
375
  if self.current_phase == "focus" and self.time_remaining <= 120:
 
376
  reward += 0.30
377
  feedback_parts.append("Well-timed break at session boundary: +0.30.")
378
+ self.current_phase = "break"
379
  self.time_remaining = SHORT_BREAK_SECONDS
380
  self.breaks_taken += 1
381
  elif self.cognitive_load > 0.75:
 
382
  reward += 0.20
383
  feedback_parts.append(f"Recovery break (load={self.cognitive_load:.2f}): +0.20.")
384
+ self.current_phase = "break"
385
  self.time_remaining = SHORT_BREAK_SECONDS
386
  self.breaks_taken += 1
387
  elif self.current_phase == "break":
 
390
  reward -= 0.10
391
  feedback_parts.append("Premature break: -0.10.")
392
  self.breaks_taken += 1
393
+
394
  elif atype == "defer_event":
395
  if self.pending_event:
396
  if self.pending_event.can_defer:
 
407
  feedback_parts.append("Cannot defer this event! -0.20 penalty.")
408
  else:
409
  feedback_parts.append("No pending event to defer.")
410
+
411
  elif atype == "respond_to_event":
412
  if self.pending_event:
413
  correct = self.pending_event.correct_action == "respond_to_event"
414
  r = 0.20 if correct else -0.10
415
  reward += r
 
416
  if action.response_text and len(action.response_text) > 15:
417
  reward += 0.05
418
  feedback_parts.append("Good response text: +0.05.")
 
426
 
427
  elif atype == "plan_day":
428
  if action.day_plan and len(action.day_plan) >= 2:
429
+ plan_text = " ".join(action.day_plan).lower()
430
+ has_sessions = "focus" in plan_text or "study" in plan_text or "session" in plan_text
431
+ has_breaks = "break" in plan_text or "rest" in plan_text
 
432
  has_deadlines = any(
433
  dl["task"].lower().split()[0] in plan_text
434
  for dl in self.day_context.pending_deadlines
 
442
  else:
443
  reward -= 0.10
444
  feedback_parts.append("Empty or trivial plan: -0.10.")
445
+
446
  elif atype == "adjust_energy":
447
  if self.day_context.energy_level < 0.5 or self.cognitive_load > 0.6:
448
  reward += 0.10
 
450
  else:
451
  reward += 0.01
452
  feedback_parts.append("Energy fine, minor action: +0.01.")
453
+
454
  elif atype == "check_app":
455
  app = action.app_name or (
456
  self.active_distractions[0] if self.active_distractions else None
457
  )
458
  if app:
459
  reward -= 0.50
 
460
  self.apps_checked.append(app)
461
+ self.total_distraction_s += 60
462
  self.cognitive_load = min(1.0, self.cognitive_load + 0.10)
463
  feedback_parts.append(f"Gave in to {app}: -0.50 hard penalty.")
464
  else:
 
474
  feedback_parts.append(f"Unknown action '{atype}': -0.05.")
475
 
476
  return reward, " | ".join(feedback_parts)
477
+
478
  def _compute_deadline_pressure(self) -> float:
479
+ """
480
+ For each uncompleted deadline, calculates how close you are to missing it.
481
+ At 50+ steps away β†’ pressure = 0.0. At 0 steps away β†’ pressure = 1.0.
482
+ Returns the highest pressure across all deadlines.
483
+ """
484
  if not self.day_context.pending_deadlines:
485
  return 0.0
486
  pressures = []
487
  for dl in self.day_context.pending_deadlines:
488
+ if dl["completed"]:
489
  continue
490
  steps_left = dl["due_step"] - self.step_count
491
  if steps_left <= 0:
492
  pressures.append(1.0)
493
  else:
494
+ pressures.append(max(0.0, 1.0 - steps_left / 50.0))
495
  return max(pressures) if pressures else 0.0
496
 
497
  # ── Public OpenEnv API ────────────────────────────────────────────────────
498
  def reset(self) -> FocusObservation:
499
  self._reset_internal()
500
  return FocusObservation(
501
+ time_remaining_seconds = self.time_remaining,
502
+ current_phase = self.current_phase,
503
+ active_distractions = list(self.active_distractions),
504
+ blocked_apps = list(self.apps_blocked),
505
+ sessions_completed = 0,
506
+ focus_score = 0.0,
507
+ pending_event = None,
508
+ day_context = self.day_context,
509
+ cognitive_load = self.cognitive_load,
510
+ deadline_pressure = self._compute_deadline_pressure(),
511
+ last_action_feedback = f"Environment reset. Task: {self.task['description']}",
512
+ last_action_reward = 0.0,
513
+ reasoning_quality_score = 0.0,
514
  )
515
+
516
  def step(self, action: FocusAction) -> Tuple[FocusObservation, float, bool, dict]:
517
+ """
518
+ Main loop. Every call:
519
+ 1. Advances time
520
+ 2. Ticks pending event expiry
521
+ 3. Updates cognitive load
522
+ 4. Computes reward
523
+ 5. Maybe spawns new event (probability controlled here)
524
+ 6. Checks success/timeout
525
+ """
526
  if self.done:
527
  raise RuntimeError("Episode done. Call reset().")
528
 
 
536
  # Compute reward
537
  reward, feedback = self._compute_reward(action)
538
 
539
+ # FIX: Single probability check here (not doubled inside _maybe_spawn_event)
540
  spawn_chance = 0.25 + 0.15 * self.cognitive_load
541
  if self.pending_event is None and random.random() < spawn_chance:
542
  self.pending_event = self._maybe_spawn_event()
543
 
544
+ # Focus score β€” now updates every step
545
  focus_ratio = (
546
  self.total_focus_secs /
547
  max(1, self.total_focus_secs + self.total_distraction_s)
 
590
  )
591
 
592
  info = {
593
+ "step": self.step_count,
594
+ "success": success,
595
+ "timed_out": timed_out,
596
+ "cumulative": round(self.cumulative_reward, 4),
597
+ "deadlines_missed": self.deadlines_missed,
598
+ "reasoning_avg": round(
599
  sum(self.reasoning_scores) / max(1, len(self.reasoning_scores)), 3
600
  ),
601
  }
 
604
 
605
  def state(self) -> FocusState:
606
  return FocusState(
607
+ episode_step = self.step_count,
608
+ max_steps = self.max_steps,
609
+ total_focus_seconds = self.total_focus_secs,
610
+ total_distraction_seconds = self.total_distraction_s,
611
+ sessions_completed = self.sessions_completed,
612
+ breaks_taken = self.breaks_taken,
613
+ apps_blocked = list(self.apps_blocked),
614
+ apps_checked = list(self.apps_checked),
615
+ events_deferred = list(self.events_deferred),
616
+ events_responded = list(self.events_responded),
617
+ current_phase = self.current_phase,
618
+ time_remaining_seconds = self.time_remaining,
619
+ cumulative_reward = round(self.cumulative_reward, 4),
620
+ day_context = self.day_context,
621
+ cognitive_load = round(self.cognitive_load, 3),
622
+ done = self.done,
623
  )