"""
Minimal OpenAI-compatible chat completions server backed by a local
transformers + PEFT checkpoint.

Lets the trained CricketCaptain model be consumed by anything that speaks
OpenAI's chat-completions API:
  - inference.py:     --model local --api-base http://your-host:8080/v1
  - compare_eval.py:  same flags
  - HF Space live LLM opponent / captain: set Space secrets
        CRICKET_OPPONENT_API_BASE = http://your-host:8080/v1
        CRICKET_OPPONENT_MODEL    = local
        HF_TOKEN                  = (any non-empty string; this server
                                     ignores auth — only set so the openai
                                     client doesn't refuse to call)
    The env's llm_live opponent (server/opponent_policy.py) reads those env
    vars and routes its calls here automatically.
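
For reference, an OpenAI-compatible client configured from those env vars
looks roughly like this (illustrative sketch only, not the actual
opponent_policy.py code):

    import os
    from openai import OpenAI

    client = OpenAI(
        base_url=os.environ["CRICKET_OPPONENT_API_BASE"],
        api_key=os.environ.get("HF_TOKEN", "unused"),  # auth is ignored here
    )
    resp = client.chat.completions.create(
        model=os.environ.get("CRICKET_OPPONENT_MODEL", "local"),
        messages=[{"role": "user", "content": "Suggest a bowling change."}],
    )
    print(resp.choices[0].message.content)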

Run on a GPU box that has the trained adapter:

    python model_server.py --checkpoint ./checkpoints/stage2_final --port 8080

Endpoints:
    POST /v1/chat/completions   chat-completions (OpenAI-compatible)
    GET  /v1/models             {"data": [{"id": "<checkpoint>", ...}]}
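
Example request (a minimal sketch; assumes the default host/port from the run
command above and that `requests` is installed):

    import requests

    r = requests.post(
        "http://localhost:8080/v1/chat/completions",
        json={
            "model": "local",
            "messages": [{"role": "user", "content": "Hello"}],
            "max_tokens": 64,
            "temperature": 0.2,
        },
        timeout=120,
    )
    print(r.json()["choices"][0]["message"]["content"])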

Notes:
  - Loads the base model named by the `base_model_name_or_path` field of
    `adapter_config.json`, then applies the LoRA adapter via PEFT. If the
    checkpoint is a full model (no adapter_config.json), it is loaded directly.
  - Pre-fills an empty `<think>...</think>` block before generation so
    Qwen3-Thinking variants don't burn tokens reasoning. Harmless on the
    Qwen3-4B-Instruct-2507 path (no native thinking blocks).
"""
import argparse
import time
import uuid

import torch
from fastapi import FastAPI
from pydantic import BaseModel
import uvicorn


app = FastAPI()
_model = None
_tokenizer = None
_model_name = "local"


class Message(BaseModel):
    role: str
    content: str


class ChatRequest(BaseModel):
    model: str = "local"
    messages: list[Message]
    max_tokens: int = 300
    temperature: float = 0.2


def _load_model(checkpoint: str):
    global _model, _tokenizer, _model_name
    from transformers import AutoTokenizer, AutoModelForCausalLM
    from peft import PeftModel
    import json, pathlib

    cfg_path = pathlib.Path(checkpoint) / "adapter_config.json"
    if cfg_path.exists():
        base = json.loads(cfg_path.read_text())["base_model_name_or_path"]
        print(f"Loading base model: {base}")
        _tokenizer = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
        base_model = AutoModelForCausalLM.from_pretrained(
            base,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )
        print(f"Loading LoRA adapter from: {checkpoint}")
        _model = PeftModel.from_pretrained(base_model, checkpoint)
    else:
        print(f"Loading model directly: {checkpoint}")
        _tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
        _model = AutoModelForCausalLM.from_pretrained(
            checkpoint,
            torch_dtype=torch.bfloat16,
            device_map="auto",
            trust_remote_code=True,
        )

    _model.eval()
    _model_name = checkpoint
    print("Model ready.")


@app.post("/v1/chat/completions")
def chat_completions(req: ChatRequest):
    # Build prompt with thinking disabled (pre-fill empty think block)
    chat = []
    for m in req.messages:
        chat.append({"role": m.role, "content": m.content})

    text = _tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
    )
    # Disable thinking by pre-filling an empty think block
    text += "<think>\n\n</think>\n"

    inputs = _tokenizer(text, return_tensors="pt").to(_model.device)
    input_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        out = _model.generate(
            **inputs,
            max_new_tokens=req.max_tokens,
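            # Near-zero temperatures fall back to greedy decoding; the max()
            # clamp keeps temperature strictly positive, which transformers
            # requires whenever do_sample is True.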
            temperature=max(req.temperature, 1e-4),
            do_sample=req.temperature > 0.01,
            pad_token_id=_tokenizer.eos_token_id,
        )

    new_tokens = out[0][input_len:]
    generated = _tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    return {
        "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": _model_name,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": generated},
            "finish_reason": "stop",
        }],
        "usage": {"prompt_tokens": input_len, "completion_tokens": len(new_tokens), "total_tokens": input_len + len(new_tokens)},
    }


@app.get("/v1/models")
def list_models():
    return {"object": "list", "data": [{"id": _model_name, "object": "model"}]}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", default="./checkpoints/stage2_final")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8080)
    args = parser.parse_args()

    _load_model(args.checkpoint)
    uvicorn.run(app, host=args.host, port=args.port)