rahul24raj committed on
Commit
a1121a8
·
verified ·
1 Parent(s): a0b8672

Upload folder using huggingface_hub

Browse files
baseline_inference.py CHANGED
@@ -1,184 +1,94 @@
1
- """Baseline inference script for the Code Review environment.
2
 
3
- Runs an LLM (via OpenAI-compatible API) against all tasks and reports
4
- reproducible baseline scores. This is a required submission artifact.
 
5
 
6
  Usage:
7
- # Start the environment server first:
8
- uvicorn src.agentic_rl.server.app:app --port 8000
9
-
10
- # Run baseline (requires OPENAI_API_KEY env var):
11
  python baseline_inference.py
12
 
13
- # Or specify a different model/base URL:
14
- python baseline_inference.py --model gpt-4o-mini --base-url http://localhost:8000
 
 
 
15
  """
16
 
17
  import argparse
18
  import json
19
- import os
20
- import sys
21
  import time
22
 
23
- try:
24
- from openai import OpenAI
25
- except ImportError:
26
- print("ERROR: openai package not installed. Run: pip install openai")
27
- sys.exit(1)
28
-
29
- try:
30
- import httpx
31
- except ImportError:
32
- print("ERROR: httpx package not installed. Run: pip install httpx")
33
- sys.exit(1)
34
-
35
-
36
- SYSTEM_PROMPT = """You are an expert code reviewer. You will be given a code snippet and must identify all bugs, logic errors, security vulnerabilities, and style issues.
37
-
38
- For each issue you find, provide:
39
- - line: the line number where the issue occurs
40
- - severity: "critical", "major", or "minor"
41
- - category: "bug", "security", "style", "performance", or "logic"
42
- - description: a clear explanation of what's wrong
43
- - suggestion: how to fix it
44
-
45
- Also provide an overall_assessment: "approve", "request_changes", or "comment".
46
-
47
- Respond ONLY with valid JSON in this exact format:
48
- {
49
- "issues_found": [
50
- {
51
- "line": "5",
52
- "severity": "critical",
53
- "category": "bug",
54
- "description": "Description of the issue",
55
- "suggestion": "How to fix it"
56
- }
57
- ],
58
- "overall_assessment": "request_changes",
59
- "confidence": 0.9
60
- }"""
61
-
62
-
63
- def call_llm(client: OpenAI, model: str, code: str, context: str) -> dict:
64
- """Send code to LLM for review and parse the response."""
65
- user_message = f"""Review this code for bugs, logic errors, and security vulnerabilities.
66
-
67
- Context: {context}
68
-
69
- ```python
70
- {code}
71
- ```
72
-
73
- Respond with JSON only."""
74
-
75
- response = client.chat.completions.create(
76
- model=model,
77
- messages=[
78
- {"role": "system", "content": SYSTEM_PROMPT},
79
- {"role": "user", "content": user_message},
80
- ],
81
- temperature=0.0, # Deterministic for reproducibility
82
- max_tokens=2000,
83
- )
84
-
85
- content = response.choices[0].message.content.strip()
86
 
87
- # Extract JSON from potential markdown code blocks
88
- if "```json" in content:
89
- content = content.split("```json")[1].split("```")[0].strip()
90
- elif "```" in content:
91
- content = content.split("```")[1].split("```")[0].strip()
92
 
93
- return json.loads(content)
 
 
 
94
 
95
-
96
- def run_baseline(env_url: str, model: str, openai_base_url: str = None):
97
- """Run baseline inference on all tasks and report scores."""
98
- api_key = os.environ.get("OPENAI_API_KEY")
99
- if not api_key:
100
- print("ERROR: OPENAI_API_KEY environment variable not set.")
101
- print("Set it with: export OPENAI_API_KEY=your_key_here")
102
- sys.exit(1)
103
-
104
- # Initialize clients
105
- llm_kwargs = {"api_key": api_key}
106
- if openai_base_url:
107
- llm_kwargs["base_url"] = openai_base_url
108
- llm_client = OpenAI(**llm_kwargs)
109
- env_client = httpx.Client(base_url=env_url, timeout=30.0)
110
-
111
- # Get all tasks
112
- tasks_resp = env_client.get("/tasks")
113
- tasks_resp.raise_for_status()
114
- all_tasks = tasks_resp.json()["tasks"]
115
 
