"""
train.py — Supervised fine-tuning (SFT) for StreamingVLM with Qwen3-VL on ROCm.

Trains the streaming capability using the overlapped-chunk, full-attention
strategy described in the StreamingVLM paper (Section 2.2).

Key changes from the original:
- Uses Qwen3VLForConditionalGeneration (not Qwen2_5)
- attn_implementation="sdpa" for ROCm compatibility
- Updated for the latest transformers API (4.57.0+)
- No flash-attn dependency
"""
|
|
| import os |
| import sys |
| from types import MethodType |
| from dataclasses import dataclass, field |
| from typing import Optional, List |
|
|
| import torch |
| import transformers |
| from transformers import ( |
| AutoProcessor, |
| HfArgumentParser, |
| Trainer, |
| TrainingArguments, |
| Qwen3VLForConditionalGeneration, |
| ) |
|
|
| from streaming_vlm.inference.qwen3.pos_emb import get_rope_index_streaming |
|
|
|
|
@dataclass
class ModelArguments:
    """Options that choose which model is loaded and how it is set up.

    Every field is a ``dataclasses.field`` whose ``metadata["help"]`` holds
    the human-readable description of the option.
    """

    # Local path or Hub identifier of the checkpoint to load.
    pretrained_model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from HF Hub"},
        default="Qwen/Qwen3-VL-4B-Instruct",
    )

    # When true, the visual encoder is excluded from gradient updates.
    freeze_visual: bool = field(
        metadata={"help": "Freeze visual encoder during training"},
        default=True,
    )

    # Attention backend name forwarded to the model loader.
    attn_implementation: str = field(
        metadata={"help": "Attention implementation: sdpa (ROCm), flash_attention_2 (CUDA), eager"},
        default="sdpa",
    )
|
|
|
|
@dataclass
class DataArguments:
    """Options that describe the training data and sequence limits.

    Every field is a ``dataclasses.field`` whose ``metadata["help"]`` holds
    the human-readable description of the option.
    """

    # JSONL file with training samples; empty string means "not provided".
    train_data_path: str = field(
        metadata={"help": "Path to training data JSONL file"},
        default="",
    )

    # JSONL file with evaluation samples; empty string means "not provided".
    eval_data_path: str = field(
        metadata={"help": "Path to evaluation data JSONL file"},
        default="",
    )

    # Count of attention-sink tokens.
    text_sink: int = field(
        metadata={"help": "Number of attention sink tokens"},
        default=512,
    )

    # Size of the sliding window applied to text.
    text_sliding_window: int = field(
        metadata={"help": "Text sliding window size"},
        default=512,
    )

    # Upper bound on the training sequence length.
    max_seq_length: int = field(
        metadata={"help": "Maximum sequence length for training"},
        default=8192,
    )
|
|
|
|
def _freeze_visual_encoder(model):
    """Turn off gradients for the model's ``visual`` submodule, if present.

    ``model.model`` is checked first, then ``model`` itself.

    Returns:
        True when a ``visual`` attribute was found and frozen, False otherwise.
    """
    for owner in (model.model, model):
        if hasattr(owner, "visual"):
            owner.visual.requires_grad_(False)
            return True
    return False


def main():
    """Parse CLI arguments, load Qwen3-VL, and prepare the training setup.

    Steps, in order:
      1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.
      2. Exit with status 1 if ``--train_data_path`` is empty. This runs
         before any model weights are loaded so a missing argument fails fast.
      3. Load the model in bfloat16 with the requested attention backend and
         bind ``get_rope_index_streaming`` as ``model.model.get_rope_index``.
      4. Optionally freeze the visual encoder and report parameter counts.
      5. Load the processor and force gradient checkpointing, bf16, and
         ``remove_unused_columns=False`` on the training arguments.

    No dataset or ``Trainer`` is created yet; the function ends after
    printing the resolved configuration.

    Raises:
        SystemExit: with code 1 when no training data path is given.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Validate cheap preconditions first: loading the checkpoint is slow and
    # memory-hungry, so do not pay for it when the run cannot proceed anyway.
    if not data_args.train_data_path:
        print("[Train] ERROR: No training data path specified. Use --train_data_path")
        print("[Train] Download data from: https://huggingface.co/datasets/mit-han-lab/Inf-Stream-Train")
        sys.exit(1)

    print(f"[Train] Loading model: {model_args.pretrained_model_name_or_path}")
    print(f"[Train] Attention: {model_args.attn_implementation}")

    # `dtype` is the current keyword; `torch_dtype` is deprecated in the
    # transformers releases this script targets (4.57.0+).
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        model_args.pretrained_model_name_or_path,
        dtype=torch.bfloat16,
        attn_implementation=model_args.attn_implementation,
    )

    # Swap in the streaming position-index routine on the inner model.
    model.model.get_rope_index = MethodType(get_rope_index_streaming, model.model)

    if model_args.freeze_visual:
        print("[Train] Freezing visual encoder")
        if not _freeze_visual_encoder(model):
            # Previously this case passed silently and trained everything.
            print("[Train] WARNING: No visual encoder found on the model; nothing was frozen")

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Train] Total params: {total_params:,}")
    print(f"[Train] Trainable params: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")

    # Loaded for the dataset/collator that is still to be implemented; not
    # used further in this function yet.
    processor = AutoProcessor.from_pretrained(
        model_args.pretrained_model_name_or_path,
    )

    print(f"[Train] Data path: {data_args.train_data_path}")
    print(f"[Train] Text sink: {data_args.text_sink}")
    print(f"[Train] Text window: {data_args.text_sliding_window}")

    # These settings are forced regardless of what was passed on the CLI.
    training_args.gradient_checkpointing = True
    training_args.bf16 = True
    training_args.remove_unused_columns = False

    print(f"[Train] Output dir: {training_args.output_dir}")
    print(f"[Train] Learning rate: {training_args.learning_rate}")
    print(f"[Train] Batch size: {training_args.per_device_train_batch_size}")
    print(f"[Train] Gradient accumulation: {training_args.gradient_accumulation_steps}")

    print("[Train] Training setup complete. Implement LMMDataset to start training.")
|
|
|
|
# Run the training setup only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|