Optimizing Performance

#1
by triplet33 - opened

Hello, thank you for your work on releasing this, and so soon too! I was wondering if there is a way to speed up inference.

Right now it's more than 2x faster (image one: 1.13s, image two: 1.72s) than the safetensors version (image one: 2.32s, image two: 4.5s), which is amazing. However, I have a feeling this can go even faster...

import os

# Force the Hugging Face hub into offline mode (no network requests).
# This must be set BEFORE the transformers / huggingface_hub imports below
# so the flag is already in the environment when those libraries read it.
os.environ['HF_HUB_OFFLINE'] = '1'


import onnxruntime
import numpy as np
from transformers import AutoConfig, AutoProcessor, GenerationConfig
from huggingface_hub import snapshot_download
from PIL import Image
from time import perf_counter


model_id = "onnx-community/gemma-4-E2B-it-ONNX"

# Tokenizer/processor, model config, and generation defaults all come from
# the same repo snapshot (served from the local cache — hub is offline).
processor = AutoProcessor.from_pretrained(model_id)
config = AutoConfig.from_pretrained(model_id)
generation_config = GenerationConfig.from_pretrained(model_id)

# The model is split into three quantized (q4) ONNX graphs.
vision_model = "onnx/vision_encoder_q4.onnx"
embed_model = "onnx/embed_tokens_q4.onnx"
decoder_model = "onnx/decoder_model_merged_q4.onnx"

# Download only those three graphs; the trailing "*" also matches their
# external-data companion files.
_graphs = (vision_model, embed_model, decoder_model)
model_dir = snapshot_download(
    model_id,
    allow_patterns=[f"{g}*" for g in _graphs],
)

vision_model_path, embed_model_path, decoder_model_path = (
    os.path.join(model_dir, g) for g in _graphs
)

providers = [
    # TensorrtExecutionProvider did not work at the time of writing; the
    # attempted configuration is kept below for reference:
    #   ('TensorrtExecutionProvider', {
    #       'device_id': 0,
    #       'trt_fp16_enable': True,
    #       'trt_engine_cache_enable': True,
    #       'trt_engine_cache_path': model_dir,
    #       # Shape profiles bound the min/opt/max batch of images:
    #       # 'trt_profile_min_shapes': 'input:1x448x448x3',
    #       # 'trt_profile_opt_shapes': 'input:2x448x448x3',
    #       # 'trt_profile_max_shapes': 'input:2x448x448x3',
    #       'trt_max_workspace_size': 1_073_741_824 * 2,
    #       # Builder optimization level: default 3, valid range [0-5].
    #       'trt_builder_optimization_level': 3,
    #   }),
    ('CUDAExecutionProvider', {'device_id': 0}),
]

# One inference session per graph, all on the same provider list.
vision_session = onnxruntime.InferenceSession(vision_model_path, providers=providers)
embed_session = onnxruntime.InferenceSession(embed_model_path, providers=providers)
decoder_session = onnxruntime.InferenceSession(decoder_model_path, providers=providers)

# Special-token ids used by the generation loop below.
eos_token_id = generation_config.eos_token_id
image_token_id = config.image_token_id
audio_token_id = config.audio_token_id

while True:
    p = input('Image path: ')
    if not os.path.isfile(p):
        print(f'Could not find image: {p}')
        continue

    start = perf_counter()
    # Context manager releases the file handle (the original leaked it).
    # PIL decodes lazily and the processor consumes the pixels inside the
    # `with` block, so nothing reads the image after it is closed.
    with Image.open(p) as image:
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Write single detailed caption for this image."},
                ],
            },
        ]
        inputs = processor.apply_chat_template(
            messages,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            add_generation_prompt=True,
        )

    input_ids = inputs["input_ids"].numpy()
    attention_mask = inputs["attention_mask"].numpy()
    position_ids = np.cumsum(attention_mask, axis=-1) - 1

    # Optional multimodal inputs; only the image path is exercised below.
    pixel_values = inputs["pixel_values"].numpy() if "pixel_values" in inputs else None
    pixel_position_ids = inputs["image_position_ids"].numpy() if "image_position_ids" in inputs else None
    # NOTE(review): audio features are extracted but never fed to any session
    # in this script — confirm whether audio support was intended.
    input_features = inputs["input_features"].numpy().astype(np.float32) if "input_features" in inputs else None
    input_features_mask = inputs["input_features_mask"].numpy() if "input_features_mask" in inputs else None

    batch_size = input_ids.shape[0]
    num_logits_to_keep = np.array(1, dtype=np.int64)
    # Start with empty (sequence-length-0) KV-cache tensors; dtype mirrors
    # what each decoder input declares (float32 vs float16).
    past_key_values = {
        inp.name: np.zeros(
            [batch_size, inp.shape[1], 0, inp.shape[3]],
            dtype=np.float32 if inp.type == "tensor(float)" else np.float16,
        )
        for inp in decoder_session.get_inputs()
        if inp.name.startswith("past_key_values")
    }

    max_new_tokens = 512
    # Accumulate generated tokens in a Python list: appending is O(1) per
    # step, whereas the original per-step np.concatenate reallocated the
    # whole array every iteration (O(n^2) over the generation).
    generated = []
    image_features = None
    for _ in range(max_new_tokens):
        inputs_embeds, per_layer_inputs = embed_session.run(None, {"input_ids": input_ids})

        # Run the vision encoder exactly once (first step), then splice the
        # image features into the embeddings at the image-token positions.
        if image_features is None and pixel_values is not None:
            image_features = vision_session.run(
                ["image_features"],
                {"pixel_values": pixel_values, "pixel_position_ids": pixel_position_ids},
            )[0]
            mask = (input_ids == image_token_id).reshape(-1)
            flat_embeds = inputs_embeds.reshape(-1, inputs_embeds.shape[-1])
            flat_embeds[mask] = image_features
            inputs_embeds = flat_embeds.reshape(inputs_embeds.shape)

        logits, *present_key_values = decoder_session.run(None, dict(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            per_layer_inputs=per_layer_inputs,
            position_ids=position_ids,
            num_logits_to_keep=num_logits_to_keep,
            **past_key_values,
        ))

        # Greedy decoding: argmax over the last position's logits.
        input_ids = logits[:, -1].argmax(-1, keepdims=True)
        attention_mask = np.concatenate([attention_mask, np.ones_like(input_ids)], axis=-1)
        position_ids = position_ids[:, -1:] + 1
        # Feed the updated KV cache back in for the next step.
        for j, key in enumerate(past_key_values):
            past_key_values[key] = present_key_values[j]

        generated.append(input_ids)
        if np.isin(input_ids, eos_token_id).any():
            break

        # print(processor.decode(input_ids[0]), end="", flush=True)
    # print()

    # Join once at the end instead of on every step.
    generated_tokens = (
        np.concatenate(generated, axis=-1) if generated else np.empty((1, 0), dtype=np.int64)
    )
    result = processor.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    end = perf_counter()
    print(f'Time: {end-start:,.4f}s | Caption: ', result)
ONNX Community org

Hi πŸ‘‹ Yes, I'm aware you can specify different execution providers, but since this is hardware dependent, we use CPU (default) and don't enforce a default. e.g., mac does not support CUDAExecutionProvider.

My question was about optimizing inference for speed.

Does anyone know if this can be sped up at all? Maybe with shaping and decoding, or an optimal execution provider config.

Sign up or log in to comment