| import faulthandler |
| faulthandler.enable() |
|
|
| import os |
| os.environ["RKLLM_LOG_LEVEL"] = "1" |
| import numpy as np |
| import onnxruntime as real_ort |
| import ztu_somemodelruntime_rknnlite2 as ort |
| from tokenizers import Tokenizer |
| import cv2 |
| import tqdm |
| import time |
| import ctypes |
|
|
| from rkllm_binding import * |
|
|
# --- Configuration -----------------------------------------------------------
# All model artifacts (tokenizer, .rknn/.onnx/.rkllm files) are expected in
# the current directory.
model_path = "."
onnx_model_path = f"{model_path}"
tokenizer = Tokenizer.from_file(f"{model_path}/tokenizer.json")


# Input image path (None => text-to-image generation, no image input).
image = None
prompt = "A stunning princess from kabul in red, white traditional clothing, blue eyes, brown hair"
# "t2i" = text-to-image; anything else is treated as image/text understanding.
mode = "t2i"




# Sampling temperature for token sampling.
# NOTE(review): name is a typo for "temperature"; kept as-is because the
# sampling loop below references it by this name.
tempature = 0.7


# Shared state between the RKLLM callback (called from the runtime) and the
# main thread, which polls 'finished'.  'hidden_states' receives a numpy copy
# of the last hidden layer; 'error' flags a failed run.
rkllm_result_data = {
    'hidden_states': None,
    'finished': False,
    'error': False
}
|
|
def rkllm_callback(result_ptr, userdata_ptr, state_enum):
    """RKLLM inference callback.

    Invoked by the RKLLM runtime with a pointer to the current result, an
    (unused) userdata pointer, and a state enum.  Communicates back to the
    main thread exclusively through the module-level ``rkllm_result_data``
    dict: sets 'finished' when the run completes or errors, and stores a
    numpy copy of the last hidden layer under 'hidden_states'.

    NOTE(review): the FINISH branch returns None while other paths return 1;
    whether the C side interprets the return value is not visible here —
    confirm against the rkllm_binding declaration before changing.
    """
    global rkllm_result_data

    try:
        state = LLMCallState(state_enum)

        if state == LLMCallState.RKLLM_RUN_FINISH:
            # Normal end of generation: just signal the polling loop.
            rkllm_result_data['finished'] = True
            print("RKLLM 推理完成")
            return
        elif state == LLMCallState.RKLLM_RUN_ERROR:
            # Record the error and unblock the polling loop.
            # NOTE(review): no return here — control intentionally(?) falls
            # through to the result handling below; with a non-NORMAL state
            # it only reaches the final `return 1`.
            rkllm_result_data['error'] = True
            rkllm_result_data['error_msg'] = "RKLLM 推理出错"
            rkllm_result_data['finished'] = True
            print("错误: RKLLM 推理出错")

        # Guard against a NULL result pointer before dereferencing.
        if not result_ptr:
            print("警告: result_ptr 为空指针")
            return

        result = result_ptr.contents

        if state == LLMCallState.RKLLM_RUN_NORMAL:
            # Only capture hidden states when the pointer is non-NULL and the
            # embedding size is sane.
            if result.last_hidden_layer.hidden_states and result.last_hidden_layer.embd_size > 0:

                hidden_size = result.last_hidden_layer.embd_size
                num_tokens = result.last_hidden_layer.num_tokens

                # Wrap the C buffer as (num_tokens, hidden_size) and .copy()
                # so the data survives after the runtime reuses the buffer.
                hidden_array = np.ctypeslib.as_array(
                    result.last_hidden_layer.hidden_states,
                    shape=(num_tokens, hidden_size)
                ).copy()

                rkllm_result_data['hidden_states'] = hidden_array
                # Hidden states captured: treat this step as done.
                rkllm_result_data['finished'] = True
                return 1
            else:
                print("警告: 没有获取到有效的 hidden states")

        # Default: acknowledge the callback.
        return 1
    except Exception as e:
        # Never let an exception escape into the C runtime; surface it via
        # the shared dict instead.
        print(f"回调函数异常: {e}")
        rkllm_result_data['error'] = True
        rkllm_result_data['error_msg'] = str(e)
        rkllm_result_data['finished'] = True
|
|
| |
|
|
| |
| |
| |
# --- Model initialization ----------------------------------------------------
# Vision encoder runs on the NPU via the rknnlite2 runtime wrapper.
vision_encoder = ort.InferenceSession(f"{onnx_model_path}/vision_encoder.rknn")


# Language model runs through the RKLLM runtime; the callback above receives
# its outputs.  Init order matters: params are created from defaults, then
# patched before init().
print("初始化 RKLLM 语言模型...")
rkllm_runtime = RKLLMRuntime()
rkllm_params = rkllm_runtime.create_default_param()
rkllm_params.model_path = f"{model_path}/language_model.rkllm".encode('utf-8')
rkllm_params.max_context_len = 1024
# Only hidden states are consumed per call, so one step's worth of new tokens
# is enough (see RKLLM_INFER_GET_LAST_HIDDEN_LAYER below).
rkllm_params.max_new_tokens = 5
# Keep special tokens in the stream (0 = do not skip).
rkllm_params.skip_special_token = 0
rkllm_params.extend_param.base_domain_id = 1
rkllm_runtime.init(rkllm_params, rkllm_callback)


