| import faulthandler |
| faulthandler.enable() |
| import sys |
| import os |
| os.environ["RKLLM_LOG_LEVEL"] = "1" |
| import ctypes |
| import argparse |
| import cv2 |
| import numpy as np |
| import ztu_somemodelruntime_rknnlite2 as ort |
| from rkllm_binding import ( |
| RKLLMRuntime, |
| RKLLMParam, |
| RKLLMInput, |
| RKLLMInferParam, |
| LLMCallState, |
| RKLLMInputType, |
| RKLLMInferMode, |
| RKLLMResult |
| ) |
|
|
| |
# Target size, in pixels, that the padded input image is resized to before it
# is fed to the vision encoder.
IMAGE_WIDTH = 448
IMAGE_HEIGHT = 448

# Expected encoder output layout: IMAGE_SEQ_LENGTH image tokens, each a vector
# of MULTIMODAL_HIDDEN_DIM values. The script only warns on a mismatch.
IMAGE_SEQ_LENGTH = 256
MULTIMODAL_HIDDEN_DIM = 2048

# Per-channel (R, G, B) normalization statistics, applied after pixel values
# have been scaled to [0, 1].
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)
|
|
def expand2square(img, background_color):
    """
    Pad an image to a square canvas, keeping the original content centered.

    The longer side of ``img`` becomes the side length of the result. The
    shorter dimension is padded on both sides with ``background_color``; when
    the padding is odd, the extra pixel goes to the bottom/right.

    Args:
        img: Image array of shape (height, width) or (height, width, channels).
        background_color: Fill value for the padding. It can be a scalar or a
            per-channel sequence, but it must broadcast against one pixel of
            ``img``. It is cast to ``img.dtype``; for uint8 images a value such
            as 127.5 is truncated to 127.

    Returns:
        A new array with the same dtype and trailing (channel) shape as
        ``img``. The input is never modified; a square input is returned as a
        copy.
    """
    height, width = img.shape[:2]
    if width == height:
        return img.copy()

    size = max(width, height)
    # Take dtype and channel layout from the input instead of hard-coding
    # uint8 RGB. Otherwise float images would be silently truncated when copied
    # onto the canvas, and 2-D / non-3-channel images could not be handled.
    square_img = np.full((size, size) + img.shape[2:], background_color, dtype=img.dtype)

    x_offset = (size - width) // 2
    y_offset = (size - height) // 2

    square_img[y_offset:y_offset + height, x_offset:x_offset + width] = img
    return square_img
|
|
def llm_callback(result_ptr, userdata_ptr, state_enum):
    """
    Stream RKLLM generation events to stdout.

    Registered through ``rk_llm.init(param, llm_callback)`` in ``main``.

    Args:
        result_ptr: ctypes pointer to an ``RKLLMResult``. It is only
            dereferenced for ``RKLLM_RUN_NORMAL`` events, and only when it is
            not NULL.
        userdata_ptr: Opaque user-data pointer (unused here).
        state_enum: Raw value convertible to ``LLMCallState``.

    Returns:
        Always 0.
    """
    import codecs

    # Text arrives as raw UTF-8 byte chunks, and one multi-byte character
    # (e.g. CJK) can be split across two callbacks. Decoding each chunk on its
    # own with errors='ignore' silently drops such characters, so a single
    # incremental decoder is kept alive across calls.
    decoder = getattr(llm_callback, "_utf8_decoder", None)
    if decoder is None:
        decoder = codecs.getincrementaldecoder("utf-8")(errors="ignore")
        llm_callback._utf8_decoder = decoder

    state = LLMCallState(state_enum)

    if state == LLMCallState.RKLLM_RUN_NORMAL:
        # Dereference only when needed. ``.contents`` on a NULL ctypes pointer
        # raises ValueError, which would surface inside the C callback.
        if result_ptr:
            text = result_ptr.contents.text
            if text:
                print(decoder.decode(text), end='', flush=True)
    elif state == LLMCallState.RKLLM_RUN_FINISH:
        # Flush any complete character still buffered (incomplete trailing
        # bytes are ignored), then reset for the next generation.
        tail = decoder.decode(b"", final=True)
        decoder.reset()
        print(tail + "\n", flush=True)
    elif state == LLMCallState.RKLLM_RUN_ERROR:
        # Drop any partial character left over from the failed generation.
        decoder.reset()
        print("\nrun error", flush=True)

    return 0
