# Spaces:
# Paused
# Paused
from __future__ import annotations

import json
import os
import sys
from typing import Any

import requests

from osint_env.baselines.openai_runner import SYSTEM_PROMPT, build_action_tools
from osint_env.llm.interface import OllamaLLMClient

def _env(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when unset or blank.

    A variable that is set but empty or whitespace-only (for example
    ``MAX_STEPS=``) is treated as unset, so the integer settings below never
    hit ``int("")`` and the base URLs never collapse to an empty string.
    """
    return os.getenv(name, "").strip() or default


# Base URL of the remote environment server. Trailing slashes are removed so
# endpoint paths can be appended with a single "/".
SPACE_URL = _env("SPACE_URL", "https://siddeshwar1625-osint.hf.space").rstrip("/")
# Base URL handed to the Ollama LLM client.
OLLAMA_BASE = _env("OLLAMA_BASE_URL", "http://127.0.0.1:11434").rstrip("/")
# Model name handed to the Ollama LLM client.
MODEL = _env("OLLAMA_MODEL", "qwen3:2b")
# Maximum number of environment steps attempted per episode.
MAX_STEPS = int(_env("MAX_STEPS", "8"))
# Timeout used for every HTTP request and passed to the LLM client.
REQUEST_TIMEOUT = int(_env("REQUEST_TIMEOUT", "90"))
# Task indices to run, parsed from a comma-separated list; blank entries are
# ignored.
TASK_INDICES = [int(x.strip()) for x in _env("TASK_INDICES", "0").split(",") if x.strip()]
def _message_text(message: Any) -> str:
    """Extract the plain-text content of a chat message.

    Args:
        message: Either an object exposing a ``content`` attribute or a dict
            with a ``"content"`` key (the dict form is how chat history is
            represented elsewhere in this script).

    Returns:
        The content itself when it is a string. When it is a list, the
        ``"text"`` values of every dict item whose ``"type"`` is ``"text"``,
        joined with newlines (empty parts are dropped). Otherwise ``str()``
        of the content, with falsy content mapped to ``""``.
    """
    # A plain dict has no ``content`` attribute, so look the key up instead of
    # silently returning an empty string.
    if isinstance(message, dict):
        content = message.get("content", "")
    else:
        content = getattr(message, "content", "")

    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = [
            str(item.get("text", ""))
            for item in content
            if isinstance(item, dict) and item.get("type") == "text"
        ]
        return "\n".join(part for part in parts if part)
    return str(content or "")
| def _assistant_tool_call_id(message: dict[str, Any]) -> str | None: | |
| tool_calls = list(message.get("tool_calls", [])) | |
| if not tool_calls: | |
| return None | |
| tool_call_id = tool_calls[0].get("id") | |
| return str(tool_call_id) if tool_call_id else None | |
def _tool_result_message(assistant_message: dict[str, Any], result: dict[str, Any]) -> dict[str, Any] | None:
    """Build the ``tool`` role message that answers an assistant tool call.

    Args:
        assistant_message: The assistant message whose first tool call is
            being answered.
        result: The step result to report back; it is serialized to JSON with
            sorted keys.

    Returns:
        A message dict with ``role``, ``tool_call_id`` and ``content`` keys,
        or ``None`` when the assistant message carries no usable tool-call id.
    """
    call_id = _assistant_tool_call_id(assistant_message)
    if call_id:
        serialized = json.dumps(result, sort_keys=True)
        return {"role": "tool", "tool_call_id": call_id, "content": serialized}
    return None
def _decode_action(tool_name: str, args: dict[str, Any]) -> dict[str, Any]:
    """Translate a model tool call into an environment action dict.

    Args:
        tool_name: Name of the tool the model invoked.
        args: Arguments supplied by the model for that tool.

    Returns:
        A dict with ``"action_type"`` and ``"payload"`` keys:

        * ``submit_answer`` -> ``ANSWER`` with the stripped ``answer`` text.
        * ``add_edge`` -> ``ADD_EDGE`` with stripped ``src``/``rel``/``dst``
          and a float ``confidence``.
        * anything else -> ``CALL_TOOL`` carrying the tool name and a copy of
          the arguments.
    """

    def text(key: str) -> str:
        # A missing or null field becomes "" rather than the literal "None".
        value = args.get(key)
        return "" if value is None else str(value).strip()

    if tool_name == "submit_answer":
        return {"action_type": "ANSWER", "payload": {"answer": text("answer")}}

    if tool_name == "add_edge":
        # The arguments are model output: a null, non-numeric or non-finite
        # confidence falls back to the default instead of aborting the run
        # (NaN/inf also have no strict-JSON representation).
        try:
            confidence = float(args.get("confidence", 1.0))
        except (TypeError, ValueError):
            confidence = 1.0
        if not float("-inf") < confidence < float("inf"):
            confidence = 1.0
        return {
            "action_type": "ADD_EDGE",
            "payload": {
                "src": text("src"),
                "rel": text("rel"),
                "dst": text("dst"),
                "confidence": confidence,
            },
        }

    return {"action_type": "CALL_TOOL", "payload": {"tool_name": tool_name, "args": dict(args)}}
def _format_action(action: dict[str, Any]) -> str:
    """Render an action dict as a compact one-line string for log output.

    Args:
        action: Dict with ``"action_type"`` and ``"payload"`` keys.

    Returns:
        ``answer(<answer>)`` for ``ANSWER``,
        ``add_edge(<src>,<rel>,<dst>,<confidence>)`` for ``ADD_EDGE`` (the
        confidence is shown with two decimals), and
        ``<tool_name>(<key>=<value>,...)`` with arguments ordered by key for
        everything else.
    """
    action_type = str(action.get("action_type", ""))
    # This helper only feeds log lines, so a null payload or argument map is
    # shown as empty rather than raising.
    payload = dict(action.get("payload") or {})

    if action_type == "ANSWER":
        return f"answer({payload.get('answer', 'unknown')})"

    if action_type == "ADD_EDGE":
        confidence = payload.get("confidence", 1.0)
        try:
            confidence_text = f"{float(confidence):.2f}"
        except (TypeError, ValueError):
            # Show a non-numeric confidence verbatim instead of crashing.
            confidence_text = str(confidence)
        return (
            "add_edge("
            f"{payload.get('src', '')},"
            f"{payload.get('rel', '')},"
            f"{payload.get('dst', '')},"
            f"{confidence_text}"
            ")"
        )

    tool_name = str(payload.get("tool_name", "tool"))
    args = dict(payload.get("args") or {})
    if not args:
        return f"{tool_name}()"
    # Order by the key's string form so mixed-type keys cannot break sorting.
    ordered = sorted(args.items(), key=lambda pair: str(pair[0]))
    arg_str = ",".join(f"{key}={value}" for key, value in ordered)
    return f"{tool_name}({arg_str})"
def get_model_action(client: OllamaLLMClient, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Ask the model for its next move and convert it into an environment action.

    Args:
        client: LLM client; its ``generate(messages, tools)`` result is read
            through the ``content`` and ``tool_calls`` attributes.
        messages: Chat history sent to the model.
        tools: Tool definitions offered to the model.

    Returns:
        ``(action, assistant_message)``. When the model makes no tool call the
        action is an ``ANSWER`` built from the stripped reply text
        (``"unknown"`` if that is empty) and the assistant message has no
        ``tool_calls``. Otherwise only the first tool call is used: it is
        decoded with ``_decode_action`` and echoed in the assistant message
        under the fixed id ``"local"``.
    """
    llm_resp = client.generate(messages, tools)
    content = llm_resp.content or ""
    tool_calls = list(llm_resp.tool_calls or [])

    if not tool_calls:
        answer = content.strip() or "unknown"
        return (
            {"action_type": "ANSWER", "payload": {"answer": answer}},
            {"role": "assistant", "content": content},
        )

    tool_call = tool_calls[0]
    tool_name = str(tool_call.get("tool_name", ""))
    # ``or {}`` covers a tool call whose arguments are present but null, which
    # ``dict(None)`` would reject with a TypeError.
    args = dict(tool_call.get("args") or {})
    assistant_message = {
        "role": "assistant",
        "content": content,
        "tool_calls": [
            {
                "id": "local",
                "type": "function",
                "function": {"name": tool_name, "arguments": json.dumps(args, sort_keys=True)},
            }
        ],
    }
    return _decode_action(tool_name, args), assistant_message
def main() -> None:
    """Run one episode per configured task index against the remote environment.

    The Space is first probed at ``/healthz``. For every index in
    ``TASK_INDICES`` the task is reset through ``/openenv/reset``; the model
    is then queried and its action posted to ``/openenv/step`` until the
    environment reports ``done`` or ``MAX_STEPS`` steps have been attempted.
    Per-step results and an episode summary are printed to stdout.

    Raises:
        SystemExit: If the health check request fails or its body is not JSON.
        requests.HTTPError: If a reset request returns an error status.
    """
    try:
        ping = requests.get(f"{SPACE_URL}/healthz", timeout=REQUEST_TIMEOUT)
        ping.raise_for_status()
        print(f"Space health: {ping.json()}")
    except (requests.RequestException, ValueError) as exc:
        raise SystemExit(f"Space health check failed: {exc}") from exc

    client = OllamaLLMClient(model=MODEL, base_url=OLLAMA_BASE, timeout_seconds=REQUEST_TIMEOUT)
    tools = build_action_tools()

    for task_index in TASK_INDICES:
        print(f"Resetting task {task_index} via {SPACE_URL}/openenv/reset")
        resp = requests.post(f"{SPACE_URL}/openenv/reset", json={"task_index": task_index}, timeout=REQUEST_TIMEOUT)
        resp.raise_for_status()
        data = resp.json()

        raw_session_id = data.get("session_id")
        if raw_session_id is None:
            # str(None) would be sent to the server as the session id "None";
            # report the malformed reset response and move on instead.
            print(f"Task {task_index}: reset response has no session_id; skipping", file=sys.stderr)
            continue
        session_id = str(raw_session_id)

        observation = data.get("observation", {})
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(observation, indent=2, sort_keys=True)},
        ]
        done = bool(data.get("done", False))
        step = 0
        rewards: list[float] = []

        while not done and step < MAX_STEPS:
            step += 1
            action, assistant_message = get_model_action(client, messages, tools)
            # Always None on the success path below; kept so the per-step log
            # line retains its "error=" field.
            error = None
            try:
                step_resp = requests.post(
                    f"{SPACE_URL}/openenv/step",
                    json={
                        "session_id": session_id,
                        "action_type": action["action_type"],
                        "payload": action["payload"],
                    },
                    timeout=REQUEST_TIMEOUT,
                )
                step_resp.raise_for_status()
                result = step_resp.json()
            except (requests.RequestException, ValueError) as exc:
                # Transport/HTTP failure or a non-JSON body ends this episode.
                error = str(exc)
                print(f"Step {step}: request failed: {error}")
                break

            reward = float(result.get("reward", 0.0) or 0.0)
            done = bool(result.get("done", False))
            rewards.append(reward)
            print(f"Step {step}: action={_format_action(action)} reward={reward:.3f} done={done} error={error}")

            messages.append(assistant_message)
            feedback = _tool_result_message(assistant_message, result)
            if feedback is None:
                # The model replied without a tool call, so there is no call id
                # to attach a tool message to. Relay the result as a user turn
                # so the next prompt still contains the environment feedback.
                feedback = {"role": "user", "content": json.dumps(result, sort_keys=True)}
            messages.append(feedback)

        print(f"Episode finished. steps={step} total_reward={sum(rewards):.3f} rewards={rewards}")
if __name__ == "__main__":
    # Script entry point: a Ctrl-C is reported on stderr and turned into exit
    # status 1 instead of a traceback.
    exit_code = 0
    try:
        main()
    except KeyboardInterrupt:
        print("Interrupted", file=sys.stderr)
        exit_code = 1
    if exit_code:
        sys.exit(exit_code)