# Text-generation head (understanding mode).
lm_head = ort.InferenceSession(f"{onnx_model_path}/lm_head.onnx")


# Image-token head (t2i mode).
gen_head = ort.InferenceSession(f"{onnx_model_path}/gen_head.onnx")


# Embedding lookup for generated image tokens.
gen_img_embeds = ort.InferenceSession(f"{onnx_model_path}/gen_img_embeds.onnx")


# Text token embedding table — note this one runs on real onnxruntime (CPU),
# not the rknnlite2 wrapper; presumably kept on CPU deliberately.
text_embeds = real_ort.InferenceSession(f"{onnx_model_path}/embed_tokens.onnx")


# VQ image decoder: image-token ids -> RGB image.
image_decode = ort.InferenceSession(f"{onnx_model_path}/image_decode.onnx")
|
|
| |
| |
# Build the chat-formatted prompt.  The exact whitespace inside these
# templates is part of the model's expected input format — do not reflow.
if mode == "t2i":
    input_str = f"""<|User|>: {prompt}

<|Assistant|>:<begin_of_image>"""
else:
    input_str = f"""You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.

<|User|>: <image_placeholder>
{prompt}

<|Assistant|>:"""
|
|
| |
|
|
| |
# Expand each <image_placeholder> into the 576 per-patch placeholder tokens,
# then tokenize the whole prompt.
input_str = input_str.replace("<image_placeholder>", "<image_placeholder>" * 576)
encoding = tokenizer.encode(input_str)
token_ids = encoding.ids
input_len = len(token_ids)
input_ids = np.asarray(token_ids, dtype=np.int64)[np.newaxis, :]
attention_mask = np.asarray(encoding.attention_mask, dtype=np.int64)[np.newaxis, :]
# True at every position holding the image-placeholder token id (100581).
images_seq_mask = (input_ids == 100581)
position_ids = np.arange(input_len)[np.newaxis, :]
| |
# Preprocess the optional input image into pixel values for the vision
# encoder; produce empty tensors when no image is supplied.
if image:
    bgr = cv2.imread(image)
    if bgr is None:
        raise ValueError(f"无法读取图片: {image}")

    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # Fixed 384x384 input resolution.
    side = 384
    rgb = cv2.resize(rgb, (side, side), interpolation=cv2.INTER_LINEAR)

    # Scale to [0, 1] (multiply by 1/255), then normalize with mean=std=0.5.
    scaled = rgb.astype(np.float32) * 0.00392156862745098
    half = np.array([0.5, 0.5, 0.5], dtype=np.float32)
    normed = (scaled - half) / half

    # HWC -> CHW, then add batch and per-sample image-count axes:
    # final shape (1, 1, 3, 384, 384).
    chw = normed.transpose(2, 0, 1)
    pixel_values = chw[np.newaxis, np.newaxis, ...]
    images_emb_mask = np.ones((1, 1, 576), dtype=np.bool_)
else:
    # No image: zero-length tensors keep downstream shapes consistent.
    pixel_values = np.zeros((0, 0, 3, 384, 384), dtype=np.float32)
    images_emb_mask = np.zeros((1, 0, 576), dtype=np.bool_)
|
|
| |
| |
# Look up the text token embeddings (CPU onnxruntime session).
text_inputs_embeds = text_embeds.run(None, {"input_ids": input_ids})[0]


if image:
    # Encode the image, then splice its patch embeddings into the slots
    # occupied by <image_placeholder> tokens, capped at the number of
    # patches the encoder produced.
    vision_embeds = vision_encoder.run(None, {"pixel_values": pixel_values})[0]

    placeholder_pos = np.flatnonzero(images_seq_mask[0])
    n_patch = min(placeholder_pos.size, vision_embeds.shape[1])
    text_inputs_embeds[0, placeholder_pos[:n_patch], :] = vision_embeds[0, :n_patch, :]


inputs_embeds = text_inputs_embeds
|
|
| |
| |
# Accumulates sampled token ids across the generation loop.
generated_tokens = []


