Spaces:
Running
Running
ajaxwin committed on
Commit ·
e5b8b13
1
Parent(s): dccaaac
refactor: Update API base URL and model name, enhance message handling with history tracking
Browse files
- inference.py +21 -20
inference.py
CHANGED
|
@@ -26,6 +26,7 @@ import asyncio
|
|
| 26 |
import json
|
| 27 |
import os
|
| 28 |
import sys
|
|
|
|
| 29 |
from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
|
| 30 |
|
| 31 |
from openai import AsyncOpenAI
|
|
@@ -40,8 +41,8 @@ from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
|
|
| 40 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
|
| 42 |
load_dotenv()
|
| 43 |
-
API_BASE_URL = os.getenv("API_BASE_URL", "https://
|
| 44 |
-
MODEL_NAME = os.getenv("MODEL_NAME", "")
|
| 45 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 46 |
|
| 47 |
if not HF_TOKEN:
|
|
@@ -52,10 +53,6 @@ if not MODEL_NAME:
|
|
| 52 |
|
| 53 |
client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
|
| 54 |
|
| 55 |
-
# from groq import AsyncGroq
|
| 56 |
-
# GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
| 57 |
-
# client = AsyncGroq(api_key=GROQ_API_KEY)
|
| 58 |
-
|
| 59 |
# Benchmark / environment identifier (constant for this env)
|
| 60 |
ENV_BENCHMARK = "smart-contract-audit"
|
| 61 |
|
|
@@ -151,18 +148,13 @@ async def run_episode(
|
|
| 151 |
default_action: ActionType = ActionType.LIST_FUNCTIONS,
|
| 152 |
extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
| 153 |
) -> Dict[str, Any]:
|
| 154 |
-
"""
|
| 155 |
-
Run one episode with the given environment and task-specific parameters.
|
| 156 |
-
Emits [START]/[STEP]/[END] lines and returns a dict with episode results.
|
| 157 |
-
"""
|
| 158 |
r = env.reset(seed=seed)
|
| 159 |
obs = r.observation.model_dump()
|
| 160 |
-
|
| 161 |
log_start(task=task_id, env=ENV_BENCHMARK, model=MODEL_NAME)
|
| 162 |
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
step_rewards: List[float] = []
|
| 167 |
grader_score = 0.0
|
| 168 |
steps_taken = 0
|
|
@@ -170,29 +162,39 @@ async def run_episode(
|
|
| 170 |
|
| 171 |
try:
|
| 172 |
for step in range(1, MAX_STEPS + 1):
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
try:
|
| 175 |
-
raw = await get_llm_response(
|
| 176 |
error_msg = None
|
| 177 |
except Exception as e:
|
| 178 |
raw = ""
|
| 179 |
error_msg = str(e)[:80]
|
| 180 |
print(f"[DEBUG] {task_id} LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
|
| 181 |
|
|
|
|
|
|
|
|
|
|
| 182 |
try:
|
| 183 |
parsed = json.loads(raw)
|
| 184 |
at = ActionType(parsed["action"])
|
| 185 |
params = parsed.get("params", {})
|
| 186 |
except Exception as e:
|
| 187 |
at, params = default_action, {}
|
| 188 |
-
print("Error in parsing LLM
|
| 189 |
|
| 190 |
-
messages.append({"role": "assistant", "content": raw})
|
| 191 |
result = env.step(Action(action_type=at, params=params))
|
| 192 |
obs = result.observation.model_dump()
|
| 193 |
r_val = result.reward.value
|
| 194 |
done = result.done
|
| 195 |
-
|
| 196 |
step_rewards.append(r_val)
|
| 197 |
steps_taken = step
|
| 198 |
log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
|
|
@@ -215,7 +217,6 @@ async def run_episode(
|
|
| 215 |
}
|
| 216 |
if extra_fields:
|
| 217 |
result_dict.update(extra_fields(obs))
|
| 218 |
-
|
| 219 |
return result_dict
|
| 220 |
|
| 221 |
# ─────────────────────────────────────────────────────────────────────────────
|
|
|
|
| 26 |
import json
|
| 27 |
import os
|
| 28 |
import sys
|
| 29 |
+
from collections import deque
|
| 30 |
from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
|
| 31 |
|
| 32 |
from openai import AsyncOpenAI
|
|
|
|
| 41 |
# ─────────────────────────────────────────────────────────────────────────────
|
| 42 |
|
| 43 |
load_dotenv()
|
| 44 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1/")
|
| 45 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "CohereLabs/tiny-aya-fire:cohere")
|
| 46 |
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
| 47 |
|
| 48 |
if not HF_TOKEN:
|
|
|
|
| 53 |
|
| 54 |
client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
|
| 55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
# Benchmark / environment identifier (constant for this env)
|
| 57 |
ENV_BENCHMARK = "smart-contract-audit"
|
| 58 |
|
|
|
|
| 148 |
default_action: ActionType = ActionType.LIST_FUNCTIONS,
|
| 149 |
extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
|
| 150 |
) -> Dict[str, Any]:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
r = env.reset(seed=seed)
|
| 152 |
obs = r.observation.model_dump()
|
|
|
|
| 153 |
log_start(task=task_id, env=ENV_BENCHMARK, model=MODEL_NAME)
|
| 154 |
|
| 155 |
+
# Keep only the last 2 user-assistant pairs (4 messages).
|
| 156 |
+
history: deque = deque(maxlen=4)
|
| 157 |
+
|
| 158 |
step_rewards: List[float] = []
|
| 159 |
grader_score = 0.0
|
| 160 |
steps_taken = 0
|
|
|
|
| 162 |
|
| 163 |
try:
|
| 164 |
for step in range(1, MAX_STEPS + 1):
|
| 165 |
+
user_msg_content = user_msg_formatter(obs)
|
| 166 |
+
user_message = {"role": "user", "content": user_msg_content}
|
| 167 |
+
history.append(user_message)
|
| 168 |
+
|
| 169 |
+
# Always prepend the system prompt so it survives deque eviction
|
| 170 |
+
messages_for_llm = [
|
| 171 |
+
{"role": "system", "content": system_prompt},
|
| 172 |
+
*list(history),
|
| 173 |
+
]
|
| 174 |
+
|
| 175 |
try:
|
| 176 |
+
raw = await get_llm_response(messages_for_llm, max_tokens=max_tokens, temperature=0.0)
|
| 177 |
error_msg = None
|
| 178 |
except Exception as e:
|
| 179 |
raw = ""
|
| 180 |
error_msg = str(e)[:80]
|
| 181 |
print(f"[DEBUG] {task_id} LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
|
| 182 |
|
| 183 |
+
# Append the assistant reply so the next step sees the full turn
|
| 184 |
+
history.append({"role": "assistant", "content": raw})
|
| 185 |
+
|
| 186 |
try:
|
| 187 |
parsed = json.loads(raw)
|
| 188 |
at = ActionType(parsed["action"])
|
| 189 |
params = parsed.get("params", {})
|
| 190 |
except Exception as e:
|
| 191 |
at, params = default_action, {}
|
| 192 |
+
print("Error in parsing LLM response: " + str(e))
|
| 193 |
|
|
|
|
| 194 |
result = env.step(Action(action_type=at, params=params))
|
| 195 |
obs = result.observation.model_dump()
|
| 196 |
r_val = result.reward.value
|
| 197 |
done = result.done
|
|
|
|
| 198 |
step_rewards.append(r_val)
|
| 199 |
steps_taken = step
|
| 200 |
log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
|
|
|
|
| 217 |
}
|
| 218 |
if extra_fields:
|
| 219 |
result_dict.update(extra_fields(obs))
|
|
|
|
| 220 |
return result_dict
|
| 221 |
|
| 222 |
# ─────────────────────────────────────────────────────────────────────────────
|