Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import textwrap | |
| from typing import List, Optional | |
| from openai import OpenAI | |
| from env import SOCEnvironment | |
| from tasks import evaluate_environment | |
| from models import Action | |
# 1. Hackathon-required settings, read from the process environment.
# Each value is overridden by an environment variable of the same name; the
# literal fallback applies only when that variable is not set at all (an
# empty-string value is kept as-is).
_FALLBACK_API_KEY = os.environ.get("OPENAI_API_KEY", "dummy-token")

API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
# HF_TOKEN wins; otherwise OPENAI_API_KEY; otherwise a placeholder string.
# No credential check happens here, so a missing key only shows up once a
# request is actually sent.
HF_TOKEN = os.environ.get("HF_TOKEN", _FALLBACK_API_KEY)
BENCHMARK = "soc-analyst-simulator"

# Single OpenAI-compatible client shared by every chat-completion request the
# agent makes.
client = OpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
| # --- STRICT LOGGING FUNCTIONS FORMATTED FOR HACKATHON GRADER --- | |
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` record that opens one episode.

    Emits ``[START] task=<task> env=<env> model=<model>`` on a single line and
    flushes immediately so the grader sees it even if the run dies later.
    """
    pairs = (("task", task), ("env", env), ("model", model))
    header = " ".join(f"{label}={value}" for label, value in pairs)
    print("[START] " + header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` record for the grader.

    The record must occupy exactly one physical line. Exception text can span
    several lines (e.g. validation errors), and ``action`` embeds model-chosen
    text, so every run of whitespace (including newlines) inside ``action`` and
    ``error`` is collapsed to a single space before printing.

    Args:
        step: Step number as counted by the caller.
        action: Human-readable action label, e.g. ``BLOCK_IP(10.0.0.5)``.
        reward: Reward for this step; printed with two decimals.
        done: Whether the episode has ended; printed lower-cased (``true``/``false``).
        error: Error text for this step, or a falsy value when there was none.
            Printed as ``null`` when falsy or when it contains only whitespace.
    """
    action_val = " ".join(str(action).split())
    error_val = " ".join(str(error).split()) if error else ""
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} "
        f"done={done_val} error={error_val or 'null'}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the ``[END]`` record that closes one episode.

    ``success`` is lower-cased (``true``/``false`` for a bool), ``score`` is
    printed with three decimals, and ``rewards`` becomes a comma-separated list
    with two decimals per entry (empty when no steps ran).
    """
    per_step = ",".join(map("{:.2f}".format, rewards))
    outcome = str(success).lower()
    line = (
        f"[END] success={outcome} steps={steps} "
        f"score={score:.3f} rewards={per_step}"
    )
    print(line, flush=True)
| # --- AGENT PIPELINE --- | |
def _request_action(system_prompt: str, user_prompt: str) -> Action:
    """Ask the chat model for the next move and parse its JSON reply into an ``Action``.

    Raises:
        ValueError: if the model returns an empty reply, or the reply is not
            valid JSON (``json.JSONDecodeError`` subclasses ``ValueError``).
        Exception: anything raised by the API client or the ``Action`` constructor.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        response_format={"type": "json_object"},
        temperature=0.1,
    )
    raw_reply = response.choices[0].message.content
    if not raw_reply:
        raise ValueError("model returned an empty reply")
    return Action(**json.loads(raw_reply))


def run_agent_on_task(env: SOCEnvironment, task_id: str, max_steps: int = 10) -> float:
    """Run one LLM-driven episode on ``task_id`` and return the grader's score.

    Resets ``env`` to the task, then repeatedly asks the chat model for a JSON
    action, applies it with ``env.step`` and logs the step, until the
    environment reports ``done`` or ``max_steps`` steps have been taken.
    The episode is bracketed by ``log_start`` / ``log_end``.

    Args:
        env: Environment instance; it is reset at the start of the episode.
        task_id: Task identifier passed to ``env.reset`` and to the grader.
        max_steps: Upper bound on agent steps (default 10, the original limit).

    Returns:
        The score returned by ``evaluate_environment(env, task_id)``.

    Raises:
        Whatever ``env.reset`` or ``evaluate_environment`` raise. An exception
        inside a single step does not propagate: it is logged as an ``ERROR``
        step and ends the episode. The ``[END]`` line is printed even when the
        grader raises.
    """
    observation = env.reset(task_id)
    done = False
    history: List[str] = []
    step_rewards: List[float] = []
    steps_taken = 0
    final_score = 0.0

    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)

    system_prompt = textwrap.dedent("""
        You are an elite, autonomous Security Operations Center (SOC) Analyst.
        Protect the network by analyzing SIEM logs and taking decisive action.
        AVAILABLE ACTIONS: INVESTIGATE, BLOCK_IP, ISOLATE_HOST, DISMISS_ALERT, ESCALATE_TO_HUMAN
        You MUST respond with ONLY valid JSON matching this schema:
        {"action_type": "ACTION", "target_ip": "IP_ADDRESS_OR_null"}
    """).strip()

    # try/finally guarantees the [END] record is emitted even when something
    # outside the per-step handler (e.g. the grader below) raises.
    try:
        while not done and steps_taken < max_steps:
            steps_taken += 1
            obs_json = observation.model_dump_json(indent=2)
            # Short-term memory: show the model its last four actions and their
            # results so it does not keep repeating the same action.
            history_block = "\n".join(history[-4:]) if history else "None"
            user_prompt = (
                f"Current State:\n{obs_json}\n\nPast Actions:\n{history_block}"
                "\n\nWhat is your next action? Return ONLY JSON."
            )

            action_str = "unknown"
            reward_val = 0.0
            step_error: Optional[str] = None
            try:
                action = _request_action(system_prompt, user_prompt)
                action_str = f"{action.action_type.value}({action.target_ip})"
                observation, reward_obj, done, info = env.step(action)
                reward_val = reward_obj.score_delta
                # .get(): a missing "msg" must not turn an already-applied step
                # into an ERROR. Assumes info is a dict — TODO confirm in env.
                result_msg = info.get("msg", "")
                history.append(f"Step {steps_taken}: Action {action_str} -> Result: {result_msg}")
            except Exception as e:  # agent boundary: any step failure ends the episode
                # Fall back to the type name so an empty-message exception is
                # not reported as error=null.
                step_error = str(e) or type(e).__name__
                action_str = "ERROR"
                done = True

            step_rewards.append(reward_val)
            log_step(step=steps_taken, action=action_str, reward=reward_val,
                     done=done, error=step_error)

        final_score = evaluate_environment(env, task_id)
    finally:
        log_end(success=final_score > 0.0, steps=steps_taken,
                score=final_score, rewards=step_rewards)
    return final_score
if __name__ == "__main__":
    # Script entry point: run the agent once on each benchmark task, in order,
    # reusing one environment instance (run_agent_on_task resets it per task).
    env = SOCEnvironment()
    tasks = ["task_1_triage", "task_2_false_positive", "task_3_kill_chain"]
    # Running sum of the per-task scores returned by run_agent_on_task.
    # NOTE(review): total_score is not printed or used within the visible lines;
    # confirm a later line reports it, otherwise it is dead code.
    total_score = 0.0
    for task in tasks:
        total_score += run_agent_on_task(env, task)