Spaces:
Sleeping
Sleeping
Update environment.py
Browse files- environment.py +8 -8
environment.py
CHANGED
|
@@ -5,7 +5,7 @@ from typing import cast
|
|
| 5 |
|
| 6 |
from pydantic import ValidationError
|
| 7 |
|
| 8 |
-
from graders import grade_easy, grade_hard, grade_medium_step
|
| 9 |
from models import (
|
| 10 |
EmailObservation,
|
| 11 |
EnvironmentState,
|
|
@@ -117,7 +117,7 @@ class EmailTriageEnv:
|
|
| 117 |
if self._done:
|
| 118 |
return StepResult(
|
| 119 |
observation=self._terminal_observation(),
|
| 120 |
-
reward=
|
| 121 |
done=True,
|
| 122 |
info={
|
| 123 |
"task_id": self.task_id,
|
|
@@ -132,11 +132,11 @@ class EmailTriageEnv:
|
|
| 132 |
validated_action = TriageAction.model_validate(action)
|
| 133 |
except ValidationError as validation_error:
|
| 134 |
self._current_step += 1
|
| 135 |
-
self._reward_history.append(
|
| 136 |
self._done = self._current_step >= self._max_steps
|
| 137 |
return StepResult(
|
| 138 |
observation=self._build_observation(self._current_index),
|
| 139 |
-
reward=
|
| 140 |
done=self._done,
|
| 141 |
info={
|
| 142 |
"task_id": self.task_id,
|
|
@@ -351,8 +351,8 @@ class EmailTriageEnv:
|
|
| 351 |
"""
|
| 352 |
if not self._ground_truth:
|
| 353 |
return RewardResult(
|
| 354 |
-
score=
|
| 355 |
-
breakdown={"missing_ground_truth": 1.0},
|
| 356 |
feedback="Missing ground truth for task.",
|
| 357 |
)
|
| 358 |
|
|
@@ -458,7 +458,7 @@ class EmailTriageEnv:
|
|
| 458 |
)
|
| 459 |
|
| 460 |
def _clip_reward(self, reward_value: float) -> float:
|
| 461 |
-
"""Clip reward to the
|
| 462 |
|
| 463 |
Args:
|
| 464 |
reward_value: Raw reward value.
|
|
@@ -466,4 +466,4 @@ class EmailTriageEnv:
|
|
| 466 |
Returns:
|
| 467 |
Clipped reward.
|
| 468 |
"""
|
| 469 |
-
return max(
|
|
|
|
| 5 |
|
| 6 |
from pydantic import ValidationError
|
| 7 |
|
| 8 |
+
from graders import SCORE_EPSILON, grade_easy, grade_hard, grade_medium_step
|
| 9 |
from models import (
|
| 10 |
EmailObservation,
|
| 11 |
EnvironmentState,
|
|
|
|
| 117 |
if self._done:
|
| 118 |
return StepResult(
|
| 119 |
observation=self._terminal_observation(),
|
| 120 |
+
reward=SCORE_EPSILON,
|
| 121 |
done=True,
|
| 122 |
info={
|
| 123 |
"task_id": self.task_id,
|
|
|
|
| 132 |
validated_action = TriageAction.model_validate(action)
|
| 133 |
except ValidationError as validation_error:
|
| 134 |
self._current_step += 1
|
| 135 |
+
self._reward_history.append(SCORE_EPSILON)
|
| 136 |
self._done = self._current_step >= self._max_steps
|
| 137 |
return StepResult(
|
| 138 |
observation=self._build_observation(self._current_index),
|
| 139 |
+
reward=SCORE_EPSILON,
|
| 140 |
done=self._done,
|
| 141 |
info={
|
| 142 |
"task_id": self.task_id,
|
|
|
|
| 351 |
"""
|
| 352 |
if not self._ground_truth:
|
| 353 |
return RewardResult(
|
| 354 |
+
score=SCORE_EPSILON,
|
| 355 |
+
breakdown={"missing_ground_truth": 1.0 - SCORE_EPSILON},
|
| 356 |
feedback="Missing ground truth for task.",
|
| 357 |
)
|
| 358 |
|
|
|
|
| 458 |
)
|
| 459 |
|
| 460 |
def _clip_reward(self, reward_value: float) -> float:
|
| 461 |
+
"""Clip reward to the strict range [0.0, 1.0].
|
| 462 |
|
| 463 |
Args:
|
| 464 |
reward_value: Raw reward value.
|
|
|
|
| 466 |
Returns:
|
| 467 |
Clipped reward.
|
| 468 |
"""
|
| 469 |
+
return max(SCORE_EPSILON, min(1.0 - SCORE_EPSILON, reward_value))
|