CivicAI / scripts /baseline_inference.py
mahammadaftab's picture
Initial Update
315caa2
"""
CivicAI Baseline Inference Script
Uses OpenAI GPT-4o-mini to make policy decisions.
Falls back to rule-based agent if no API key is available.
"""
from __future__ import annotations
import json
import os
import sys
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from civicai.environment import CivicAIEnv
from civicai.models import Action, SubsidyPolicy
def _load_json_payload(response_text: str):
    """Decode the JSON payload embedded in an LLM reply.

    The reply is first decoded the strict way: the whole stripped text, or the
    contents of the first Markdown code fence when one is present. If that is
    not valid JSON, the span from the first ``{`` to the last ``}`` of the
    reply is tried instead, which covers replies that wrap the object in prose
    or use a fence tag other than lowercase ``json``.

    Raises:
        ValueError: if neither candidate decodes as JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    text = response_text.strip()
    if "```json" in text:
        text = text.split("```json")[1].split("```")[0]
    elif "```" in text:
        # The piece between the first two fences cannot contain another fence.
        text = text.split("```")[1]
    try:
        return json.loads(text)
    except ValueError:
        start = response_text.find("{")
        end = response_text.rfind("}")
        if start == -1 or end < start:
            raise
        return json.loads(response_text[start:end + 1])


def _clamped_fraction(data: dict, key: str, default: float) -> float:
    """Return ``data[key]`` as a float clamped to the range [0, 1].

    A missing key, an explicit ``null`` or a NaN yields *default*.

    Raises:
        TypeError, ValueError: if the value cannot be converted by ``float``.
    """
    import math  # function-scope import keeps this block self-contained

    value = data.get(key)
    if value is None:
        return default
    number = float(value)
    if math.isnan(number):
        # Comparisons with NaN are always false, so min()/max() clamping
        # would silently turn NaN into 1.0; use the default instead.
        return default
    return max(0.0, min(1.0, number))


def parse_action(response_text: str) -> Action:
    """Parse an LLM response into an ``Action``, falling back to defaults.

    The response is expected to hold a JSON object, optionally inside a
    Markdown code fence or surrounded by other text. Numeric fields are
    clamped to [0, 1]; missing, ``null`` or NaN numeric fields use the same
    per-field defaults as a missing key (0.25 / 0.20 / 0.15 / 0.10). A missing
    or ``null`` ``subsidy_policy`` is read as ``"none"``.

    Args:
        response_text: Raw text returned by the model.

    Returns:
        The parsed ``Action``, or ``Action()`` when the response cannot be
        turned into a valid action.
    """
    try:
        data = _load_json_payload(response_text)
        return Action(
            tax_rate=_clamped_fraction(data, "tax_rate", 0.25),
            healthcare_budget=_clamped_fraction(data, "healthcare_budget", 0.20),
            education_budget=_clamped_fraction(data, "education_budget", 0.15),
            police_budget=_clamped_fraction(data, "police_budget", 0.10),
            subsidy_policy=SubsidyPolicy(data.get("subsidy_policy") or "none"),
            emergency_response=data.get("emergency_response"),
        )
    except Exception:
        # Deliberate best-effort: any malformed reply (bad JSON, non-object
        # payload, unknown policy name, rejected field value) must not abort
        # the episode, so the default action is used for this turn.
        return Action()
def build_prompt(obs_dict: dict) -> str:
    """Render an observation dictionary as the text prompt sent to the LLM.

    Args:
        obs_dict: Observation mapping. It must provide the keys ``turn``,
            ``population``, ``employment_rate``, ``inflation``,
            ``public_satisfaction``, ``health_index``, ``crime_rate``,
            ``gdp``, ``budget_balance``, ``resources`` and ``active_events``.

    Returns:
        A newline-separated prompt: an instruction line, the formatted current
        state, and the JSON schema the model is asked to answer with.

    Raises:
        KeyError: if one of the keys listed above is missing.
    """
    # Rates are rendered as percentages with one decimal; resources are
    # pretty-printed as JSON so nested values stay readable.
    state_lines = [
        f"- Turn: {obs_dict['turn']}/50",
        f"- Population: {obs_dict['population']:,}",
        f"- Employment Rate: {obs_dict['employment_rate']:.1%}",
        f"- Inflation: {obs_dict['inflation']:.1%}",
        f"- Public Satisfaction: {obs_dict['public_satisfaction']:.1%}",
        f"- Health Index: {obs_dict['health_index']:.1%}",
        f"- Crime Rate: {obs_dict['crime_rate']:.1%}",
        f"- GDP: ${obs_dict['gdp']:.1f}B",
        f"- Budget Balance: {obs_dict['budget_balance']:.1%}",
        f"- Resources: {json.dumps(obs_dict['resources'], indent=2)}",
        f"- Active Events: {obs_dict['active_events']}",
    ]
    # Literal answer template shown to the model; nothing is interpolated.
    schema_lines = [
        "{",
        '"tax_rate": 0.0-1.0,',
        '"healthcare_budget": 0.0-1.0,',
        '"education_budget": 0.0-1.0,',
        '"police_budget": 0.0-1.0,',
        '"subsidy_policy": "none|agriculture|industry|technology",',
        '"emergency_response": null or "lockdown|stimulus|open"',
        "}",
    ]
    intro = (
        "You are an AI policy advisor managing a society. "
        "Analyze the current state and decide on policy actions."
    )
    sections = [
        intro,
        "CURRENT STATE:",
        *state_lines,
        "Respond with ONLY a JSON object (no other text):",
        *schema_lines,
    ]
    return "\n".join(sections)
def run_llm_agent(task_id: str = "stabilize_economy") -> dict:
    """Run one baseline episode with OpenAI GPT-4o-mini choosing each action.

    For up to 50 turns the current observation is rendered with
    ``build_prompt``, sent to the chat-completions API, and the reply is turned
    into an ``Action`` with ``parse_action``. Per-turn metrics and a final
    summary are printed to stdout.

    Args:
        task_id: Task identifier passed to ``CivicAIEnv.reset``.

    Returns:
        A dict with ``task_id``, ``total_reward``, ``avg_reward``, ``steps``
        (number of turns played), ``reward_curve`` (per-turn rewards) and
        ``step_details`` (turn, action, reward and observation per turn).
    """
    # Function-scope import: openai is only needed on this code path.
    from openai import OpenAI

    client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    env = CivicAIEnv()
    obs = env.reset(task_id)
    total_reward = 0.0
    rewards = []
    steps = []
    rule = "=" * 60

    print(f"\n{rule}")
    print(f" CivicAI Baseline — Task: {task_id}")
    print(" Model: GPT-4o-mini")
    print(f"{rule}\n")

    for turn in range(50):
        prompt = build_prompt(obs.model_dump())
        response = client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": prompt}],
            temperature=0.3,
        )
        # The message content may be None; parse_action then falls back to
        # its default action.
        action = parse_action(response.choices[0].message.content or "")
        obs, reward, done, info = env.step(action)
        total_reward += reward
        rewards.append(reward)
        steps.append({
            "turn": turn,
            "action": action.model_dump(),
            "reward": reward,
            "obs": obs.model_dump(),
        })
        print(f" Turn {turn:2d} | Reward: {reward:.3f} | "
              f"Emp: {obs.employment_rate:.1%} | Inf: {obs.inflation:.1%} | "
              f"Sat: {obs.public_satisfaction:.1%} | Crime: {obs.crime_rate:.1%}")
        if done:
            print(f"\n Episode ended: {info.get('termination_reason', 'max_steps')}")
            break

    # Guarded the same way as run_random_agent so the average is always
    # defined, and computed once for both the printout and the result.
    avg_reward = total_reward / max(1, len(rewards))

    print(f"\n{rule}")
    print(f" Total Reward: {total_reward:.4f}")
    print(f" Avg Reward: {avg_reward:.4f}")
    print(f" Steps: {len(rewards)}")
    print(f"{rule}\n")

    return {
        "task_id": task_id,
        "total_reward": total_reward,
        "avg_reward": avg_reward,
        "steps": len(rewards),
        "reward_curve": rewards,
        "step_details": steps,
    }
def run_rule_agent(task_id: str = "stabilize_economy") -> dict:
    """Run one baseline episode with the rule-based orchestrator (no API key).

    The episode itself is executed by ``Orchestrator.run_episode``; this
    function only prints the per-turn rewards, the totals and, when present,
    the entries of the result's ``emergent_summary``.

    Args:
        task_id: Task identifier passed to ``Orchestrator.run_episode``.

    Returns:
        The result dict produced by ``Orchestrator.run_episode``, unchanged.
        The keys ``reward_curve``, ``total_reward``, ``avg_reward`` and
        ``steps`` are read here and must be present.
    """
    # Function-scope import: the orchestrator is only needed on this code path.
    from agents.orchestrator import Orchestrator

    env = CivicAIEnv()
    orch = Orchestrator(env)
    rule = "=" * 60

    print(f"\n{rule}")
    print(f" CivicAI Baseline — Task: {task_id}")
    print(" Model: Multi-Agent Rule System")
    print(f"{rule}\n")

    result = orch.run_episode(task_id)

    for turn, reward in enumerate(result["reward_curve"]):
        print(f" Turn {turn:2d} | Reward: {reward:.3f}")

    print(f"\n{rule}")
    print(f" Total Reward: {result['total_reward']:.4f}")
    print(f" Avg Reward: {result['avg_reward']:.4f}")
    print(f" Steps: {result['steps']}")
    print(f"{rule}\n")

    # Emergent insights are optional: a missing or empty summary prints nothing.
    summary = result.get("emergent_summary", {})
    if summary.get("key_insights"):
        print(" 🧠 Emergent Insights:")
        for insight in summary["key_insights"]:
            print(f" → {insight}")
    if summary.get("patterns"):
        print(" 📊 Patterns Detected:")
        for pattern in summary["patterns"]:
            print(f" → {pattern}")

    return result
def run_random_agent(task_id: str = "stabilize_economy", seed: int | None = None) -> dict:
    """Run one baseline episode with uniformly random actions.

    Each turn draws the tax rate and the three budgets from fixed uniform
    ranges and picks a subsidy policy at random; the episode lasts at most
    50 turns or until the environment reports it is done.

    Args:
        task_id: Task identifier passed to ``CivicAIEnv.reset``.
        seed: Optional seed for a private random generator, making the sampled
            actions reproducible. When omitted, the module-level ``random``
            state is used, as before.

    Returns:
        A dict with ``task_id``, ``total_reward``, ``avg_reward``, ``steps``
        (number of turns played) and ``reward_curve`` (per-turn rewards).
    """
    import random

    # A dedicated generator keeps a seeded run from disturbing global state.
    rng = random if seed is None else random.Random(seed)

    env = CivicAIEnv()
    env.reset(task_id)
    total_reward = 0.0
    rewards = []
    policies = list(SubsidyPolicy)
    rule = "=" * 60

    print(f"\n{rule}")
    print(f" CivicAI Baseline — Task: {task_id}")
    print(" Model: Random Agent")
    print(f"{rule}\n")

    for _ in range(50):
        action = Action(
            tax_rate=rng.uniform(0.1, 0.5),
            healthcare_budget=rng.uniform(0.05, 0.4),
            education_budget=rng.uniform(0.05, 0.3),
            police_budget=rng.uniform(0.03, 0.2),
            subsidy_policy=rng.choice(policies),
        )
        # Only the reward and the termination flag are needed here.
        _, reward, done, _ = env.step(action)
        total_reward += reward
        rewards.append(reward)
        if done:
            break

    avg_reward = total_reward / max(1, len(rewards))

    print(f" Total Reward: {total_reward:.4f}")
    print(f" Avg Reward: {avg_reward:.4f}")

    return {
        "task_id": task_id,
        "total_reward": total_reward,
        "avg_reward": avg_reward,
        "steps": len(rewards),
        "reward_curve": rewards,
    }
if __name__ == "__main__":
    # Optional first CLI argument selects the task; otherwise use the default.
    cli_args = sys.argv[1:]
    task = cli_args[0] if cli_args else "stabilize_economy"

    # Without an API key the LLM agent cannot run, so use the rule-based one.
    if not os.getenv("OPENAI_API_KEY"):
        print(" No OPENAI_API_KEY found. Running rule-based agent.\n")
        run_rule_agent(task)
    else:
        run_llm_agent(task)