# Migration Guide: StreamingVLM → Qwen3-VL on ROCm

## Overview of Changes

This document details every change made to port StreamingVLM from its original implementation (Qwen2.5-VL-7B + flash-attn + CUDA) to the new stack (Qwen3-VL-4B + SDPA + ROCm).

---

## 1. Dependency Changes

### Original (`infer_requirements.txt`)

```
transformers==4.52.4
flash_attn==2.8.0.post2
torch==2.7.1
liger_kernel==0.6.1
qwen-vl-utils==0.0.11
```

### Updated (`requirements.txt`)

```
transformers>=4.57.0   # install from source: pip install git+https://github.com/huggingface/transformers
torch>=2.4.0           # ROCm or CUDA build
qwen-vl-utils>=0.0.14
# NO flash-attn required!
```

**Why**: Qwen3-VL requires transformers 4.57.0+ (unreleased, must install from source). Flash-attn is completely eliminated — replaced with `torch.nn.functional.scaled_dot_product_attention`.

---

## 2. Model Class Changes

| Original | Updated |
|----------|---------|
| `Qwen2_5_VLForConditionalGeneration` | `Qwen3VLForConditionalGeneration` |
| `from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ...` | No private imports needed |
| `attn_implementation="flash_attention_2"` | `attn_implementation="sdpa"` |

### Loading Code

```python
# BEFORE (original)
from transformers import Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_path,
    torch_dtype="auto",
    device_map="cuda",
    attn_implementation="flash_attention_2"
)

# AFTER (this port)
import torch
from transformers import Qwen3VLForConditionalGeneration

model = Qwen3VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen3-VL-4B-Instruct",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"  # Works on ROCm + CUDA
)
```

---

## 3. Attention Mechanism Replacement

### Language Model Attention

**Original**: Used `_flash_attention_forward()` (private transformers API, CUDA-only)

```python
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _flash_attention_forward

attn_output = _flash_attention_forward(query, key, value, attention_mask, ...)
```

**Updated**: Uses `F.scaled_dot_product_attention()` (PyTorch native, ROCm + CUDA)

```python
import torch.nn.functional as F

attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attention_mask,
    dropout_p=0.0,
    is_causal=(attention_mask is None and q_len > 1),
)
```

### Vision Encoder Attention

**Original**: Used `flash_attn_varlen_func` for packed variable-length sequences

```python
from flash_attn import flash_attn_varlen_func

attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
```

**Updated**: Chunked SDPA — splits at sequence boundaries, applies SDPA per chunk

```python
def _sdpa_varlen(query, key, value, cu_seqlens, max_seqlen):
    # query/key/value: [total_seq_len, num_heads, head_dim]; cu_seqlens marks the
    # packed-sequence boundaries. max_seqlen is kept for signature parity with
    # flash_attn_varlen_func but is not needed by the chunked loop.
    outputs = []
    for i in range(cu_seqlens.shape[0] - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        q_i = query[start:end].unsqueeze(0).transpose(1, 2)  # [1, heads, len_i, dim]
        k_i = key[start:end].unsqueeze(0).transpose(1, 2)
        v_i = value[start:end].unsqueeze(0).transpose(1, 2)
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, is_causal=False)
        outputs.append(out.transpose(1, 2).squeeze(0))       # back to [len_i, heads, dim]
    return torch.cat(outputs, dim=0)
```
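As a sanity check, the chunked loop can be compared against a single SDPA call over the packed sequence with a block-diagonal mask, which is the semantics `flash_attn_varlen_func` provides for non-causal packed attention. The snippet below is a standalone sketch: the tensor sizes and `cu_seqlens` values are made-up toy numbers, not the real ViT dimensions, and `_sdpa_varlen` is reproduced so the check runs on its own.

```python
import torch
import torch.nn.functional as F

def _sdpa_varlen(query, key, value, cu_seqlens, max_seqlen=None):
    # Same chunked-SDPA logic as above, repeated here so the check is self-contained.
    outputs = []
    for i in range(cu_seqlens.shape[0] - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        q_i = query[start:end].unsqueeze(0).transpose(1, 2)
        k_i = key[start:end].unsqueeze(0).transpose(1, 2)
        v_i = value[start:end].unsqueeze(0).transpose(1, 2)
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, is_causal=False)
        outputs.append(out.transpose(1, 2).squeeze(0))
    return torch.cat(outputs, dim=0)

torch.manual_seed(0)
total_len, num_heads, head_dim = 12, 4, 16      # toy sizes for illustration only
cu_seqlens = torch.tensor([0, 3, 7, 12])        # three packed "images"
q = torch.randn(total_len, num_heads, head_dim)
k = torch.randn(total_len, num_heads, head_dim)
v = torch.randn(total_len, num_heads, head_dim)

# Reference: one SDPA call with a block-diagonal mask, so tokens attend
# only within their own packed sequence.
mask = torch.full((total_len, total_len), float("-inf"))
for i in range(cu_seqlens.shape[0] - 1):
    s, e = cu_seqlens[i], cu_seqlens[i + 1]
    mask[s:e, s:e] = 0.0

ref = F.scaled_dot_product_attention(
    q.unsqueeze(0).transpose(1, 2),
    k.unsqueeze(0).transpose(1, 2),
    v.unsqueeze(0).transpose(1, 2),
    attn_mask=mask,
).transpose(1, 2).squeeze(0)

print(torch.allclose(_sdpa_varlen(q, k, v, cu_seqlens), ref, atol=1e-5))  # expect True
```

The same check runs on a ROCm GPU by moving the tensors to `cuda`; PyTorch's ROCm build exposes HIP devices through the `cuda` namespace.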
---

## 4. Architecture Differences (Qwen2.5-VL → Qwen3-VL)

| Feature | Qwen2.5-VL-7B | Qwen3-VL-4B |
|---------|---------------|-------------|
| ViT patch size | 14×14 | **16×16** |
| ViT depth | 32 layers | **24 layers** |
| ViT hidden | 1280 | **1024** |
| LM hidden | 3584 | **2560** |
| LM layers | 28 | **36** |
| KV heads | 4 | **8** |
| Max context | 128K | **256K** |
| RoPE theta | 1M | **5M** |
| MRoPE | standard | **interleaved** |
| QK-Norm | ❌ | **✅** |
| DeepStack | ❌ | **✅ (layers 5, 11, 17)** |
| ViT window attn | Yes | **No (full attn)** |

### Impact on StreamingVLM:

- **QK-Norm**: Added `self.q_norm(query)` and `self.k_norm(key)` before RoPE in attention
- **DeepStack**: Vision encoder extracts features from intermediate layers for richer representation
- **Interleaved MRoPE**: `mrope_section=[24, 20, 20]` (temporal/height/width) with interleaved application
- **No ViT window attention**: Simpler vision encoder — no `fullatt_block_indexes` logic needed

---

## 5. Monkey-Patch Target Changes

### Original (10 patches on Qwen2.5-VL):

```python
model.generate = streaming_generate
model.prepare_inputs_for_generation = ...
model._sample = ...
model.forward = qwen2_5_vl_forward
model.model.forward = model_forward
model.model.language_model.forward = streaming_language_model_forward
model.model.language_model._update_causal_mask = ...
for layer in model.model.language_model.layers:  # "language_model" submodule
    layer.forward = streaming_text_decoder_layer_forward
    layer.self_attn.forward = streaming_text_flash_attn_forward
model.model.visual.forward = streaming_visual_encoder_forward
```

### Updated (simplified patches on Qwen3-VL):

```python
model.forward = streaming_qwen3_vl_forward
model.model.forward = streaming_model_forward
for layer in model.model.layers:  # directly on model.model (no "language_model"!)
    layer.forward = streaming_text_decoder_layer_forward
    layer.self_attn.forward = streaming_text_sdpa_forward
model.model.visual.forward = streaming_visual_encoder_forward
for blk in model.model.visual.blocks:
    blk.forward = streaming_visual_block_forward
    blk.attn.forward = streaming_visual_attention_forward
model.model.get_rope_index = get_rope_index_streaming
```

**Key structural difference**: In Qwen2.5-VL, text layers are at `model.model.language_model.layers`. In Qwen3-VL, they are at `model.model.layers` (no intermediate `language_model` wrapper).

---

## 6. KV Cache API Changes

### transformers 4.50-4.52 (original):

```python
cache.key_cache[layer_idx]    # List[Tensor]
cache.value_cache[layer_idx]
```

### transformers 4.57.0+ (updated):

```python
cache.layers[layer_idx].keys    # DynamicLayer objects
cache.layers[layer_idx].values
```

Our `StreamingCache` handles both via `_get_layer_keys()`/`_set_layer_keys()` compat methods.

---

## 7. Token IDs (Verified for Qwen3-VL)

| Token | ID | Same as 2.5? |
|-------|-----|--------------|
| `<\|im_start\|>` | 151644 | ✅ |
| `<\|im_end\|>` | 151645 | ✅ |
| `<\|vision_start\|>` | 151652 | ✅ |
| `<\|vision_end\|>` | 151653 | ✅ |
| `<\|image_pad\|>` | 151655 | ✅ |
| `<\|video_pad\|>` | 151656 | ✅ |

Token IDs are identical — no changes needed for token-level logic.
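These IDs can be re-checked against the released processor with a short snippet. This is a sketch only; it assumes the `Qwen/Qwen3-VL-4B-Instruct` checkpoint is available locally or downloadable from the Hub, and that the processor exposes its tokenizer as `processor.tokenizer`.

```python
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-4B-Instruct")
tokenizer = processor.tokenizer

special_tokens = ["<|im_start|>", "<|im_end|>", "<|vision_start|>",
                  "<|vision_end|>", "<|image_pad|>", "<|video_pad|>"]
for token in special_tokens:
    # convert_tokens_to_ids maps the literal token string to its vocabulary id
    print(f"{token}: {tokenizer.convert_tokens_to_ids(token)}")
# Expected (per the table above): 151644, 151645, 151652, 151653, 151655, 151656
```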
---

## 8. ROCm-Specific Configuration

### Performance tiers (recommended):

| Priority | Implementation | Install | Performance | Platform |
|----------|----------------|---------|-------------|----------|
| 1 | `flash_attention_2` + CK backend | Build flash-attn from source | Best | ROCm ≥5.7 |
| 2 | `sdpa` (default) | None (PyTorch built-in) | Good | Any |
| 3 | `eager` | None | Slowest | Any |

### For ROCm flash-attn (optional, for maximum throughput):

```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install .
# Then use: attn_implementation="flash_attention_2"
```

### Environment variables:

```bash
export PYTORCH_ROCM_ARCH="gfx942"    # MI300X
# export PYTORCH_ROCM_ARCH="gfx90a"  # MI250X
# export PYTORCH_ROCM_ARCH="gfx1100" # RX 7900 XTX

# Flash-attn backend selection:
export FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE"  # CK (default, recommended)
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" # Triton (alternative)
```

---

## 9. Eliminated Private API Dependencies

The original code imported many private symbols from transformers internals. These are **all eliminated** in this port:

```python
# REMOVED — These broke between transformers versions:
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    _flash_attention_forward,       # Private, CUDA-only
    flash_attn_varlen_func,         # Private re-export of flash_attn
    rotate_half,                    # Reimplemented locally
    repeat_kv,                      # Reimplemented locally
    apply_rotary_pos_emb_flashatt,  # Not needed with SDPA
    StaticCache,                    # Using DynamicCache directly
    SlidingWindowCache,             # Not used
    AttentionMaskConverter,         # Not used
    make_flex_block_causal_mask,    # Not used
    BlockMask,                      # Not used
)

# KEPT — Only stable public APIs:
from transformers import (
    Qwen3VLForConditionalGeneration,  # Public model class
    AutoProcessor,                    # Public processor
    DynamicCache,                     # Public cache class
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,   # Public output type
    CausalLMOutputWithPast,    # Public output type
)
```

---

## 10. Running the Updated Code

### Inference (streaming commentary):

```bash
python -m streaming_vlm.inference.inference \
    --model_path Qwen/Qwen3-VL-4B-Instruct \
    --video_path match.mp4 \
    --output_path commentary.vtt \
    --fps 2 \
    --attn_implementation sdpa
```

### Training (Stage 1):

```bash
# Download data
# huggingface-cli download mit-han-lab/Inf-Stream-Train --local-dir ./data
bash scripts/sft_stage_1.sh
```

### Tests:

```bash
python test_imports.py
# Expected: 6/6 tests pass with no CUDA/flash-attn dependency
```
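For a quick manual check that the environment matches the assumptions above (a ROCm or CUDA build of PyTorch, SDPA available, no flash-attn installed), the following standalone sketch can be run before inference; it is illustrative and independent of the test suite.

```python
import torch
import torch.nn.functional as F

print("torch:", torch.__version__)
print("ROCm (HIP) build:", torch.version.hip)       # version string on ROCm builds, None on CUDA builds
print("CUDA build:", torch.version.cuda)            # None on ROCm builds
print("GPU available:", torch.cuda.is_available())  # True on ROCm too (HIP devices use the cuda namespace)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32
q = k = v = torch.randn(1, 2, 8, 16, device=device, dtype=dtype)

# Confirms scaled_dot_product_attention works without flash-attn installed.
out = F.scaled_dot_product_attention(q, k, v)
print("SDPA OK:", tuple(out.shape))  # (1, 2, 8, 16)
```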