File size: 5,564 Bytes
cf587f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
"""

verify_prismatic.py



Given an HF-exported Prismatic model, attempt to load via AutoClasses, and verify forward() and generate().

"""

import time

import requests
import torch
from PIL import Image
from transformers import AutoModelForVision2Seq, AutoProcessor

# === Verification Arguments ===
# Hugging Face Hub path of the HF-exported Prismatic checkpoint to verify.
MODEL_PATH = "TRI-ML/prismatic-siglip-224px-7b"
# Test image shared by all sample prompts (beignets/coffee photo from the HF documentation assets).
DEFAULT_IMAGE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
)

# Select the prompt template by checkpoint family.
# NOTE(review): "-prism-" checkpoints appear to expect a bare "In: ... Out:" template, while all
# other checkpoints get a Vicuna-style chat format (system prompt + "USER: ... ASSISTANT:") —
# confirm against the model's training configuration.
if "-prism-" in MODEL_PATH:
    SAMPLE_PROMPTS_FOR_GENERATION = [
        "In: What is sitting in the coffee?\nOut:",
        "In: What's the name of the food on the plate?\nOut:",
        "In: caption.\nOut:",
        "In: how many beinets..?\nOut:",
        "In: Can you give me a lyrical description of the scene\nOut:",
    ]
else:
    SYSTEM_PROMPT = (
        "A chat between a curious user and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the user's questions."
    )
    SAMPLE_PROMPTS_FOR_GENERATION = [
        f"{SYSTEM_PROMPT} USER: What is sitting in the coffee? ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: What's the name of the food on the plate? ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: caption. ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: how many beinets..? ASSISTANT:",
        f"{SYSTEM_PROMPT} USER: Can you give me a lyrical description of the scene ASSISTANT:",
    ]


@torch.inference_mode()
def verify_prismatic() -> None:
    """Load an HF-exported Prismatic VLM via AutoClasses and verify `generate()` end-to-end.

    Instantiates the processor and model from `MODEL_PATH` (BF16 + Flash-Attention by default;
    alternative loading modes are kept below as commented-out options), downloads the test image,
    then runs greedy decoding over each prompt in `SAMPLE_PROMPTS_FOR_GENERATION` five times and
    reports aggregate tokens/second.

    Raises:
        requests.HTTPError: if the test-image download returns a non-2xx status.
        requests.Timeout: if the image download exceeds the 30s timeout.
    """
    print(f"[*] Verifying PrismaticForConditionalGeneration using Model `{MODEL_PATH}`")
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load Processor & VLM
    print("[*] Instantiating Processor and Pretrained VLM")
    processor = AutoProcessor.from_pretrained(MODEL_PATH, trust_remote_code=True)

    # === AUTOCAST MODE ===
    # print("[*] Loading in BF16 Autocast Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(MODEL_PATH, low_cpu_mem_usage=True, trust_remote_code=True).to(
    #     device, dtype=torch.bfloat16
    # )

    # === NATIVE BFLOAT16 MODE ===
    # print("[*] Loading in BF16")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, trust_remote_code=True
    # ).to(device)

    # === BFLOAT16 + FLASH-ATTN MODE :: [~14GB of VRAM Passive || 18GB of VRAM Active] ===
    print("[*] Loading in BF16 with Flash-Attention Enabled")
    vlm = AutoModelForVision2Seq.from_pretrained(
        MODEL_PATH,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    ).to(device)

    # === 8-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~9GB of VRAM Passive || 10GB of VRAM Active] ===
    # print("[*] Loading in 8-Bit Quantization Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # === 4-BIT QUANTIZATION MODE (`pip install bitsandbytes`) :: [~6GB of VRAM Passive || 7GB of VRAM Active] ===
    # print("[*] Loading in 4-Bit Quantization Mode")
    # vlm = AutoModelForVision2Seq.from_pretrained(
    #     MODEL_PATH,
    #     attn_implementation="flash_attention_2",
    #     torch_dtype=torch.float16,
    #     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    #     low_cpu_mem_usage=True,
    #     trust_remote_code=True,
    # )

    # Iterate over Sample Prompts =>> Generate
    #   =>> `timeout` prevents an indefinite hang (requests has no default timeout), and
    #       `raise_for_status` fails fast instead of feeding an HTTP error page to PIL.
    response = requests.get(DEFAULT_IMAGE_URL, stream=True, timeout=30)
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")
    num_tokens, total_time = 0, 0.0

    print("[*] Iterating over Sample Prompts\n===\n")
    for idx, prompt in enumerate(SAMPLE_PROMPTS_FOR_GENERATION):
        # === AUTOCAST MODE (Reproduces Prismatic `scripts/generate.py`) ===
        # inputs = processor(prompt, image).to(device)
        #
        # # Using "autocast" to evaluate bit-wise equivalence to `scripts/generate.py`
        # #   =>> Running in native BF16 is also fine (but leads to slightly different generations)
        # with torch.autocast("cuda", dtype=torch.bfloat16, enabled=True):
        #     gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)

        # === BFLOAT16 MODE ===
        inputs = processor(prompt, image).to(device, dtype=torch.bfloat16)

        # === 8-BIT/4-BIT QUANTIZATION MODE ===
        # inputs = processor(prompt, image).to(device, dtype=torch.float16)

        # Run Inference =>> 5 repeats per prompt to get a stable throughput estimate.
        gen_ids = None
        for _ in range(5):
            start_time = time.time()
            gen_ids = vlm.generate(**inputs, do_sample=False, min_length=1, max_length=512)
            total_time += time.time() - start_time

            # Strip the prompt prefix so only newly generated tokens are counted/decoded.
            gen_ids = gen_ids[0, inputs.input_ids.shape[1] :]
            num_tokens += len(gen_ids)

        # Decode the final repeat (greedy decoding =>> all repeats produce identical text).
        gen_text = processor.decode(gen_ids, skip_special_tokens=True).strip()
        print(f"[{idx + 1}] Input Prompt => {prompt}\n    Generated    => {gen_text}\n")

    # Compute Tokens / Second
    print(f"[*] Generated Tokens per Second = {num_tokens / total_time} w/ {num_tokens = } and {total_time = }")


if __name__ == "__main__":
    verify_prismatic()