116
  print(f"{'='*60}")
117
- print(f"Baseline Inference — Code Review Environment")
118
- print(f"Model: {model}")
119
- print(f"Environment: {env_url}")
120
  print(f"Tasks: {len(all_tasks)}")
121
  print(f"{'='*60}\n")
122
 
123
  results = []
 
124
 
125
  for task_info in all_tasks:
126
  task_id = task_info["task_id"]
127
- difficulty = task_info["difficulty"]
128
-
129
- print(f"--- Task: {task_id} ({difficulty}) ---")
130
-
131
- # Reset environment
132
- reset_resp = env_client.post("/reset", json={"task_id": task_id})
133
- reset_resp.raise_for_status()
134
- obs = reset_resp.json()
135
-
136
- code = obs["code_snippet"]
137
- context = obs["context"]
138
-
139
- # Get LLM review
140
- try:
141
- review = call_llm(llm_client, model, code, context)
142
- except (json.JSONDecodeError, Exception) as e:
143
- print(f" LLM Error: {e}")
144
- review = {
145
- "issues_found": [],
146
- "overall_assessment": "comment",
147
- "confidence": 0.0,
148
- }
149
-
150
- # Submit to environment
151
- step_resp = env_client.post("/step", json=review)
152
- step_resp.raise_for_status()
153
- result = step_resp.json()
154
-
155
- score = result["reward"]
156
- feedback = result["feedback"]
157
-
158
- print(f" Score: {score:.3f}")
159
- print(f" Feedback: {feedback}")
160
- print(f" Issues reported: {len(review.get('issues_found', []))}")
161
- print()
162
 
163
  results.append({
164
  "task_id": task_id,
165
- "difficulty": difficulty,
166
  "score": score,
167
- "issues_found": len(review.get("issues_found", [])),
168
- "feedback": feedback,
 
 
169
  })
170
 
 
 
171
  # Summary
172
- print(f"{'='*60}")
173
  print("BASELINE RESULTS SUMMARY")
174
  print(f"{'='*60}")
175
 
176
- by_difficulty = {"easy": [], "medium": [], "hard": []}
177
  for r in results:
178
  by_difficulty[r["difficulty"]].append(r["score"])
179
 
180
  total_scores = []
181
- for difficulty in ["easy", "medium", "hard"]:
182
  scores = by_difficulty[difficulty]
183
  if scores:
184
  avg = sum(scores) / len(scores)
@@ -187,39 +97,40 @@ def run_baseline(env_url: str, model: str, openai_base_url: str = None):
187
 
188
  overall_avg = sum(total_scores) / len(total_scores) if total_scores else 0.0
189
  print(f" {'OVERALL':8s}: {overall_avg:.3f} avg ({len(total_scores)} tasks)")
 
190
  print(f"{'='*60}")
191
 
192
  # Save results
193
  output = {
194
- "model": model,
195
- "environment": env_url,
196
  "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
197
  "results": results,
198
  "summary": {
199
- "overall_avg": overall_avg,
 
 
200
  "by_difficulty": {
201
- k: sum(v) / len(v) if v else 0.0
202
  for k, v in by_difficulty.items()
203
  },
204
  },
205
  }
206
 
207
- with open("baseline_results.json", "w") as f:
208
- json.dump(output, f, indent=2)
 
 
209
 
210
- print(f"\nResults saved to baseline_results.json")
211
  return output
212
 
213
 
214
  if __name__ == "__main__":
