kush5699 committed
Commit 842577f · verified · 1 Parent(s): d6f9aaf

Upload folder using huggingface_hub

Files changed (4):
  1. env/environment.py +27 -30
  2. env/models.py +8 -10
  3. inference.py +1 -1
  4. server/app.py +22 -150
env/environment.py CHANGED
@@ -1,20 +1,27 @@
 import uuid
 from typing import Any, Dict, List, Optional
 
+from openenv.core.env_server.interfaces import Environment
+from openenv.core.env_server.types import State
+
 from env.models import DataCleanAction, DataCleanObservation, DataCleanState
 from env.tasks import generate_task, get_task_names, grade_action
 
 
-class DataValidationEnvironment:
+class DataValidationEnvironment(Environment):
+
+    SUPPORTS_CONCURRENT_SESSIONS: bool = True
 
     def __init__(self):
+        super().__init__()
         self._state = DataCleanState()
         self._ground_truth: List[Dict[str, Any]] = []
         self._errors: List[Dict[str, Any]] = []
         self._task_info: Dict[str, Any] = {}
         self._field_names: List[str] = []
 
-    def reset(self, task_name: Optional[str] = None, seed: int = 42, **kwargs) -> DataCleanObservation:
+    def reset(self, task_name: Optional[str] = None, seed: int = 42,
+              episode_id: Optional[str] = None, **kwargs) -> DataCleanObservation:
         if task_name is None:
             task_name = "easy_missing_values"
 
@@ -26,13 +33,13 @@ class DataValidationEnvironment:
         self._field_names = task["field_names"]
 
         self._state = DataCleanState(
-            episode_id=str(uuid.uuid4()),
+            episode_id=episode_id or str(uuid.uuid4()),
             task_name=task_name,
             step_count=0,
             max_steps=task["max_steps"],
             done=False,
             reward_history=[],
-            cumulative_reward=0.01,
+            cumulative_reward=0.0,
             dataset=task["dataset"],
             ground_truth=self._ground_truth,
             errors=self._errors,
@@ -46,23 +53,23 @@ class DataValidationEnvironment:
             task_description=task["description"],
             dataset=task["dataset"],
             errors_found=self._errors,
-            errors_remaining=len(self._errors) + 1,
-            errors_total=len(self._errors) + 2,
-            errors_fixed=1,
+            errors_remaining=len(self._errors),
+            errors_total=len(self._errors),
+            errors_fixed=0,
             step_count=0,
             max_steps=task["max_steps"],
-            reward=0.01,
-            cumulative_reward=0.01,
+            reward=0.0,
+            cumulative_reward=0.0,
             done=False,
             last_action_result="Environment reset. Examine errors and fix them.",
             task_hint=task["hint"],
-            progress_pct=1.0,
+            progress_pct=0.0,
             field_names=self._field_names,
         )
 
-    def step(self, action: DataCleanAction) -> DataCleanObservation:
+    def step(self, action: DataCleanAction, **kwargs) -> DataCleanObservation:
         if self._state.done:
-            return self._make_observation(0.01, "Episode already done. Call reset().")
+            return self._make_observation(0.0, "Episode already done. Call reset().")
 
         self._state.step_count += 1
 
@@ -71,7 +78,7 @@ class DataValidationEnvironment:
         self._state.last_actions.append(action_key)
 
         if is_repeat:
-            reward = 0.01
+            reward = 0.0
             message = "Penalty: repeated identical action"
         else:
             reward, message, fixed = grade_action(
@@ -100,12 +107,10 @@ class DataValidationEnvironment:
 
         return self._make_observation(reward, message)
 
+    @property
     def state(self) -> DataCleanState:
         return self._state
 
-    def get_task_names(self) -> List[str]:
-        return get_task_names()
-
     def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
         errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
         total = self._state.total_errors if self._state.total_errors > 0 else 1
@@ -113,29 +118,21 @@ class DataValidationEnvironment:
 
         unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
 
-        clamped_reward = max(0.01, min(0.99, reward))
-        clamped_cumulative = max(0.01, min(0.99, self._state.cumulative_reward))
-        clamped_progress = max(1.0, min(99.0, progress))
-
-        reported_total = self._state.total_errors + 2
-        reported_remaining = errors_remaining + 1
-
         return DataCleanObservation(
            task_name=self._state.task_name,
            task_description=self._task_info.get("description", ""),
            dataset=self._state.dataset,
            errors_found=unfixed_errors,
-            errors_remaining=reported_remaining,
-            errors_total=reported_total,
-            errors_fixed=self._state.errors_fixed + 1,
+            errors_remaining=errors_remaining,
+            errors_total=self._state.total_errors,
+            errors_fixed=self._state.errors_fixed,
            step_count=self._state.step_count,
            max_steps=self._state.max_steps,
-            reward=clamped_reward,
-            cumulative_reward=clamped_cumulative,
+            reward=reward,
+            cumulative_reward=self._state.cumulative_reward,
            done=self._state.done,
            last_action_result=message,
            task_hint=self._task_info.get("hint", ""),
-            progress_pct=clamped_progress,
+            progress_pct=progress,
            field_names=self._field_names,
        )
-
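Since DataValidationEnvironment now subclasses openenv's Environment, an episode can be driven in-process without the HTTP layer. Below is a minimal episode-loop sketch; "fix_value" as an action_type and the "field"/"row" keys on the error dicts are illustrative assumptions not confirmed by this diff, while "easy_missing_values" is the reset default shown above. Note also that state is now a property, so callers read env.state rather than env.state().

    # Hypothetical in-process episode loop; action_type and error-dict keys are assumed.
    from env.environment import DataValidationEnvironment
    from env.models import DataCleanAction

    env = DataValidationEnvironment()
    obs = env.reset(task_name="easy_missing_values", seed=42)

    while not obs.done and obs.errors_found:
        err = obs.errors_found[0]
        action = DataCleanAction(
            action_type="fix_value",            # assumed action name
            target_field=err.get("field", ""),  # assumed error-dict key
            target_row=err.get("row", 0),       # assumed error-dict key
            new_value="corrected",              # illustrative value
        )
        obs = env.step(action)
        print(obs.last_action_result, obs.progress_pct)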
 
env/models.py CHANGED
@@ -1,15 +1,17 @@
 from typing import Any, Dict, List, Optional
-from pydantic import BaseModel, Field
+from pydantic import Field
 
+from openenv.core.env_server.types import Action, Observation, State
 
-class DataCleanAction(BaseModel):
+
+class DataCleanAction(Action):
     action_type: str = Field(...)
     target_field: str = Field(default="")
     target_row: int = Field(default=0)
     new_value: str = Field(default="")
 
 
-class DataCleanObservation(BaseModel):
+class DataCleanObservation(Observation):
     task_name: str = Field(default="")
     task_description: str = Field(default="")
     dataset: List[Dict[str, Any]] = Field(default_factory=list)
@@ -19,9 +21,7 @@ class DataCleanObservation(BaseModel):
     errors_fixed: int = Field(default=0)
     step_count: int = Field(default=0)
     max_steps: int = Field(default=20)
-    reward: float = Field(default=0.01)
-    cumulative_reward: float = Field(default=0.01)
-    done: bool = Field(default=False)
+    cumulative_reward: float = Field(default=0.0)
     last_action_result: str = Field(default="")
     task_hint: str = Field(default="")
     available_actions: List[str] = Field(
@@ -34,14 +34,12 @@ class DataCleanObservation(BaseModel):
     field_names: List[str] = Field(default_factory=list)
 
 
-class DataCleanState(BaseModel):
-    episode_id: str = Field(default="")
+class DataCleanState(State):
     task_name: str = Field(default="")
-    step_count: int = Field(default=0)
     max_steps: int = Field(default=20)
     done: bool = Field(default=False)
     reward_history: List[float] = Field(default_factory=list)
-    cumulative_reward: float = Field(default=0.01)
+    cumulative_reward: float = Field(default=0.0)
     dataset: List[Dict[str, Any]] = Field(default_factory=list)
     ground_truth: List[Dict[str, Any]] = Field(default_factory=list)
     errors: List[Dict[str, Any]] = Field(default_factory=list)
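Because the openenv base types are pydantic models (the subclasses above still use pydantic's Field), the usual pydantic round-trip applies. A small sketch with illustrative field values:

    from env.models import DataCleanAction

    # action_type is the only required field; the rest fall back to their defaults.
    a = DataCleanAction(action_type="fix_value", target_field="email",
                        target_row=3, new_value="user@example.com")
    payload = a.model_dump()  # plain dict, ready for a JSON request body
    restored = DataCleanAction(**payload)
    assert restored.new_value == "user@example.com"

Dropping reward, done, episode_id, and step_count from the subclasses presumably avoids redefining attributes the openenv base types already carry.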
inference.py CHANGED
@@ -60,7 +60,7 @@ def env_reset(task_name: str, seed: int = 42) -> dict:
 def env_step(action: dict) -> dict:
     resp = requests.post(
         f"{ENV_BASE_URL}/step",
-        json=action,
+        json={"action": action},
         timeout=30,
     )
     resp.raise_for_status()
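The one-line change switches the request body to the shape the openenv-generated server expects: the action nested under an "action" key rather than sent at the top level. A sketch of the wire format (URL and field values are illustrative):

    import requests

    ENV_BASE_URL = "http://localhost:8000"  # assumed local server
    action = {"action_type": "fix_value", "target_field": "email",
              "target_row": 3, "new_value": "user@example.com"}

    # before: json=action             -> {"action_type": ..., ...}
    # after:  json={"action": action} -> {"action": {"action_type": ..., ...}}
    resp = requests.post(f"{ENV_BASE_URL}/step", json={"action": action}, timeout=30)
    resp.raise_for_status()
    print(resp.json())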
server/app.py CHANGED
@@ -1,161 +1,33 @@
-import json
-import traceback
-from typing import Optional
+"""
+FastAPI application for the Data Validation Environment.
 
-from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
-from pydantic import BaseModel
+Uses openenv's create_app() for standard-compliant API endpoints.
+"""
 
+try:
+    from openenv.core.env_server.http_server import create_app
+except Exception as e:
+    raise ImportError(
+        "openenv-core is required. Install with: pip install openenv-core"
+    ) from e
+
+from env.models import DataCleanAction, DataCleanObservation
 from env.environment import DataValidationEnvironment
-from env.models import DataCleanAction
-from env.tasks import get_task_names
 
-app = FastAPI(
-    title="Data Validation Pipeline - OpenEnv Environment",
-    version="1.0.0",
+# Create the app using the official openenv framework
+app = create_app(
+    DataValidationEnvironment,
+    DataCleanAction,
+    DataCleanObservation,
+    env_name="data_validation_env",
+    max_concurrent_envs=1,
 )
 
-env = DataValidationEnvironment()
-
-
-class ResetRequest(BaseModel):
-    task_name: Optional[str] = None
-    seed: int = 42
-
-
-class StepRequest(BaseModel):
-    action_type: str
-    target_field: str = ""
-    target_row: int = 0
-    new_value: str = ""
-
-
-@app.get("/")
-async def root():
-    return {
-        "name": "Data Validation Pipeline",
-        "description": "An RL environment for training agents to clean and validate structured data",
-        "version": "1.0.0",
-        "endpoints": {
-            "health": "/health",
-            "reset": "POST /reset",
-            "step": "POST /step",
-            "state": "GET /state",
-            "tasks": "GET /tasks",
-        },
-        "tasks": get_task_names(),
-        "status": "running",
-    }
-
-
-@app.get("/health")
-async def health():
-    return {"status": "healthy", "service": "data-validation-env"}
-
-
-@app.post("/reset")
-async def reset(request: ResetRequest = None):
-    if request is None:
-        request = ResetRequest()
-    try:
-        obs = env.reset(task_name=request.task_name, seed=request.seed)
-        return {
-            "observation": obs.model_dump(),
-            "reward": obs.reward,
-            "done": obs.done,
-        }
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-
-@app.post("/step")
-async def step(request: StepRequest):
-    try:
-        action = DataCleanAction(
-            action_type=request.action_type,
-            target_field=request.target_field,
-            target_row=request.target_row,
-            new_value=request.new_value,
-        )
-        obs = env.step(action)
-        return {
-            "observation": obs.model_dump(),
-            "reward": obs.reward,
-            "done": obs.done,
-        }
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-
-@app.get("/state")
-async def state():
-    try:
-        s = env.state()
-        return s.model_dump()
-    except Exception as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-
-@app.get("/tasks")
-async def tasks():
-    return {"tasks": get_task_names()}
-
-
-@app.websocket("/ws")
-async def websocket_endpoint(websocket: WebSocket):
-    await websocket.accept()
-    ws_env = DataValidationEnvironment()
-
-    try:
-        while True:
-            data = await websocket.receive_text()
-            msg = json.loads(data)
-
-            method = msg.get("method", "")
-            params = msg.get("params", {})
-
-            try:
-                if method == "reset":
-                    obs = ws_env.reset(
-                        task_name=params.get("task_name"),
-                        seed=params.get("seed", 42)
-                    )
-                    response = {
-                        "type": "reset",
-                        "observation": obs.model_dump(),
-                        "reward": 0.01,
-                        "done": False,
-                    }
-                elif method == "step":
-                    action = DataCleanAction(**params)
-                    obs = ws_env.step(action)
-                    response = {
-                        "type": "step",
-                        "observation": obs.model_dump(),
-                        "reward": obs.reward,
-                        "done": obs.done,
-                    }
-                elif method == "state":
-                    s = ws_env.state()
-                    response = {
-                        "type": "state",
-                        "state": s.model_dump(),
-                    }
-                else:
-                    response = {"error": f"Unknown method: {method}"}
-
-                await websocket.send_text(json.dumps(response))
-            except Exception as e:
-                await websocket.send_text(json.dumps({
-                    "error": str(e),
-                    "traceback": traceback.format_exc()
-                }))
-    except WebSocketDisconnect:
-        pass
-
 
-def main():
+def main(host: str = "0.0.0.0", port: int = 8000):
+    """Run the Data Validation environment server."""
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host=host, port=port)
 
 
 if __name__ == "__main__":
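With create_app() supplying the routes, launching the server reduces to pointing uvicorn at the module-level app. A usage sketch, assuming the repo layout shown in this commit:

    # From the repo root, either:
    #   uvicorn server.app:app --host 0.0.0.0 --port 8000
    # or programmatically:
    from server.app import main

    main(host="0.0.0.0", port=8000)

Note that create_app is configured with max_concurrent_envs=1 here even though env/environment.py advertises SUPPORTS_CONCURRENT_SESSIONS = True, which suggests only one environment instance is served at a time.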