"""
Minimal OpenAI-compatible chat completions server backed by a local
transformers + PEFT checkpoint.
Lets the trained CricketCaptain model be consumed by anything that speaks
OpenAI's chat-completions API:
- inference.py: --model local --api-base http://your-host:8080/v1
- compare_eval.py: same flags
- HF Space live LLM opponent / captain: set Space secrets
CRICKET_OPPONENT_API_BASE = http://your-host:8080/v1
CRICKET_OPPONENT_MODEL = local
HF_TOKEN = (any non-empty string; this server
ignores auth — only set so the openai
client doesn't refuse to call)
The env's llm_live opponent (server/opponent_policy.py) reads those env
vars and routes its calls here automatically.
Run on a GPU box that has the trained adapter:
python model_server.py --checkpoint ./checkpoints/stage2_final --port 8080
Endpoints:
POST /v1/chat/completions chat-completions (OpenAI-compatible)
GET /v1/models {"data": [{"id": "<checkpoint>", ...}]}
Notes:
- Loads the base model from `adapter_config.json.base_model_name_or_path`
and applies the LoRA adapter via PEFT. If the path is a full model
(no adapter_config), it loads it directly.
- Pre-fills an empty `<think>...</think>` block before generation so
Qwen3-Thinking variants don't burn tokens reasoning. Harmless on the
Qwen3-4B-Instruct-2507 path (no native thinking blocks).
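
Example client call (a minimal sketch; assumes the `openai` Python package and
that the server is reachable at http://localhost:8080):

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8080/v1", api_key="unused")
    resp = client.chat.completions.create(
        model="local",
        messages=[{"role": "user", "content": "Hello"}],
        max_tokens=100,
    )
    print(resp.choices[0].message.content)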
"""
import argparse
import time
import uuid
import torch
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn
app = FastAPI()
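# Module-level model state; populated once by _load_model() before uvicorn starts.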
_model = None
_tokenizer = None
_model_name = "local"
class Message(BaseModel):
role: str
content: str
class ChatRequest(BaseModel):
model: str = "local"
messages: list[Message]
max_tokens: int = 300
temperature: float = 0.2
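
# Example request body accepted by ChatRequest (illustrative; unknown OpenAI
# fields such as "stream" are ignored by pydantic's default config):
#   {"model": "local",
#    "messages": [{"role": "user", "content": "Pick a bowler for the next over."}],
#    "max_tokens": 200, "temperature": 0.2}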
def _load_model(checkpoint: str):
global _model, _tokenizer, _model_name
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import json, pathlib
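    # A LoRA checkpoint saved with PEFT ships an adapter_config.json whose
    # base_model_name_or_path field names the base model to load first.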
cfg_path = pathlib.Path(checkpoint) / "adapter_config.json"
if cfg_path.exists():
base = json.loads(cfg_path.read_text())["base_model_name_or_path"]
print(f"Loading base model: {base}")
_tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
base_model = AutoModelForCausalLM.from_pretrained(
base,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
print(f"Loading LoRA adapter from: {checkpoint}")
_model = PeftModel.from_pretrained(base_model, checkpoint)
else:
print(f"Loading model directly: {checkpoint}")
_tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
_model = AutoModelForCausalLM.from_pretrained(
checkpoint,
torch_dtype=torch.bfloat16,
device_map="auto",
trust_remote_code=True,
)
_model.eval()
_model_name = checkpoint
print("Model ready.")
@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
# Build prompt with thinking disabled (pre-fill empty think block)
chat = []
for m in req.messages:
chat.append({"role": m.role, "content": m.content})
text = _tokenizer.apply_chat_template(
chat,
tokenize=False,
add_generation_prompt=True,
)
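    # With Qwen-style chat templates, add_generation_prompt=True leaves `text`
    # ending in the assistant turn header, so the empty think block appended
    # below sits at the very start of the model's reply.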
# Disable thinking by pre-filling an empty think block
text += "<think>\n\n</think>\n"
inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
input_len = inputs["input_ids"].shape[1]
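    # Greedy decode when the requested temperature is effectively zero,
    # otherwise sample; the temperature is clamped to a small positive value
    # because some transformers versions reject a temperature of exactly 0.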
with torch.no_grad():
out = _model.generate(
**inputs,
max_new_tokens=req.max_tokens,
temperature=max(req.temperature, 1e-4),
do_sample=req.temperature > 0.01,
pad_token_id=_tokenizer.eos_token_id,
)
new_tokens = out[0][input_len:]
generated = _tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
return {
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"object": "chat.completion",
"created": int(time.time()),
"model": _model_name,
"choices": [{
"index": 0,
"message": {"role": "assistant", "content": generated},
"finish_reason": "stop",
}],
"usage": {"prompt_tokens": input_len, "completion_tokens": len(new_tokens), "total_tokens": input_len + len(new_tokens)},
}
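
# Quick smoke test once the server is up (illustrative):
#   curl -s http://localhost:8080/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{"model": "local", "messages": [{"role": "user", "content": "hi"}]}'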
@app.get("/v1/models")
def list_models():
return {"object": "list", "data": [{"id": _model_name, "object": "model"}]}
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--checkpoint", default="./checkpoints/stage2_final")
parser.add_argument("--host", default="0.0.0.0")
parser.add_argument("--port", type=int, default=8080)
args = parser.parse_args()
_load_model(args.checkpoint)
uvicorn.run(app, host=args.host, port=args.port)