""" verify_prismatic.py Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate(). """ import time import requests import torch from PIL import Image from transformers import AutoModelForVision2Seq, AutoProcessor # === Verification Arguments === MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b" DEFAULT_IMAGE_URL = ( "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png" ) if "-prism-" in MODEL_PATH: SAMPLE_PROMPTS_FOR_GENERATION = [ "In: What is sitting in the coffee?\nOut:", "In: What's the name of the food on the plate?\nOut:", "In: caption.\nOut:", "In: how many beinets..?\nOut:", "In: Can you give me a lyrical description of the scene\nOut:", ] else: SYSTEM_PROMPT = ( "A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions." ) SAMPLE_PROMPTS_FOR_GENERATION = [ f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:", f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:", f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:", f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:", f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:", ] @torch.inference_mode() def verify_prismatic() -> None: print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`") device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") # Load Processor & VLM print("[*] Instantiating Processor and Pretrained VLM") processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True) # === AUTOCAST MODE === # print("[*] Loading in BF16 Autocast Mode") # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to( # device, dtype=torch.bfloat16 # ) # === NATIVE BFLOAT16 MODE === # print("[*] Loading in BF16") # vlm = AutoModelForVision2Seq.from_pretrained( # MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True # ).to(device) # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] === print("[*] Loading in BF16 with Flash-Attention Enabled") vlm = AutoModelForVision2Seq.from_pretrained( MODEL_PATH, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True, ).to(device) # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] === # print("[*] Loading in 8-Bit Quantization Mode") # vlm = AutoModelForVision2Seq.from_pretrained( # MODEL_PATH, # attn_implementation="flash_attention_2", # torch_dtype=torch.float16, # quantization_config=BitsAndBytesConfig(load_in_8bit=True), # low_cpu_mem_usage=True, # trust_remote_code=True, # ) # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] === # print("[*] Loading in 4-Bit Quantization Mode") # vlm = AutoModelForVision2Seq.from_pretrained( # MODEL_PATH, # attn_implementation="flash_attention_2", # torch_dtype=torch.float16, # quantization_config=BitsAndBytesConfig(load_in_4bit=True), # low_cpu_mem_usage=True, # trust_remote_code=True, # ) # Iterate over Sample Prompts =>> Generate image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB") num_tokens, total_time = 0, 0.0 print("[*] Iterating over Sample Prompts\n===\n") for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION): # 
        # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
        # inputs = processor(prompt, image).to(device)
        #
        # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
        # #   =>> Running in native BF16 is also fine (but leads to slightly different generations)
        # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
        #     gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)

        # === BFLOAT16 MODE ===
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

        # === 8-BIT/4-BIT QUANTIZATION MODE ===
        # inputs = processor(prompt, image).to(device, dtype=torch.float16)

        # Run Inference =>> repeat generation 5x per prompt to smooth out the timing estimate
        gen_ids = None
        for _ in range(5):
            start_time = time.time()
            gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
            total_time += time.time() - start_time

            # Count generated (non-prompt) tokens on *every* timed run, so tokens/sec stays consistent with `total_time`
            num_tokens += gen_ids.shape[1] - inputs.input_ids.shape[1]

        # Decode the last run's generation (greedy decoding =>> identical across repeats)
        gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
        print(f"[{idx + 1}] Input Prompt => {prompt}\n    Generated => {gen_text}\n")

    # Compute Tokens / Second
    print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")


if __name__ == "__main__":
    verify_prismatic()
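
# Optional forward() sanity check =>> the module docstring mentions verifying forward(), but the script
# above only exercises generate(). A minimal sketch (not part of the original flow; it assumes the
# remote-code model returns a standard HF output with a `.logits` field) that could be dropped into the
# prompt loop right after `inputs` is built:
#
#   output = vlm(**inputs)
#   assert output.logits.ndim == 3, "Expected [batch, seq_len, vocab_size] logits from forward()!"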