Draken1606 committed on
Commit
16229c6
·
1 Parent(s): 986ebb9

Fix inference env/auth handling and validator output format

Browse files
Files changed (2) hide show
  1. README.md +15 -1
  2. inference.py +18 -12
README.md CHANGED
@@ -55,7 +55,21 @@ python inference.py --difficulty easy
55
  python inference.py --url http://127.0.0.1:7860 --difficulty all
56
  ```
57
 
58
- For LLM mode, set `HF_TOKEN` first.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  ## Docker
61
 
 
55
  python inference.py --url http://127.0.0.1:7860 --difficulty all
56
  ```
57
 
58
+ LLM mode is enabled by default in `inference.py` and requires:
59
+
60
+ ```bash
61
+ export API_BASE_URL="https://api.openai.com/v1" # or validator-provided proxy URL
62
+ export OPENAI_API_KEY="your-validator-provided-token"
63
+ ```
64
+
65
+ `MODEL_NAME` is optional and defaults to `meta-llama/Llama-3.1-8B-Instruct`.
66
+ For compatibility with different validator versions, `API_KEY` and `HF_TOKEN` are also accepted.
67
+
68
+ To run greedy mode locally without LLM calls:
69
+
70
+ ```bash
71
+ python inference.py --no-llm
72
+ ```
73
 
74
  ## Docker
75
 
inference.py CHANGED
@@ -12,7 +12,7 @@ Usage:
12
  python inference.py
13
  python inference.py --difficulty easy
14
  python inference.py --difficulty all
15
- python inference.py --use-llm
16
  python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space
17
  """
18
 
@@ -43,11 +43,16 @@ def _load_dotenv() -> None:
43
 
44
  _load_dotenv()
45
 
46
- API_BASE_URL = os.getenv('API_BASE_URL', 'https://router.huggingface.co/v1')
47
- MODEL_NAME = os.getenv('MODEL_NAME', 'meta-llama/Llama-3.1-8B-Instruct')
48
  HF_TOKEN = os.getenv('HF_TOKEN')
 
 
49
  LOCAL_IMAGE_NAME = os.getenv('LOCAL_IMAGE_NAME')
50
- API_KEY = HF_TOKEN or os.getenv('API_KEY')
 
 
 
 
 
51
 
52
  ENV_URL = os.getenv('ENV_URL', 'http://localhost:7860')
53
  TASK_NAME = 'container-stacking'
@@ -64,15 +69,15 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
64
  error_val = error if error else 'null'
65
  done_val = str(done).lower()
66
  print(
67
- f'[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}',
68
  flush=True,
69
  )
70
 
71
 
72
- def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
73
  rewards_str = ','.join(f'{r:.2f}' for r in rewards)
74
  print(
75
- f'[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}',
76
  flush=True,
77
  )
78
 
@@ -183,7 +188,7 @@ async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = Fals
183
  if not ws_url.endswith('/ws'):
184
  ws_url = ws_url.rstrip('/') + '/ws'
185
 
186
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if use_llm else None
187
  model_label = MODEL_NAME if use_llm else 'greedy'
188
 
189
  log_start(task=f'{TASK_NAME}-{difficulty}', env=BENCHMARK, model=model_label)
@@ -233,7 +238,7 @@ async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = Fals
233
  print(f'[DEBUG] Episode error: {exc}', file=sys.stderr, flush=True)
234
 
235
  finally:
236
- log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
237
 
238
  return score
239
 
@@ -247,10 +252,11 @@ if __name__ == '__main__':
247
  parser = argparse.ArgumentParser(description='Container Port Baseline Agent')
248
  parser.add_argument('--url', default=ENV_URL)
249
  parser.add_argument('--difficulty', default='all', choices=['easy', 'medium', 'hard', 'all'])
250
- parser.add_argument('--use-llm', action='store_true', help='Use LLM agent via HF router (requires HF_TOKEN)')
251
  args = parser.parse_args()
 
252
 
253
  if args.difficulty == 'all':
254
- asyncio.run(run_all(args.url, use_llm=args.use_llm))
255
  else:
256
- asyncio.run(run_episode(args.url, difficulty=args.difficulty, use_llm=args.use_llm))
 
12
  python inference.py
13
  python inference.py --difficulty easy
14
  python inference.py --difficulty all
15
+ python inference.py --no-llm
16
  python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space
17
  """
18
 
 
43
 
44
  _load_dotenv()
45
 
 
 
46
  HF_TOKEN = os.getenv('HF_TOKEN')
47
+ API_BASE_URL = os.getenv('API_BASE_URL', 'https://api.openai.com/v1')
48
+ MODEL_NAME = os.getenv('MODEL_NAME', 'meta-llama/Llama-3.1-8B-Instruct')
49
  LOCAL_IMAGE_NAME = os.getenv('LOCAL_IMAGE_NAME')
50
+ OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
51
+ API_KEY = os.getenv('API_KEY')
52
+ AUTH_TOKEN = OPENAI_API_KEY or API_KEY or HF_TOKEN
53
+
54
+ if AUTH_TOKEN is None:
55
+ raise ValueError('OPENAI_API_KEY (or API_KEY/HF_TOKEN) environment variable is required')
56
 
57
  ENV_URL = os.getenv('ENV_URL', 'http://localhost:7860')
58
  TASK_NAME = 'container-stacking'
 
69
  error_val = error if error else 'null'
70
  done_val = str(done).lower()
71
  print(
72
+ f'[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}',
73
  flush=True,
74
  )
75
 
76
 
77
+ def log_end(success: bool, steps: int, rewards: List[float]) -> None:
78
  rewards_str = ','.join(f'{r:.2f}' for r in rewards)
79
  print(
80
+ f'[END] success={str(success).lower()} steps={steps} rewards={rewards_str}',
81
  flush=True,
82
  )
83
 
 
188
  if not ws_url.endswith('/ws'):
189
  ws_url = ws_url.rstrip('/') + '/ws'
190
 
191
+ client = OpenAI(base_url=API_BASE_URL, api_key=AUTH_TOKEN) if use_llm else None
192
  model_label = MODEL_NAME if use_llm else 'greedy'
193
 
194
  log_start(task=f'{TASK_NAME}-{difficulty}', env=BENCHMARK, model=model_label)
 
238
  print(f'[DEBUG] Episode error: {exc}', file=sys.stderr, flush=True)
239
 
240
  finally:
241
+ log_end(success=success, steps=steps_taken, rewards=rewards)
242
 
243
  return score
244
 
 
252
  parser = argparse.ArgumentParser(description='Container Port Baseline Agent')
253
  parser.add_argument('--url', default=ENV_URL)
254
  parser.add_argument('--difficulty', default='all', choices=['easy', 'medium', 'hard', 'all'])
255
+ parser.add_argument('--no-llm', action='store_true', help='Disable LLM agent and use greedy policy')
256
  args = parser.parse_args()
257
+ use_llm = not args.no_llm
258
 
259
  if args.difficulty == 'all':
260
+ asyncio.run(run_all(args.url, use_llm=use_llm))
261
  else:
262
+ asyncio.run(run_episode(args.url, difficulty=args.difficulty, use_llm=use_llm))