hannan2859r commited on
Commit
7bc45d7
Β·
verified Β·
1 Parent(s): 8ca38b0

Update environment.py

Browse files
Files changed (1) hide show
  1. environment.py +352 -5
environment.py CHANGED
@@ -99,7 +99,7 @@ EVENT_POOL: List[Dict[str, Any]] = [
99
  "hint": "Cognitive fatigue signal β†’ take a break before performance crashes"
100
  },
101
  ]
102
-
103
  def grade_reasoning(reasoning: str, action_type: str, event: Optional[DistractionEvent]) -> float:
104
  """
105
  Upgraded heuristic grader with anti-spam protections.
@@ -135,7 +135,6 @@ def grade_reasoning(reasoning: str, action_type: str, event: Optional[Distractio
135
  score += 0.2
136
 
137
  return round(min(1.0, score), 3)
138
-
139
  # ─── Tasks ────────────────────────────────────────────────────────────────────
140
 
141
  TASKS = [
@@ -191,6 +190,38 @@ TASKS = [
191
  ]
192
 
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  # ─── Environment ──────────────────────────────────────────────────────────────
195
 
196
  class FocusFlowEnvironment:
@@ -251,8 +282,7 @@ class FocusFlowEnvironment:
251
  {"task": "CS Project Demo", "due_day": 3, "due_step": 200,"completed": False},
252
  ]
253
  return deadlines[:self.task["days"]]
254
-
255
- #Randomly picking apps which are not blocked and called at the start when new session begin
256
  def _sample_apps(self, n: int) -> List[str]:
257
  available = [d.name for d in DISTRACTION_POOL if d.name not in self.apps_blocked]
258
  return random.sample(available, min(n, len(available)))
@@ -293,4 +323,321 @@ class FocusFlowEnvironment:
293
  if action_type == "focus":
294
  self.cognitive_load = min(1.0, self.cognitive_load + 0.04)
295
  elif action_type == "take_break":
296
- self.cognitive_
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  "hint": "Cognitive fatigue signal β†’ take a break before performance crashes"
100
  },
101
  ]
102
+
103
  def grade_reasoning(reasoning: str, action_type: str, event: Optional[DistractionEvent]) -> float:
104
  """
105
  Upgraded heuristic grader with anti-spam protections.
 
135
  score += 0.2
136
 
137
  return round(min(1.0, score), 3)
 
138
  # ─── Tasks ────────────────────────────────────────────────────────────────────
139
 
140
  TASKS = [
 
190
  ]
191
 
192
 
193
+ # ─── Reasoning quality grader ─────────────────────────────────────────────────
194
+
195
+ def grade_reasoning(reasoning: str, action_type: str, event: Optional[DistractionEvent]) -> float:
196
+ """
197
+ Simple heuristic grader for reasoning quality (0–1).
198
+ Real training would use an LLM-as-judge here.
199
+ """
200
+ if not reasoning or len(reasoning.strip()) < 10:
201
+ return 0.0
202
+
203
+ score = 0.3 # baseline for non-empty reasoning
204
+
205
+ text = reasoning.lower()
206
+
207
+ # Reward mentioning relevant concepts
208
+ #It checks how many of these words appear in the reasoning text. More relevant words = higher score.
209
+ focus_keywords = ["focus", "deadline", "study", "priority", "session", "pomodoro"]
210
+ context_keywords = ["urgent", "can wait", "defer", "later", "energy", "tired", "break"]
211
+ planning_words = ["because", "since", "therefore", "so that", "in order to", "plan"]
212
+
213
+ score += 0.1 * min(2, sum(1 for k in focus_keywords if k in text)) / 2
214
+ score += 0.2 * min(2, sum(1 for k in context_keywords if k in text)) / 2
215
+ score += 0.2 * min(2, sum(1 for k in planning_words if k in text)) / 2
216
+
217
+ # Bonus: reasoning matches correct action for event
218
+ if event and event.correct_action == action_type:
219
+ score += 0.2
220
+ #If score above 0.5 reward else penalty
221
+
222
+ return round(min(1.0, score), 3)
223
+
224
+
225
  # ─── Environment ──────────────────────────────────────────────────────────────
