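# NOTE: The lines below are Hugging Face file-viewer residue captured along with the source (not code):
#   haofuly's picture
#   Add files using upload-large-folder tool
#   cf587f4 verified
#   raw
#   history blame
#   5.56 kB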
"""
verify_prismatic.py
Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate().
"""
import time
import requests
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor
# === Verification Arguments ===
# Model identifier handed to `AutoProcessor.from_pretrained` / `AutoModelForVision2Seq.from_pretrained`.
MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"

# Sample image that every prompt below is asked about.
DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)

# Questions asked about the sample image; each one is wrapped in the prompt template selected below.
# NOTE(review): "beinets" is kept exactly as written in the original prompts -- confirm whether the
# misspelling is deliberate before changing it, since it alters the text fed to the model.
_SAMPLE_QUESTIONS = (
    "What is sitting in the coffee?",
    "What's the name of the food on the plate?",
    "caption.",
    "how many beinets..?",
    "Can you give me a lyrical description of the scene",
)

if "-prism-" in MODEL_PATH:
    # Model IDs containing "-prism-" get the bare "In: ...\nOut:" template.
    SAMPLE_PROMPTS_FOR_GENERATION = [f"In: {question}\nOut:" for question in _SAMPLE_QUESTIONS]
else:
    # All other model IDs get a system prompt followed by a "USER: ... ASSISTANT:" turn.
    SYSTEM_PROMPT = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    SAMPLE_PROMPTS_FOR_GENERATION = [
        f"{SYSTEM_PROMPT} USER: {question} ASSISTANT:" for question in _SAMPLE_QUESTIONS
    ]
@torch.inference_mode()
def verify_prismatic() -> None:
    """Load the model at `MODEL_PATH` via the HF AutoClasses and exercise `generate()`.

    Steps:
      1. Instantiate the processor and the VLM (BF16 + Flash-Attention by default; alternative
         loading modes are kept below as commented-out recipes).
      2. Download the sample image from `DEFAULT_IMAGE_URL`.
      3. For every prompt in `SAMPLE_PROMPTS_FOR_GENERATION`, run greedy generation five times,
         accumulating the number of generated tokens and the time spent inside `generate()`.
      4. Print the decoded text of the last run for each prompt, then the overall tokens/second.

    Raises:
        requests.HTTPError: If the sample image download returns an error status.
        requests.Timeout: If the sample image server does not respond in time.
    """
    print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load Processor & VLM
    print("[*] Instantiating Processor and Pretrained VLM")
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

    # === AUTOCAST MODE ===
    # print("[*] Loading in BF16 Autocast Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to(
    #     device, dtype=torch.bfloat16
    # )

    # === NATIVE BFLOAT16 MODE ===
    # print("[*] Loading in BF16")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
    # ).to(device)

    # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] ===
    # NOTE(review): Flash-Attention presumably needs a CUDA device; on the CPU fallback above, switch to
    # one of the other loading modes -- confirm against the installed `transformers` version.
    print("[*] Loading in BF16 with Flash-Attention Enabled")
    vlm = AutoModelForVision2Seq.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device)

    # NOTE: Both quantization recipes below use `BitsAndBytesConfig`, which this file does not import;
    #       add `from transformers import BitsAndBytesConfig` before enabling either of them.

    # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
    # print("[*] Loading in 8-Bit Quantization Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
    # print("[*] Loading in 4-Bit Quantization Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # Fetch the sample image. A timeout keeps the script from hanging forever on a stalled connection,
    # `raise_for_status()` surfaces HTTP errors instead of handing an error page to PIL, and the `with`
    # block releases the connection. `.convert("RGB")` decodes the image while the stream is still open.
    request_timeout_s = 30
    with requests.get(DEFAULT_IMAGE_URL, stream=True, timeout=request_timeout_s) as response:
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")

    # Iterate over Sample Prompts =>> Generate
    num_tokens, total_time = 0, 0.0
    on_cuda = device.type == "cuda"
    print("[*] Iterating over Sample Prompts\n===\n")
    for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
        # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
        # inputs = processor(prompt, image).to(device)
        #
        # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
        # # =>> Running in native BF16 is also fine (but leads to slightly different generations)
        # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
        #     gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)

        # === BFLOAT16 MODE ===
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

        # === 8-BIT/4-BIT QUANTIZATION MODE ===
        # inputs = processor(prompt, image).to(device, dtype=torch.float16)

        # Run Inference (five repetitions per prompt; only the last generation is printed)
        gen_ids = None
        for _ in range(5):
            # CUDA kernels launch asynchronously, so synchronize around the timed region to make sure
            # the measured interval covers exactly the work done by this `generate()` call.
            if on_cuda:
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
            if on_cuda:
                torch.cuda.synchronize()
            total_time += time.perf_counter() - start_time

            # Drop the prompt tokens so that only newly generated tokens are counted / decoded
            gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
            num_tokens += len(gen_ids)

        # ===
        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
        print(f"[{idx + 1}] Input Prompt => {prompt}\n Generated => {gen_text}\n")

    # Compute Tokens / Second
    print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")


if __name__ == "__main__":
    verify_prismatic()