""" train.py — SFT training for StreamingVLM with Qwen3-VL on ROCm. Trains the streaming capability using the overlapped-chunk, full-attention strategy described in the StreamingVLM paper (Section 2.2). Key changes from original: - Uses Qwen3VLForConditionalGeneration (not Qwen2_5) - attn_implementation="sdpa" for ROCm compatibility - Updated for latest transformers API (4.57.0+) - No flash-attn dependency """ import os import sys from types import MethodType from dataclasses import dataclass, field from typing import Optional, List import torch import transformers from transformers import ( AutoProcessor, HfArgumentParser, Trainer, TrainingArguments, Qwen3VLForConditionalGeneration, ) from streaming_vlm.inference.qwen3.pos_emb import get_rope_index_streaming @dataclass class ModelArguments: """Arguments for model configuration.""" pretrained_model_name_or_path: str = field( default="Qwen/Qwen3-VL-4B-Instruct", metadata={"help": "Path to pretrained model or model identifier from HF Hub"} ) freeze_visual: bool = field( default=True, metadata={"help": "Freeze visual encoder during training"} ) attn_implementation: str = field( default="sdpa", metadata={"help": "Attention implementation: sdpa (ROCm), flash_attention_2 (CUDA), eager"} ) @dataclass class DataArguments: """Arguments for data configuration.""" train_data_path: str = field( default="", metadata={"help": "Path to training data JSONL file"} ) eval_data_path: str = field( default="", metadata={"help": "Path to evaluation data JSONL file"} ) text_sink: int = field( default=512, metadata={"help": "Number of attention sink tokens"} ) text_sliding_window: int = field( default=512, metadata={"help": "Text sliding window size"} ) max_seq_length: int = field( default=8192, metadata={"help": "Maximum sequence length for training"} ) def main(): parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) model_args, data_args, training_args = parser.parse_args_into_dataclasses() # === Load Model === print(f"[Train] Loading model: {model_args.pretrained_model_name_or_path}") print(f"[Train] Attention: {model_args.attn_implementation}") model = Qwen3VLForConditionalGeneration.from_pretrained( model_args.pretrained_model_name_or_path, torch_dtype=torch.bfloat16, attn_implementation=model_args.attn_implementation, # "sdpa" for ROCm ) # Patch RoPE for streaming (contiguous mode) model.model.get_rope_index = MethodType(get_rope_index_streaming, model.model) # Freeze visual encoder (standard practice — vision model already well-trained) if model_args.freeze_visual: print("[Train] Freezing visual encoder") if hasattr(model.model, 'visual'): model.model.visual.requires_grad_(False) elif hasattr(model, 'visual'): model.visual.requires_grad_(False) # Print trainable parameters total_params = sum(p.numel() for p in model.parameters()) trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) print(f"[Train] Total params: {total_params:,}") print(f"[Train] Trainable params: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)") # === Load Processor === processor = AutoProcessor.from_pretrained( model_args.pretrained_model_name_or_path, ) # === Load Data === # TODO: Implement LMMDataset for Qwen3-VL format # The dataset should load the streaming SFT data (overlapped chunks) # from mit-han-lab/Inf-Stream-Train print(f"[Train] Data path: {data_args.train_data_path}") print(f"[Train] Text sink: {data_args.text_sink}") print(f"[Train] Text window: {data_args.text_sliding_window}") # Placeholder — integrate with actual dataset class if not data_args.train_data_path: print("[Train] ERROR: No training data path specified. Use --train_data_path") print("[Train] Download data from: https://huggingface.co/datasets/mit-han-lab/Inf-Stream-Train") sys.exit(1) # === Training === training_args.gradient_checkpointing = True training_args.bf16 = True training_args.remove_unused_columns = False print(f"[Train] Output dir: {training_args.output_dir}") print(f"[Train] Learning rate: {training_args.learning_rate}") print(f"[Train] Batch size: {training_args.per_device_train_batch_size}") print(f"[Train] Gradient accumulation: {training_args.gradient_accumulation_steps}") # TODO: Initialize Trainer with LMMDataset and custom data collator # trainer = Trainer( # model=model, # args=training_args, # train_dataset=train_dataset, # data_collator=train_dataset.data_collator, # processing_class=processor, # ) # trainer.train() print("[Train] Training setup complete. Implement LMMDataset to start training.") if __name__ == "__main__": main()