vllm support?

#2
by prudant - opened

Can this be served with vLLM? If yes, how do I use it?

regards!

Hi, here is my suggestion.

Heads-up on the current state of things

Just to be upfront: personally, I haven't been able to get Qwen3.5 running on either vLLM or SGLang yet due to library compatibility issues. This is very much a "me right now" problem — I'm waiting on upstream updates to land, and the situation is generally unsettled. So take the framework-specific code below as the intended path; you may need to revisit it once the dust settles.

What we actually need

Regardless of which inference stack we end up on, the core requirement is simple:

  • First token: greedy (argmax, deterministic)
  • All subsequent tokens: temperature sampling, with temperature in the 0.3–0.5 range as the recommended sweet spot.

That's the whole spec. Now, how to implement it.
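Before the engine-specific versions, here is the rule itself as a minimal plain-Python sketch. The function name `pick_token` and its arguments are made up for illustration and do not belong to any library; it only assumes you have a list of raw logits for the next token and know which generation step you are on.

```python
import math
import random

def pick_token(logits, step, temperature=0.4, rng=None):
    """Hypothetical helper: argmax at step 0, temperature sampling afterwards."""
    rng = rng or random
    indices = range(len(logits))
    if step == 0:
        # First generated token: deterministic greedy choice.
        return max(indices, key=logits.__getitem__)
    # Later tokens: softmax over logits / temperature, then sample.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)                       # subtract the max for numerical stability
    weights = [math.exp(x - peak) for x in scaled]
    return rng.choices(indices, weights=weights, k=1)[0]
```

Everything below is just a way of getting a real inference engine to behave like this function.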

Recommended approach: two calls + prefix caching

Honestly, the cleanest way is to just split it into two requests:

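from vllm import SamplingParams

# Assumes llm is an already-constructed vllm.LLM, prompt is a str,
# and N is the total number of tokens to generate.

# Step 1: greedy decode the first token
out1 = llm.generate(prompt, SamplingParams(temperature=0, max_tokens=1))
first_tok = out1[0].outputs[0].text

# Step 2: append that token and sample the rest
out2 = llm.generate(prompt + first_tok,
                    SamplingParams(temperature=0.4, max_tokens=N-1))
text = first_tok + out2[0].outputs[0].text   # full completion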

Why this is the right call:

  • SGLang enables prefix caching (RadixAttention) by default, and so do recent vLLM releases; on older vLLM versions you may need to pass enable_prefix_caching=True. With it on, the second request reuses the KV cache from the first, so its prefill cost is close to zero.
  • No interference with batching, CUDA graphs, or chunked prefill.
  • Trivial to read, trivial to debug, trivial to swap engines.
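To make the "swap engines" point concrete, here is a rough engine-agnostic sketch of the same two-call pattern. It works on token IDs rather than text, which sidesteps the case where prompt + first_tok re-tokenizes differently from the original IDs. The `generate` callable, its `(prompt_ids, temperature, max_tokens)` signature, and the `eos_id` argument are all assumptions made up for this example; you would wrap your own vLLM, SGLang, or HTTP client call to match.

```python
from typing import Callable, List, Optional

# Hypothetical adapter type: (prompt_token_ids, temperature, max_tokens) -> new token ids
GenerateFn = Callable[[List[int], float, int], List[int]]

def greedy_first_then_sample(
    generate: GenerateFn,
    prompt_ids: List[int],
    max_tokens: int,
    temperature: float = 0.4,
    eos_id: Optional[int] = None,
) -> List[int]:
    """Greedy-decode one token, then sample the remainder in a second call."""
    if max_tokens < 1:
        return []
    # Call 1: temperature 0.0 stands in for greedy decoding in this sketch.
    first = generate(list(prompt_ids), 0.0, 1)
    # Stop early if there is no budget left or the model ended immediately.
    if max_tokens == 1 or not first or first[-1] == eos_id:
        return first
    # Call 2: same prefix plus the greedy token, so a prefix cache can be reused.
    rest = generate(list(prompt_ids) + first, temperature, max_tokens - 1)
    return first + rest
```

Only the small adapter behind `generate` changes when you move between engines; the two-step logic stays the same.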

Alternative: custom logits processor

If a single streaming call is non-negotiable (e.g., you're streaming tokens to a frontend and don't want the seam between the two requests), a logits processor is a few lines:

vLLM:

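import torch
from vllm import SamplingParams

# Per-request callable of the form (generated_token_ids, logits) -> logits.
# Newer vLLM releases changed the logits-processor interface, so check
# what your installed version accepts.
def greedy_first(token_ids, logits):
    if len(token_ids) == 0:                  # about to emit the first token
        mask = torch.full_like(logits, float('-inf'))
        mask[logits.argmax()] = 0
        return logits + mask                 # collapse to argmax only
    return logits                            # let the sampler's temperature handle the rest

sp = SamplingParams(temperature=0.4, logits_processors=[greedy_first])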

HuggingFace Transformers:

import torch
from transformers import LogitsProcessor, LogitsProcessorList

class GreedyFirstToken(LogitsProcessor):
    def __init__(self, prompt_len):
        self.prompt_len = prompt_len             # prompt length in tokens
    def __call__(self, input_ids, scores):
        if input_ids.shape[-1] == self.prompt_len:        # first generation step
            top = scores.argmax(dim=-1, keepdim=True)
            mask = torch.full_like(scores, float('-inf'))
            mask.scatter_(-1, top, 0)
            return mask                                    # only argmax survives
        return scores                                      # later steps: temperature sampling as usual

# Assumes model and inputs (the tokenized prompt) are already set up.
out = model.generate(**inputs, do_sample=True, temperature=0.4,
                     logits_processor=LogitsProcessorList(
                         [GreedyFirstToken(inputs.input_ids.shape[-1])]))
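Both processors lean on the same fact: once every logit except the argmax is set to -inf, the sampler's temperature no longer matters for that step, because softmax puts all the probability on the one surviving token. A small stdlib-only check of that claim (the helper names here are made up for illustration):

```python
import math

def softmax_with_temperature(logits, temperature):
    """Plain softmax over logits / temperature."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]   # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]

def keep_only_argmax(logits):
    """Set every logit except the largest to -inf."""
    top = max(range(len(logits)), key=logits.__getitem__)
    return [x if i == top else float("-inf") for i, x in enumerate(logits)]

masked = keep_only_argmax([1.2, 3.4, 0.7])
print(softmax_with_temperature(masked, 0.4))   # [0.0, 1.0, 0.0]
print(softmax_with_temperature(masked, 1.5))   # [0.0, 1.0, 0.0]
```

So the first step is effectively greedy even though the request as a whole runs with temperature=0.4.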

thanks, very helpful information!
