ritvik360 commited on
Commit
23a01fc
·
verified ·
1 Parent(s): c071eb6

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +3 -2
inference.py CHANGED
@@ -190,8 +190,9 @@ async def run_task(client: OpenAI, env, task_name: str) -> dict:
190
 
191
  sql = call_llm(client, user_prompt)
192
 
193
- from client import NL2SQLAction # local to avoid circular at module level
194
- result = await env.step({"query": sql}) # changed
 
195
  obs = result.observation
196
 
197
  reward = obs.reward or 0.0
 
190
 
191
  sql = call_llm(client, user_prompt)
192
 
193
+ from models import NL2SQLAction # local to avoid circular at module level
194
+ action = NL2SQLAction(query=sql)
195
+ result = await env.step(action)
196
  obs = result.observation
197
 
198
  reward = obs.reward or 0.0