|
|
def _preprocess_image(image_path):
    """
    Load an image and turn it into the vision encoder's input tensor.

    The image is read with OpenCV (BGR) and converted to RGB. It is padded to a
    square with a mid-gray background and resized to IMAGE_WIDTH x
    IMAGE_HEIGHT. Pixels are scaled to [0, 1] and normalized with
    IMAGENET_MEAN / IMAGENET_STD. Finally the layout is reordered from HWC to
    NCHW with a batch of one.

    Args:
        image_path: Path of the image file.

    Returns:
        A C-contiguous float32 array of shape (1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
        or None if OpenCV could not read the file.
    """
    img = cv2.imread(image_path)
    if img is None:
        return None

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    background_color = (127.5, 127.5, 127.5)
    square_img = expand2square(img, background_color)
    resized_img = cv2.resize(square_img, (IMAGE_WIDTH, IMAGE_HEIGHT), interpolation=cv2.INTER_LINEAR)

    input_tensor = resized_img.astype(np.float32)
    input_tensor = (input_tensor / 255.0 - IMAGENET_MEAN) / IMAGENET_STD
    input_tensor = np.transpose(input_tensor, (2, 0, 1))
    # transpose() only produces a strided view; give the runtime a dense buffer.
    return np.ascontiguousarray(input_tensor[np.newaxis], dtype=np.float32)


def _run_vision_encoder(ort_session, input_name, output_name, input_tensor):
    """
    Run the vision encoder once and sanity-check its output.

    Prints a warning when the token count or hidden size differs from
    IMAGE_SEQ_LENGTH / MULTIMODAL_HIDDEN_DIM, then prints the inference time.

    Returns:
        The first encoder output, a 3-D array (batch, tokens, hidden).

    Raises:
        RuntimeError: if the output is not 3-D.
    """
    import time

    # perf_counter is monotonic and high-resolution, so it suits interval timing.
    start_time = time.perf_counter()
    img_vec_output = ort_session.run([output_name], {input_name: input_tensor})[0]
    elapsed_time = time.perf_counter() - start_time

    if img_vec_output.ndim != 3:
        raise RuntimeError(f"Unexpected encoder output shape {img_vec_output.shape}, expected (batch, tokens, hidden)")
    if img_vec_output.shape[-1] != MULTIMODAL_HIDDEN_DIM:
        print(f"Warning: hidden dim {img_vec_output.shape[-1]} differs from expected {MULTIMODAL_HIDDEN_DIM}")
    if img_vec_output.shape[1] != IMAGE_SEQ_LENGTH:
        print(f"Warning: token count {img_vec_output.shape[1]} differs from expected {IMAGE_SEQ_LENGTH}")
    print(f"视觉编码器推理耗时: {elapsed_time:.4f} 秒")
    return img_vec_output


def _build_rkllm_input(input_str, img_vec, n_image_tokens):
    """
    Build the RKLLMInput for one user turn.

    Prompts containing "<image>" become a multimodal input that carries the
    precomputed image embeddings. Any other prompt becomes a text-only input.

    Args:
        input_str: The user's prompt text.
        img_vec: Flat, contiguous float32 image-embedding buffer. The struct
            stores only a raw pointer into it, so the caller must keep it alive
            while the returned input is in use.
        n_image_tokens: Number of image tokens encoded in ``img_vec``.

    Returns:
        A populated ``RKLLMInput``.
    """
    rkllm_input = RKLLMInput()
    rkllm_input.role = b"user"

    if "<image>" in input_str:
        rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_MULTIMODAL
        rkllm_input.multimodal_input.prompt = input_str.encode('utf-8')
        rkllm_input.multimodal_input.image_embed = img_vec.ctypes.data_as(ctypes.POINTER(ctypes.c_float))
        rkllm_input.multimodal_input.n_image_tokens = n_image_tokens
        print("n_image_tokens: ", rkllm_input.multimodal_input.n_image_tokens)
        rkllm_input.multimodal_input.n_image = 1
        rkllm_input.multimodal_input.image_height = IMAGE_HEIGHT
        rkllm_input.multimodal_input.image_width = IMAGE_WIDTH
    else:
        rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_PROMPT
        rkllm_input.prompt_input = input_str.encode('utf-8')
    return rkllm_input


def _chat_loop(rk_llm, rkllm_infer_params, img_vec, n_image_tokens):
    """
    Read prompts from stdin and run them through RKLLM until the user exits.

    Commands:
        "exit"   - leave the loop.
        "clear"  - call ``rk_llm.clear_kv_cache(keep_system_prompt=True)``.
        "<n>"    - a valid index selects the corresponding preset prompt.
    End of input (Ctrl-D / closed pipe) also leaves the loop.
    """
    pre_input = [
        "<image>What is in the image?",
        "<image>这张图片中有什么?"
    ]
    print("\n**********************可输入以下问题对应序号获取回答/或自定义输入********************\n")
    for i, p in enumerate(pre_input):
        print(f"[{i}] {p}")
    print("\n*************************************************************************\n")

    while True:
        print("\nuser: ", end="", flush=True)
        line = sys.stdin.readline()
        if not line:
            # readline() returns "" only at EOF. Without this check the empty
            # string would hit `continue` below and the loop would spin forever.
            print()
            break
        input_str = line.strip()

        if not input_str:
            continue
        if input_str == "exit":
            break
        if input_str == "clear":
            try:
                rk_llm.clear_kv_cache(keep_system_prompt=True)
                print("KV cache cleared.")
            except RuntimeError as e:
                print(f"Failed to clear KV cache: {e}")
            continue

        # A bare in-range number selects a preset prompt; anything else is
        # used verbatim.
        try:
            idx = int(input_str)
        except ValueError:
            pass
        else:
            if 0 <= idx < len(pre_input):
                input_str = pre_input[idx]
                print(input_str)

        print("robot: ", end="", flush=True)
        rkllm_input = _build_rkllm_input(input_str, img_vec, n_image_tokens)

        try:
            rk_llm.run(rkllm_input, rkllm_infer_params)
        except RuntimeError as e:
            print(f"\nError during rkllm_run: {e}")