# Reusable RKLLM input/inference-parameter structs (ctypes); filled in per
# call by run_rkllm_inference().
rkllm_input = RKLLMInput()
rkllm_input.input_type = RKLLMInputType.RKLLM_INPUT_EMBED
embed_input = RKLLMEmbedInput()
infer_params = RKLLMInferParam()
# Ask the runtime for the last hidden layer instead of sampled tokens;
# the heads (gen_head / lm_head) do the projection on our side.
infer_params.mode = RKLLMInferMode.RKLLM_INFER_GET_LAST_HIDDEN_LAYER
# Keep KV-cache history between calls so each step only feeds new tokens.
infer_params.keep_history = 1
|
|
def run_rkllm_inference(inputs_embeds):
    """Run one RKLLM forward pass on an embedding batch.

    Feeds ``inputs_embeds`` (shape (1, n_tokens, hidden)) to the NPU runtime
    and blocks until rkllm_callback reports completion, then returns the
    captured last-layer hidden states.

    Raises:
        RuntimeError: if the callback flagged an error.
    """
    global rkllm_result_data

    # Fresh slate for the callback to fill in.
    rkllm_result_data = {'hidden_states': None, 'finished': False, 'error': False}

    # Copy the embeddings into a C float array the runtime can read.
    flat = inputs_embeds.flatten().astype(np.float32)
    c_floats = (ctypes.c_float * len(flat))(*flat)
    embed_input.embed = c_floats
    embed_input.n_tokens = inputs_embeds.shape[1]
    rkllm_input._union_data.embed_input = embed_input

    # Results are delivered via rkllm_callback into rkllm_result_data.
    rkllm_runtime.run(rkllm_input, infer_params)

    # Poll until the callback flips the 'finished' flag.
    while not rkllm_result_data['finished']:
        time.sleep(0.001)

    if rkllm_result_data['error']:
        raise RuntimeError("RKLLM 推理出错")
    return rkllm_result_data['hidden_states']
|
|
| |
# --- Autoregressive generation loop ------------------------------------------
# 576 iterations — matches the 576 image tokens per image; presumably a
# 24x24 token grid for 384x384 output (confirm against the decoder).
with tqdm.tqdm(range(576)) as pbar:
    for i in pbar:
        # One RKLLM step: embeddings in, last hidden layer out.
        hidden_states = run_rkllm_inference(inputs_embeds)

        if hidden_states is None:
            raise RuntimeError("RKLLM 未返回有效的 hidden states")

        # Normalize to (batch, tokens, hidden) if the runtime returned 2-D.
        if len(hidden_states.shape) == 2:

            hidden_states = hidden_states.reshape(1, hidden_states.shape[0], hidden_states.shape[1])

        # Only the final position's hidden state feeds the head.
        hs = hidden_states[:, -1:, :]

        # t2i uses the image-token head; understanding uses the LM head.
        logits = (gen_head if mode == "t2i" else lm_head).run(None, {"hidden_states": hs})[0]
        logits = logits[:, -1, :]

        # Temperature scaling ('tempature' is defined in the config above).
        logits = logits / tempature
        # Numerically stable softmax; [0] drops the batch axis.
        exp_logits = np.exp(logits - np.max(logits, axis=-1, keepdims=True))[0]
        probs = exp_logits / np.sum(exp_logits, axis=-1, keepdims=True)
        # Re-normalize in float64: np.random.multinomial rejects probability
        # vectors that sum to slightly more than 1 due to float32 rounding.
        probs = probs.astype(np.float64)
        probs /= probs.sum()
        next_token = int(np.random.multinomial(1, probs).argmax())
        pbar.set_postfix(next_token=tokenizer.decode([next_token]))
        generated_tokens.append(next_token)
        # 100001 — presumably the end-of-sequence token id; verify against
        # the tokenizer config.
        if next_token == 100001:
            break

        # Embed the sampled token for the next step; KV-cache history means
        # only the new token's embedding is fed.
        if mode == "t2i":
            new_embed = gen_img_embeds.run(None, {"image_ids": np.array([[next_token]], dtype=np.int64)})[0]
        else:
            new_embed = text_embeds.run(None, {"input_ids": np.array([[next_token]], dtype=np.int64)})[0]

        inputs_embeds = new_embed


# Drop the accumulated KV cache now that generation is done.
rkllm_runtime.clear_kv_cache(False)
|
|
| |
# --- Decode and emit the result ----------------------------------------------
if mode == "t2i":
    # Decode the generated image-token ids into an RGB image.
    generated_tokens_array = np.array([generated_tokens], dtype=np.int64)
    decoded_image = image_decode.run(None, {"generated_tokens": generated_tokens_array})[0]
    # Decoder output is assumed in [-1, 1]; map to [0, 255].
    decoded_image = np.clip((decoded_image + 1) / 2 * 255, 0, 255)
    # (1, C, H, W) -> (H, W, C)
    decoded_image = np.squeeze(decoded_image, axis=0)
    decoded_image = np.transpose(decoded_image, (1, 2, 0))
    # FIX: cv2.imwrite requires 8-bit data for PNG output; passing the
    # clipped float32 array directly can fail or write a corrupt image,
    # so convert explicitly after clipping.
    decoded_image = decoded_image.astype(np.uint8)
    cv2.imwrite("generated.png", cv2.cvtColor(decoded_image, cv2.COLOR_RGB2BGR))
    print("(generated.png)")
else:
    # Understanding mode: detokenize the generated text.
    decoded_text = tokenizer.decode(generated_tokens)
    print(f"{decoded_text}")


print("清理 RKLLM 资源...")
rkllm_runtime.destroy()
|
|