kush5699 committed
Commit 1bac517 · verified · 1 Parent(s): 842577f

Upload folder using huggingface_hub

Files changed (2):
  1. env/environment.py +10 -7
  2. env/models.py +2 -2
env/environment.py CHANGED
@@ -39,7 +39,7 @@ class DataValidationEnvironment(Environment):
             max_steps=task["max_steps"],
             done=False,
             reward_history=[],
-            cumulative_reward=0.0,
+            cumulative_reward=0.01,
             dataset=task["dataset"],
             ground_truth=self._ground_truth,
             errors=self._errors,
@@ -58,8 +58,8 @@ class DataValidationEnvironment(Environment):
             errors_fixed=0,
             step_count=0,
             max_steps=task["max_steps"],
-            reward=0.0,
-            cumulative_reward=0.0,
+            reward=0.01,
+            cumulative_reward=0.01,
             done=False,
             last_action_result="Environment reset. Examine errors and fix them.",
             task_hint=task["hint"],
@@ -69,7 +69,7 @@ class DataValidationEnvironment(Environment):
 
     def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
         if self._state.done:
-            return self._make_observation(0.0, "Episode already done. Call reset().")
+            return self._make_observation(0.01, "Episode already done. Call reset().")
 
         self._state.step_count += 1
 
@@ -78,7 +78,7 @@ class DataValidationEnvironment(Environment):
         self._state.last_actions.append(action_key)
 
         if is_repeat:
-            reward = 0.0
+            reward = 0.01
             message = "Penalty: repeated identical action"
         else:
             reward, message, fixed = grade_action(
@@ -118,6 +118,9 @@ class DataValidationEnvironment(Environment):
 
         unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
 
+        clamped_reward = max(0.01, min(0.99, reward))
+        clamped_cumulative = max(0.01, min(0.99, self._state.cumulative_reward))
+
         return DataCleanObservation(
             task_name=self._state.task_name,
             task_description=self._task_info.get("description", ""),
@@ -128,8 +131,8 @@ class DataValidationEnvironment(Environment):
             errors_fixed=self._state.errors_fixed,
             step_count=self._state.step_count,
             max_steps=self._state.max_steps,
-            reward=reward,
-            cumulative_reward=self._state.cumulative_reward,
+            reward=clamped_reward,
+            cumulative_reward=clamped_cumulative,
             done=self._state.done,
             last_action_result=message,
             task_hint=self._task_info.get("hint", ""),
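
The net effect of this file's change is that per-step and cumulative rewards are clamped into [0.01, 0.99] instead of bottoming out at 0.0. A minimal runnable sketch of that clamping behavior follows; the clamp_reward helper name is hypothetical, since the committed code inlines the max/min expression directly.

# Sketch of the clamping introduced above (hypothetical helper name;
# the committed code writes max(0.01, min(0.99, ...)) inline).
def clamp_reward(value: float, low: float = 0.01, high: float = 0.99) -> float:
    """Clamp a reward into [low, high], mirroring max(low, min(high, value))."""
    return max(low, min(high, value))

assert clamp_reward(-1.0) == 0.01  # below range -> floored to 0.01
assert clamp_reward(0.5) == 0.5    # in range -> unchanged
assert clamp_reward(2.0) == 0.99   # above range -> capped at 0.99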
env/models.py CHANGED
@@ -21,7 +21,7 @@ class DataCleanObservation(Observation):
     errors_fixed: int = Field(default=0)
     step_count: int = Field(default=0)
     max_steps: int = Field(default=20)
-    cumulative_reward: float = Field(default=0.0)
+    cumulative_reward: float = Field(default=0.01)
     last_action_result: str = Field(default="")
     task_hint: str = Field(default="")
     available_actions: List[str] = Field(
@@ -39,7 +39,7 @@ class DataCleanState(State):
     max_steps: int = Field(default=20)
     done: bool = Field(default=False)
     reward_history: List[float] = Field(default_factory=list)
-    cumulative_reward: float = Field(default=0.0)
+    cumulative_reward: float = Field(default=0.01)
     dataset: List[Dict[str, Any]] = Field(default_factory=list)
     ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
     errors: List[Dict[str, Any]] = Field(default_factory=list)
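
The models change raises the cumulative_reward default from 0.0 to 0.01 to match the environment. A stripped-down sketch of how such a Pydantic Field default behaves, assuming Pydantic-style models (the real classes inherit from Observation and State and carry many more fields):

# Stripped-down stand-in for DataCleanState (assumption: pydantic models).
from typing import List
from pydantic import BaseModel, Field

class MiniState(BaseModel):
    reward_history: List[float] = Field(default_factory=list)
    cumulative_reward: float = Field(default=0.01)  # was 0.0 before this commit

state = MiniState()
assert state.cumulative_reward == 0.01  # new default applies when unset
assert state.reward_history == []       # default_factory gives a fresh list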