How to Build a vLLM Plugin: A Guide to the general_plugins Entry Point

Community Article Published April 10, 2026

If you've ever tried to modify vLLM's behavior, you probably know the pain. We certainly do.

When we built FlashHead (a fast approximate replacement for the LM head) we started by string-patching vLLM's source and baking it into custom Docker images. Then we wrapped it all in an embedl-models package that users had to install separately, with its own custom LLM import. Every vLLM update broke something. Users had to pull multi-GB containers just to try it. "Just pip install and run" was a distant dream.

Then we found vLLM's plugin system... I wish we'd known about it sooner. It's a simple entry point mechanism that lets you hook into vLLM's startup, register custom architectures, and patch internals, all from a standard pip install. Our entire integration went from a custom Docker container to:

pip install flash-head
vllm serve embedl/Qwen3-1.7B-FlashHead-W4A16

This post explains how the plugin system works, how to build your own, and walks through how we rebuilt FlashHead on top of it.


🔌 How vLLM discovers plugins

vLLM uses Python's standard entry points mechanism. When vLLM starts up, it scans all installed packages for entry points registered under specific groups. For each one it finds, it calls the registered function.

This happens early -- before model loading, before CUDA initialization, and in every worker process. Your plugin gets to set up whatever it needs before anything else runs.

Plugin types

vLLM supports four plugin groups:

| Entry point group | Loaded where | Use case |
|---|---|---|
| vllm.general_plugins | All processes | Register models, patch internals |
| vllm.platform_plugins | All processes (platform init) | Custom hardware platforms |
| vllm.io_processor_plugins | Process 0 only | Pre/post-processing for pooling models |
| vllm.stat_logger_plugins | Process 0 (async serving) | Custom stat loggers |

This post focuses on vllm.general_plugins, which is the most versatile -- it's where you register custom architectures and modify inference behavior.

Filtering with VLLM_PLUGINS

vLLM respects the VLLM_PLUGINS environment variable to filter which plugins load. This is useful during debugging or when you want to selectively enable/disable plugins without uninstalling them.
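The filtering behavior can be sketched like this: if VLLM_PLUGINS is set, it acts as a comma-separated allowlist; if it's unset, everything loads. (This is a simplified sketch of the behavior, not vLLM's actual code; the should_load helper is hypothetical.)

```python
import os

def should_load(plugin_name: str) -> bool:
    """Sketch of the VLLM_PLUGINS allowlist: unset means load everything,
    otherwise only comma-separated names in the variable are loaded."""
    allowed = os.environ.get("VLLM_PLUGINS")
    if allowed is None:
        return True
    return plugin_name in [p.strip() for p in allowed.split(",")]

os.environ["VLLM_PLUGINS"] = "flash_head"
print(should_load("flash_head"))    # True
print(should_load("other_plugin"))  # False
```

In practice you'd set it on the command line, e.g. VLLM_PLUGINS=flash_head vllm serve ... to load only that plugin, or VLLM_PLUGINS="" to load none.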


🧩 The minimal plugin

A vLLM plugin needs two things: a register() function and an entry point declaration.

1. The register function

# src/my_plugin/__init__.py

def register():
    """Called by vLLM at startup in every process."""
    print("My plugin is loaded!")

This is your hook into vLLM's initialization. From here you can:

  • Register custom model architectures
  • Monkey-patch vLLM internals
  • Set up logging, metrics, or telemetry
  • Load configuration from environment variables or files

2. The entry point in pyproject.toml

[build-system]
requires = ["setuptools>=64"]
build-backend = "setuptools.build_meta"

[project]
name = "my-vllm-plugin"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = []  # vLLM and torch are already installed

[project.entry-points."vllm.general_plugins"]
my_plugin = "my_plugin:register"

[tool.setuptools.packages.find]
where = ["src"]

The key line is:

[project.entry-points."vllm.general_plugins"]
my_plugin = "my_plugin:register"

This tells Python's packaging system: "when someone looks for vllm.general_plugins entry points, point them to the register function in the my_plugin module."

After pip install ., vLLM will automatically discover and call your register() function at startup.


๐Ÿ—๏ธ Registering custom architectures

One of the most useful things a plugin can do is register custom model architectures with vLLM's ModelRegistry. This lets you publish models on Hugging Face with custom architecture names that only work when your plugin is installed.

from vllm import ModelRegistry

def register():
    ModelRegistry.register_model(
        "MyCustomLlamaForCausalLM",
        "vllm.model_executor.models.llama:LlamaForCausalLM",
    )

Now a model with "architectures": ["MyCustomLlamaForCausalLM"] in its config.json will load through your plugin. Without the plugin installed, vLLM rejects the model with an "architecture not supported" error, so users can't accidentally run it without your optimization.
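For concreteness, the relevant part of such a model's config.json looks like this (the surrounding fields are illustrative placeholders, not a complete config):

```json
{
  "architectures": ["MyCustomLlamaForCausalLM"],
  "model_type": "llama",
  "vocab_size": 128256
}
```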

The second argument can be either a lazy string import ("module:class") or a direct class reference. Lazy strings are preferred since vLLM won't import the model class until it's actually needed:

# Lazy string (recommended) -- no import until model is used
ModelRegistry.register_model(
    "MyLlama", "my_plugin.models:MyLlamaForCausalLM"
)

# Direct class -- imported immediately
from my_plugin.models import MyLlamaForCausalLM
ModelRegistry.register_model("MyLlama", MyLlamaForCausalLM)

🔧 Monkey-patching vLLM internals

For deeper modifications, plugins can replace vLLM's internal methods at startup. The pattern is straightforward:

def register():
    from vllm.model_executor.layers.logits_processor import LogitsProcessor

    _original = LogitsProcessor._get_logits

    def _patched(self, hidden_states, lm_head, embedding_bias):
        # Your custom logic here
        return _original(self, hidden_states, lm_head, embedding_bias)

    LogitsProcessor._get_logits = _patched

Save the original method, define a replacement that wraps or replaces it, and assign it back to the class. Since register() runs before model initialization, your patches are in place before any inference happens.


⚡ Real-world example: FlashHead

FlashHead is a vLLM plugin that replaces the dense language model head with a two-stage retrieval pipeline, delivering up to 2x inference speedup.

Instead of scoring all 128K+ vocabulary tokens at every decode step, FlashHead first identifies relevant cluster regions, then scores only the candidates within those clusters:
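The idea can be illustrated with a toy sketch (this is not FlashHead's implementation -- the random "LM head", cluster assignments, and sizes are all made up for illustration):

```python
import numpy as np

rng = np.random.default_rng(0)

vocab_size, hidden, n_clusters = 32_000, 256, 64
W = rng.standard_normal((vocab_size, hidden)).astype(np.float32)  # toy "LM head"
token_cluster = rng.integers(0, n_clusters, size=vocab_size)      # token -> cluster id
# Represent each cluster by the mean of its token embeddings.
centroids = np.stack([W[token_cluster == c].mean(axis=0) for c in range(n_clusters)])

def two_stage_argmax(h: np.ndarray, top_clusters: int = 4) -> int:
    # Stage 1: score the small set of cluster centroids.
    cluster_scores = centroids @ h
    best = np.argpartition(cluster_scores, -top_clusters)[-top_clusters:]
    # Stage 2: score only the tokens inside the selected clusters.
    candidates = np.flatnonzero(np.isin(token_cluster, best))
    logits = W[candidates] @ h
    return int(candidates[np.argmax(logits)])

h = rng.standard_normal(hidden).astype(np.float32)
print(two_stage_argmax(h))  # a token id, found by scoring ~1/16 of the vocab
```

Stage 1 touches 64 centroids instead of 32K rows, and stage 2 only a few thousand candidates -- that asymmetry is where the speedup comes from.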


Here's how it uses the plugin system.

Entry point

# pyproject.toml
[project.entry-points."vllm.general_plugins"]
flash_head = "flash_head:register"

The register function

_patches_applied = False

def register():
    """vLLM plugin entry point. Called in every process before model init."""
    global _patches_applied
    if _patches_applied:
        return
    _patches_applied = True

    _register_architectures()

    if os.environ.get("FLASHHEAD_ENABLED", "1") == "0":
        logger.info("[FlashHead] Disabled via FLASHHEAD_ENABLED=0")
        return

    from flash_head.patches import apply_all
    apply_all()

A few things worth noting:

  • Idempotency guard. _patches_applied prevents double-patching in multi-worker setups where register() might be called more than once.
  • Kill switch. FLASHHEAD_ENABLED=0 disables the plugin without uninstalling it. Architectures are still registered (so the model loads), but the performance patches are skipped. This is different from VLLM_PLUGINS filtering -- with FLASHHEAD_ENABLED=0 the model still loads (useful for A/B benchmarking), while excluding via VLLM_PLUGINS would cause the model to fail with "architecture not supported".
  • Lazy imports. Patches are imported inside the function to avoid loading CUDA-dependent code before it's needed.

Architecture registration

def _register_architectures():
    from vllm import ModelRegistry

    _FLASHHEAD_ARCHITECTURES = {
        "FlashHeadLlamaForCausalLM": "vllm.model_executor.models.llama:LlamaForCausalLM",
        "FlashHeadQwen3ForCausalLM": "vllm.model_executor.models.qwen3:Qwen3ForCausalLM",
        "FlashHeadQwen3VLForConditionalGeneration": "vllm.model_executor.models.qwen3_vl:Qwen3VLForConditionalGeneration",
        "FlashHeadGemma3ForCausalLM": "vllm.model_executor.models.gemma3:Gemma3ForCausalLM",
    }

    supported = ModelRegistry.get_supported_archs()
    for fh_arch, model_cls_path in _FLASHHEAD_ARCHITECTURES.items():
        if fh_arch not in supported:
            ModelRegistry.register_model(fh_arch, model_cls_path)

FlashHead models use custom architecture names like FlashHeadQwen3ForCausalLM. These map to the standard vLLM model classes; the actual optimization happens via patches to the logits processor, not through a custom model class. The custom names exist purely as a safety mechanism: without the plugin, the model won't load.

What FlashHead patches

FlashHead patches six vLLM components:

| Component | What the patch does |
|---|---|
| LogitsProcessor._get_logits | Replaces dense vocabulary scoring with FlashHead's two-stage retrieval |
| Sampler.forward | Bypasses sampling when logits are already token IDs |
| RejectionSampler.forward | Handles FlashHead tokens in speculative decoding |
| EagleProposer._greedy_sample | Supports FlashHead in EAGLE draft proposals |
| GPUModelRunner._dummy_sampler_run | Handles warmup edge cases |
| LLMEngine.from_engine_args | Loads FlashHead clustering assets before engine init |

Each patch follows the same pattern: save the original, check if FlashHead should activate (based on model config), and either run the custom path or fall through to the original.
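That shared pattern can be factored into a small helper. (This is a generic sketch, not FlashHead's code; patch_method, should_activate, and the Demo class are all illustrative names.)

```python
def patch_method(cls, name, should_activate, custom_impl):
    """Wrap cls.<name>: route to custom_impl when active,
    otherwise fall through to the original method."""
    original = getattr(cls, name)

    def wrapper(self, *args, **kwargs):
        if should_activate(self):
            return custom_impl(self, original, *args, **kwargs)
        return original(self, *args, **kwargs)

    wrapper.__wrapped__ = original  # keep a handle for debugging/unpatching
    setattr(cls, name, wrapper)

# Toy demonstration: activate the custom path only when `self.fast` is set.
class Demo:
    def forward(self, x):
        return x + 1

patch_method(
    Demo, "forward",
    should_activate=lambda self: getattr(self, "fast", False),
    custom_impl=lambda self, orig, x: x * 10,
)

d = Demo()
print(d.forward(3))  # original path -> 4
d.fast = True
print(d.forward(3))  # custom path -> 30
```

Keeping a reference to the original on the wrapper also makes it possible to unpatch or A/B-test the two paths later.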


๐Ÿ›ก๏ธ Design patterns for plugins

From vLLM's own documentation: "Plugins can be loaded multiple times in different processes. They should be designed in a way that they can be loaded multiple times without causing issues."

Make it re-entrant

register() will be called in every vLLM process (main process + workers). Use a guard:

_initialized = False

def register():
    global _initialized
    if _initialized:
        return
    _initialized = True
    # ... your setup code

Add a kill switch

Let users disable your plugin without uninstalling it:

def register():
    if os.environ.get("MY_PLUGIN_ENABLED", "1") == "0":
        return
    # ... apply patches

Don't add vLLM as a dependency

Your plugin runs inside a vLLM installation. Adding vLLM to your dependencies would create circular installs or version conflicts. List dependencies = [] and rely on the host environment.

Use lazy imports

Don't import CUDA-heavy modules at the top level. Import inside register() to avoid premature initialization:

def register():
    # Good: lazy import
    from vllm.model_executor.layers.logits_processor import LogitsProcessor

Fail gracefully

vLLM versions change. Wrap patches in try/except so your plugin doesn't crash on import if an internal API shifted:

def register():
    try:
        from vllm.spec_decode.eagle import EagleProposer
        # ... patch it
    except ImportError:
        pass  # This vLLM version doesn't have EAGLE support

🚀 Getting started

my-vllm-plugin/
├── src/my_plugin/
│   ├── __init__.py      # Contains register()
│   └── patches.py       # Your patches
└── pyproject.toml       # Entry point declaration

Install it with pip install . alongside vLLM, and your register() function will be called automatically on every vllm serve or LLM() initialization.