215
- parser = argparse.ArgumentParser(description="Run baseline inference")
216
- parser.add_argument("--model", default="gpt-4o-mini", help="Model name")
217
- parser.add_argument("--base-url", default="http://localhost:8000", help="Env server URL")
218
- parser.add_argument("--openai-base-url", default=None, help="OpenAI API base URL")
 
219
  args = parser.parse_args()
220
 
221
- run_baseline(
222
- env_url=args.base_url,
223
- model=args.model,
224
- openai_base_url=args.openai_base_url,
225
- )
 
1
+ """Baseline inference script for the Fish Farm environment.
2
 
3
+ Runs the heuristic agent against all 12 tasks and reports reproducible
4
+ baseline scores. This is a required submission artifact that demonstrates
5
+ the graders produce meaningful, varying signal across difficulty levels.
6
 
7
  Usage:
8
+ # Heuristic baseline (no API key needed, no server needed)
 
 
 
9
  python baseline_inference.py
10
 
11
+ # Save results to a specific file (default: baseline_results.json)
12
+ python baseline_inference.py --output baseline_results.json
13
+
14
+ # Single task
15
+ python baseline_inference.py --task feeding_basics
16
  """
17
 
18
  import argparse
19
  import json
 
 
20
  import time
21
 
22
+ from src.agentic_rl.server.environment import FishFarmEnvironment
23
+ from src.agentic_rl.models import FarmAction
24
+ from src.agentic_rl.tasks import list_all_tasks
25
+ from inference import heuristic_action
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
 
 
 
 
 
27
 
28
+ def run_baseline(task_ids=None, output_file=None):
29
+ """Run heuristic baseline on all tasks and report scores."""
30
+ all_tasks = list_all_tasks()
31
+ all_tasks.sort(key=lambda t: t["episode_hours"])
32
 
33
+ if task_ids:
34
+ all_tasks = [t for t in all_tasks if t["task_id"] in task_ids]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  print(f"{'='*60}")
37
+ print("Baseline Inference — Fish Farm Environment")
38
+ print("Agent: Heuristic (rule-based, deterministic)")
 
39
  print(f"Tasks: {len(all_tasks)}")
40
  print(f"{'='*60}\n")
41
 
42
  results = []
43
+ total_start = time.time()
44
 
45
  for task_info in all_tasks:
46
  task_id = task_info["task_id"]
