ajaxwin committed on
Commit
e5b8b13
·
1 Parent(s): dccaaac

refactor: Update API base URL and model name, enhance message handling with history tracking

Browse files
Files changed (1) hide show
  1. inference.py +21 -20
inference.py CHANGED
@@ -26,6 +26,7 @@ import asyncio
26
  import json
27
  import os
28
  import sys
 
29
  from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
30
 
31
  from openai import AsyncOpenAI
@@ -40,8 +41,8 @@ from utils import T1_SYSTEM, T2_SYSTEM, T3_SYSTEM
40
  # ─────────────────────────────────────────────────────────────────────────────
41
 
42
  load_dotenv()
43
- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
44
- MODEL_NAME = os.getenv("MODEL_NAME", "")
45
  HF_TOKEN = os.getenv("HF_TOKEN", "")
46
 
47
  if not HF_TOKEN:
@@ -52,10 +53,6 @@ if not MODEL_NAME:
52
 
53
  client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
54
 
55
- # from groq import AsyncGroq
56
- # GROQ_API_KEY = os.getenv("GROQ_API_KEY")
57
- # client = AsyncGroq(api_key=GROQ_API_KEY)
58
-
59
  # Benchmark / environment identifier (constant for this env)
60
  ENV_BENCHMARK = "smart-contract-audit"
61
 
@@ -151,18 +148,13 @@ async def run_episode(
151
  default_action: ActionType = ActionType.LIST_FUNCTIONS,
152
  extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
153
  ) -> Dict[str, Any]:
154
- """
155
- Run one episode with the given environment and task-specific parameters.
156
- Emits [START]/[STEP]/[END] lines and returns a dict with episode results.
157
- """
158
  r = env.reset(seed=seed)
159
  obs = r.observation.model_dump()
160
-
161
  log_start(task=task_id, env=ENV_BENCHMARK, model=MODEL_NAME)
162
 
163
- messages: List[Dict[str, str]] = [
164
- {"role": "system", "content": system_prompt}
165
- ]
166
  step_rewards: List[float] = []
167
  grader_score = 0.0
168
  steps_taken = 0
@@ -170,29 +162,39 @@ async def run_episode(
170
 
171
  try:
172
  for step in range(1, MAX_STEPS + 1):
173
- messages.append({"role": "user", "content": user_msg_formatter(obs)})
 
 
 
 
 
 
 
 
 
174
  try:
175
- raw = await get_llm_response(messages, max_tokens=max_tokens, temperature=0.0)
176
  error_msg = None
177
  except Exception as e:
178
  raw = ""
179
  error_msg = str(e)[:80]
180
  print(f"[DEBUG] {task_id} LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
181
 
 
 
 
182
  try:
183
  parsed = json.loads(raw)
184
  at = ActionType(parsed["action"])
185
  params = parsed.get("params", {})
186
  except Exception as e:
187
  at, params = default_action, {}
188
- print("Error in parsing LLM respoonse: " + str(e))
189
 
190
- messages.append({"role": "assistant", "content": raw})
191
  result = env.step(Action(action_type=at, params=params))
192
  obs = result.observation.model_dump()
193
  r_val = result.reward.value
194
  done = result.done
195
-
196
  step_rewards.append(r_val)
197
  steps_taken = step
198
  log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
@@ -215,7 +217,6 @@ async def run_episode(
215
  }
216
  if extra_fields:
217
  result_dict.update(extra_fields(obs))
218
-
219
  return result_dict
220
 
221
  # ─────────────────────────────────────────────────────────────────────────────
 
26
  import json
27
  import os
28
  import sys
29
+ from collections import deque
30
  from typing import Any, Dict, List, Optional, Callable, Awaitable, Union
31
 
32
  from openai import AsyncOpenAI
 
41
  # ─────────────────────────────────────────────────────────────────────────────
42
 
43
  load_dotenv()
44
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1/")
45
+ MODEL_NAME = os.getenv("MODEL_NAME", "CohereLabs/tiny-aya-fire:cohere")
46
  HF_TOKEN = os.getenv("HF_TOKEN", "")
47
 
48
  if not HF_TOKEN:
 
53
 
54
  client = AsyncOpenAI(api_key=HF_TOKEN, base_url=API_BASE_URL)
55
 
 
 
 
 
56
  # Benchmark / environment identifier (constant for this env)
57
  ENV_BENCHMARK = "smart-contract-audit"
58
 
 
148
  default_action: ActionType = ActionType.LIST_FUNCTIONS,
149
  extra_fields: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
150
  ) -> Dict[str, Any]:
 
 
 
 
151
  r = env.reset(seed=seed)
152
  obs = r.observation.model_dump()
 
153
  log_start(task=task_id, env=ENV_BENCHMARK, model=MODEL_NAME)
154
 
155
+ # Keep only the last 2 user-assistant pairs (4 messages).
156
+ history: deque = deque(maxlen=4)
157
+
158
  step_rewards: List[float] = []
159
  grader_score = 0.0
160
  steps_taken = 0
 
162
 
163
  try:
164
  for step in range(1, MAX_STEPS + 1):
165
+ user_msg_content = user_msg_formatter(obs)
166
+ user_message = {"role": "user", "content": user_msg_content}
167
+ history.append(user_message)
168
+
169
+ # Always prepend the system prompt so it survives deque eviction
170
+ messages_for_llm = [
171
+ {"role": "system", "content": system_prompt},
172
+ *list(history),
173
+ ]
174
+
175
  try:
176
+ raw = await get_llm_response(messages_for_llm, max_tokens=max_tokens, temperature=0.0)
177
  error_msg = None
178
  except Exception as e:
179
  raw = ""
180
  error_msg = str(e)[:80]
181
  print(f"[DEBUG] {task_id} LLM error ep={ep_num} step={step}: {e}", file=sys.stderr)
182
 
183
+ # Append the assistant reply so the next step sees the full turn
184
+ history.append({"role": "assistant", "content": raw})
185
+
186
  try:
187
  parsed = json.loads(raw)
188
  at = ActionType(parsed["action"])
189
  params = parsed.get("params", {})
190
  except Exception as e:
191
  at, params = default_action, {}
192
+ print("Error in parsing LLM response: " + str(e))
193
 
 
194
  result = env.step(Action(action_type=at, params=params))
195
  obs = result.observation.model_dump()
196
  r_val = result.reward.value
197
  done = result.done
 
198
  step_rewards.append(r_val)
199
  steps_taken = step
200
  log_step(step=step, action=at.value, reward=r_val, done=done, error=error_msg)
 
217
  }
218
  if extra_fields:
219
  result_dict.update(extra_fields(obs))
 
220
  return result_dict
221
 
222
  # ─────────────────────────────────────────────────────────────────────────────