# Migration Guide: StreamingVLM → Qwen3-VL on ROCm
## Overview of Changes
This document details every change made to port StreamingVLM from its original stack
(Qwen2.5-VL-7B, flash-attn, CUDA) to Qwen3-VL-4B with PyTorch SDPA on ROCm.
---
## 1. Dependency Changes
### Original (`infer_requirements.txt`)
```
transformers==4.52.4
flash_attn==2.8.0.post2
torch==2.7.1
liger_kernel==0.6.1
qwen-vl-utils==0.0.11
```
### Updated (`requirements.txt`)
```
transformers>=4.57.0  # install from source: pip install git+https://github.com/huggingface/transformers
torch>=2.4.0          # ROCm or CUDA build
qwen-vl-utils>=0.0.14
# No flash-attn required
```
**Why**: Qwen3-VL requires transformers 4.57.0 or newer, which had not been released when this port was written and therefore had to be installed from source.
Flash-attn is eliminated entirely; attention is computed with `torch.nn.functional.scaled_dot_product_attention` (SDPA) instead.
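Because the port relies on a source build of transformers, failing fast on an older release avoids a confusing import error later. The sketch below is a hypothetical helper, not part of the repository. It compares only the numeric release segment, so a source build that reports `4.57.0.dev0` passes a 4.57.0 minimum; `packaging.version.Version` is stricter and treats `.dev0` as older than the final release.

```python
import re
from importlib import metadata


def release_tuple(version: str) -> tuple:
    """Extract the leading numeric release segment, e.g. '4.57.0.dev0' -> (4, 57, 0)."""
    match = re.match(r"(\d+(?:\.\d+)*)", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(1).split("."))


def meets_minimum(version: str, minimum: tuple) -> bool:
    """Lenient check (hypothetical helper) that ignores dev/pre-release suffixes."""
    release = release_tuple(version)
    # Pad the shorter tuple with zeros so (4, 58) compares like (4, 58, 0)
    width = max(len(release), len(minimum))
    release += (0,) * (width - len(release))
    minimum = tuple(minimum) + (0,) * (width - len(minimum))
    return release >= minimum


def check_transformers(minimum=(4, 57, 0)) -> str:
    """Raise if the installed transformers is too old for Qwen3-VL."""
    installed = metadata.version("transformers")  # PackageNotFoundError if absent
    if not meets_minimum(installed, minimum):
        wanted = ".".join(map(str, minimum))
        raise RuntimeError(f"transformers {installed} found; Qwen3-VL needs >= {wanted}")
    return installed
```

Calling `check_transformers()` at the top of an entry script turns a late `ImportError` for `Qwen3VLForConditionalGeneration` into an explicit version message.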
---
## 2. Model Class Changes
| Original | Updated |
|----------|---------|
| `Qwen2_5_VLForConditionalGeneration` | `Qwen3VLForConditionalGeneration` |
| `from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ...` | No private imports needed |
| `attn_implementation="flash_attention_2"` | `attn_implementation="sdpa"` |
### Loading Code
```python
# BEFORE (original)
from transformers import Qwen2_5_VLForConditionalGeneration
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
model_path, torch_dtype="auto", device_map="cuda",
attn_implementation="flash_attention_2"
)
# AFTER (this port)
from transformers import Qwen3VLForConditionalGeneration
model = Qwen3VLForConditionalGeneration.from_pretrained(
"Qwen/Qwen3-VL-4B-Instruct",
torch_dtype=torch.bfloat16,
device_map="auto",
attn_implementation="sdpa" # Works on ROCm + CUDA
)
```
---
## 3. Attention Mechanism Replacement
### Language Model Attention
**Original**: Used `_flash_attention_forward()` (private transformers API, CUDA-only)
```python
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import _flash_attention_forward
attn_output = _flash_attention_forward(query, key, value, attention_mask, ...)
```
**Updated**: Uses `F.scaled_dot_product_attention()` (PyTorch native, ROCm+CUDA)
```python
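import torch.nn.functional as F

# query_states: (batch, num_heads, q_len, head_dim)
# key_states/value_states: (batch, num_heads, kv_len, head_dim), already expanded
# from the KV heads to the query heads with the local repeat_kv
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states,
    attn_mask=attention_mask,  # boolean or additive float mask, or None
    dropout_p=0.0,
    # is_causal builds a top-left-aligned mask, which is only correct when
    # q_len == kv_len; with a populated cache, pass an explicit attention_mask
    is_causal=(attention_mask is None and q_len > 1),
)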
```
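The `is_causal` shortcut deserves care in a streaming setting. PyTorch documents `is_causal=True` as a lower-triangular mask over the (q_len, kv_len) grid, aligned to the top-left corner. When a multi-token chunk is appended to an existing KV cache (q_len < kv_len), the correct mask is aligned to the bottom-right so every new query still sees all cached tokens. The stdlib-only sketch below builds both variants for comparison; `causal_mask` is an illustrative helper, not part of the port.

```python
def causal_mask(q_len: int, kv_len: int, align: str = "bottom_right") -> list:
    """Return a q_len x kv_len boolean mask (True = may attend).

    'top_left' matches torch.ones(L, S).tril(0), i.e. SDPA's is_causal=True.
    'bottom_right' shifts the diagonal by kv_len - q_len, which is what new
    queries appended after a KV cache need.
    """
    if align not in ("top_left", "bottom_right"):
        raise ValueError(f"unknown alignment: {align}")
    offset = kv_len - q_len if align == "bottom_right" else 0
    return [[j <= i + offset for j in range(kv_len)] for i in range(q_len)]


if __name__ == "__main__":
    # Two new tokens appended after three cached tokens
    for align in ("top_left", "bottom_right"):
        print(align)
        for row in causal_mask(2, 5, align):
            print("".join("x" if allowed else "." for allowed in row))
```

With q_len == kv_len the two alignments coincide, which is why the shortcut is safe for a cache-free prefill.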
### Vision Encoder Attention
**Original**: Used `flash_attn_varlen_func` for packed variable-length sequences
```python
from flash_attn import flash_attn_varlen_func
attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen)
```
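`cu_seqlens` holds the cumulative boundaries of the packed sequences. In the Hugging Face Qwen2-VL/Qwen2.5-VL vision encoders, each temporal slice of an image or video contributes one sequence of `h * w` patches derived from `grid_thw` (the library does this with `torch.repeat_interleave` and `cumsum`). The sketch below assumes Qwen3-VL keeps that convention, which is worth confirming against your transformers version; `cu_seqlens_from_grid` is a hypothetical stdlib helper.

```python
from itertools import accumulate


def cu_seqlens_from_grid(grid_thw):
    """Hypothetical helper: packed-sequence boundaries from (t, h, w) patch grids.

    Each of the t temporal slices becomes its own sequence of h * w patches,
    so attention never crosses a frame or image boundary.
    """
    lengths = [h * w for t, h, w in grid_thw for _ in range(t)]
    return [0] + list(accumulate(lengths))


# One image with a 4x6 patch grid, one video with 2 temporal slices of 2x2 patches
print(cu_seqlens_from_grid([(1, 4, 6), (2, 2, 2)]))  # [0, 24, 28, 32]
```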
**Updated**: Chunked SDPA. Split at the sequence boundaries, then apply SDPA to each chunk:
```python
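import torch
import torch.nn.functional as F

def _sdpa_varlen(query, key, value, cu_seqlens, max_seqlen):
    """Bidirectional attention over packed sequences of shape (total, heads, dim).

    cu_seqlens holds boundaries [0, len_0, len_0 + len_1, ...]; max_seqlen is unused
    and kept only for drop-in compatibility with flash_attn_varlen_func.
    """
    outputs = []
    for i in range(cu_seqlens.shape[0] - 1):
        start, end = cu_seqlens[i].item(), cu_seqlens[i + 1].item()
        # (len, H, D) -> (1, H, len, D), the layout SDPA expects
        q_i = query[start:end].unsqueeze(0).transpose(1, 2)
        k_i = key[start:end].unsqueeze(0).transpose(1, 2)
        v_i = value[start:end].unsqueeze(0).transpose(1, 2)
        out = F.scaled_dot_product_attention(q_i, k_i, v_i, is_causal=False)
        # Back to (len, H, D) so the chunks can be re-packed along dim 0
        outputs.append(out.transpose(1, 2).squeeze(0))
    return torch.cat(outputs, dim=0)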
```
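Running SDPA once per chunk is equivalent to a single attention call over the whole packed sequence with a block-diagonal mask, in which token `i` may attend token `j` only if both fall in the same `[cu_seqlens[k], cu_seqlens[k+1])` interval. The stdlib sketch below (`block_diagonal_mask` is illustrative) builds that mask. It also shows the trade-off behind the chunked loop: the loop never materializes a total × total mask, at the cost of one kernel launch per chunk.

```python
from bisect import bisect_right


def block_diagonal_mask(cu_seqlens):
    """Boolean mask (True = may attend) equivalent to per-chunk attention."""
    total = cu_seqlens[-1]

    def segment(idx):
        # Index k such that cu_seqlens[k] <= idx < cu_seqlens[k + 1]
        return bisect_right(cu_seqlens, idx) - 1

    seg = [segment(i) for i in range(total)]
    return [[seg[i] == seg[j] for j in range(total)] for i in range(total)]


for row in block_diagonal_mask([0, 2, 5]):
    print("".join("x" if allowed else "." for allowed in row))
```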
---
## 4. Architecture Differences (Qwen2.5-VL → Qwen3-VL)
| Feature | Qwen2.5-VL-7B | Qwen3-VL-4B |
|---------|---------------|-------------|
| ViT patch size | 14×14 | **16×16** |
| ViT depth | 32 layers | **24 layers** |
| ViT hidden size | 1280 | **1024** |
| LM hidden size | 3584 | **2560** |
| LM layers | 28 | **36** |
| KV heads | 4 | **8** |
| Max context (tokens) | 128K | **256K** |
| RoPE theta | 1M | **5M** |
| MRoPE | standard | **interleaved** |
| QK-Norm | ❌ | **✅** |
| DeepStack | ❌ | **✅ (ViT layers 5, 11, 17)** |
| ViT window attention | Yes | **No (full attention)** |
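The patch-size change alters how many visual tokens each frame costs, which directly affects how much video fits in StreamingVLM's context. The arithmetic below assumes both models merge 2×2 patches into one LM token and pack 2 frames per temporal patch (the Qwen2.5-VL defaults; treat them as assumptions for Qwen3-VL and check the model config). `visual_tokens` is a hypothetical helper.

```python
def visual_tokens(height: int, width: int, frames: int, patch: int,
                  merge: int = 2, temporal: int = 2) -> int:
    """LM tokens for a clip whose sides are multiples of patch * merge.

    merge=2 and temporal=2 are assumed defaults, not values read from a config.
    """
    unit = patch * merge
    if height % unit or width % unit:
        raise ValueError(f"height and width must be multiples of {unit}")
    t = -(-frames // temporal)  # ceil: a trailing odd frame is padded
    return t * (height // patch) * (width // patch) // (merge * merge)


print(visual_tokens(448, 448, 2, patch=14))  # Qwen2.5-VL-style: 256
print(visual_tokens(448, 448, 2, patch=16))  # Qwen3-VL-style: 196
```

Under these assumptions the larger patch cuts the per-frame cost at a fixed resolution by roughly (14/16)², about 23%.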
### Impact on StreamingVLM:
- **QK-Norm**: Added `self.q_norm(query)` and `self.k_norm(key)` before RoPE in attention
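- **DeepStack**: The vision encoder also emits features from intermediate ViT layers (5, 11, 17), which are injected into the early language-model layers for a richer visual representation
- **Interleaved MRoPE**: `mrope_section=[24,20,20]` (temporal/height/width); the three axes are interleaved across the rotary frequency dimensions instead of occupying contiguous blocks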
- **No ViT window attention**: Simpler vision encoder β€” no `fullatt_block_indexes` logic needed
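To make "interleaved" concrete: with `mrope_section=[24,20,20]` there are 64 rotary frequency slots per head. Based on our reading of the Hugging Face Qwen3-VL implementation (worth re-checking against the installed version), the slots cycle T, H, W for the first 3 × 20 positions and the remaining slots stay temporal, instead of the contiguous T…T H…H W…W blocks of standard MRoPE. The hypothetical helper below only computes that axis assignment.

```python
from collections import Counter


def interleaved_mrope_layout(mrope_section):
    """Axis ('T', 'H', 'W') that drives each rotary frequency slot.

    Assumed pattern: slot i goes to H when i % 3 == 1 and to W when i % 3 == 2,
    up to 3 * section size; every other slot keeps the temporal frequency.
    """
    t_sec, h_sec, w_sec = mrope_section
    layout = ["T"] * (t_sec + h_sec + w_sec)
    for axis, offset, size in (("H", 1, h_sec), ("W", 2, w_sec)):
        for i in range(offset, 3 * size, 3):
            layout[i] = axis
    return layout


layout = interleaved_mrope_layout([24, 20, 20])
print("".join(layout[:12]), "...", "".join(layout[-6:]))
print(Counter(layout))
```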
---
## 5. Monkey-Patch Target Changes
### Original (10 patches on Qwen2.5-VL):
```python
model.generate = streaming_generate
model.prepare_inputs_for_generation = ...
model._sample = ...
model.forward = qwen2_5_vl_forward
model.model.forward = model_forward
model.model.language_model.forward = streaming_language_model_forward
model.model.language_model._update_causal_mask = ...
for layer in model.model.language_model.layers:  # "language_model" submodule
    layer.forward = streaming_text_decoder_layer_forward
    layer.self_attn.forward = streaming_text_flash_attn_forward
model.model.visual.forward = streaming_visual_encoder_forward
```
### Updated (simplified patches on Qwen3-VL):
```python
model.forward = streaming_qwen3_vl_forward
model.model.forward = streaming_model_forward
for layer in model.model.layers:  # directly on model.model (no "language_model")
    layer.forward = streaming_text_decoder_layer_forward
    layer.self_attn.forward = streaming_text_sdpa_forward
model.model.visual.forward = streaming_visual_encoder_forward
for blk in model.model.visual.blocks:
    blk.forward = streaming_visual_block_forward
    blk.attn.forward = streaming_visual_attention_forward
model.model.get_rope_index = get_rope_index_streaming
```
**Key structural difference**: In Qwen2.5-VL, text layers are at `model.model.language_model.layers`.
In Qwen3-VL, they are at `model.model.layers` (no intermediate `language_model` wrapper).
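Two practical details when applying these patches. First, assigning a plain function to an instance attribute does not bind `self`, so a replacement that expects the module as its first argument should be wrapped with `types.MethodType`. Second, since the text-layer path differs between Qwen2.5-VL and Qwen3-VL (and may shift across transformers versions), a defensive lookup that tries both paths avoids silently patching nothing. The helpers below are a hypothetical sketch, not the port's code, and use stand-in objects instead of a real model.

```python
import types


def find_text_layers(model):
    """Return the text decoder layers, trying both known module layouts."""
    for path in ("model.language_model.layers", "model.layers"):
        obj = model
        try:
            for attr in path.split("."):
                obj = getattr(obj, attr)
        except AttributeError:
            continue  # this layout is not present; try the next one
        return obj
    raise AttributeError("no text decoder layers found on model")


def patch_method(module, name, fn):
    """Attach fn as a bound method so it receives the module as `self`."""
    setattr(module, name, types.MethodType(fn, module))


def demo_forward(self, hidden):
    # Hypothetical replacement forward used only for this demo
    return f"{self.tag}:{hidden}"


# Stand-in for a model whose layers sit directly under model.model
layers = [types.SimpleNamespace(tag=f"layer{i}") for i in range(2)]
fake_model = types.SimpleNamespace(model=types.SimpleNamespace(layers=layers))
for layer in find_text_layers(fake_model):
    patch_method(layer, "forward", demo_forward)
print(layers[1].forward("h"))  # layer1:h
```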
---
## 6. KV Cache API Changes
### transformers 4.50-4.52 (original):
```python
cache.key_cache[layer_idx] # List[Tensor]
cache.value_cache[layer_idx]
```
### transformers 4.57.0+ (updated, as required by this port):
```python
cache.layers[layer_idx].keys # DynamicLayer objects
cache.layers[layer_idx].values
```
Our `StreamingCache` handles both via `_get_layer_keys()`/`_set_layer_keys()` compat methods.
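The compatibility layer can be as small as one accessor pair that prefers the newer `layers` list and falls back to the legacy `key_cache`/`value_cache` lists. The sketch below is illustrative only: it is not the actual `StreamingCache` implementation, the function names are hypothetical, and stand-in objects replace real `DynamicCache` instances so it runs without transformers.

```python
from types import SimpleNamespace


def get_layer_kv(cache, layer_idx):
    """Return (keys, values) for one layer under either cache layout."""
    layers = getattr(cache, "layers", None)
    if layers is not None:  # newer API: per-layer objects with .keys/.values
        layer = layers[layer_idx]
        return layer.keys, layer.values
    # Older API: parallel lists of tensors
    return cache.key_cache[layer_idx], cache.value_cache[layer_idx]


def set_layer_kv(cache, layer_idx, keys, values):
    """Overwrite one layer's keys/values, e.g. after trimming the cache."""
    layers = getattr(cache, "layers", None)
    if layers is not None:
        layers[layer_idx].keys = keys
        layers[layer_idx].values = values
    else:
        cache.key_cache[layer_idx] = keys
        cache.value_cache[layer_idx] = values


new_style = SimpleNamespace(layers=[SimpleNamespace(keys="k0", values="v0")])
old_style = SimpleNamespace(key_cache=["k0"], value_cache=["v0"])
for cache in (new_style, old_style):
    set_layer_kv(cache, 0, "k0-trimmed", "v0-trimmed")
    print(get_layer_kv(cache, 0))
```

Checking for `layers` first matters because some versions keep deprecated `key_cache`/`value_cache` aliases alongside the new structure.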
---
## 7. Token IDs (Verified for Qwen3-VL)
| Token | ID | Same as Qwen2.5-VL? |
|-------|-----|---------------------|
| `<\|im_start\|>` | 151644 | ✅ |
| `<\|im_end\|>` | 151645 | ✅ |
| `<\|vision_start\|>` | 151652 | ✅ |
| `<\|vision_end\|>` | 151653 | ✅ |
| `<\|image_pad\|>` | 151655 | ✅ |
| `<\|video_pad\|>` | 151656 | ✅ |

Token IDs are identical, so no changes are needed for token-level logic.
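Since the IDs are unchanged, helpers that locate vision segments in `input_ids` carry over as-is. As an illustration (not the port's code), the stdlib sketch below finds each span between `<|vision_start|>` and `<|vision_end|>` and counts the pad tokens inside it, using the IDs from the table.

```python
VISION_START, VISION_END = 151652, 151653
IMAGE_PAD, VIDEO_PAD = 151655, 151656


def vision_spans(input_ids):
    """Yield (start, end, kind, num_pad) for each vision segment.

    start/end are the indices of the vision_start/vision_end tokens.
    """
    start = None
    for idx, tok in enumerate(input_ids):
        if tok == VISION_START:
            start = idx
        elif tok == VISION_END and start is not None:
            inner = input_ids[start + 1:idx]
            kind = "video" if VIDEO_PAD in inner else "image"
            num_pad = sum(t in (IMAGE_PAD, VIDEO_PAD) for t in inner)
            yield start, idx, kind, num_pad
            start = None


# 872 stands in for an arbitrary text token
ids = [151644, 872, VISION_START, VIDEO_PAD, VIDEO_PAD, VIDEO_PAD, VISION_END, 151645]
print(list(vision_spans(ids)))  # [(2, 6, 'video', 3)]
```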
---
## 8. ROCm-Specific Configuration
### Performance tiers (recommended):
| Priority | Implementation | Install | Performance | Platform |
|----------|---------------|---------|-------------|----------|
| 1 | `flash_attention_2` + CK backend | Build flash-attn from source | Best | ROCm ≥5.7 |
| 2 | `sdpa` (default) | None (PyTorch built-in) | Good | Any |
| 3 | `eager` | None | Slowest | Any |
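The tiers above can be applied automatically by checking what is importable. This is a hypothetical helper with an important limitation: finding a `flash_attn` package only shows that it is installed, not that it was built with the CK backend for your GPU, so treat the result as a first filter.

```python
import importlib.util


def pick_attn_implementation(prefer_flash: bool = True, force_eager: bool = False) -> str:
    """Return an attn_implementation string following the priority table.

    Hypothetical helper: an installed flash_attn package is not proof that it
    was built for this GPU or backend.
    """
    if force_eager:
        return "eager"  # tier 3: slowest, kept as an explicit fallback
    if prefer_flash and importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"  # tier 1
    # Tier 2: SDPA ships with every PyTorch version this port supports (>= 2.4)
    return "sdpa"


print(pick_attn_implementation())
```

The returned string is meant to be passed as `attn_implementation=` to `Qwen3VLForConditionalGeneration.from_pretrained`, as in the loading code in Section 2.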
### For ROCm flash-attn (optional, for maximum throughput):
```bash
git clone https://github.com/Dao-AILab/flash-attention.git
cd flash-attention
pip install .
# Then use: attn_implementation="flash_attention_2"
```
### Environment variables:
```bash
export PYTORCH_ROCM_ARCH="gfx942" # MI300X
# export PYTORCH_ROCM_ARCH="gfx90a" # MI250X
# export PYTORCH_ROCM_ARCH="gfx1100" # RX 7900 XTX
# Flash-attn backend selection:
export FLASH_ATTENTION_TRITON_AMD_ENABLE="FALSE" # CK (default, recommended)
# export FLASH_ATTENTION_TRITON_AMD_ENABLE="TRUE" # Triton (alternative)
```
---
## 9. Eliminated Private API Dependencies
The original code imported many private symbols from transformers internals.
These are **all eliminated** in this port:
```python
# REMOVED (these broke between transformers versions):
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
    _flash_attention_forward,       # Private, CUDA-only
    flash_attn_varlen_func,         # Private re-export of flash_attn
    rotate_half,                    # Reimplemented locally
    repeat_kv,                      # Reimplemented locally
    apply_rotary_pos_emb_flashatt,  # Not needed with SDPA
    StaticCache,                    # Using DynamicCache directly
    SlidingWindowCache,             # Not used
    AttentionMaskConverter,         # Not used
    make_flex_block_causal_mask,    # Not used
    BlockMask,                      # Not used
)
# KEPT (only stable public APIs):
from transformers import (
    Qwen3VLForConditionalGeneration,  # Public model class
    AutoProcessor,                    # Public processor
    DynamicCache,                     # Public cache class
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,  # Public output type
    CausalLMOutputWithPast,   # Public output type
)
```
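`rotate_half` is small enough to reimplement locally. For reference, the stdlib sketch below shows the rotate-half formulation of RoPE on a single head vector, `x * cos + rotate_half(x) * sin`, which pairs dimension `i` with `i + d/2`. It is a scalar illustration of the math, not the tensor code used in the port; the `base` default is taken from the RoPE theta row of the Section 4 table and the helper names are illustrative.

```python
import math


def rotate_half(x):
    """[x1, x2] -> [-x2, x1], splitting the vector into two halves."""
    half = len(x) // 2
    return [-v for v in x[half:]] + x[:half]


def apply_rope(x, position, base=5_000_000.0):
    """Rotate one head vector x (even length d) for a given position."""
    d = len(x)
    # Frequency for pair (i, i + d/2) is base ** (-2i / d), repeated for both halves
    inv_freq = [base ** (-2 * i / d) for i in range(d // 2)] * 2
    cos = [math.cos(position * f) for f in inv_freq]
    sin = [math.sin(position * f) for f in inv_freq]
    rot = rotate_half(x)
    return [xi * c + ri * s for xi, c, ri, s in zip(x, cos, rot, sin)]


q = [1.0, 0.0, 0.5, 2.0]
q_rot = apply_rope(q, position=7)
# RoPE is a rotation of each (i, i + d/2) pair, so the vector norm is unchanged
print(math.isclose(math.hypot(*q), math.hypot(*q_rot)))  # True
```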
---
## 10. Running the Updated Code
### Inference (streaming commentary):
```bash
python -m streaming_vlm.inference.inference \
--model_path Qwen/Qwen3-VL-4B-Instruct \
--video_path match.mp4 \
--output_path commentary.vtt \
--fps 2 \
--attn_implementation sdpa
```
### Training (Stage 1):
```bash
# Download data (Inf-Stream-Train is a dataset repo)
# huggingface-cli download mit-han-lab/Inf-Stream-Train --repo-type dataset --local-dir ./data
bash scripts/sft_stage_1.sh
```
### Tests:
```bash
python test_imports.py
# Expected: 6/6 tests pass with no CUDA/flash-attn dependency
```