def main():
    """
    Parse CLI arguments, load both models, encode the image once, then chat.

    The vision encoder is loaded first, then the RKLLM language model. The
    image is preprocessed and encoded a single time, and its embeddings are
    reused for every prompt that contains "<image>". Once RKLLM has been
    initialized, every exit path releases it through ``rk_llm.destroy()``.
    """
    parser = argparse.ArgumentParser(
        description="Run RKLLM visual language model inference based on the C++ example."
    )
    parser.add_argument("image_path", type=str, help="Path to the input image.")
    parser.add_argument("encoder_model_path", type=str, help="Path to the ONNX vision encoder model.")
    parser.add_argument("llm_model_path", type=str, help="Path to the .rkllm language model.")
    parser.add_argument("max_new_tokens", type=int, help="Maximum number of new tokens to generate.")
    parser.add_argument("max_context_len", type=int, help="Maximum context length.")
    parser.add_argument("rknn_core_num", type=int, help="Sets the number of npu cores used in vision encoder.")
    args = parser.parse_args()

    print("Initializing ONNX Runtime for vision encoder...")
    try:
        sess_options = ort.SessionOptions()
        sess_options.intra_op_num_threads = args.rknn_core_num
        ort_session = ort.InferenceSession(args.encoder_model_path, sess_options=sess_options)
    except Exception as e:
        print(f"Failed to load ONNX model: {e}")
        sys.exit(1)
    print("Vision encoder loaded successfully.")

    input_name = ort_session.get_inputs()[0].name
    output_name = ort_session.get_outputs()[0].name
    print(f"ONNX Input: {input_name}, ONNX Output: {output_name}")

    print("Initializing RKLLM Runtime...")
    rk_llm = RKLLMRuntime()
    param = rk_llm.create_default_param()

    param.model_path = args.llm_model_path.encode('utf-8')
    param.top_k = 1  # always take the single most likely token
    param.max_new_tokens = args.max_new_tokens
    param.max_context_len = args.max_context_len
    param.skip_special_token = True
    param.img_start = b"<img>"
    param.img_end = b"</img>\n"
    param.img_content = b""
    param.extend_param.base_domain_id = 1

    try:
        rk_llm.init(param, llm_callback)
        print("RKLLM initialized successfully.")
    except RuntimeError as e:
        print(f"RKLLM init failed: {e}")
        sys.exit(1)

    # From here on RKLLM holds native resources. Every exit path below, including
    # the sys.exit() calls and Ctrl-C, must reach destroy() exactly once.
    try:
        print("Preprocessing image...")
        input_tensor = _preprocess_image(args.image_path)
        if input_tensor is None:
            print(f"Failed to read image from {args.image_path}")
            sys.exit(1)

        print("Running vision encoder...")
        try:
            img_vec_output = _run_vision_encoder(ort_session, input_name, output_name, input_tensor)
            # RKLLM reads the embeddings through a raw float* pointer, so they
            # must be a flat, contiguous float32 buffer.
            img_vec = np.ascontiguousarray(img_vec_output, dtype=np.float32).reshape(-1)
        except Exception as e:
            print(f"Failed to run vision encoder inference: {e}")
            sys.exit(1)
        n_image_tokens = img_vec_output.shape[1]

        print("Image encoded successfully.")

        rkllm_infer_params = RKLLMInferParam()
        rkllm_infer_params.mode = RKLLMInferMode.RKLLM_INFER_GENERATE
        # Presumably keeps the conversation in the KV cache across turns; the
        # "clear" command in the chat loop clears that cache.
        rkllm_infer_params.keep_history = 1

        _chat_loop(rk_llm, rkllm_infer_params, img_vec, n_image_tokens)
    except KeyboardInterrupt:
        print("\nExiting...")
    finally:
        print("Releasing resources...")
        rk_llm.destroy()
        print("RKLLM instance destroyed.")
|
|
if __name__ == "__main__":
    # Script entry point. main() returns None, so the process exits with status 0.
    sys.exit(main())
|
|