| """ |
| materialize.py |
| |
| Factory module defining functions for instantiating various Training Strategies, supporting different VLMs, |
| backbones, and strategy configurations. |
| """ |
|
|
| from typing import Callable, Optional |
|
|
| import torch |
|
|
| from prismatic.models.vlms import PrismaticVLM |
| from prismatic.training.strategies import FSDPStrategy, TrainingStrategy |
|
|
| |
# Registry of supported training strategies. Each name maps to the class to instantiate ("cls") and to the
# extra keyword arguments ("kwargs") that `get_train_strategy` splices into that class's constructor call.
TRAIN_STRATEGIES = {
    # Both entries build an FSDPStrategy; they differ only in the `sharding_strategy` value they request.
    "fsdp-shard-grad-op": dict(cls=FSDPStrategy, kwargs=dict(sharding_strategy="shard-grad-op")),
    "fsdp-full-shard": dict(cls=FSDPStrategy, kwargs=dict(sharding_strategy="full-shard")),
}
|
|
|
|
def get_train_strategy(
    train_strategy: str,
    vlm: PrismaticVLM,
    device_id: int,
    stage: str,
    epochs: int,
    max_steps: Optional[int],
    global_batch_size: int,
    per_device_batch_size: int,
    learning_rate: float,
    weight_decay: float,
    max_grad_norm: float,
    lr_scheduler_type: str,
    warmup_ratio: float,
    enable_gradient_checkpointing: bool = True,
    enable_mixed_precision_training: bool = True,
    reduce_in_full_precision: bool = False,
    mixed_precision_dtype: torch.dtype = torch.bfloat16,
    worker_init_fn: Optional[Callable[[int], None]] = None,
) -> TrainingStrategy:
    """Instantiate the training strategy registered under `train_strategy`.

    The name is looked up in `TRAIN_STRATEGIES`. The registered class is then constructed with every other
    argument of this function, passed through unchanged as keyword arguments, plus the entry's own extra
    "kwargs" (e.g. the `sharding_strategy` for the FSDP entries).

    :param train_strategy: Key into `TRAIN_STRATEGIES` (e.g. "fsdp-shard-grad-op" or "fsdp-full-shard").
    :return: The newly constructed strategy object.
    :raises ValueError: If `train_strategy` is not a key of `TRAIN_STRATEGIES`.
    """
    # Reject unknown names up front so the happy path below stays flat.
    if train_strategy not in TRAIN_STRATEGIES:
        raise ValueError(f"Train Strategy `{train_strategy}` is not supported!")

    registry_entry = TRAIN_STRATEGIES[train_strategy]
    strategy_cls = registry_entry["cls"]
    extra_kwargs = registry_entry["kwargs"]

    # Arguments shared by every registered strategy, forwarded verbatim.
    common_kwargs = dict(
        vlm=vlm,
        device_id=device_id,
        stage=stage,
        epochs=epochs,
        max_steps=max_steps,
        global_batch_size=global_batch_size,
        per_device_batch_size=per_device_batch_size,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
        max_grad_norm=max_grad_norm,
        lr_scheduler_type=lr_scheduler_type,
        warmup_ratio=warmup_ratio,
        enable_gradient_checkpointing=enable_gradient_checkpointing,
        enable_mixed_precision_training=enable_mixed_precision_training,
        reduce_in_full_precision=reduce_in_full_precision,
        mixed_precision_dtype=mixed_precision_dtype,
        worker_init_fn=worker_init_fn,
    )
    # A key appearing in both mappings raises TypeError, exactly as an explicit keyword + `**` splat would.
    return strategy_cls(**common_kwargs, **extra_kwargs)
|
|