47
+ max_hours = task_info["episode_hours"]
48
+
49
+ print(f"--- {task_id} ({task_info['difficulty']}, {max_hours}h) ---")
50
+
51
+ env = FishFarmEnvironment()
52
+ obs = env.reset(task_id=task_id, seed=42)
53
+ obs_dict = obs.model_dump()
54
+
55
+ steps = 0
56
+ while not obs_dict.get("done", False) and steps < max_hours:
57
+ action_dict = heuristic_action(obs_dict, task_id, steps, max_hours)
58
+ action = FarmAction(**action_dict)
59
+ obs = env.step(action)
60
+ obs_dict = obs.model_dump()
61
+ steps += 1
62
+
63
+ score = obs_dict.get("reward", 0) or 0
64
+
65
+ print(f" Score: {score:.3f} | Weight: {obs_dict.get('avg_fish_weight', 0):.0f}g "
66
+ f"| Pop: {obs_dict.get('population', 0)} "
67
+ f"| Profit: ${obs_dict.get('current_profit', 0):.0f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
  results.append({
70
  "task_id": task_id,
71
+ "difficulty": task_info["difficulty"],
72
  "score": score,
73
+ "steps": steps,
74
+ "final_weight": obs_dict.get("avg_fish_weight", 0),
75
+ "final_population": obs_dict.get("population", 0),
76
+ "final_profit": obs_dict.get("current_profit", 0),
77
  })
78
 
79
+ total_elapsed = time.time() - total_start
80
+
81
  # Summary
82
+ print(f"\n{'='*60}")
83
  print("BASELINE RESULTS SUMMARY")
84
  print(f"{'='*60}")
85
 
86
+ by_difficulty = {"easy": [], "medium": [], "hard": [], "extreme": []}
87
  for r in results:
88
  by_difficulty[r["difficulty"]].append(r["score"])
89
 
90
  total_scores = []
91
+ for difficulty in ["easy", "medium", "hard", "extreme"]:
92
  scores = by_difficulty[difficulty]
93
  if scores:
94
  avg = sum(scores) / len(scores)
 
97
 
98
  overall_avg = sum(total_scores) / len(total_scores) if total_scores else 0.0
99
  print(f" {'OVERALL':8s}: {overall_avg:.3f} avg ({len(total_scores)} tasks)")
100
+ print(f" Time: {total_elapsed:.1f}s")
101
  print(f"{'='*60}")
102
 
103
  # Save results
104
  output = {
105
+ "agent": "heuristic",
106
+ "seed": 42,
107
  "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
108
  "results": results,
109
  "summary": {
110
+ "overall_avg": round(overall_avg, 3),
111
+ "total_tasks": len(results),
112
+ "elapsed_s": round(total_elapsed, 1),
113
  "by_difficulty": {
114
+ k: round(sum(v) / len(v), 3) if v else 0.0
115
  for k, v in by_difficulty.items()
116
  },
117
  },
118
  }
119
 
120
+ if output_file:
121
+ with open(output_file, "w") as f:
122
+ json.dump(output, f, indent=2)
123
+ print(f"\nResults saved to {output_file}")
124
 
 
125
  return output
126
 
127
 
128
  if __name__ == "__main__":
129
+ parser = argparse.ArgumentParser(description="Fish Farm Baseline Inference")
130
+ parser.add_argument("--task", type=str, nargs="+", default=None,
131
+ help="Specific task(s) to run")
132
+ parser.add_argument("--output", type=str, default="baseline_results.json",
133
+ help="Output file (default: baseline_results.json)")
134
  args = parser.parse_args()
135
 
136
+ run_baseline(task_ids=args.task, output_file=args.output)
 
 
 
 
baseline_results.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "agent": "heuristic",
3
+ "seed": 42,
4
+ "timestamp": "2026-04-02T14:34:26Z",
5
+ "results": [
6
+ {
7
+ "task_id": "oxygen_management",
8
+ "difficulty": "easy",
9
+ "score": 1.0,
10
+ "steps": 72,
11
+ "final_weight": 102.82,
12
+ "final_population": 4000,
13
+ "final_profit": 318.35
14
+ },
15
+ {
16
+ "task_id": "feeding_basics",
17
+ "difficulty": "easy",
18
+ "score": 0.857,
19
+ "steps": 168,
20
+ "final_weight": 55.32,
21
+ "final_population": 5000,
22
+ "final_profit": -176.92
23
+ }
24
+ ],
25
+ "summary": {
26
+ "overall_avg": 0.928,
27
+ "total_tasks": 2,
28
+ "elapsed_s": 0.0,
29
+ "by_difficulty": {
30
+ "easy": 0.928,
31
+ "medium": 0.0,
32
+ "hard": 0.0,
33
+ "extreme": 0.0
34
+ }
35
+ }
36
+ }
inference.py CHANGED
@@ -24,8 +24,7 @@ import json
24
  import os
25
  import time
26
  import re
27
- import math
28
- from typing import Any, Dict, List, Optional
29
 
30
  import httpx
31
  from openai import OpenAI
