""" Minimal OpenAI-compatible chat completions server backed by a local transformers + PEFT checkpoint. Lets the trained CricketCaptain model be consumed by anything that speaks OpenAI's chat-completions API: - inference.py: --model local --api-base http://your-host:8080/v1 - compare_eval.py: same flags - HF Space live LLM opponent / captain: set Space secrets CRICKET_OPPONENT_API_BASE = http://your-host:8080/v1 CRICKET_OPPONENT_MODEL = local HF_TOKEN = (any non-empty string; this server ignores auth — only set so the openai client doesn't refuse to call) The env's llm_live opponent (server/opponent_policy.py) reads those env vars and routes its calls here automatically. Run on a GPU box that has the trained adapter: python model_server.py --checkpoint ./checkpoints/stage2_final --port 8080 Endpoints: POST /v1/chat/completions chat-completions (OpenAI-compatible) GET /v1/models {"data": [{"id": "", ...}]} Notes: - Loads the base model from `adapter_config.json.base_model_name_or_path` and applies the LoRA adapter via PEFT. If the path is a full model (no adapter_config), it loads it directly. - Pre-fills an empty `...` block before generation so Qwen3-Thinking variants don't burn tokens reasoning. Harmless on the Qwen3-4B-Instruct-2507 path (no native thinking blocks). """ import argparse import time import uuid import torch from fastapi import FastAPI from pydantic import BaseModel from typing import Any import uvicorn app = FastAPI() _model = None _tokenizer = None _model_name = "local" class Message(BaseModel): role: str content: str class ChatRequest(BaseModel): model: str = "local" messages: list[Message] max_tokens: int = 300 temperature: float = 0.2 def _load_model(checkpoint: str): global _model, _tokenizer, _model_name from transformers import AutoTokenizer, AutoModelForCausalLM from peft import PeftModel import json, pathlib cfg_path = pathlib.Path(checkpoint) / "adapter_config.json" if cfg_path.exists(): base = json.loads(cfg_path.read_text())["base_model_name_or_path"] print(f"Loading base model: {base}") _tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True) base_model = AutoModelForCausalLM.from_pretrained( base, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) print(f"Loading LoRA adapter from: {checkpoint}") _model = PeftModel.from_pretrained(base_model, checkpoint) else: print(f"Loading model directly: {checkpoint}") _tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) _model = AutoModelForCausalLM.from_pretrained( checkpoint, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True, ) _model.eval() _model_name = checkpoint print("Model ready.") @app.post("/v1/chat/completions") def chat_completions(req: ChatRequest): # Build prompt with thinking disabled (pre-fill empty think block) chat = [] for m in req.messages: chat.append({"role": m.role, "content": m.content}) text = _tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=True, ) # Disable thinking by pre-filling an empty think block text += "\n\n\n" inputs = _tokenizer(text, return_tensors="pt").to(_model.device) input_len = inputs["input_ids"].shape[1] with torch.no_grad(): out = _model.generate( **inputs, max_new_tokens=req.max_tokens, temperature=max(req.temperature, 1e-4), do_sample=req.temperature > 0.01, pad_token_id=_tokenizer.eos_token_id, ) new_tokens = out[0][input_len:] generated = _tokenizer.decode(new_tokens, skip_special_tokens=True).strip() return { "id": f"chatcmpl-{uuid.uuid4().hex[:8]}", "object": "chat.completion", "created": int(time.time()), "model": _model_name, "choices": [{ "index": 0, "message": {"role": "assistant", "content": generated}, "finish_reason": "stop", }], "usage": {"prompt_tokens": input_len, "completion_tokens": len(new_tokens), "total_tokens": input_len + len(new_tokens)}, } @app.get("/v1/models") def list_models(): return {"object": "list", "data": [{"id": _model_name, "object": "model"}]} if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--checkpoint", default="./checkpoints/stage2_final") parser.add_argument("--host", default="0.0.0.0") parser.add_argument("--port", type=int, default=8080) args = parser.parse_args() _load_model(args.checkpoint) uvicorn.run(app, host=args.host, port=args.port)