
# Migration Guide: StreamingVLM → Qwen3-VL on ROCm

## Overview of Changes

This document details every change made to port StreamingVLM from its original stack (Qwen2.5-VL-7B + flash-attn + CUDA) to Qwen3-VL-4B + SDPA + ROCm.


## 1. Dependency Changes

### Original (`infer_requirements.txt`)

```text
transformers==4.52.4
flash_attn==2.8.0.post2
torch==2.7.1
liger_kernel==0.6.1
qwen-vl-utils==0.0.11
```

### Updated (`requirements.txt`)

```text
# Install transformers from source: pip install git+https://github.com/huggingface/transformers
transformers>=4.57.0
torch>=2.4.0          # ROCm or CUDA build
qwen-vl-utils>=0.0.14
# NO flash-attn required!
```

Why: Qwen3-VL requires transformers 4.57.0+. That version was unreleased at the time of writing, so it must be installed from source. Flash-attn is eliminated entirely and replaced with `torch.nn.functional.scaled_dot_product_attention` (SDPA).
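A short pre-flight check can confirm that an environment meets the updated requirements before the model is loaded. The sketch below uses only the standard library. The helper names (`version_tuple`, `meets_minimum`, `check_environment`) are hypothetical and not part of the repository. Version comparison is deliberately simplified: it ignores pre-release tags such as `.dev0`, which a source install of transformers typically carries.

```python
import importlib.metadata
import importlib.util
import re

# Minimum versions taken from the updated requirements.txt above.
MINIMUMS = {"transformers": "4.57.0", "torch": "2.4.0", "qwen-vl-utils": "0.0.14"}


def version_tuple(version: str) -> tuple:
    """Parse the leading numeric part of a version, e.g. '4.57.0.dev0' -> (4, 57, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(installed: str, minimum: str) -> bool:
    """True if `installed` >= `minimum`, comparing numeric components only."""
    a, b = version_tuple(installed), version_tuple(minimum)
    # Pad with zeros so that (4, 57) compares equal to (4, 57, 0).
    width = max(len(a), len(b))
    return a + (0,) * (width - len(a)) >= b + (0,) * (width - len(b))


def check_environment() -> dict:
    """Map each package to (installed_version or None, requirement_met)."""
    report = {}
    for package, minimum in MINIMUMS.items():
        try:
            installed = importlib.metadata.version(package)
            report[package] = (installed, meets_minimum(installed, minimum))
        except importlib.metadata.PackageNotFoundError:
            report[package] = (None, False)
    # flash-attn is optional in this port, so its absence is reported, not treated as an error.
    report["flash_attn (optional)"] = importlib.util.find_spec("flash_attn") is not None
    return report


if __name__ == "__main__":
    for name, status in check_environment().items():
        print(f"{name}: {status}")
```

Local ROCm builds of torch report versions such as `2.4.0+rocm6.1`. The regex keeps only the numeric prefix, so these compare correctly.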


## 2. Model Class Changes

| Original | Updated |
|---|---|
| `Qwen2_5_VLForConditionalGeneration` | `Qwen3VLForConditionalGeneration` |
| `from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ...` | No private imports needed |
| `attn_implementation="flash_attention_2"` | `attn_implementation="sdpa"` |

### Loading Code

```python
# BEFORE (original)
from transformers import Qwen2_5_VLForConditionalGeneration
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,  # local or Hub path to the Qwen2.5-VL checkpoint
    torch_dtype="auto", device_map="cuda",
    attn_implementation="flash_attention_2"
)

# AFTER (this port)
import torch
from transformers import Qwen3VLForConditionalGeneration
model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"  # Works on ROCm + CUDA
)
```

## 3. Attention Mechanism Replacement

### Language Model Attention

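Original: used `_flash_attention_forward()` (private transformers API, CUDA-only).

```python
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _flash_attention_forward
attn_output = _flash_attention_forward(query, key, value, attention_mask, ...)
```

Updated: uses `F.scaled_dot_product_attention()` (native PyTorch; works on ROCm and CUDA).

```python
import torch.nn.functional as F

# query_states: (batch, num_heads, q_len, head_dim). key_states/value_states must already
# match num_heads (e.g. expanded with repeat_kv), since Qwen3-VL uses grouped-query attention.
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attention_mask,
    dropout_p=0.0,
    # SDPA's built-in causal mask is top-left aligned, so it is only correct when there is
    # no cached prefix (q_len == kv_len). With a KV cache, pass an explicit attention_mask.
    is_causal=(attention_mask is None and q_len > 1),
)
```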
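The `is_causal` shortcut is only safe when query and key lengths match. PyTorch documents that for a non-square mask, `is_causal=True` produces a top-left-aligned causal bias. A streaming step with a cached prefix instead needs the new queries aligned to the end of the key sequence. The sketch below shows one way to handle both cases together with grouped-query attention. `cached_sdpa` and `bottom_right_causal_mask` are hypothetical helpers, not functions from this repository. `repeat_kv` is the standard GQA expansion, not necessarily the port's own copy.

```python
import torch
import torch.nn.functional as F


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Expand (batch, kv_heads, seq, dim) to (batch, kv_heads * n_rep, seq, dim) for GQA."""
    if n_rep == 1:
        return hidden_states
    batch, kv_heads, seq_len, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)


def bottom_right_causal_mask(q_len: int, kv_len: int, device=None) -> torch.Tensor:
    """Boolean mask (True = may attend); query i sits at absolute position kv_len - q_len + i."""
    offset = kv_len - q_len
    return torch.ones(q_len, kv_len, dtype=torch.bool, device=device).tril(diagonal=offset)


def cached_sdpa(query, key, value):
    """query: (b, q_heads, q_len, d); key/value: (b, kv_heads, kv_len, d) incl. cached prefix."""
    n_rep = query.shape[1] // key.shape[1]
    key, value = repeat_kv(key, n_rep), repeat_kv(value, n_rep)
    q_len, kv_len = query.shape[2], key.shape[2]
    if q_len == kv_len:
        # No cached prefix: SDPA's built-in (top-left aligned) causal mask is correct.
        return F.scaled_dot_product_attention(query, key, value, is_causal=q_len > 1)
    mask = bottom_right_causal_mask(q_len, kv_len, device=query.device)
    return F.scaled_dot_product_attention(query, key, value, attn_mask=mask)
```

With a single new token (`q_len == 1`), the mask is all `True`: the token sees every cached key. For multi-token chunks, the explicit mask costs `q_len × kv_len` booleans.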

### Vision Encoder Attention

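Original: used `flash_attn_varlen_func` for packed variable-length sequences.

```python
from flash_attn import flash_attn_varlen_func
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
```

Updated: chunked SDPA, which splits the packed sequence at the `cu_seqlens` boundaries and applies SDPA to each chunk.

```python
import torch
import torch.nn.functional as F


def _sdpa_varlen(query, key, value, cu_seqlens, max_seqlen):
    """query/key/value: packed (total_tokens, num_heads, head_dim); cu_seqlens: cumulative lengths.

    max_seqlen is unused here; it only mirrors the flash_attn_varlen_func arguments.
    """
    bounds = cu_seqlens.tolist()  # one host sync instead of two .item() calls per chunk
    outputs = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        # (seq, heads, dim) -> (1, heads, seq, dim), the layout SDPA expects
        q_i = query[start:end].unsqueeze(0).transpose(1, 2)
        k_i = key[start:end].unsqueeze(0).transpose(1, 2)
        v_i = value[start:end].unsqueeze(0).transpose(1, 2)
        # Vision tokens attend bidirectionally within their own image/frame
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, is_causal=False)
        outputs.append(out.transpose(1, 2).squeeze(0))
    return torch.cat(outputs, dim=0)
```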
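One way to sanity-check the chunked approach is to compare it against a single SDPA call that uses a block-diagonal mask over the whole packed sequence. Both should agree within floating-point tolerance. The helper names below are hypothetical. The masked variant serves only as a reference: it materializes a `total_tokens × total_tokens` mask, which the per-chunk loop avoids at the cost of one kernel launch per chunk.

```python
import torch
import torch.nn.functional as F


def sdpa_varlen_chunked(query, key, value, cu_seqlens):
    """Per-chunk SDPA over packed (total_tokens, heads, head_dim) tensors."""
    bounds = cu_seqlens.tolist()
    outputs = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        # (seq, heads, dim) -> (1, heads, seq, dim)
        q, k, v = (t[start:end].transpose(0, 1).unsqueeze(0) for t in (query, key, value))
        out = F.scaled_dot_product_attention(q, k, v)
        outputs.append(out.squeeze(0).transpose(0, 1))
    return torch.cat(outputs, dim=0)


def sdpa_varlen_masked(query, key, value, cu_seqlens):
    """Reference: one SDPA call with a block-diagonal boolean mask (True = may attend)."""
    lengths = (cu_seqlens[1:] - cu_seqlens[:-1]).long()
    # segment_ids[i] is the index of the image/frame that packed token i belongs to.
    segment_ids = torch.repeat_interleave(torch.arange(len(lengths), device=cu_seqlens.device), lengths)
    mask = segment_ids[:, None] == segment_ids[None, :]
    q, k, v = (t.transpose(0, 1).unsqueeze(0) for t in (query, key, value))
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    return out.squeeze(0).transpose(0, 1)
```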

## 4. Architecture Differences (Qwen2.5-VL → Qwen3-VL)

| Feature | Qwen2.5-VL-7B | Qwen3-VL-4B |
|---|---|---|
| ViT patch size | 14×14 | 16×16 |
| ViT depth | 32 layers | 24 layers |
| ViT hidden size | 1280 | 1024 |
| LM hidden size | 3584 | 2560 |
| LM layers | 28 | 36 |
| KV heads | 4 | 8 |
| Max context | 128K | 256K |
| RoPE theta | 1M | 5M |
| MRoPE | standard | interleaved |
| QK-Norm | ❌ | ✅ |
| DeepStack | ❌ | ✅ (layers 5, 11, 17) |
| ViT window attention | Yes | No (full attention) |
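The table implies that the smaller model needs more KV cache per token, because it has more layers and twice the KV heads. That is worth quantifying for a streaming workload, where the cache is likely to dominate memory use. The arithmetic below assumes bf16 (2 bytes per element) and a head dimension of 128 for both models. `head_dim` is not in the table, so treat that value as an assumption to check against each model's `config.json`.

```python
def kv_cache_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int, bytes_per_elem: int = 2) -> int:
    """Keys + values across all layers: 2 * layers * kv_heads * head_dim * bytes per element."""
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


# Layer and KV-head counts from the table; head_dim = 128 is an assumption for both models.
qwen25_vl_7b = kv_cache_bytes_per_token(num_layers=28, num_kv_heads=4, head_dim=128)
qwen3_vl_4b = kv_cache_bytes_per_token(num_layers=36, num_kv_heads=8, head_dim=128)

print(qwen25_vl_7b // 1024, "KiB/token")         # 56 KiB/token
print(qwen3_vl_4b // 1024, "KiB/token")          # 144 KiB/token
print(round(qwen3_vl_4b / qwen25_vl_7b, 2))      # 2.57
print(qwen3_vl_4b * 32_768 / 2**30, "GiB @ 32K")  # 4.5 GiB @ 32K
```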

Impact on StreamingVLM:

- QK-Norm: added `self.q_norm(query)` and `self.k_norm(key)` before RoPE in the attention forward.
- DeepStack: the vision encoder also extracts features from intermediate layers (5, 11, 17) for a richer visual representation.
- Interleaved MRoPE: `mrope_section=[24,20,20]` (temporal/height/width), applied in interleaved order across the rotary frequencies.
- No ViT window attention: a simpler vision encoder, with no `fullatt_block_indexes` logic needed.
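To make "interleaved" concrete: `mrope_section=[24,20,20]` sums to 64 rotary frequencies (half of a 128-dim head), and each frequency index is driven by the temporal, height or width position id. The sketch below follows our reading of the transformers Qwen3-VL rotary code. Height and width take every third index, starting at 1 and 2 respectively, and the temporal axis keeps the rest. The function is hypothetical. It only returns the assignment and does not compute embeddings.

```python
def mrope_axis_per_frequency(mrope_section, interleaved=True):
    """For each rotary frequency index, return the position axis that drives it: 'T', 'H' or 'W'."""
    t_len, h_len, w_len = mrope_section
    if not interleaved:
        # Contiguous layout: one band per axis, in T, H, W order.
        return ["T"] * t_len + ["H"] * h_len + ["W"] * w_len
    axes = ["T"] * (t_len + h_len + w_len)  # every index starts as temporal
    for name, offset, length in (("H", 1, h_len), ("W", 2, w_len)):
        # H claims indices 1, 4, 7, ...; W claims 2, 5, 8, ...; each stops before 3 * length.
        for j in range(offset, 3 * length, 3):
            axes[j] = name
    return axes


axes = mrope_axis_per_frequency([24, 20, 20])
print("".join(axes[:9]))                  # THWTHWTHW
print({a: axes.count(a) for a in "THW"})  # {'T': 24, 'H': 20, 'W': 20}
```

In the contiguous layout, each axis occupies a single band of frequencies. Interleaving spreads all three axes across both high and low frequencies.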

## 5. Monkey-Patch Target Changes

Original (10 patches on Qwen2.5-VL):

```python
model.generate = streaming_generate
model.prepare_inputs_for_generation = ...
model._sample = ...
model.forward = qwen2_5_vl_forward
model.model.forward = model_forward
model.model.language_model.forward = streaming_language_model_forward
model.model.language_model._update_causal_mask = ...
for layer in model.model.language_model.layers:  # "language_model" submodule
    layer.forward = streaming_text_decoder_layer_forward
    layer.self_attn.forward = streaming_text_flash_attn_forward
model.model.visual.forward = streaming_visual_encoder_forward
```

Updated (simplified patches on Qwen3-VL):

```python
model.forward = streaming_qwen3_vl_forward
model.model.forward = streaming_model_forward
for layer in model.model.layers:  # directly on model.model (no "language_model"!)
    layer.forward = streaming_text_decoder_layer_forward
    layer.self_attn.forward = streaming_text_sdpa_forward
model.model.visual.forward = streaming_visual_encoder_forward
for blk in model.model.visual.blocks:
    blk.forward = streaming_visual_block_forward
    blk.attn.forward = streaming_visual_attention_forward
model.model.get_rope_index = get_rope_index_streaming
```

Key structural difference: In Qwen2.5-VL, text layers are at model.model.language_model.layers. In Qwen3-VL, they are at model.model.layers (no intermediate language_model wrapper).
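Because the two models in this guide keep their decoder layers at different paths, a patching helper can search both places instead of hard-coding one. The sketch also binds each replacement with `types.MethodType`. Assigning a plain function to `layer.forward` does not pass the layer as `self`, so a replacement written like a method (with `self` as its first parameter) must be bound first. `find_text_layers` and `patch_text_layers` are hypothetical names. Whether the repository's patch functions take `self` is an assumption.

```python
import types


def find_text_layers(model):
    """Return the decoder layer list, trying both layouts described in this guide."""
    for path in (("layers",), ("language_model", "layers")):
        obj = model.model
        try:
            for attr in path:
                obj = getattr(obj, attr)
        except AttributeError:
            continue  # this layout is not present; try the next one
        return obj
    raise AttributeError("could not locate text decoder layers under model.model")


def patch_text_layers(model, layer_forward, attn_forward):
    """Bind replacement forwards to each layer instance so they receive `self` like the originals."""
    layers = find_text_layers(model)
    for layer in layers:
        layer.forward = types.MethodType(layer_forward, layer)
        layer.self_attn.forward = types.MethodType(attn_forward, layer.self_attn)
    return len(layers)
```

`nn.Module.__call__` invokes `self.forward`, and an instance attribute takes precedence over the class method. The bound replacement is therefore what runs on every call.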


## 6. KV Cache API Changes

transformers 4.50-4.52 (original):

```python
cache.key_cache[layer_idx]   # List[Tensor], one entry per layer
cache.value_cache[layer_idx]
```

transformers 4.57.0+ (updated, as required for Qwen3-VL):

```python
cache.layers[layer_idx].keys   # DynamicLayer objects
cache.layers[layer_idx].values
```

Our StreamingCache handles both via _get_layer_keys()/_set_layer_keys() compat methods.
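Below is a minimal sketch of what compat accessors of this kind might look like, assuming only the two layouts shown above. The names `get_layer_kv` and `set_layer_kv` are hypothetical, and the real `StreamingCache` methods may differ. A production version should also confirm that no other cache bookkeeping needs updating after keys or values are replaced, for example after evicting tokens from a streaming window.

```python
def get_layer_kv(cache, layer_idx):
    """Return (keys, values) for one layer under either cache layout."""
    if hasattr(cache, "layers"):
        # Newer layout: per-layer objects exposing .keys / .values
        layer = cache.layers[layer_idx]
        return layer.keys, layer.values
    # Older layout: parallel lists of tensors
    return cache.key_cache[layer_idx], cache.value_cache[layer_idx]


def set_layer_kv(cache, layer_idx, keys, values):
    """Overwrite one layer's keys/values under either cache layout."""
    if hasattr(cache, "layers"):
        layer = cache.layers[layer_idx]
        layer.keys, layer.values = keys, values
    else:
        cache.key_cache[layer_idx] = keys
        cache.value_cache[layer_idx] = values
```

The `layers` check comes first because a newer cache may still expose `key_cache` as a deprecated view. Writing through such a view would not persist.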


## 7. Token IDs (Verified for Qwen3-VL)

| Token | ID | Same as Qwen2.5-VL? |
|---|---|---|
| `<\|im_start\|>` | 151644 | ✅ |
| `<\|im_end\|>` | 151645 | ✅ |
| `<\|vision_start\|>` | 151652 | ✅ |
| `<\|vision_end\|>` | 151653 | ✅ |
| `<\|image_pad\|>` | 151655 | ✅ |
| `<\|video_pad\|>` | 151656 | ✅ |

Token IDs are identical, so no changes are needed for token-level logic.
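Because the port relies on these IDs staying unchanged, it is cheap to assert them once at startup rather than trust a hard-coded table. The helper below is a hypothetical sketch. Its only tokenizer call is `convert_tokens_to_ids`.

```python
# IDs from the table above (identical for Qwen2.5-VL and Qwen3-VL).
EXPECTED_SPECIAL_TOKEN_IDS = {
    "<|im_start|>": 151644,
    "<|im_end|>": 151645,
    "<|vision_start|>": 151652,
    "<|vision_end|>": 151653,
    "<|image_pad|>": 151655,
    "<|video_pad|>": 151656,
}


def check_special_token_ids(tokenizer, expected=EXPECTED_SPECIAL_TOKEN_IDS) -> dict:
    """Return {token: actual_id} for every special token whose ID differs from `expected`."""
    mismatches = {}
    for token, expected_id in expected.items():
        actual_id = tokenizer.convert_tokens_to_ids(token)
        if actual_id != expected_id:
            mismatches[token] = actual_id
    return mismatches


# Usage (downloads the processor on first run):
# from transformers import AutoProcessor
# processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
# assert not check_special_token_ids(processor.tokenizer)
```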


## 8. ROCm-Specific Configuration

Performance tiers (recommended):

| Priority | Implementation | Install | Performance | Platform |
|---|---|---|---|---|
| 1 | `flash_attention_2` + CK backend | Build flash-attn from source | Best | ROCm ≥5.7 |
| 2 | `sdpa` (default) | None (PyTorch built-in) | Good | Any |
| 3 | `eager` | None | Slowest | Any |

For ROCm flash-attn (optional, for maximum throughput):

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install .
# Then use: attn_implementation="flash_attention_2"
```

Environment variables (export them before building flash-attn):

```bash
# PYTORCH_ROCM_ARCH selects the GPU targets compiled when building extensions such as flash-attn
export PYTORCH_ROCM_ARCH="gfx942"    # MI300X
# export PYTORCH_ROCM_ARCH="gfx90a"  # MI250X
# export PYTORCH_ROCM_ARCH="gfx1100" # RX 7900 XTX

# Flash-attn backend selection:
export FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"  # CK (default, recommended)
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" # Triton (alternative)
```
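The tier table can be turned into a small selection rule: use `flash_attention_2` when flash-attn is installed, and fall back to `sdpa` otherwise. This is a sketch with hypothetical helper names. `importlib.util.find_spec` only confirms that the package is installed, not that it was built for the current GPU architecture. `gpu_backend` reports what the PyTorch build targets (`torch.version.hip` is set on ROCm builds), not whether a GPU is present.

```python
import importlib.util


def pick_attn_implementation(prefer_flash: bool = True) -> str:
    """Tier 1 if flash-attn is installed (and preferred), otherwise tier 2 (sdpa)."""
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


def gpu_backend() -> str:
    """Which GPU stack the installed PyTorch build targets: 'rocm', 'cuda', or 'none'."""
    try:
        import torch
    except ImportError:
        return "none"
    if getattr(torch.version, "hip", None):  # version string on ROCm builds, None otherwise
        return "rocm"
    if getattr(torch.version, "cuda", None):  # version string on CUDA builds, None otherwise
        return "cuda"
    return "none"


# Usage:
# model = Qwen3VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen3-VL-4B-Instruct", attn_implementation=pick_attn_implementation())
```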

## 9. Eliminated Private API Dependencies

The original code imported many private symbols from transformers internals. This port eliminates all of them:

```python
# REMOVED: these broke between transformers versions
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    _flash_attention_forward,      # Private, CUDA-only
    flash_attn_varlen_func,        # Private re-export of flash_attn
    rotate_half,                   # Reimplemented locally
    repeat_kv,                     # Reimplemented locally
    apply_rotary_pos_emb_flashatt, # Not needed with SDPA
    StaticCache,                   # Using DynamicCache directly
    SlidingWindowCache,            # Not used
    AttentionMaskConverter,        # Not used
    make_flex_block_causal_mask,   # Not used
    BlockMask,                     # Not used
)

# KEPT: only stable public APIs
from transformers import (
    Qwen3VLForConditionalGeneration,  # Public model class
    AutoProcessor,                     # Public processor
    DynamicCache,                      # Public cache class
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,           # Public output type
    CausalLMOutputWithPast,            # Public output type
)
```
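To stop private imports from creeping back in, a test can scan the source tree with Python's `ast` module. It flags any import from `transformers.models.*` and any underscore-prefixed name imported from transformers. This is a hypothetical guard, not part of the repository's `test_imports.py`. Treating every `transformers.models.*` import as private is stricter than necessary, but it matches the public-APIs-only rule above.

```python
import ast

PRIVATE_MODULE_PREFIX = "transformers.models."


def _is_transformers_module(module: str) -> bool:
    return module == "transformers" or module.startswith("transformers.")


def find_private_transformers_imports(source: str, filename: str = "<string>") -> list:
    """Return (line, module, name) for imports that reach into transformers internals."""
    hits = []
    for node in ast.walk(ast.parse(source, filename=filename)):
        if isinstance(node, ast.ImportFrom) and node.module and _is_transformers_module(node.module):
            for alias in node.names:
                # Flag anything from transformers.models.* and any underscore-prefixed name.
                if node.module.startswith(PRIVATE_MODULE_PREFIX) or alias.name.startswith("_"):
                    hits.append((node.lineno, node.module, alias.name))
        elif isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name.startswith(PRIVATE_MODULE_PREFIX):
                    hits.append((node.lineno, alias.name, None))
    return hits
```

Running this over every `*.py` file (for example via `pathlib.Path(...).rglob("*.py")`) and asserting an empty result turns the rule into a regression test.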

## 10. Running the Updated Code

Inference (streaming commentary):

```bash
python -m streaming_vlm.inference.inference \
    --model_path Qwen/Qwen3-VL-4B-Instruct \
    --video_path match.mp4 \
    --output_path commentary.vtt \
    --fps 2 \
    --attn_implementation sdpa
```

Training (Stage 1):

```bash
# Download data
# huggingface-cli download mit-han-lab/Inf-Stream-Train --local-dir ./data

bash scripts/sft_stage_1.sh
```

Tests:

```bash
python test_imports.py
# Expected: 6/6 tests pass with no CUDA/flash-attn dependency
```