Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files- baseline_inference.py +71 -160
- baseline_results.json +36 -0
- inference.py +2 -4
- inference_local.py +3 -9
- training/train_grpo.py +1 -1
baseline_inference.py
CHANGED
|
@@ -1,184 +1,94 @@
|
|
| 1 |
-
"""Baseline inference script for the
|
| 2 |
|
| 3 |
-
Runs
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
Usage:
|
| 7 |
-
#
|
| 8 |
-
uvicorn src.agentic_rl.server.app:app --port 8000
|
| 9 |
-
|
| 10 |
-
# Run baseline (requires OPENAI_API_KEY env var):
|
| 11 |
python baseline_inference.py
|
| 12 |
|
| 13 |
-
#
|
| 14 |
-
python baseline_inference.py --
|
|
|
|
|
|
|
|
|
|
| 15 |
"""
|
| 16 |
|
| 17 |
import argparse
|
| 18 |
import json
|
| 19 |
-
import os
|
| 20 |
-
import sys
|
| 21 |
import time
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
sys.exit(1)
|
| 28 |
-
|
| 29 |
-
try:
|
| 30 |
-
import httpx
|
| 31 |
-
except ImportError:
|
| 32 |
-
print("ERROR: httpx package not installed. Run: pip install httpx")
|
| 33 |
-
sys.exit(1)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
SYSTEM_PROMPT = """You are an expert code reviewer. You will be given a code snippet and must identify all bugs, logic errors, security vulnerabilities, and style issues.
|
| 37 |
-
|
| 38 |
-
For each issue you find, provide:
|
| 39 |
-
- line: the line number where the issue occurs
|
| 40 |
-
- severity: "critical", "major", or "minor"
|
| 41 |
-
- category: "bug", "security", "style", "performance", or "logic"
|
| 42 |
-
- description: a clear explanation of what's wrong
|
| 43 |
-
- suggestion: how to fix it
|
| 44 |
-
|
| 45 |
-
Also provide an overall_assessment: "approve", "request_changes", or "comment".
|
| 46 |
-
|
| 47 |
-
Respond ONLY with valid JSON in this exact format:
|
| 48 |
-
{
|
| 49 |
-
"issues_found": [
|
| 50 |
-
{
|
| 51 |
-
"line": "5",
|
| 52 |
-
"severity": "critical",
|
| 53 |
-
"category": "bug",
|
| 54 |
-
"description": "Description of the issue",
|
| 55 |
-
"suggestion": "How to fix it"
|
| 56 |
-
}
|
| 57 |
-
],
|
| 58 |
-
"overall_assessment": "request_changes",
|
| 59 |
-
"confidence": 0.9
|
| 60 |
-
}"""
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def call_llm(client: OpenAI, model: str, code: str, context: str) -> dict:
|
| 64 |
-
"""Send code to LLM for review and parse the response."""
|
| 65 |
-
user_message = f"""Review this code for bugs, logic errors, and security vulnerabilities.
|
| 66 |
-
|
| 67 |
-
Context: {context}
|
| 68 |
-
|
| 69 |
-
```python
|
| 70 |
-
{code}
|
| 71 |
-
```
|
| 72 |
-
|
| 73 |
-
Respond with JSON only."""
|
| 74 |
-
|
| 75 |
-
response = client.chat.completions.create(
|
| 76 |
-
model=model,
|
| 77 |
-
messages=[
|
| 78 |
-
{"role": "system", "content": SYSTEM_PROMPT},
|
| 79 |
-
{"role": "user", "content": user_message},
|
| 80 |
-
],
|
| 81 |
-
temperature=0.0, # Deterministic for reproducibility
|
| 82 |
-
max_tokens=2000,
|
| 83 |
-
)
|
| 84 |
-
|
| 85 |
-
content = response.choices[0].message.content.strip()
|
| 86 |
|
| 87 |
-
# Extract JSON from potential markdown code blocks
|
| 88 |
-
if "```json" in content:
|
| 89 |
-
content = content.split("```json")[1].split("```")[0].strip()
|
| 90 |
-
elif "```" in content:
|
| 91 |
-
content = content.split("```")[1].split("```")[0].strip()
|
| 92 |
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
"""Run baseline inference on all tasks and report scores."""
|
| 98 |
-
api_key = os.environ.get("OPENAI_API_KEY")
|
| 99 |
-
if not api_key:
|
| 100 |
-
print("ERROR: OPENAI_API_KEY environment variable not set.")
|
| 101 |
-
print("Set it with: export OPENAI_API_KEY=your_key_here")
|
| 102 |
-
sys.exit(1)
|
| 103 |
-
|
| 104 |
-
# Initialize clients
|
| 105 |
-
llm_kwargs = {"api_key": api_key}
|
| 106 |
-
if openai_base_url:
|
| 107 |
-
llm_kwargs["base_url"] = openai_base_url
|
| 108 |
-
llm_client = OpenAI(**llm_kwargs)
|
| 109 |
-
env_client = httpx.Client(base_url=env_url, timeout=30.0)
|
| 110 |
-
|
| 111 |
-
# Get all tasks
|
| 112 |
-
tasks_resp = env_client.get("/tasks")
|
| 113 |
-
tasks_resp.raise_for_status()
|
| 114 |
-
all_tasks = tasks_resp.json()["tasks"]
|
| 115 |
|
| 116 |
print(f"{'='*60}")
|
| 117 |
-
print(
|
| 118 |
-
print(
|
| 119 |
-
print(f"Environment: {env_url}")
|
| 120 |
print(f"Tasks: {len(all_tasks)}")
|
| 121 |
print(f"{'='*60}\n")
|
| 122 |
|
| 123 |
results = []
|
|
|
|
| 124 |
|
| 125 |
for task_info in all_tasks:
|
| 126 |
task_id = task_info["task_id"]
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
print(f"---
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
# Submit to environment
|
| 151 |
-
step_resp = env_client.post("/step", json=review)
|
| 152 |
-
step_resp.raise_for_status()
|
| 153 |
-
result = step_resp.json()
|
| 154 |
-
|
| 155 |
-
score = result["reward"]
|
| 156 |
-
feedback = result["feedback"]
|
| 157 |
-
|
| 158 |
-
print(f" Score: {score:.3f}")
|
| 159 |
-
print(f" Feedback: {feedback}")
|
| 160 |
-
print(f" Issues reported: {len(review.get('issues_found', []))}")
|
| 161 |
-
print()
|
| 162 |
|
| 163 |
results.append({
|
| 164 |
"task_id": task_id,
|
| 165 |
-
"difficulty": difficulty,
|
| 166 |
"score": score,
|
| 167 |
-
"
|
| 168 |
-
"
|
|
|
|
|
|
|
| 169 |
})
|
| 170 |
|
|
|
|
|
|
|
| 171 |
# Summary
|
| 172 |
-
print(f"{'='*60}")
|
| 173 |
print("BASELINE RESULTS SUMMARY")
|
| 174 |
print(f"{'='*60}")
|
| 175 |
|
| 176 |
-
by_difficulty = {"easy": [], "medium": [], "hard": []}
|
| 177 |
for r in results:
|
| 178 |
by_difficulty[r["difficulty"]].append(r["score"])
|
| 179 |
|
| 180 |
total_scores = []
|
| 181 |
-
for difficulty in ["easy", "medium", "hard"]:
|
| 182 |
scores = by_difficulty[difficulty]
|
| 183 |
if scores:
|
| 184 |
avg = sum(scores) / len(scores)
|
|
@@ -187,39 +97,40 @@ def run_baseline(env_url: str, model: str, openai_base_url: str = None):
|
|
| 187 |
|
| 188 |
overall_avg = sum(total_scores) / len(total_scores) if total_scores else 0.0
|
| 189 |
print(f" {'OVERALL':8s}: {overall_avg:.3f} avg ({len(total_scores)} tasks)")
|
|
|
|
| 190 |
print(f"{'='*60}")
|
| 191 |
|
| 192 |
# Save results
|
| 193 |
output = {
|
| 194 |
-
"
|
| 195 |
-
"
|
| 196 |
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
| 197 |
"results": results,
|
| 198 |
"summary": {
|
| 199 |
-
"overall_avg": overall_avg,
|
|
|
|
|
|
|
| 200 |
"by_difficulty": {
|
| 201 |
-
k: sum(v) / len(v) if v else 0.0
|
| 202 |
for k, v in by_difficulty.items()
|
| 203 |
},
|
| 204 |
},
|
| 205 |
}
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
|
|
|
|
|
|
| 209 |
|
| 210 |
-
print(f"\nResults saved to baseline_results.json")
|
| 211 |
return output
|
| 212 |
|
| 213 |
|
| 214 |
if __name__ == "__main__":
|
| 215 |
-
parser = argparse.ArgumentParser(description="
|
| 216 |
-
parser.add_argument("--
|
| 217 |
-
|
| 218 |
-
parser.add_argument("--
|
|
|
|
| 219 |
args = parser.parse_args()
|
| 220 |
|
| 221 |
-
run_baseline(
|
| 222 |
-
env_url=args.base_url,
|
| 223 |
-
model=args.model,
|
| 224 |
-
openai_base_url=args.openai_base_url,
|
| 225 |
-
)
|
|
|
|
| 1 |
+
"""Baseline inference script for the Fish Farm environment.
|
| 2 |
|
| 3 |
+
Runs the heuristic agent against all 12 tasks and reports reproducible
|
| 4 |
+
baseline scores. This is a required submission artifact that demonstrates
|
| 5 |
+
the graders produce meaningful, varying signal across difficulty levels.
|
| 6 |
|
| 7 |
Usage:
|
| 8 |
+
# Heuristic baseline (no API key needed, no server needed)
|
|
|
|
|
|
|
|
|
|
| 9 |
python baseline_inference.py
|
| 10 |
|
| 11 |
+
# Save results to file
|
| 12 |
+
python baseline_inference.py --output baseline_results.json
|
| 13 |
+
|
| 14 |
+
# Single task
|
| 15 |
+
python baseline_inference.py --task feeding_basics
|
| 16 |
"""
|
| 17 |
|
| 18 |
import argparse
|
| 19 |
import json
|
|
|
|
|
|
|
| 20 |
import time
|
| 21 |
|
| 22 |
+
from src.agentic_rl.server.environment import FishFarmEnvironment
|
| 23 |
+
from src.agentic_rl.models import FarmAction
|
| 24 |
+
from src.agentic_rl.tasks import list_all_tasks
|
| 25 |
+
from inference import heuristic_action
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
def run_baseline(task_ids=None, output_file=None):
|
| 29 |
+
"""Run heuristic baseline on all tasks and report scores."""
|
| 30 |
+
all_tasks = list_all_tasks()
|
| 31 |
+
all_tasks.sort(key=lambda t: t["episode_hours"])
|
| 32 |
|
| 33 |
+
if task_ids:
|
| 34 |
+
all_tasks = [t for t in all_tasks if t["task_id"] in task_ids]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
print(f"{'='*60}")
|
| 37 |
+
print("Baseline Inference — Fish Farm Environment")
|
| 38 |
+
print("Agent: Heuristic (rule-based, deterministic)")
|
|
|
|
| 39 |
print(f"Tasks: {len(all_tasks)}")
|
| 40 |
print(f"{'='*60}\n")
|
| 41 |
|
| 42 |
results = []
|
| 43 |
+
total_start = time.time()
|
| 44 |
|
| 45 |
for task_info in all_tasks:
|
| 46 |
task_id = task_info["task_id"]
|
| 47 |
+
max_hours = task_info["episode_hours"]
|
| 48 |
+
|
| 49 |
+
print(f"--- {task_id} ({task_info['difficulty']}, {max_hours}h) ---")
|
| 50 |
+
|
| 51 |
+
env = FishFarmEnvironment()
|
| 52 |
+
obs = env.reset(task_id=task_id, seed=42)
|
| 53 |
+
obs_dict = obs.model_dump()
|
| 54 |
+
|
| 55 |
+
steps = 0
|
| 56 |
+
while not obs_dict.get("done", False) and steps < max_hours:
|
| 57 |
+
action_dict = heuristic_action(obs_dict, task_id, steps, max_hours)
|
| 58 |
+
action = FarmAction(**action_dict)
|
| 59 |
+
obs = env.step(action)
|
| 60 |
+
obs_dict = obs.model_dump()
|
| 61 |
+
steps += 1
|
| 62 |
+
|
| 63 |
+
score = obs_dict.get("reward", 0) or 0
|
| 64 |
+
|
| 65 |
+
print(f" Score: {score:.3f} | Weight: {obs_dict.get('avg_fish_weight', 0):.0f}g "
|
| 66 |
+
f"| Pop: {obs_dict.get('population', 0)} "
|
| 67 |
+
f"| Profit: ${obs_dict.get('current_profit', 0):.0f}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
results.append({
|
| 70 |
"task_id": task_id,
|
| 71 |
+
"difficulty": task_info["difficulty"],
|
| 72 |
"score": score,
|
| 73 |
+
"steps": steps,
|
| 74 |
+
"final_weight": obs_dict.get("avg_fish_weight", 0),
|
| 75 |
+
"final_population": obs_dict.get("population", 0),
|
| 76 |
+
"final_profit": obs_dict.get("current_profit", 0),
|
| 77 |
})
|
| 78 |
|
| 79 |
+
total_elapsed = time.time() - total_start
|
| 80 |
+
|
| 81 |
# Summary
|
| 82 |
+
print(f"\n{'='*60}")
|
| 83 |
print("BASELINE RESULTS SUMMARY")
|
| 84 |
print(f"{'='*60}")
|
| 85 |
|
| 86 |
+
by_difficulty = {"easy": [], "medium": [], "hard": [], "extreme": []}
|
| 87 |
for r in results:
|
| 88 |
by_difficulty[r["difficulty"]].append(r["score"])
|
| 89 |
|
| 90 |
total_scores = []
|
| 91 |
+
for difficulty in ["easy", "medium", "hard", "extreme"]:
|
| 92 |
scores = by_difficulty[difficulty]
|
| 93 |
if scores:
|
| 94 |
avg = sum(scores) / len(scores)
|
|
|
|
| 97 |
|
| 98 |
overall_avg = sum(total_scores) / len(total_scores) if total_scores else 0.0
|
| 99 |
print(f" {'OVERALL':8s}: {overall_avg:.3f} avg ({len(total_scores)} tasks)")
|
| 100 |
+
print(f" Time: {total_elapsed:.1f}s")
|
| 101 |
print(f"{'='*60}")
|
| 102 |
|
| 103 |
# Save results
|
| 104 |
output = {
|
| 105 |
+
"agent": "heuristic",
|
| 106 |
+
"seed": 42,
|
| 107 |
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
| 108 |
"results": results,
|
| 109 |
"summary": {
|
| 110 |
+
"overall_avg": round(overall_avg, 3),
|
| 111 |
+
"total_tasks": len(results),
|
| 112 |
+
"elapsed_s": round(total_elapsed, 1),
|
| 113 |
"by_difficulty": {
|
| 114 |
+
k: round(sum(v) / len(v), 3) if v else 0.0
|
| 115 |
for k, v in by_difficulty.items()
|
| 116 |
},
|
| 117 |
},
|
| 118 |
}
|
| 119 |
|
| 120 |
+
if output_file:
|
| 121 |
+
with open(output_file, "w") as f:
|
| 122 |
+
json.dump(output, f, indent=2)
|
| 123 |
+
print(f"\nResults saved to {output_file}")
|
| 124 |
|
|
|
|
| 125 |
return output
|
| 126 |
|
| 127 |
|
| 128 |
if __name__ == "__main__":
|
| 129 |
+
parser = argparse.ArgumentParser(description="Fish Farm Baseline Inference")
|
| 130 |
+
parser.add_argument("--task", type=str, nargs="+", default=None,
|
| 131 |
+
help="Specific task(s) to run")
|
| 132 |
+
parser.add_argument("--output", type=str, default="baseline_results.json",
|
| 133 |
+
help="Output file (default: baseline_results.json)")
|
| 134 |
args = parser.parse_args()
|
| 135 |
|
| 136 |
+
run_baseline(task_ids=args.task, output_file=args.output)
|
|
|
|
|
|
|
|
|
|
|
|
baseline_results.json
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"agent": "heuristic",
|
| 3 |
+
"seed": 42,
|
| 4 |
+
"timestamp": "2026-04-02T14:34:26Z",
|
| 5 |
+
"results": [
|
| 6 |
+
{
|
| 7 |
+
"task_id": "oxygen_management",
|
| 8 |
+
"difficulty": "easy",
|
| 9 |
+
"score": 1.0,
|
| 10 |
+
"steps": 72,
|
| 11 |
+
"final_weight": 102.82,
|
| 12 |
+
"final_population": 4000,
|
| 13 |
+
"final_profit": 318.35
|
| 14 |
+
},
|
| 15 |
+
{
|
| 16 |
+
"task_id": "feeding_basics",
|
| 17 |
+
"difficulty": "easy",
|
| 18 |
+
"score": 0.857,
|
| 19 |
+
"steps": 168,
|
| 20 |
+
"final_weight": 55.32,
|
| 21 |
+
"final_population": 5000,
|
| 22 |
+
"final_profit": -176.92
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"summary": {
|
| 26 |
+
"overall_avg": 0.928,
|
| 27 |
+
"total_tasks": 2,
|
| 28 |
+
"elapsed_s": 0.0,
|
| 29 |
+
"by_difficulty": {
|
| 30 |
+
"easy": 0.928,
|
| 31 |
+
"medium": 0.0,
|
| 32 |
+
"hard": 0.0,
|
| 33 |
+
"extreme": 0.0
|
| 34 |
+
}
|
| 35 |
+
}
|
| 36 |
+
}
|
inference.py
CHANGED
|
@@ -24,8 +24,7 @@ import json
|
|
| 24 |
import os
|
| 25 |
import time
|
| 26 |
import re
|
| 27 |
-
import
|
| 28 |
-
from typing import Any, Dict, List, Optional
|
| 29 |
|
| 30 |
import httpx
|
| 31 |
from openai import OpenAI
|
|
@@ -174,7 +173,6 @@ def heuristic_action(obs: Dict[str, Any], task_id: str, step: int, max_hours: in
|
|
| 174 |
nighttime_do_risk = obs.get("nighttime_do_risk", 0.0)
|
| 175 |
feed_price = obs.get("feed_price_per_kg", 0.50)
|
| 176 |
hours_left = max_hours - step
|
| 177 |
-
wq_score = obs.get("water_quality_score", 0.8)
|
| 178 |
|
| 179 |
# ---- Aeration (proactive nighttime crash prevention) ----
|
| 180 |
algae_bloom = obs.get("algae_bloom", False)
|
|
@@ -787,7 +785,7 @@ async def async_main():
|
|
| 787 |
|
| 788 |
# Summary
|
| 789 |
print(f"\n{'='*60}")
|
| 790 |
-
print(
|
| 791 |
print(f"{'='*60}")
|
| 792 |
avg_score = sum(r["final_reward"] for r in results) / len(results) if results else 0
|
| 793 |
total_llm = sum(r["llm_calls"] for r in results)
|
|
|
|
| 24 |
import os
|
| 25 |
import time
|
| 26 |
import re
|
| 27 |
+
from typing import Any, Dict, List
|
|
|
|
| 28 |
|
| 29 |
import httpx
|
| 30 |
from openai import OpenAI
|
|
|
|
| 173 |
nighttime_do_risk = obs.get("nighttime_do_risk", 0.0)
|
| 174 |
feed_price = obs.get("feed_price_per_kg", 0.50)
|
| 175 |
hours_left = max_hours - step
|
|
|
|
| 176 |
|
| 177 |
# ---- Aeration (proactive nighttime crash prevention) ----
|
| 178 |
algae_bloom = obs.get("algae_bloom", False)
|
|
|
|
| 785 |
|
| 786 |
# Summary
|
| 787 |
print(f"\n{'='*60}")
|
| 788 |
+
print(" SUMMARY")
|
| 789 |
print(f"{'='*60}")
|
| 790 |
avg_score = sum(r["final_reward"] for r in results) / len(results) if results else 0
|
| 791 |
total_llm = sum(r["llm_calls"] for r in results)
|
inference_local.py
CHANGED
|
@@ -18,15 +18,13 @@ Usage:
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import argparse
|
| 21 |
-
import json
|
| 22 |
import os
|
| 23 |
-
import re
|
| 24 |
import time
|
| 25 |
-
from typing import Any, Dict, List
|
| 26 |
|
| 27 |
from src.agentic_rl.server.environment import FishFarmEnvironment
|
| 28 |
from src.agentic_rl.models import FarmAction
|
| 29 |
-
from src.agentic_rl.tasks import
|
| 30 |
from inference import (
|
| 31 |
heuristic_action,
|
| 32 |
build_observation_prompt,
|
|
@@ -233,10 +231,6 @@ def main():
|
|
| 233 |
|
| 234 |
# Time budget: 18 min total
|
| 235 |
total_budget_s = 18 * 60
|
| 236 |
-
total_hours = sum(
|
| 237 |
-
next(t["episode_hours"] for t in list_all_tasks() if t["task_id"] == tid)
|
| 238 |
-
for tid in task_ids
|
| 239 |
-
)
|
| 240 |
|
| 241 |
results = []
|
| 242 |
total_start = time.time()
|
|
@@ -262,7 +256,7 @@ def main():
|
|
| 262 |
|
| 263 |
# Summary
|
| 264 |
print(f"\n{'='*60}")
|
| 265 |
-
print(
|
| 266 |
print(f"{'='*60}")
|
| 267 |
avg_score = sum(r["score"] for r in results) / len(results) if results else 0
|
| 268 |
total_llm = sum(r["llm_calls"] for r in results)
|
|
|
|
| 18 |
"""
|
| 19 |
|
| 20 |
import argparse
|
|
|
|
| 21 |
import os
|
|
|
|
| 22 |
import time
|
| 23 |
+
from typing import Any, Dict, List
|
| 24 |
|
| 25 |
from src.agentic_rl.server.environment import FishFarmEnvironment
|
| 26 |
from src.agentic_rl.models import FarmAction
|
| 27 |
+
from src.agentic_rl.tasks import list_all_tasks
|
| 28 |
from inference import (
|
| 29 |
heuristic_action,
|
| 30 |
build_observation_prompt,
|
|
|
|
| 231 |
|
| 232 |
# Time budget: 18 min total
|
| 233 |
total_budget_s = 18 * 60
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
| 235 |
results = []
|
| 236 |
total_start = time.time()
|
|
|
|
| 256 |
|
| 257 |
# Summary
|
| 258 |
print(f"\n{'='*60}")
|
| 259 |
+
print(" SUMMARY")
|
| 260 |
print(f"{'='*60}")
|
| 261 |
avg_score = sum(r["score"] for r in results) / len(results) if results else 0
|
| 262 |
total_llm = sum(r["llm_calls"] for r in results)
|
training/train_grpo.py
CHANGED
|
@@ -19,7 +19,7 @@ import os
|
|
| 19 |
|
| 20 |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 21 |
|
| 22 |
-
from rewards import CompositeReward
|
| 23 |
|
| 24 |
|
| 25 |
SYSTEM_PROMPT = """You are an expert Nile Tilapia aquaculture manager. Given the current state of a 100m³ RAS fish farm, decide the next hour's actions.
|
|
|
|
| 19 |
|
| 20 |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 21 |
|
| 22 |
+
from rewards import CompositeReward
|
| 23 |
|
| 24 |
|
| 25 |
SYSTEM_PROMPT = """You are an expert Nile Tilapia aquaculture manager. Given the current state of a 100m³ RAS fish farm, decide the next hour's actions.
|