sync: pull latest from main (model_server.py, captain LLM toggle in ui.py, 0.6B configs, SUBMISSION + RUNTIME_DURABILITY docs)
"""
Minimal OpenAI-compatible chat completions server backed by a local
transformers + PEFT checkpoint.

Lets the trained CricketCaptain model be consumed by anything that speaks
OpenAI's chat-completions API:

  - inference.py:    --model local --api-base http://your-host:8080/v1
  - compare_eval.py: same flags
  - HF Space live LLM opponent / captain: set Space secrets
        CRICKET_OPPENT_API_BASE   -> see below
        CRICKET_OPPONENT_API_BASE = http://your-host:8080/v1
        CRICKET_OPPONENT_MODEL    = local
        HF_TOKEN                  = (any non-empty string; this server
                                     ignores auth, only set so the openai
                                     client doesn't refuse to call)

The env's llm_live opponent (server/opponent_policy.py) reads those env
vars and routes its calls here automatically.

Run on a GPU box that has the trained adapter:

    python model_server.py --checkpoint ./checkpoints/stage2_final --port 8080

Endpoints:

    POST /v1/chat/completions   chat-completions (OpenAI-compatible)
    GET  /v1/models             {"data": [{"id": "<checkpoint>", ...}]}
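
Example client call (a minimal sketch; it assumes the server is reachable
at http://localhost:8080/v1, any non-empty api_key works because auth is
ignored, and the message content is just a placeholder):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8080/v1", api_key="unused")
    resp = client.chat.completions.create(
        model="local",
        messages=[{"role": "user", "content": "Pick a bowler for the next over."}],
        max_tokens=200,
        temperature=0.2,
    )
    print(resp.choices[0].message.content)
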
Notes:
  - Loads the base model named by `base_model_name_or_path` in the
    checkpoint's `adapter_config.json` and applies the LoRA adapter via
    PEFT. If the path is a full model (no adapter_config.json), it loads
    it directly.
  - Pre-fills an empty `<think>...</think>` block before generation so
    Qwen3-Thinking variants don't burn tokens reasoning. Harmless on the
    Qwen3-4B-Instruct-2507 path (no native thinking blocks).
"""

import argparse
import time
import uuid

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

_model = None
_tokenizer = None
_model_name = "local"


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str = "local"
    messages: list[Message]
    max_tokens: int = 300
    temperature: float = 0.2
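    # Illustrative JSON body this schema accepts (values are placeholders):
    #   {"model": "local",
    #    "messages": [{"role": "user", "content": "Pick a bowler for over 12."}],
    #    "max_tokens": 200, "temperature": 0.2}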


def _load_model(checkpoint: str):
    global _model, _tokenizer, _model_name
    import json
    import pathlib

    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    cfg_path = pathlib.Path(checkpoint) / "adapter_config.json"
    if cfg_path.exists():
        # LoRA checkpoint: load the base model it was trained on, then attach the adapter.
        base = json.loads(cfg_path.read_text())["base_model_name_or_path"]
        print(f"Loading base model: {base}")
        _tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            base,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        print(f"Loading LoRA adapter from: {checkpoint}")
        _model = PeftModel.from_pretrained(base_model, checkpoint)
    else:
        # Full model (no adapter_config.json): load it directly.
        print(f"Loading model directly: {checkpoint}")
        _tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
    _model.eval()
    _model_name = checkpoint
    print("Model ready.")


@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    # Build the prompt from the incoming messages.
    chat = [{"role": m.role, "content": m.content} for m in req.messages]
    text = _tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Disable thinking by pre-filling an empty think block.
    text += "<think>\n\n</think>\n"

    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    input_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=req.max_tokens,
            temperature=max(req.temperature, 1e-4),
            do_sample=req.temperature > 0.01,
            pad_token_id=_tokenizer.eos_token_id,
        )
    new_tokens = out[0][input_len:]
    generated = _tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": _model_name,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": generated},
            "finish_reason": "stop",
        }],
        "usage": {
            "prompt_tokens": input_len,
            "completion_tokens": len(new_tokens),
            "total_tokens": input_len + len(new_tokens),
        },
    }


@app.get("/v1/models")
def list_models():
    return {"object": "list", "data": [{"id": _model_name, "object": "model"}]}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="./checkpoints/stage2_final")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()

    _load_model(args.checkpoint)
    uvicorn.run(app, host=args.host, port=args.port)