Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- env/environment.py +10 -7
- env/models.py +2 -2
env/environment.py
CHANGED
|
@@ -39,7 +39,7 @@ class DataValidationEnvironment(Environment):
|
|
| 39 |
max_steps=task["max_steps"],
|
| 40 |
done=False,
|
| 41 |
reward_history=[],
|
| 42 |
-
cumulative_reward=0.0,
|
| 43 |
dataset=task["dataset"],
|
| 44 |
ground_truth=self._ground_truth,
|
| 45 |
errors=self._errors,
|
|
@@ -58,8 +58,8 @@ class DataValidationEnvironment(Environment):
|
|
| 58 |
errors_fixed=0,
|
| 59 |
step_count=0,
|
| 60 |
max_steps=task["max_steps"],
|
| 61 |
-
reward=0.0,
|
| 62 |
-
cumulative_reward=0.0,
|
| 63 |
done=False,
|
| 64 |
last_action_result="Environment reset. Examine errors and fix them.",
|
| 65 |
task_hint=task["hint"],
|
|
@@ -69,7 +69,7 @@ class DataValidationEnvironment(Environment):
|
|
| 69 |
|
| 70 |
def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
|
| 71 |
if self._state.done:
|
| 72 |
-
return self._make_observation(0.0, "Episode already done. Call reset().")
|
| 73 |
|
| 74 |
self._state.step_count += 1
|
| 75 |
|
|
@@ -78,7 +78,7 @@ class DataValidationEnvironment(Environment):
|
|
| 78 |
self._state.last_actions.append(action_key)
|
| 79 |
|
| 80 |
if is_repeat:
|
| 81 |
-
reward = 0.0
|
| 82 |
message = "Penalty: repeated identical action"
|
| 83 |
else:
|
| 84 |
reward, message, fixed = grade_action(
|
|
@@ -118,6 +118,9 @@ class DataValidationEnvironment(Environment):
|
|
| 118 |
|
| 119 |
unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
|
| 120 |
|
|
|
|
|
|
|
|
|
|
| 121 |
return DataCleanObservation(
|
| 122 |
task_name=self._state.task_name,
|
| 123 |
task_description=self._task_info.get("description", ""),
|
|
@@ -128,8 +131,8 @@ class DataValidationEnvironment(Environment):
|
|
| 128 |
errors_fixed=self._state.errors_fixed,
|
| 129 |
step_count=self._state.step_count,
|
| 130 |
max_steps=self._state.max_steps,
|
| 131 |
-
reward=reward,
|
| 132 |
-
cumulative_reward=self._state.cumulative_reward,
|
| 133 |
done=self._state.done,
|
| 134 |
last_action_result=message,
|
| 135 |
task_hint=self._task_info.get("hint", ""),
|
|
|
|
| 39 |
max_steps=task["max_steps"],
|
| 40 |
done=False,
|
| 41 |
reward_history=[],
|
| 42 |
+
cumulative_reward=0.01,
|
| 43 |
dataset=task["dataset"],
|
| 44 |
ground_truth=self._ground_truth,
|
| 45 |
errors=self._errors,
|
|
|
|
| 58 |
errors_fixed=0,
|
| 59 |
step_count=0,
|
| 60 |
max_steps=task["max_steps"],
|
| 61 |
+
reward=0.01,
|
| 62 |
+
cumulative_reward=0.01,
|
| 63 |
done=False,
|
| 64 |
last_action_result="Environment reset. Examine errors and fix them.",
|
| 65 |
task_hint=task["hint"],
|
|
|
|
| 69 |
|
| 70 |
def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
|
| 71 |
if self._state.done:
|
| 72 |
+
return self._make_observation(0.01, "Episode already done. Call reset().")
|
| 73 |
|
| 74 |
self._state.step_count += 1
|
| 75 |
|
|
|
|
| 78 |
self._state.last_actions.append(action_key)
|
| 79 |
|
| 80 |
if is_repeat:
|
| 81 |
+
reward = 0.01
|
| 82 |
message = "Penalty: repeated identical action"
|
| 83 |
else:
|
| 84 |
reward, message, fixed = grade_action(
|
|
|
|
| 118 |
|
| 119 |
unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
|
| 120 |
|
| 121 |
+
clamped_reward = max(0.01, min(0.99, reward))
|
| 122 |
+
clamped_cumulative = max(0.01, min(0.99, self._state.cumulative_reward))
|
| 123 |
+
|
| 124 |
return DataCleanObservation(
|
| 125 |
task_name=self._state.task_name,
|
| 126 |
task_description=self._task_info.get("description", ""),
|
|
|
|
| 131 |
errors_fixed=self._state.errors_fixed,
|
| 132 |
step_count=self._state.step_count,
|
| 133 |
max_steps=self._state.max_steps,
|
| 134 |
+
reward=clamped_reward,
|
| 135 |
+
cumulative_reward=clamped_cumulative,
|
| 136 |
done=self._state.done,
|
| 137 |
last_action_result=message,
|
| 138 |
task_hint=self._task_info.get("hint", ""),
|
env/models.py
CHANGED
|
@@ -21,7 +21,7 @@ class DataCleanObservation(Observation):
|
|
| 21 |
errors_fixed: int = Field(default=0)
|
| 22 |
step_count: int = Field(default=0)
|
| 23 |
max_steps: int = Field(default=20)
|
| 24 |
-
cumulative_reward: float = Field(default=0.0)
|
| 25 |
last_action_result: str = Field(default="")
|
| 26 |
task_hint: str = Field(default="")
|
| 27 |
available_actions: List[str] = Field(
|
|
@@ -39,7 +39,7 @@ class DataCleanState(State):
|
|
| 39 |
max_steps: int = Field(default=20)
|
| 40 |
done: bool = Field(default=False)
|
| 41 |
reward_history: List[float] = Field(default_factory=list)
|
| 42 |
-
cumulative_reward: float = Field(default=0.0)
|
| 43 |
dataset: List[Dict[str, Any]] = Field(default_factory=list)
|
| 44 |
ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
|
| 45 |
errors: List[Dict[str, Any]] = Field(default_factory=list)
|
|
|
|
| 21 |
errors_fixed: int = Field(default=0)
|
| 22 |
step_count: int = Field(default=0)
|
| 23 |
max_steps: int = Field(default=20)
|
| 24 |
+
cumulative_reward: float = Field(default=0.01)
|
| 25 |
last_action_result: str = Field(default="")
|
| 26 |
task_hint: str = Field(default="")
|
| 27 |
available_actions: List[str] = Field(
|
|
|
|
| 39 |
max_steps: int = Field(default=20)
|
| 40 |
done: bool = Field(default=False)
|
| 41 |
reward_history: List[float] = Field(default_factory=list)
|
| 42 |
+
cumulative_reward: float = Field(default=0.01)
|
| 43 |
dataset: List[Dict[str, Any]] = Field(default_factory=list)
|
| 44 |
ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
|
| 45 |
errors: List[Dict[str, Any]] = Field(default_factory=list)
|