@@ -174,7 +173,6 @@ def heuristic_action(obs: Dict[str, Any], task_id: str, step: int, max_hours: in
174
  nighttime_do_risk = obs.get("nighttime_do_risk", 0.0)
175
  feed_price = obs.get("feed_price_per_kg", 0.50)
176
  hours_left = max_hours - step
177
- wq_score = obs.get("water_quality_score", 0.8)
178
 
179
  # ---- Aeration (proactive nighttime crash prevention) ----
180
  algae_bloom = obs.get("algae_bloom", False)
@@ -787,7 +785,7 @@ async def async_main():
787
 
788
  # Summary
789
  print(f"\n{'='*60}")
790
- print(f" SUMMARY")
791
  print(f"{'='*60}")
792
  avg_score = sum(r["final_reward"] for r in results) / len(results) if results else 0
793
  total_llm = sum(r["llm_calls"] for r in results)
 
24
  import os
25
  import time
26
  import re
27
+ from typing import Any, Dict, List
 
28
 
29
  import httpx
30
  from openai import OpenAI
 
173
  nighttime_do_risk = obs.get("nighttime_do_risk", 0.0)
174
  feed_price = obs.get("feed_price_per_kg", 0.50)
175
  hours_left = max_hours - step
 
176
 
177
  # ---- Aeration (proactive nighttime crash prevention) ----
178
  algae_bloom = obs.get("algae_bloom", False)
 
785
 
786
  # Summary
787
  print(f"\n{'='*60}")
788
+ print(" SUMMARY")
789
  print(f"{'='*60}")
790
  avg_score = sum(r["final_reward"] for r in results) / len(results) if results else 0
791
  total_llm = sum(r["llm_calls"] for r in results)
inference_local.py CHANGED
@@ -18,15 +18,13 @@ Usage:
18
  """
19
 
20
  import argparse
21
- import json
22
  import os
23
- import re
24
  import time
25
- from typing import Any, Dict, List, Optional
26
 
27
  from src.agentic_rl.server.environment import FishFarmEnvironment
28
  from src.agentic_rl.models import FarmAction
29
- from src.agentic_rl.tasks import TASKS, list_all_tasks
30
  from inference import (
31
  heuristic_action,
32
  build_observation_prompt,
@@ -233,10 +231,6 @@ def main():
233
 
234
  # Time budget: 18 min total
235
  total_budget_s = 18 * 60
236
- total_hours = sum(
237
- next(t["episode_hours"] for t in list_all_tasks() if t["task_id"] == tid)
238
- for tid in task_ids
239
- )
240
 
241
  results = []
242
  total_start = time.time()
@@ -262,7 +256,7 @@ def main():
262
 
263
  # Summary
264
  print(f"\n{'='*60}")
265
- print(f" SUMMARY")
266
  print(f"{'='*60}")
267
  avg_score = sum(r["score"] for r in results) / len(results) if results else 0
268
  total_llm = sum(r["llm_calls"] for r in results)
 
18
  """
19
 
20
  import argparse
 
21
  import os
 
22
  import time
23
+ from typing import Any, Dict, List
24
 
25
  from src.agentic_rl.server.environment import FishFarmEnvironment
26
  from src.agentic_rl.models import FarmAction
27
+ from src.agentic_rl.tasks import list_all_tasks
28
  from inference import (
29
  heuristic_action,
30
  build_observation_prompt,
 
231
 
232
  # Time budget: 18 min total
233
  total_budget_s = 18 * 60
 
 
 
 
234
 
235
  results = []
236
  total_start = time.time()
 
256
 
257
  # Summary
258
  print(f"\n{'='*60}")
259
+ print(" SUMMARY")
260
  print(f"{'='*60}")
261
  avg_score = sum(r["score"] for r in results) / len(results) if results else 0
262
  total_llm = sum(r["llm_calls"] for r in results)
training/train_grpo.py CHANGED
@@ -19,7 +19,7 @@ import os
19
 
20
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
21
 
22
- from rewards import CompositeReward, SurvivalReward, WaterQualityReward
23
 
24
 
25
  SYSTEM_PROMPT = """You are an expert Nile Tilapia aquaculture manager. Given the current state of a 100m³ RAS fish farm, decide the next hour's actions.
 
19
 
20
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
21
 
22
+ from rewards import CompositeReward
23
 
24
 
25
  SYSTEM_PROMPT = """You are an expert Nile Tilapia aquaculture manager. Given the current state of a 100m³ RAS fish farm, decide the next hour's actions.