# cricket-captain-llm / model_server.py
"""
Minimal OpenAI-compatible chat completions server backed by a local
transformers + PEFT checkpoint.
Lets the trained CricketCaptain model be consumed by anything that speaks
OpenAI's chat-completions API:
- inference.py: --model local --api-base http://your-host:8080/v1
- compare_eval.py: same flags
- HF Space live LLM opponent / captain: set Space secrets
CRICKET_OPPONENT_API_BASE = http://your-host:8080/v1
CRICKET_OPPONENT_MODEL = local
HF_TOKEN = (any non-empty string; this server
ignores auth — only set so the openai
client doesn't refuse to call)
The env's llm_live opponent (server/opponent_policy.py) reads those env
vars and routes its calls here automatically.

Run on a GPU box that has the trained adapter:

    python model_server.py --checkpoint ./checkpoints/stage2_final --port 8080

Endpoints:
    POST /v1/chat/completions   chat-completions (OpenAI-compatible)
    GET  /v1/models             {"data": [{"id": "<checkpoint>", ...}]}
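
Example call (illustrative: the host and prompt are placeholders, and any
client that speaks the OpenAI chat-completions API works equally well):

    from openai import OpenAI

    client = OpenAI(base_url="http://your-host:8080/v1", api_key="unused")
    resp = client.chat.completions.create(
        model="local",
        messages=[{"role": "user", "content": "Pick the bowler for over 12."}],
        max_tokens=300,
        temperature=0.2,
    )
    print(resp.choices[0].message.content)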

Notes:
  - Loads the base model from `adapter_config.json.base_model_name_or_path`
    and applies the LoRA adapter via PEFT. If the path is a full model
    (no adapter_config), it loads it directly.
  - Pre-fills an empty `<think>...</think>` block before generation so
    Qwen3-Thinking variants don't burn tokens reasoning. Harmless on the
    Qwen3-4B-Instruct-2507 path (no native thinking blocks).
"""
import argparse
import time
import uuid

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
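# Module-level handles: populated once by _load_model() at startup and read by
# the request handlers below.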
_model = None
_tokenizer = None
_model_name = "local"
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: str = "local"
messages: list[Message]
max_tokens: int = 300
temperature: float = 0.2
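
# Incoming request bodies follow the OpenAI chat-completions shape, e.g.
# (illustrative values):
#   {"model": "local",
#    "messages": [{"role": "user", "content": "Set the field for over 12."}],
#    "max_tokens": 300,
#    "temperature": 0.2}
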
def _load_model(checkpoint: str):
global _model, _tokenizer, _model_name
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import json, pathlib
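    # A PEFT LoRA checkpoint directory carries an adapter_config.json whose
    # "base_model_name_or_path" names the base model, roughly like
    # (illustrative contents):
    #   {"base_model_name_or_path": "Qwen/Qwen3-4B-Instruct-2507",
    #    "peft_type": "LORA", ...}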
cfg_path = pathlib.Path(checkpoint) / "adapter_config.json"
if cfg_path.exists():
base = json.loads(cfg_path.read_text())["base_model_name_or_path"]
print(f"Loading base model: {base}")
_tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
base,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
print(f"Loading LoRA adapter from: {checkpoint}")
_model = PeftModel.from_pretrained(base_model, checkpoint)
else:
print(f"Loading model directly: {checkpoint}")
_tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
_model = AutoModelForCausalLM.from_pretrained(
checkpoint,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
_model.eval()
_model_name = checkpoint
print("Model ready.")
@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    # Convert the request messages into the tokenizer's chat format.
    chat = [{"role": m.role, "content": m.content} for m in req.messages]
text = _tokenizer.apply_chat_template(
chat,
tokenize=False,
add_generation_prompt=True,
)
# Disable thinking by pre-filling an empty think block
text += "<think>\n\n</think>\n"
inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
input_len = inputs["input_ids"].shape[1]
with torch.no_grad():
out = _model.generate(
**inputs,
max_new_tokens=req.max_tokens,
temperature=max(req.temperature, 1e-4),
do_sample=req.temperature > 0.01,
pad_token_id=_tokenizer.eos_token_id,
)
new_tokens = out[0][input_len:]
generated = _tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
"created": int(time.time()),
"model": _model_name,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": generated},
"finish_reason": "stop",
}],
"usage": {"prompt_tokens": input_len, "completion_tokens": len(new_tokens), "total_tokens": input_len + len(new_tokens)},
}
@app.get("/v1/models")
def list_models():
return {"object": "list", "data": [{"id": _model_name, "object": "model"}]}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", default="./checkpoints/stage2_final")
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8080)
args = parser.parse_args()
_load_model(args.checkpoint)
uvicorn.run(app, host=args.host, port=args.port)