sai1912 committed on
Commit
f7739cd
·
verified ·
1 Parent(s): 32cbf02

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. inference.py +65 -29
  2. requirements.txt +2 -0
inference.py CHANGED
@@ -1,14 +1,17 @@
1
  import asyncio
 
2
  import os
3
  import textwrap
 
4
  from typing import List, Optional
5
 
6
- from openai import OpenAI
 
 
 
 
7
 
8
  from my_env_v4 import MyEnvV4Action, MyEnvV4Env
9
- from dotenv import load_dotenv
10
-
11
- load_dotenv(override=True)
12
 
13
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
14
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
@@ -71,30 +74,49 @@ def build_user_prompt(step: int, last_echoed: str, last_reward: float, history:
71
  ).strip()
72
 
73
 
74
- def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
75
  user_prompt = build_user_prompt(step, last_echoed, last_reward, history)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  try:
77
- completion = client.chat.completions.create(
78
- model=MODEL_NAME,
79
- messages=[
80
- {"role": "system", "content": SYSTEM_PROMPT},
81
- {"role": "user", "content": user_prompt},
82
- ],
83
- temperature=TEMPERATURE,
84
- max_tokens=MAX_TOKENS,
85
- stream=False,
86
  )
87
- text = (completion.choices[0].message.content or "").strip()
88
- return text if text else "hello"
 
 
89
  except Exception as exc:
90
  print(f"[DEBUG] Model request failed: {exc}", flush=True)
91
  return "hello"
92
 
93
 
94
  async def main() -> None:
95
- client = OpenAI(base_url=API_BASE_URL, api_key=_api_key)
96
-
97
- env = await MyEnvV4Env.from_docker_image(LOCAL_IMAGE_NAME)
 
 
98
 
99
  history: List[str] = []
100
  rewards: List[float] = []
@@ -105,26 +127,37 @@ async def main() -> None:
105
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
106
 
107
  try:
108
- result = await env.reset() # OpenENV.reset()
109
- last_echoed = result.observation.echoed_message
 
 
 
 
 
 
110
  last_reward = 0.0
111
 
112
  for step in range(1, MAX_STEPS + 1):
113
- if result.done:
114
  break
115
 
116
- message = get_model_message(client, step, last_echoed, last_reward, history)
117
 
118
- result = await env.step(MyEnvV4Action(message=message))
119
- obs = result.observation
 
 
 
 
120
 
121
- reward = result.reward or 0.0
122
- done = result.done
 
123
  error = getattr(result, "error", None)
124
 
125
  rewards.append(reward)
126
  steps_taken = step
127
- last_echoed = obs.echoed_message
128
  last_reward = reward
129
 
130
  # Formatting action to avoid newlines breaking stdout tracking format rules
@@ -148,4 +181,7 @@ async def main() -> None:
148
 
149
 
150
  if __name__ == "__main__":
151
- asyncio.run(main())
 
 
 
 
1
  import asyncio
2
+ import json
3
  import os
4
  import textwrap
5
+ import urllib.request
6
  from typing import List, Optional
7
 
8
+ try:
9
+ from dotenv import load_dotenv
10
+ load_dotenv(override=True)
11
+ except ImportError:
12
+ pass
13
 
14
  from my_env_v4 import MyEnvV4Action, MyEnvV4Env
 
 
 
15
 
16
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
17
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
 
74
  ).strip()
75
 
76
 
77
def get_model_message(api_key: str, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
    """Ask the chat-completions endpoint for the next message to send.

    Builds the prompt from the current episode state, POSTs it to
    ``{API_BASE_URL}/chat/completions`` with :mod:`urllib`, and returns the
    model's reply. Any failure (network error, bad HTTP status, malformed or
    empty response) degrades to the literal fallback ``"hello"`` so the
    episode loop never crashes on a model hiccup.

    Args:
        api_key: Bearer token; an empty string sends no Authorization header.
        step: 1-based step number, forwarded to the prompt builder.
        last_echoed: Observation echoed by the environment on the last step.
        last_reward: Reward received on the last step.
        history: Messages sent so far, forwarded to the prompt builder.

    Returns:
        The stripped model reply, or ``"hello"`` when the request fails or
        the reply is empty.
    """
    user_prompt = build_user_prompt(step, last_echoed, last_reward, history)

    headers = {
        "Content-Type": "application/json",
    }
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": TEMPERATURE,
        "max_tokens": MAX_TOKENS,
        "stream": False,
    }

    try:
        url = f"{API_BASE_URL.rstrip('/')}/chat/completions"
        req = urllib.request.Request(
            url,
            headers=headers,
            data=json.dumps(payload).encode("utf-8"),
            method="POST",
        )
        with urllib.request.urlopen(req, timeout=30) as response:
            # Honor the charset the server declares; fall back to UTF-8.
            charset = response.headers.get_content_charset() or "utf-8"
            result = json.loads(response.read().decode(charset))
        # Guard explicitly: result.get("choices", [{}]) only substitutes the
        # default when the key is MISSING — a present-but-empty list would
        # raise IndexError on [0].
        choices = result.get("choices") or [{}]
        text = (choices[0].get("message", {}).get("content") or "").strip()
        return text if text else "hello"
    except Exception as exc:
        # Deliberate best-effort boundary: log and fall back to the default
        # message rather than aborting the episode.
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "hello"
112
 
113
 
114
  async def main() -> None:
115
+ try:
116
+ env = await MyEnvV4Env.from_docker_image(LOCAL_IMAGE_NAME)
117
+ except Exception as e:
118
+ print(f"[DEBUG] env init failed: {e}", flush=True)
119
+ return
120
 
121
  history: List[str] = []
122
  rewards: List[float] = []
 
127
  log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
128
 
129
  try:
130
+ try:
131
+ result = await env.reset() # OpenENV.reset()
132
+ except Exception as e:
133
+ print(f"[DEBUG] env.reset() failed: {e}", flush=True)
134
+ log_end(success=False, steps=0, rewards=[])
135
+ return
136
+
137
+ last_echoed = getattr(result.observation, "echoed_message", "")
138
  last_reward = 0.0
139
 
140
  for step in range(1, MAX_STEPS + 1):
141
+ if getattr(result, "done", False):
142
  break
143
 
144
+ message = get_model_message(_api_key or "", step, last_echoed, last_reward, history)
145
 
146
+ try:
147
+ result = await env.step(MyEnvV4Action(message=message))
148
+ except Exception as e:
149
+ print(f"[DEBUG] env.step() failed: {e}", flush=True)
150
+ log_step(step=step, action=repr(message), reward=0.0, done=True, error=str(e))
151
+ break
152
 
153
+ obs = result.observation
154
+ reward = getattr(result, "reward", 0.0) or 0.0
155
+ done = getattr(result, "done", False)
156
  error = getattr(result, "error", None)
157
 
158
  rewards.append(reward)
159
  steps_taken = step
160
+ last_echoed = getattr(obs, "echoed_message", "")
161
  last_reward = reward
162
 
163
  # Formatting action to avoid newlines breaking stdout tracking format rules
 
181
 
182
 
183
if __name__ == "__main__":
    # Top-level boundary: surface any unhandled failure from the async
    # driver as a debug line instead of a raw traceback.
    try:
        asyncio.run(main())
    except Exception as e:
        print(f"[DEBUG] main() unhandled exception {e}", flush=True)
requirements.txt CHANGED
@@ -2,3 +2,5 @@ fastapi
2
  uvicorn
3
  pydantic
4
  duckdb
 
 
 
2
  uvicorn
3
  pydantic
4
  duckdb
5
+ openai
6
+ python-dotenv