# s23deepak's picture
# Add train.py and test_imports.py
# 86d7955 verified
"""
train.py — SFT training for StreamingVLM with Qwen3-VL on ROCm.
Trains the streaming capability using the overlapped-chunk, full-attention strategy
described in the StreamingVLM paper (Section 2.2).
Key changes from original:
- Uses Qwen3VLForConditionalGeneration (not Qwen2_5)
- attn_implementation="sdpa" for ROCm compatibility
- Updated for latest transformers API (4.57.0+)
- No flash-attn dependency
"""
import os
import sys
from types import MethodType
from dataclasses import dataclass, field
from typing import Optional, List
import torch
import transformers
from transformers import (
AutoProcessor,
HfArgumentParser,
Trainer,
TrainingArguments,
Qwen3VLForConditionalGeneration,
)
from streaming_vlm.inference.qwen3.pos_emb import get_rope_index_streaming
@dataclass
class ModelArguments:
    """Command-line options that select the checkpoint and how it is loaded.

    Parsed by ``HfArgumentParser`` in ``main()``; each field becomes a
    ``--<field_name>`` flag whose help text is taken from ``metadata["help"]``.
    """

    # Checkpoint to load: a local path or a Hugging Face Hub model identifier.
    # main() passes this to both the model and the processor loaders.
    pretrained_model_name_or_path: str = field(
        default="Qwen/Qwen3-VL-4B-Instruct",
        metadata={"help": "Path to pretrained model or model identifier from HF Hub"},
    )
    # When true, main() disables gradients on the model's visual encoder.
    freeze_visual: bool = field(
        default=True,
        metadata={"help": "Freeze visual encoder during training"},
    )
    # Forwarded unchanged as ``attn_implementation=`` to ``from_pretrained``.
    attn_implementation: str = field(
        default="sdpa",
        metadata={"help": "Attention implementation: sdpa (ROCm), flash_attention_2 (CUDA), eager"},
    )
@dataclass
class DataArguments:
    """Command-line options that describe the training data and streaming window.

    Parsed by ``HfArgumentParser`` in ``main()``; each field becomes a
    ``--<field_name>`` flag whose help text is taken from ``metadata["help"]``.
    """

    # Training JSONL file. The empty default means "not provided"; main()
    # exits with status 1 in that case.
    train_data_path: str = field(
        default="",
        metadata={"help": "Path to training data JSONL file"},
    )
    # Evaluation JSONL file. Not read by main() as currently written.
    eval_data_path: str = field(
        default="",
        metadata={"help": "Path to evaluation data JSONL file"},
    )
    # Attention-sink size in tokens. main() currently only logs this value.
    text_sink: int = field(
        default=512,
        metadata={"help": "Number of attention sink tokens"},
    )
    # Text sliding-window size. main() currently only logs this value.
    text_sliding_window: int = field(
        default=512,
        metadata={"help": "Text sliding window size"},
    )
    # Upper bound on training sequence length. Not read by main() as
    # currently written.
    max_seq_length: int = field(
        default=8192,
        metadata={"help": "Maximum sequence length for training"},
    )
def _freeze_visual_encoder(model):
    """Disable gradients on the model's ``visual`` submodule.

    The submodule is looked up on ``model.model`` first and on ``model``
    itself second; only the first one found is frozen.

    Args:
        model: The loaded model; must expose a ``.model`` attribute.

    Returns:
        True if a ``visual`` submodule was found and frozen, False otherwise.
    """
    for owner in (model.model, model):
        if hasattr(owner, "visual"):
            owner.visual.requires_grad_(False)
            return True
    return False


def main():
    """Parse CLI arguments, load Qwen3-VL, and prepare the SFT training setup.

    Steps, in order:
      1. Parse ``ModelArguments``, ``DataArguments`` and ``TrainingArguments``.
      2. Abort early if no training data path was given.
      3. Load the model in bfloat16, patch its RoPE index function for
         streaming, and optionally freeze the visual encoder.
      4. Load the processor and force the training options this script relies
         on (gradient checkpointing, bf16, keeping unused columns).

    The dataset and ``Trainer`` are not wired up yet (see the TODOs), so no
    training is actually run.

    Raises:
        SystemExit: With status 1 when ``--train_data_path`` is empty.
    """
    parser = HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Fail fast: validate the data path before the expensive checkpoint load
    # so a missing flag is reported immediately.
    if not data_args.train_data_path:
        print("[Train] ERROR: No training data path specified. Use --train_data_path", file=sys.stderr)
        print(
            "[Train] Download data from: https://huggingface.co/datasets/mit-han-lab/Inf-Stream-Train",
            file=sys.stderr,
        )
        sys.exit(1)

    # === Load Model ===
    print(f"[Train] Loading model: {model_args.pretrained_model_name_or_path}")
    print(f"[Train] Attention: {model_args.attn_implementation}")
    # NOTE(review): recent transformers releases appear to prefer `dtype=` over
    # `torch_dtype=` — confirm against the pinned version before switching.
    model = Qwen3VLForConditionalGeneration.from_pretrained(
        model_args.pretrained_model_name_or_path,
        torch_dtype=torch.bfloat16,
        attn_implementation=model_args.attn_implementation,  # "sdpa" for ROCm
    )

    # Patch RoPE for streaming (contiguous mode): bind the streaming variant
    # as a method on the inner model so it replaces get_rope_index.
    model.model.get_rope_index = MethodType(get_rope_index_streaming, model.model)

    # Freeze visual encoder (standard practice — vision model already well-trained)
    if model_args.freeze_visual:
        print("[Train] Freezing visual encoder")
        if not _freeze_visual_encoder(model):
            # Without this, a layout change would silently train the encoder.
            print(
                "[Train] WARNING: No visual encoder found on the model; nothing was frozen",
                file=sys.stderr,
            )

    # Print trainable parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Train] Total params: {total_params:,}")
    print(f"[Train] Trainable params: {trainable_params:,} ({100*trainable_params/total_params:.1f}%)")

    # === Load Processor ===
    processor = AutoProcessor.from_pretrained(
        model_args.pretrained_model_name_or_path,
    )

    # === Load Data ===
    # TODO: Implement LMMDataset for Qwen3-VL format
    # The dataset should load the streaming SFT data (overlapped chunks)
    # from mit-han-lab/Inf-Stream-Train
    print(f"[Train] Data path: {data_args.train_data_path}")
    print(f"[Train] Text sink: {data_args.text_sink}")
    print(f"[Train] Text window: {data_args.text_sliding_window}")

    # === Training ===
    # These override whatever was passed on the command line.
    training_args.gradient_checkpointing = True
    training_args.bf16 = True
    # Keep every dataset column; the custom collator is expected to need them.
    training_args.remove_unused_columns = False
    print(f"[Train] Output dir: {training_args.output_dir}")
    print(f"[Train] Learning rate: {training_args.learning_rate}")
    print(f"[Train] Batch size: {training_args.per_device_train_batch_size}")
    print(f"[Train] Gradient accumulation: {training_args.gradient_accumulation_steps}")

    # TODO: Initialize Trainer with LMMDataset and custom data collator
    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=train_dataset,
    #     data_collator=train_dataset.data_collator,
    #     processing_class=processor,
    # )
    # trainer.train()
    print("[Train] Training setup complete. Implement LMMDataset to start training.")
# Run the training setup only when executed as a script, not when imported.
if __name__ == "__main__":
    main()