# OSINT1/scripts/test_ollama_space.py
# (snapshot from commit db4fa53 by siddeshwar-kagatikar:
#  "Deploy clean snapshot to Hugging Face Space.")
from __future__ import annotations
import json
import os
import sys
from typing import Any
import requests
from osint_env.baselines.openai_runner import SYSTEM_PROMPT, build_action_tools
from osint_env.llm.interface import OllamaLLMClient
SPACE_URL = os.getenv("SPACE_URL", "https://siddeshwar1625-osint.hf.space").rstrip("/")
OLLAMA_BASE = os.getenv("OLLAMA_BASE_URL", "http://127.0.0.1:11434").rstrip("/")
MODEL = os.getenv("OLLAMA_MODEL", "qwen3:2b")
MAX_STEPS = int(os.getenv("MAX_STEPS", "8"))
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "90"))
TASK_INDICES = [int(x.strip()) for x in os.getenv("TASK_INDICES", "0").split(",") if x.strip()]
def _message_text(message: Any) -> str:
content = getattr(message, "content", "")
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, dict) and item.get("type") == "text":
parts.append(str(item.get("text", "")))
return "\n".join(part for part in parts if part)
return str(content or "")
def _assistant_tool_call_id(message: dict[str, Any]) -> str | None:
tool_calls = list(message.get("tool_calls", []))
if not tool_calls:
return None
tool_call_id = tool_calls[0].get("id")
return str(tool_call_id) if tool_call_id else None
def _tool_result_message(assistant_message: dict[str, Any], result: dict[str, Any]) -> dict[str, Any] | None:
    """Build the ``role="tool"`` reply carrying *result* back to the model.

    Returns None when the assistant message has no tool call to answer.
    """
    call_id = _assistant_tool_call_id(assistant_message)
    if not call_id:
        return None
    reply = {
        "role": "tool",
        "tool_call_id": call_id,
        # sort_keys keeps the serialized result deterministic across runs.
        "content": json.dumps(result, sort_keys=True),
    }
    return reply
def _decode_action(tool_name: str, args: dict[str, Any]) -> dict[str, Any]:
if tool_name == "submit_answer":
return {"action_type": "ANSWER", "payload": {"answer": str(args.get("answer", "")).strip()}}
if tool_name == "add_edge":
return {
"action_type": "ADD_EDGE",
"payload": {
"src": str(args.get("src", "")).strip(),
"rel": str(args.get("rel", "")).strip(),
"dst": str(args.get("dst", "")).strip(),
"confidence": float(args.get("confidence", 1.0)),
},
}
return {"action_type": "CALL_TOOL", "payload": {"tool_name": tool_name, "args": dict(args)}}
def _format_action(action: dict[str, Any]) -> str:
action_type = str(action.get("action_type", ""))
payload = dict(action.get("payload", {}))
if action_type == "ANSWER":
return f"answer({payload.get('answer', 'unknown')})"
if action_type == "ADD_EDGE":
return (
"add_edge("
f"{payload.get('src', '')},"
f"{payload.get('rel', '')},"
f"{payload.get('dst', '')},"
f"{float(payload.get('confidence', 1.0)):.2f}"
")"
)
tool_name = str(payload.get("tool_name", "tool"))
args = dict(payload.get("args", {}))
if not args:
return f"{tool_name}()"
arg_str = ",".join(f"{key}={value}" for key, value in sorted(args.items()))
return f"{tool_name}({arg_str})"
def get_model_action(client: OllamaLLMClient, messages: list[dict[str, Any]], tools: list[dict[str, Any]]) -> tuple[dict[str, Any], dict[str, Any]]:
    """Query the model once and map its reply to (action, assistant_message).

    A reply with no tool calls is treated as a final ANSWER (falling back
    to "unknown" for empty text); otherwise only the first tool call is
    honored and re-encoded into an OpenAI-style assistant message.
    """
    reply = client.generate(messages, tools)
    text = reply.content or ""
    calls = list(reply.tool_calls or [])
    if not calls:
        answer_action = {"action_type": "ANSWER", "payload": {"answer": text.strip() or "unknown"}}
        return answer_action, {"role": "assistant", "content": text}
    first_call = calls[0]
    name = str(first_call.get("tool_name", ""))
    call_args = dict(first_call.get("args", {}))
    assistant_message = {
        "role": "assistant",
        "content": text,
        "tool_calls": [
            {
                # Fixed id: the local client does not assign tool-call ids.
                "id": "local",
                "type": "function",
                "function": {"name": name, "arguments": json.dumps(call_args, sort_keys=True)},
            }
        ],
    }
    return _decode_action(name, call_args), assistant_message
def main() -> None:
    """Run one scripted episode per configured task index against the Space.

    Checks the Space's /healthz endpoint first, then for each task index
    resets an episode via /openenv/reset and lets the local Ollama model
    act through /openenv/step until the episode is done, a request fails,
    or MAX_STEPS is reached, printing a per-step trace and a summary.
    """
    # Fail fast if the Space itself is unreachable.
    try:
        health = requests.get(f"{SPACE_URL}/healthz", timeout=REQUEST_TIMEOUT)
        health.raise_for_status()
        print(f"Space health: {health.json()}")
    except Exception as exc:
        raise SystemExit(f"Space health check failed: {exc}") from exc

    client = OllamaLLMClient(model=MODEL, base_url=OLLAMA_BASE, timeout_seconds=REQUEST_TIMEOUT)
    tools = build_action_tools()

    for task_index in TASK_INDICES:
        print(f"Resetting task {task_index} via {SPACE_URL}/openenv/reset")
        reset_resp = requests.post(
            f"{SPACE_URL}/openenv/reset",
            json={"task_index": task_index},
            timeout=REQUEST_TIMEOUT,
        )
        reset_resp.raise_for_status()
        reset_data = reset_resp.json()
        session_id = str(reset_data.get("session_id"))
        observation = reset_data.get("observation", {})

        # Seed the conversation with the system prompt and first observation.
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": json.dumps(observation, indent=2, sort_keys=True)},
        ]
        done = bool(reset_data.get("done", False))
        step = 0
        rewards: list[float] = []

        while not done and step < MAX_STEPS:
            step += 1
            action, assistant_message = get_model_action(client, messages, tools)
            error = None
            try:
                step_resp = requests.post(
                    f"{SPACE_URL}/openenv/step",
                    json={
                        "session_id": session_id,
                        "action_type": action["action_type"],
                        "payload": action["payload"],
                    },
                    timeout=REQUEST_TIMEOUT,
                )
                step_resp.raise_for_status()
                step_data = step_resp.json()
            except Exception as exc:
                error = str(exc)
                print(f"Step {step}: request failed: {error}")
                break
            reward = float(step_data.get("reward", 0.0) or 0.0)
            done = bool(step_data.get("done", False))
            rewards.append(reward)
            # NOTE(review): error is always None here — failures break above.
            print(f"Step {step}: action={_format_action(action)} reward={reward:.3f} done={done} error={error}")
            # Record the assistant turn and echo the step result back as a
            # tool message so the model sees the environment's response.
            messages.append(assistant_message)
            tool_message = _tool_result_message(assistant_message, step_data)
            if tool_message is not None:
                messages.append(tool_message)
        print(f"Episode finished. steps={step} total_reward={sum(rewards):.3f} rewards={rewards}")
if __name__ == "__main__":
    # Exit with status 1 on Ctrl-C instead of dumping a traceback.
    try:
        main()
    except KeyboardInterrupt:
        print("Interrupted", file=sys.stderr)
        raise SystemExit(1)