226
 
227
  class FocusFlowEnvironment:
 
282
  {"task": "CS Project Demo", "due_day": 3, "due_step": 200,"completed": False},
283
  ]
284
  return deadlines[:self.task["days"]]
285
+ #Randomly picking apps which are not blocked and called at the start when new session begin
 
286
  def _sample_apps(self, n: int) -> List[str]:
287
  available = [d.name for d in DISTRACTION_POOL if d.name not in self.apps_blocked]
288
  return random.sample(available, min(n, len(available)))
 
323
  if action_type == "focus":
324
  self.cognitive_load = min(1.0, self.cognitive_load + 0.04)
325
  elif action_type == "take_break":
326
+ self.cognitive_load = max(0.0, self.cognitive_load - 0.25)
327
+ elif action_type == "adjust_energy":
328
+ self.cognitive_load = max(0.0, self.cognitive_load - 0.10)
329
+ self.max_cognitive_load = max(self.max_cognitive_load, self.cognitive_load)
330
+ #subtract 60 second everytime when it hits 0
331
+ def _advance_time(self):
332
+ self.time_remaining -= SECONDS_PER_STEP
333
+ if self.time_remaining <= 0:
334
+ if self.current_phase == "focus":
335
+ self.sessions_completed += 1
336
+ self.total_focus_secs += FOCUS_DURATION_SECONDS
337
+ # Mark relevant deadlines as completed
338
+ for dl in self.day_context.pending_deadlines:
339
+ if not dl["completed"] and dl["due_step"] <= self.step_count:
340
+ dl["completed"] = True
341
+ self.current_phase = "break"
342
+ self.time_remaining = (
343
+ SHORT_BREAK_SECONDS if self.sessions_completed % 4 != 0
344
+ else LONG_BREAK_SECONDS
345
+ )
346
+ # Energy decay each completed session
347
+ self.day_context.energy_level = max(
348
+ 0.1,
349
+ self.day_context.energy_level - 0.08
350
+ )
351
+ else:
352
+ self.current_phase = "focus"
353
+ self.time_remaining = FOCUS_DURATION_SECONDS
354
+ self.active_distractions = self._sample_apps(2)
355
+
356
+ def _compute_reward(self, action: FocusAction) -> Tuple[float, str]:
357
+ reward = 0.0
358
+ feedback_parts = []
359
+
360
+ # ── 1. Reasoning quality (universal) ─────────────────────────────────
361
+ r_score = grade_reasoning(
362
+ action.reasoning, action.action_type, self.pending_event
363
+ )
364
+ self._last_reasoning_score = r_score
365
+ self.reasoning_scores.append(r_score)
366
+
367
+ reasoning_bonus = (r_score - 0.5) * 0.20 # range: -0.10 to +0.10
368
+ reward += reasoning_bonus
369
+ if r_score < 0.3:
370
+ feedback_parts.append(f"⚠ Weak reasoning (score {r_score:.2f}): -0.10 penalty.")
371
+ elif r_score > 0.7:
372
+ feedback_parts.append(f"βœ“ Good reasoning (score {r_score:.2f}): +0.10 bonus.")
373
+
374
+ # ── 2. Action-specific rewards ────────────────────────────────────────
375
+ atype = action.action_type
376
+ #focus β€” +0.05 Γ— (1 βˆ’ cognitive_load Γ— 0.8)
377
+ if atype == "focus":
378
+ base = 0.05
379
+ # Cognitive load penalty: reward shrinks when overloaded
380
+ base *= max(0.2, 1.0 - self.cognitive_load * 0.8)
381
+ reward += base
382
+ feedback_parts.append(f"Focused. Step reward: +{base:.3f} (load={self.cognitive_load:.2f}).")
383
+
384
+ elif atype == "block_app":
385
+ if action.app_name and action.app_name not in self.apps_blocked:
386
+ app_obj = next((d for d in DISTRACTION_POOL if d.name == action.app_name), None)
387
+ if app_obj:
388
+ self.apps_blocked.append(action.app_name)
389
+ if action.app_name in self.active_distractions:
390
+ self.active_distractions.remove(action.app_name)
391
+ r = 0.20 * app_obj.temptation_level
392
+ reward += r
393
+ feedback_parts.append(
394
+ f"Blocked {action.app_name} (temptation={app_obj.temptation_level}): +{r:.2f}."
395
+ )
396
+ else:
397
+ feedback_parts.append("App not in pool β€” no reward.")
398
+ else:
399
+ feedback_parts.append("Already blocked or not specified.")
400
+
401
+ elif atype == "take_break":
402
+ if self.current_phase == "focus" and self.time_remaining <= 120:
403
+ # Well-timed: within 2 min of session end
404
+ reward += 0.30
405
+ feedback_parts.append("Well-timed break at session boundary: +0.30.")
406
+ self.current_phase = "break"
407
+ self.time_remaining = SHORT_BREAK_SECONDS
408
+ self.breaks_taken += 1
409
+ elif self.cognitive_load > 0.75:
410
+ # Needed break due to high cognitive load
411
+ reward += 0.20
412
+ feedback_parts.append(f"Recovery break (load={self.cognitive_load:.2f}): +0.20.")
413
+ self.current_phase = "break"
414
+ self.time_remaining = SHORT_BREAK_SECONDS
415
+ self.breaks_taken += 1
416
+ elif self.current_phase == "break":
417
+ feedback_parts.append("Already on break. No reward.")
418
+ else:
419
+ reward -= 0.10
420
+ feedback_parts.append("Premature break: -0.10.")
421
+ self.breaks_taken += 1
422
+ #whether I can defer this event or not it gives reward based on the differ of the events
423
+ elif atype == "defer_event":
424
+ if self.pending_event:
425
+ if self.pending_event.can_defer:
426
+ r = 0.15 if self.pending_event.correct_action == "defer_event" else -0.05
427
+ reward += r
428
+ self.events_deferred.append(self.pending_event.id)
429
+ self.day_context.deferred_events.append(self.pending_event)
430
+ label = "Correct defer" if r > 0 else "Should have responded"
431
+ feedback_parts.append(f"{label}: {r:+.2f}.")
432
+ self.pending_event = None
433
+ else:
434
+ reward -= 0.20
435
+ self.deadlines_missed += 1
436
+ feedback_parts.append("Cannot defer this event! -0.20 penalty.")
437
+ else:
438
+ feedback_parts.append("No pending event to defer.")
439
+ #This event is urgent to do and take action urgently
440
+ elif atype == "respond_to_event":
441
+ if self.pending_event:
442
+ correct = self.pending_event.correct_action == "respond_to_event"
443
+ r = 0.20 if correct else -0.10
444
+ reward += r
445
+ # Extra: score the response text quality
446
+ if action.response_text and len(action.response_text) > 15:
447
+ reward += 0.05
448
+ feedback_parts.append("Good response text: +0.05.")
449
+ self.events_responded.append(self.pending_event.id)
450
+ self.pending_event = None
451
+ feedback_parts.append(
452
+ f"{'Correct' if correct else 'Wrong'} response to event: {r:+.2f}."
453
+ )
454
+ else:
455
+ feedback_parts.append("No pending event.")
456
+
457
+ elif atype == "plan_day":
458
+ if action.day_plan and len(action.day_plan) >= 2:
459
+ # Basic plan quality: does it mention sessions and breaks?
460
+ plan_text = " ".join(action.day_plan).lower()
461
+ has_sessions = "focus" in plan_text or "study" in plan_text or "session" in plan_text
462
+ has_breaks = "break" in plan_text or "rest" in plan_text
463
+ has_deadlines = any(
464
+ dl["task"].lower().split()[0] in plan_text
465
+ for dl in self.day_context.pending_deadlines
466
+ )
467
+ score = sum([has_sessions, has_breaks, has_deadlines]) / 3.0
468
+ reward += 0.30 * score
469
+ self._agent_day_plan = action.day_plan
470
+ feedback_parts.append(
471
+ f"Day plan quality: {score:.0%} β†’ +{0.30*score:.2f}."
472
+ )
473
+ else:
474
+ reward -= 0.10
475
+ feedback_parts.append("Empty or trivial plan: -0.10.")
476
+ #If energy is less or cognitive load is greater than the given criteria reward else less reward for minor tasks
477
+ elif atype == "adjust_energy":
478
+ if self.day_context.energy_level < 0.5 or self.cognitive_load > 0.6:
479
+ reward += 0.10
480
+ feedback_parts.append("Energy management action: +0.10.")
481
+ else:
482
+ reward += 0.01
483
+ feedback_parts.append("Energy fine, minor action: +0.01.")
484
+ #It checks for app whether it is in the distraction apps or not if its not give none otherwise give -0.50 penalty
485
+ elif atype == "check_app":
486
+ app = action.app_name or (
487
+ self.active_distractions[0] if self.active_distractions else None
488
+ )
489
+ if app:
490
+ reward -= 0.50
491
+ #Which app for checked for later analysis
492
+ self.apps_checked.append(app)
493
+ self.total_distraction_s += 60#Adds 60 seconds when total time wasted on distractions
494
+ self.cognitive_load = min(1.0, self.cognitive_load + 0.10)
495
+ feedback_parts.append(f"Gave in to {app}: -0.50 hard penalty.")
496
+ else:
497
+ feedback_parts.append("No active distraction to check.")
498
+
499
+ elif atype == "quit_session":
500
+ reward -= 0.30
501
+ self.done = True
502
+ feedback_parts.append("Session quit early: -0.30.")
503
+
504
+ else:
505
+ reward -= 0.05
506
+ feedback_parts.append(f"Unknown action '{atype}': -0.05.")
507
+
508
+ return reward, " | ".join(feedback_parts)
509
+ '''For each uncompleted deadline, it calculates how close you are to missing it. At 50+ steps away β†’ pressure = 0.0. At 0 steps away β†’ pressure=1.0.
510
+ Returns the highest pressure across all deadlines.
511
+ This number appears in the observation so the LLM knows when to stop chatting and start studying.'''
512
+ def _compute_deadline_pressure(self) -> float:
513
+ if not self.day_context.pending_deadlines:
514
+ return 0.0
515
+ pressures = []
516
+ for dl in self.day_context.pending_deadlines:
517
+ if dl["completed"]:
518
+ continue
519
+ steps_left = dl["due_step"] - self.step_count
520
+ if steps_left <= 0:
521
+ pressures.append(1.0)
522
+ else:
523
+ pressures.append(max(0.0, 1.0 - steps_left / 50.0))
524
+ return max(pressures) if pressures else 0.0
525
+
526
+ # ── Public OpenEnv API ────────────────────────────────────────────────────
527
+ def reset(self) -> FocusObservation:
528
+ self._reset_internal()
529
+ return FocusObservation(
530
+ time_remaining_seconds = self.time_remaining,
531
+ current_phase = self.current_phase,
532
+ active_distractions = list(self.active_distractions),
533
+ blocked_apps = list(self.apps_blocked),
534
+ sessions_completed = 0,
535
+ focus_score = 0.0,
536
+ pending_event = None,
537
+ day_context = self.day_context,
538
+ cognitive_load = self.cognitive_load,
539
+ deadline_pressure = self._compute_deadline_pressure(),
540
+ last_action_feedback = f"Environment reset. Task: {self.task['description']}",
541
+ last_action_reward = 0.0,
542
+ reasoning_quality_score= 0.0,
543
+ )
544
+ '''The main loop. Every call does this in order:'''
545
+ def step(self, action: FocusAction) -> Tuple[FocusObservation, float, bool, dict]:
546
+ if self.done:
547
+ raise RuntimeError("Episode done. Call reset().")
548
+
549
+ self.step_count += 1
550
+
551
+ # Tick timers
552
+ self._advance_time()
553
+ self._tick_event()
554
+ self._update_cognitive_load(action.action_type)
555
+
556
+ # Compute reward
557
+ reward, feedback = self._compute_reward(action)
558
+
559
+ # Maybe spawn new event (higher chance at high cognitive load)
560
+ spawn_chance = 0.25 + 0.15 * self.cognitive_load
561
+ if self.pending_event is None and random.random() < spawn_chance:
562
+ self.pending_event = self._maybe_spawn_event()
563
+
564
+ # Focus score
565
+ focus_ratio = (
566
+ self.total_focus_secs /
567
+ max(1, self.total_focus_secs + self.total_distraction_s)
568
+ )
569
+
570
+ # Deadline pressure
571
+ deadline_pressure = self._compute_deadline_pressure()
572
+
573
+ # Success check
574
+ state_snapshot = {
575
+ "sessions_completed": self.sessions_completed,
576
+ "apps_checked": self.apps_checked,
577
+ "breaks_taken": self.breaks_taken,
578
+ "max_cognitive_load": self.max_cognitive_load,
579
+ "deadlines_missed": self.deadlines_missed,
580
+ "streak_days": self.day_context.streak_days,
581
+ "reasoning_scores": self.reasoning_scores,
582
+ }
583
+ success = self.task["success_fn"](state_snapshot)
584
+ timed_out = self.step_count >= self.max_steps
585
+
586
+ if success or timed_out:
587
+ self.done = True
588
+ if success:
589
+ bonus = self.task["bonus_fn"](state_snapshot)
590
+ reward += bonus
591
+ if bonus > 0:
592
+ feedback += f" | πŸŽ‰ Bonus: +{bonus:.2f} ({self.task['bonus_desc']})"
593
+
594
+ self.cumulative_reward += reward
595
+
596
+ obs = FocusObservation(
597
+ time_remaining_seconds = self.time_remaining,
598
+ current_phase = self.current_phase,
599
+ active_distractions = list(self.active_distractions),
600
+ blocked_apps = list(self.apps_blocked),
601
+ sessions_completed = self.sessions_completed,
602
+ focus_score = round(focus_ratio, 3),
603
+ pending_event = self.pending_event,
604
+ day_context = self.day_context,
605
+ cognitive_load = round(self.cognitive_load, 3),
606
+ deadline_pressure = round(deadline_pressure, 3),
607
+ last_action_feedback = feedback,
608
+ last_action_reward = round(reward, 4),
609
+ reasoning_quality_score = self._last_reasoning_score,
610
+ )
611
+
612
+ info = {
613
+ "step": self.step_count,
614
+ "success": success,
615
+ "timed_out": timed_out,
616
+ "cumulative": round(self.cumulative_reward, 4),
617
+ "deadlines_missed":self.deadlines_missed,
618
+ "reasoning_avg": round(
619
+ sum(self.reasoning_scores) / max(1, len(self.reasoning_scores)), 3
620
+ ),
621
+ }
622
+
623
+ return obs, round(reward, 4), self.done, info
624
+
625
+ def state(self) -> FocusState:
626
+ return FocusState(
627
+ episode_step = self.step_count,
628
+ max_steps = self.max_steps,
629
+ total_focus_seconds = self.total_focus_secs,
630
+ total_distraction_seconds= self.total_distraction_s,
631
+ sessions_completed = self.sessions_completed,
632
+ breaks_taken = self.breaks_taken,
633
+ apps_blocked = list(self.apps_blocked),
634
+ apps_checked = list(self.apps_checked),
635
+ events_deferred = list(self.events_deferred),
636
+ events_responded = list(self.events_responded),
637
+ current_phase = self.current_phase,
638
+ time_remaining_seconds = self.time_remaining,
639
+ cumulative_reward = round(self.cumulative_reward, 4),
640
+ day_context = self.day_context,
641
+ cognitive_load = round(self.cognitive_load, 3),
642
+ done = self.done,
643
+ )