# openenv / inference.py
# Zenoharsh01's picture
# Upload 10 files
# 621ef3a verified
import os
import json
import textwrap
from typing import List, Optional
from openai import OpenAI
from env import SOCEnvironment
from tasks import evaluate_environment
from models import Action
# 1. Load Hackathon Required Environment Variables
# Endpoint and model are overridable via env vars; defaults target OpenAI's public API.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
# Credential lookup order: HF_TOKEN, then OPENAI_API_KEY, then a dummy placeholder
# (keeps client construction from failing when no key is configured).
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("OPENAI_API_KEY", "dummy-token"))
# Benchmark identifier reported in the grader's [START] log line.
BENCHMARK = "soc-analyst-simulator"
# Initialize the OpenAI Client (module-level; shared by every task run)
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# --- STRICT LOGGING FUNCTIONS FORMATTED FOR HACKATHON GRADER ---
def log_start(task: str, env: str, model: str) -> None:
    """Print the grader-required ``[START]`` marker for one task run."""
    header = f"[START] task={task} env={env} model={model}"
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one grader-required ``[STEP]`` line.

    A falsy ``error`` (None or empty string) is rendered as the literal
    text ``null``; ``done`` is lower-cased to match the grader's format.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the grader-required ``[END]`` summary line for a task run."""
    joined = ",".join(format(r, ".2f") for r in rewards)
    summary = (
        f"[END] success={str(success).lower()} "
        f"steps={steps} score={score:.3f} rewards={joined}"
    )
    print(summary, flush=True)
# --- AGENT PIPELINE ---
def run_agent_on_task(env: SOCEnvironment, task_id: str, max_steps: int = 10) -> float:
    """Run the LLM agent loop on a single SOC task and return the grader score.

    Resets ``env`` to ``task_id``, then repeatedly asks the model (via the
    module-level ``client``) for a JSON action and applies it with
    ``env.step`` until the environment reports done or ``max_steps`` is hit.
    Emits the grader-required [START]/[STEP]/[END] log lines throughout.

    Args:
        env: The SOC environment to act in; reset at the start of the run.
        task_id: Identifier passed to ``env.reset`` to load the scenario.
        max_steps: Safety cap on agent steps (default 10, preserving the
            previously hard-coded limit).

    Returns:
        The final score computed by ``evaluate_environment``.
    """
    observation = env.reset(task_id)
    done = False
    history: List[str] = []
    step_rewards: List[float] = []
    steps_taken = 0
    error_msg: Optional[str] = None
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    system_prompt = textwrap.dedent("""
        You are an elite, autonomous Security Operations Center (SOC) Analyst.
        Protect the network by analyzing SIEM logs and taking decisive action.
        AVAILABLE ACTIONS: INVESTIGATE, BLOCK_IP, ISOLATE_HOST, DISMISS_ALERT, ESCALATE_TO_HUMAN
        You MUST respond with ONLY valid JSON matching this schema:
        {"action_type": "ACTION", "target_ip": "IP_ADDRESS_OR_null"}
    """).strip()
    while not done and steps_taken < max_steps:
        steps_taken += 1
        obs_json = observation.model_dump_json(indent=2)
        # Inject "Short-Term Memory" (last 4 actions) to fix the infinite loop:
        # without it the model tends to repeat the same action forever.
        history_block = "\n".join(history[-4:]) if history else "None"
        user_prompt = (
            f"Current State:\n{obs_json}\n\n"
            f"Past Actions:\n{history_block}\n\n"
            "What is your next action? Return ONLY JSON."
        )
        action_str = "unknown"
        reward_val = 0.0
        try:
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                # Force JSON output; low temperature for deterministic triage.
                response_format={"type": "json_object"},
                temperature=0.1,
            )
            raw_reply = response.choices[0].message.content
            action_dict = json.loads(raw_reply)
            action = Action(**action_dict)
            action_str = f"{action.action_type.value}({action.target_ip})"
            # Execute step
            observation, reward_obj, done, info = env.step(action)
            reward_val = reward_obj.score_delta
            # Record memory for the next prompt's "Past Actions" block.
            history.append(f"Step {steps_taken}: Action {action_str} -> Result: {info['msg']}")
        except Exception as e:
            # Broad catch is deliberate: any failure (API, JSON parse,
            # validation, env.step) ends the run but still gets logged/graded.
            error_msg = str(e)
            action_str = "ERROR"
            done = True
        step_rewards.append(reward_val)
        log_step(step=steps_taken, action=action_str, reward=reward_val, done=done, error=error_msg)
    # Final Grader Evaluation
    final_score = evaluate_environment(env, task_id)
    success_bool = final_score > 0.0
    log_end(success=success_bool, steps=steps_taken, score=final_score, rewards=step_rewards)
    return final_score
if __name__ == "__main__":
    # Run every benchmark task against one shared environment instance.
    environment = SOCEnvironment()
    task_ids = ["task_1_triage", "task_2_false_positive", "task_3_kill_chain"]
    total_score = sum(run_agent_on_task(environment, task_id) for task_id in task_ids)