Commit 16229c6
Parent(s): 986ebb9

Fix inference env/auth handling and validator output format

Files changed:
- README.md +15 -1
- inference.py +18 -12
README.md
CHANGED

@@ -55,7 +55,21 @@ python inference.py --difficulty easy
 python inference.py --url http://127.0.0.1:7860 --difficulty all
 ```
 
-
+LLM mode is enabled by default in `inference.py` and requires:
+
+```bash
+export API_BASE_URL="https://api.openai.com/v1"  # or validator-provided proxy URL
+export OPENAI_API_KEY="your-validator-provided-token"
+```
+
+`MODEL_NAME` is optional and defaults to `meta-llama/Llama-3.1-8B-Instruct`.
+For compatibility with different validator versions, `API_KEY` and `HF_TOKEN` are also accepted.
+
+To run greedy mode locally without LLM calls:
+
+```bash
+python inference.py --no-llm
+```
 
 ## Docker
 
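Note: `inference.py` calls `_load_dotenv()` before reading these variables (see the diff below), so the same values can also be supplied through a local `.env` file next to the script rather than shell exports.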
inference.py
CHANGED

@@ -12,7 +12,7 @@ Usage:
     python inference.py
     python inference.py --difficulty easy
     python inference.py --difficulty all
-
+    python inference.py --no-llm
     python inference.py --url https://YOUR_USERNAME-container-port-env.hf.space
 """
 
@@ -43,11 +43,16 @@ def _load_dotenv() -> None:
 
 _load_dotenv()
 
-API_BASE_URL = os.getenv('API_BASE_URL', 'https://router.huggingface.co/v1')
-MODEL_NAME = os.getenv('MODEL_NAME', 'meta-llama/Llama-3.1-8B-Instruct')
 HF_TOKEN = os.getenv('HF_TOKEN')
+API_BASE_URL = os.getenv('API_BASE_URL', 'https://api.openai.com/v1')
+MODEL_NAME = os.getenv('MODEL_NAME', 'meta-llama/Llama-3.1-8B-Instruct')
 LOCAL_IMAGE_NAME = os.getenv('LOCAL_IMAGE_NAME')
-
+OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
+API_KEY = os.getenv('API_KEY')
+AUTH_TOKEN = HF_TOKEN or API_KEY or OPENAI_API_KEY
+
+if AUTH_TOKEN is None:
+    raise ValueError('OPENAI_API_KEY (or API_KEY/HF_TOKEN) environment variable is required')
 
 ENV_URL = os.getenv('ENV_URL', 'http://localhost:7860')
 TASK_NAME = 'container-stacking'
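A note on the chain above: `or` short-circuits on the first truthy value, so despite the README presenting `OPENAI_API_KEY` as the primary variable, `HF_TOKEN` actually takes precedence when more than one token is set. A minimal standalone sketch (not part of the commit; values are hypothetical):

```python
import os

# Hypothetical environment: two of the three accepted tokens are set.
os.environ['HF_TOKEN'] = 'hf-token'
os.environ['OPENAI_API_KEY'] = 'openai-token'

HF_TOKEN = os.getenv('HF_TOKEN')
API_KEY = os.getenv('API_KEY')                # unset here -> None
OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')

# First truthy value wins: HF_TOKEN > API_KEY > OPENAI_API_KEY.
AUTH_TOKEN = HF_TOKEN or API_KEY or OPENAI_API_KEY
print(AUTH_TOKEN)  # -> hf-token
```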
@@ -64,15 +69,15 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
     error_val = error if error else 'null'
     done_val = str(done).lower()
     print(
-        f'[STEP]
+        f'[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}',
         flush=True,
     )
 
 
-def log_end(success: bool, steps: int,
+def log_end(success: bool, steps: int, rewards: List[float]) -> None:
     rewards_str = ','.join(f'{r:.2f}' for r in rewards)
     print(
-        f'[END] success={str(success).lower()} steps={steps}
+        f'[END] success={str(success).lower()} steps={steps} rewards={rewards_str}',
         flush=True,
     )
 
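These `[STEP]` and `[END]` lines are the validator-facing output format the commit message refers to. A self-contained sketch of what the new format emits, using made-up episode values:

```python
# Reproduce the commit's log format with hypothetical episode values.
step, action, reward, done, error = 1, 'stack_container', 0.50, False, None
done_val = str(done).lower()
error_val = error if error else 'null'
print(f'[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}')
# [STEP] step=1 action=stack_container reward=0.50 done=false error=null

success, steps, rewards = True, 2, [0.5, 1.0]
rewards_str = ','.join(f'{r:.2f}' for r in rewards)
print(f'[END] success={str(success).lower()} steps={steps} rewards={rewards_str}')
# [END] success=true steps=2 rewards=0.50,1.00
```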
@@ -183,7 +188,7 @@ async def run_episode(url: str, difficulty: str = 'medium', use_llm: bool = Fals
     if not ws_url.endswith('/ws'):
         ws_url = ws_url.rstrip('/') + '/ws'
 
-    client = OpenAI(base_url=API_BASE_URL, api_key=
+    client = OpenAI(base_url=API_BASE_URL, api_key=AUTH_TOKEN) if use_llm else None
     model_label = MODEL_NAME if use_llm else 'greedy'
 
     log_start(task=f'{TASK_NAME}-{difficulty}', env=BENCHMARK, model=model_label)
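For reference, the sketch below shows how the resolved `API_BASE_URL`, `AUTH_TOKEN`, and `MODEL_NAME` values would reach an OpenAI 1.x-style client. It is an illustration with placeholder literals, not the commit's actual call site:

```python
from openai import OpenAI

# Placeholder values standing in for API_BASE_URL / AUTH_TOKEN / MODEL_NAME.
client = OpenAI(
    base_url='https://api.openai.com/v1',
    api_key='your-validator-provided-token',
)
resp = client.chat.completions.create(
    model='meta-llama/Llama-3.1-8B-Instruct',
    messages=[{'role': 'user', 'content': 'Propose the next container move.'}],
)
print(resp.choices[0].message.content)
```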
@@ -233,7 +238,7 @@
         print(f'[DEBUG] Episode error: {exc}', file=sys.stderr, flush=True)
 
     finally:
-        log_end(success=success, steps=steps_taken,
+        log_end(success=success, steps=steps_taken, rewards=rewards)
 
     return score
 
@@ -247,10 +252,11 @@ if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Container Port Baseline Agent')
     parser.add_argument('--url', default=ENV_URL)
     parser.add_argument('--difficulty', default='all', choices=['easy', 'medium', 'hard', 'all'])
-    parser.add_argument('--
+    parser.add_argument('--no-llm', action='store_true', help='Disable LLM agent and use greedy policy')
     args = parser.parse_args()
+    use_llm = not args.no_llm
 
     if args.difficulty == 'all':
-        asyncio.run(run_all(args.url, use_llm=
+        asyncio.run(run_all(args.url, use_llm=use_llm))
     else:
-        asyncio.run(run_episode(args.url, difficulty=args.difficulty, use_llm=
+        asyncio.run(run_episode(args.url, difficulty=args.difficulty, use_llm=use_llm))
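The flag wiring is the last piece: `--no-llm` is a plain store-true flag, and `use_llm = not args.no_llm` keeps LLM mode on by default. A quick standalone check (hypothetical invocation):

```python
import argparse

parser = argparse.ArgumentParser(description='Container Port Baseline Agent')
parser.add_argument('--no-llm', action='store_true', help='Disable LLM agent and use greedy policy')

# Simulate `python inference.py --no-llm`; argparse maps --no-llm to args.no_llm.
args = parser.parse_args(['--no-llm'])
use_llm = not args.no_llm
print(use_llm)  # -> False
```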