| """ |
| materialize.py |
| |
| Factory class for initializing Vision Backbones, LLM Backbones, and VLMs from a set registry; provides and exports |
| individual functions for clear control flow. |
| """ |
|
|
| from typing import Optional, Tuple |
|
|
| from transformers import PreTrainedTokenizerBase |
|
|
| from prismatic.models.backbones.llm import LLaMa2LLMBackbone, LLMBackbone, MistralLLMBackbone, PhiLLMBackbone |
| from prismatic.models.backbones.vision import ( |
| CLIPViTBackbone, |
| DinoCLIPViTBackbone, |
| DinoSigLIPViTBackbone, |
| DinoV2ViTBackbone, |
| ImageTransform, |
| IN1KViTBackbone, |
| SigLIPViTBackbone, |
| VisionBackbone, |
| ) |
| from prismatic.models.vlms import PrismaticVLM |
|
|
| |
| |
|
|
| |
# === Vision Backbone Registry ===
# Maps a string backbone ID to the `VisionBackbone` subclass that wraps it plus the constructor kwargs
# (currently just `default_image_size`, in pixels) forwarded on instantiation.
VISION_BACKBONES = {
    # === 224px Backbones (default resolution) ===
    "clip-vit-l": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "siglip-vit-so400m": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "dinov2-vit-l": {"cls": DinoV2ViTBackbone, "kwargs": {"default_image_size": 224}},
    "in1k-vit-l": {"cls": IN1KViTBackbone, "kwargs": {"default_image_size": 224}},
    "dinosiglip-vit-so-224px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 224}},

    # === Assorted CLIP Backbones ===
    "clip-vit-b": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "clip-vit-l-336px": {"cls": CLIPViTBackbone, "kwargs": {"default_image_size": 336}},

    # === Assorted SigLIP Backbones ===
    "siglip-vit-b16-224px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 224}},
    "siglip-vit-b16-256px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 256}},
    "siglip-vit-b16-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
    "siglip-vit-so400m-384px": {"cls": SigLIPViTBackbone, "kwargs": {"default_image_size": 384}},

    # === Fused (Dino + CLIP/SigLIP) Backbones ===
    "dinoclip-vit-l-336px": {"cls": DinoCLIPViTBackbone, "kwargs": {"default_image_size": 336}},
    "dinosiglip-vit-so-384px": {"cls": DinoSigLIPViTBackbone, "kwargs": {"default_image_size": 384}},
}
|
|
|
|
| |
# === LLM Backbone Registry ===
# Maps a string backbone ID to the `LLMBackbone` subclass that wraps it; no extra constructor kwargs are
# currently needed (the ID itself selects the underlying checkpoint inside each class).
LLM_BACKBONES = {
    # === LLaMa-2 Pure (Base) Models ===
    "llama2-7b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
    "llama2-13b-pure": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

    # === LLaMa-2 Chat Models ===
    "llama2-7b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
    "llama2-13b-chat": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

    # === Vicuna v1.5 Models (LLaMa-2 derived) ===
    "vicuna-v15-7b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},
    "vicuna-v15-13b": {"cls": LLaMa2LLMBackbone, "kwargs": {}},

    # === Mistral v0.1 Models ===
    "mistral-v0.1-7b-pure": {"cls": MistralLLMBackbone, "kwargs": {}},
    "mistral-v0.1-7b-instruct": {"cls": MistralLLMBackbone, "kwargs": {}},

    # === Phi-2 Models ===
    "phi-2-3b": {"cls": PhiLLMBackbone, "kwargs": {}},
}
|
|
| |
|
|
|
|
def get_vision_backbone_and_transform(
    vision_backbone_id: str, image_resize_strategy: str
) -> Tuple[VisionBackbone, ImageTransform]:
    """
    Instantiate the Vision Backbone registered under `vision_backbone_id`.

    Returns a tuple of the backbone's nn.Module wrapper and its default image transform.
    Raises ValueError if the ID is not present in `VISION_BACKBONES`.
    """
    # Guard clause: fail fast on unknown IDs.
    if vision_backbone_id not in VISION_BACKBONES:
        raise ValueError(f"Vision Backbone `{vision_backbone_id}` is not supported!")

    backbone_spec = VISION_BACKBONES[vision_backbone_id]
    backbone: VisionBackbone = backbone_spec["cls"](
        vision_backbone_id, image_resize_strategy, **backbone_spec["kwargs"]
    )

    return backbone, backbone.get_image_transform()
|
|
|
|
def get_llm_backbone_and_tokenizer(
    llm_backbone_id: str,
    llm_max_length: int = 2048,
    hf_token: Optional[str] = None,
    inference_mode: bool = False,
) -> Tuple[LLMBackbone, PreTrainedTokenizerBase]:
    """
    Instantiate the LLM Backbone registered under `llm_backbone_id`.

    Returns a tuple of the backbone's nn.Module wrapper and its associated HF tokenizer.
    Raises ValueError if the ID is not present in `LLM_BACKBONES`.
    """
    # Guard clause: fail fast on unknown IDs.
    if llm_backbone_id not in LLM_BACKBONES:
        raise ValueError(f"LLM Backbone `{llm_backbone_id}` is not supported!")

    backbone_spec = LLM_BACKBONES[llm_backbone_id]
    backbone: LLMBackbone = backbone_spec["cls"](
        llm_backbone_id,
        llm_max_length=llm_max_length,
        hf_token=hf_token,
        inference_mode=inference_mode,
        **backbone_spec["kwargs"],
    )

    return backbone, backbone.get_tokenizer()
|
|
|
|
def get_vlm(
    model_id: str,
    arch_specifier: str,
    vision_backbone: VisionBackbone,
    llm_backbone: LLMBackbone,
    enable_mixed_precision_training: bool = True,
) -> PrismaticVLM:
    """Thin factory around `PrismaticVLM` construction, kept as a function for future-proofing (new VLM classes)."""
    vlm_kwargs = {
        "arch_specifier": arch_specifier,
        "enable_mixed_precision_training": enable_mixed_precision_training,
    }

    return PrismaticVLM(model_id, vision_backbone, llm_backbone, **vlm_kwargs)
|
|