| """
|
| verify_prismatic.py
|
|
|
| Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate().
|
| """

import time

import requests
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

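# Model to verify + default test image (the beignets photo from HF's documentation-images dataset)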
MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"
DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)

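# Prism-trained models expect bare "In: ... Out:" prompts; all other exports use the Vicuña-style chat format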
| if "-prism-" in MODEL_PATH:
|
| SAMPLE_PROMPTS_FOR_GENERATION = [
|
| "In: What is sitting in the coffee?\nOut:",
|
| "In: What's the name of the food on the plate?\nOut:",
|
| "In: caption.\nOut:",
|
| "In: how many beinets..?\nOut:",
|
| "In: Can you give me a lyrical description of the scene\nOut:",
|
| ]
|
| else:
|
| SYSTEM_PROMPT = (
|
| "A chat between a curious user and an artificial intelligence assistant. "
|
| "The assistant gives helpful, detailed, and polite answers to the user's questions."
|
| )
|
| SAMPLE_PROMPTS_FOR_GENERATION = [
|
| f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:",
|
| f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:",
|
| f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:",
|
| f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:",
|
| f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:",
|
| ]


@torch.inference_mode()
def verify_prismatic() -> None:
    print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

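    # `AutoProcessor` resolves to the model's custom processor class on the Hub (hence `trust_remote_code=True`)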
| print("[*] Instantiating Processor and Pretrained VLM")
|
| processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)
|
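    # NOTE (assumption): "flash_attention_2" requires the `flash-attn` package and a CUDA device; if either
    #   is missing, attn_implementation="sdpa" should be a drop-in fallback here.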
| print("[*] Loading in BF16 with Flash-Attention Enabled")
|
| vlm = AutoModelForVision2Seq.from_pretrained(
|
| MODEL_PATH,
|
| attn_implementation="flash_attention_2",
|
| torch_dtype=torch.bfloat16,
|
| low_cpu_mem_usage=True,
|
| trust_remote_code=True,
|
| ).to(device)
|
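    # Fetch the default test image; `stream=True` lets PIL read straight from the HTTP response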
    image = Image.open(requests.get(DEFAULT_IMAGE_URL, stream=True).raw).convert("RGB")
    num_tokens, total_time = 0, 0.0

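    # The module docstring also promises a forward() check; a minimal, commented-out sketch, assuming the
    # model returns standard HF causal-LM outputs with a `.logits` field (an assumption, not verified here):
    #
    #   fwd_inputs = processor(SAMPLE_PROMPTS_FOR_GENERATION[0], image).to(device, dtype=torch.bfloat16)
    #   assert vlm(**fwd_inputs).logits.ndim == 3  # (batch, seq_len, vocab_size)
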
| print("[*] Iterating over Sample Prompts\n===\n")
|
| for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
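        # Preprocess the (prompt, image) pair, casting to BF16 to match the model weights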
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

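        # Greedy decoding is deterministic, so time 5 identical generations to smooth out timing noise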
        gen_ids = None
        for _ in range(5):
            start_time = time.time()
            gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
            total_time += time.time() - start_time
            num_tokens += gen_ids.shape[1] - inputs.input_ids.shape[1]

        # Keep only the newly generated tokens (strip the prompt prefix)
        gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]

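        # Decode the continuation, dropping special tokens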
        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
        print(f"[{idx + 1}] Input Prompt => {prompt}\n    Generated => {gen_text}\n")

| print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")
|
|
|
|
|


if __name__ == "__main__":
    verify_prismatic()