# ztu_somemodelruntime_rknnlite2.py
# Module-level constants and helper functions
import os
import re
import ctypes
import logging
import warnings
from typing import List, Dict, Union, Optional

import numpy as np
from rknnlite.api import RKNNLite

try:
    import onnxruntime as ort
    HAS_ORT = True
except ImportError:
    HAS_ORT = False
    warnings.warn("onnxruntime is not installed; only the RKNN backend is available", ImportWarning)
# Configure logging
logger = logging.getLogger("somemodelruntime_rknnlite2")
logger.setLevel(logging.ERROR)  # only errors by default
if not logger.handlers:
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s'))
    logger.addHandler(handler)

# Mapping from ONNX Runtime log levels to Python logging levels
_LOGGING_LEVEL_MAP = {
    0: logging.DEBUG,     # Verbose
    1: logging.INFO,      # Info
    2: logging.WARNING,   # Warning
    3: logging.ERROR,     # Error
    4: logging.CRITICAL   # Fatal
}

# Honour a log level set through the environment
try:
    env_log_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
    if env_log_level is not None:
        log_level = int(env_log_level)
        if log_level in _LOGGING_LEVEL_MAP:
            logger.setLevel(_LOGGING_LEVEL_MAP[log_level])
            logger.info(f"Log level set from environment variable: {log_level}")
        else:
            logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {log_level}, expected an integer between 0 and 4")
except ValueError:
    logger.warning(f"Invalid value for ZTU_MODELRT_RKNNL2_LOG_LEVEL: {env_log_level}, expected an integer between 0 and 4")
def set_default_logger_severity(level: int) -> None:
    """
    Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal

    Args:
        level: log level (0-4)
    """
    if level not in _LOGGING_LEVEL_MAP:
        raise ValueError(f"Invalid log level: {level}, expected an integer between 0 and 4")
    logger.setLevel(_LOGGING_LEVEL_MAP[level])

def set_default_logger_verbosity(level: int) -> None:
    """
    Sets the default logging verbosity level. To activate the verbose log,
    you need to set the default logging severity to 0:Verbose level.

    Args:
        level: log level (0-4)
    """
    set_default_logger_severity(level)
# Mapping from RKNN tensor types to numpy dtypes
RKNN_DTYPE_MAP = {
    0: np.float32,   # RKNN_TENSOR_FLOAT32
    1: np.float16,   # RKNN_TENSOR_FLOAT16
    2: np.int8,      # RKNN_TENSOR_INT8
    3: np.uint8,     # RKNN_TENSOR_UINT8
    4: np.int16,     # RKNN_TENSOR_INT16
    5: np.uint16,    # RKNN_TENSOR_UINT16
    6: np.int32,     # RKNN_TENSOR_INT32
    7: np.uint32,    # RKNN_TENSOR_UINT32
    8: np.int64,     # RKNN_TENSOR_INT64
    9: bool,         # RKNN_TENSOR_BOOL
    10: np.int8,     # RKNN_TENSOR_INT4 (represented as int8)
}

# Maximum tensor name length and dimension count used by the custom-op ctypes
# structures further below; these values mirror RKNN_MAX_NAME_LEN / RKNN_MAX_DIMS
# in rknn_api.h (they were previously left undefined in this module).
RKNN_MAX_NAME_LEN = 256
RKNN_MAX_DIMS = 16
def get_available_providers() -> List[str]:
    """
    Get the list of available execution providers (placeholder kept for API compatibility)

    Returns:
        list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
    """
    return ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]

def get_device() -> str:
    """
    Get the current device

    Returns:
        str: current device
    """
    return "RKNN2"

def get_version_info() -> Dict[str, str]:
    """
    Get version information

    Returns:
        dict: dictionary with API and driver version information
    """
    runtime = RKNNLite()
    version = runtime.get_sdk_version()
    return {
        "api_version": version.split('\n')[2].split(': ')[1].split(' ')[0],
        "driver_version": version.split('\n')[3].split(': ')[1]
    }
class IOTensor:
    """Lightweight description of a model input/output tensor"""
    def __init__(self, name, shape, type=None):
        self.name = name.decode() if isinstance(name, bytes) else name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"

class SessionOptions:
    """Session options"""
    def __init__(self):
        self.enable_profiling = False    # enable performance profiling (RKNNLite verbose mode)
        self.intra_op_num_threads = 1    # number of NPU cores to use; mapped to RKNN's core_mask
        self.log_severity_level = -1     # alternative way to set the log level
        self.log_verbosity_level = -1    # alternative way to set the log level
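# Configuration sketch: intra_op_num_threads selects the NPU core mask used by
# InferenceSession below (1 -> NPU_CORE_AUTO, 2 -> NPU_CORE_0_1,
# 3 -> NPU_CORE_0_1_2). The model path in this example is hypothetical.
#
#   opts = SessionOptions()
#   opts.intra_op_num_threads = 3   # all three NPU cores (e.g. RK3588)
#   opts.log_severity_level = 2     # warnings and above
#   sess = InferenceSession("model.rknn", sess_options=opts)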
class InferenceSession:
    """
    RKNNLite runtime wrapper with an ONNX Runtime-like API
    """
    def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        processed_path = InferenceSession._process_model_path(model_path, sess_options)
        if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'):
            logger.info("Loading the model with ONNX Runtime")
            if not HAS_ORT:
                raise RuntimeError("onnxruntime is not installed; cannot load ONNX models")
            return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs)
        else:
            # Not an ONNX model: create a regular InferenceSession instance
            instance = super().__new__(cls)
            # Remember the processed path for __init__
            instance._processed_path = processed_path
            return instance
    def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """
        Initialize the runtime and load the model

        Args:
            model_path: path to the model file (.rknn or .onnx)
            sess_options: session options
            **kwargs: additional initialization parameters
        """
        options = sess_options or SessionOptions()
        # Only fall back to the SessionOptions log level when the environment variable is unset
        if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
            if options.log_severity_level != -1:
                set_default_logger_severity(options.log_severity_level)
            if options.log_verbosity_level != -1:
                set_default_logger_verbosity(options.log_verbosity_level)
        # Use the path resolved in __new__
        model_path = getattr(self, '_processed_path', model_path)
        if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
            # ONNX models are handled by ONNX Runtime in __new__; nothing to do here
            return
        self.model_path = model_path
        if not os.path.exists(self.model_path):
            logger.error(f"Model file does not exist: {self.model_path}")
            raise FileNotFoundError(f"Model file does not exist: {self.model_path}")
        self.runtime = RKNNLite(verbose=options.enable_profiling)
        logger.debug(f"Loading model: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"Failed to load RKNN model: {self.model_path}")
            raise RuntimeError(f'Failed to load RKNN model: {self.model_path}')
        logger.debug("Model loaded successfully")
        if options.intra_op_num_threads == 1:
            core_mask = RKNNLite.NPU_CORE_AUTO
        elif options.intra_op_num_threads == 2:
            core_mask = RKNNLite.NPU_CORE_0_1
        elif options.intra_op_num_threads == 3:
            core_mask = RKNNLite.NPU_CORE_0_1_2
        else:
            raise ValueError(f"Invalid value for intra_op_num_threads: {options.intra_op_num_threads}, must be 1, 2 or 3")
        logger.debug("Initializing the runtime environment")
        ret = self.runtime.init_runtime(core_mask=core_mask)
        if ret != 0:
            logger.error("Failed to initialize the runtime environment")
            raise RuntimeError('Failed to initialize the runtime environment')
        logger.debug("Runtime environment initialized")
        # After runtime initialization, automatically register custom-op plugin
        # libraries selected through environment variables
        try:
            # Register user-specified plugin paths (comma/semicolon/colon separated)
            env_custom = os.getenv('ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB', '').strip()
            if env_custom:
                paths = [seg.strip() for seg in re.split(r"[,;:]", env_custom) if seg.strip()]
                ok = 0
                for p in paths:
                    if self.register_custom_op_lib(p):
                        ok += 1
                if ok > 0:
                    logger.info(f"Registered {ok}/{len(paths)} custom-op plugin libraries")
            # Register plugins found in the system directory
            if os.getenv('ZTU_MODELRT_RKNN2_REG_SYSTEM_CUSTOM_OP_LIB', '1') == '1':
                cnt = self.register_system_custom_op_lib()
                if cnt > 0:
                    logger.info(f"Registered {cnt} custom-op plugin libraries from the system directory")
        except Exception as e:
            logger.warning(f"Automatic custom-op plugin registration failed: {e}")
        # Optional: register the bundled (Python-based) ops when requested via environment variable
        if os.getenv('ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS', '0') == '1':
            logger.info("Registering bundled ops as requested by the environment variable")
            self.register_bundled_ops()
        self._init_io_info()
        self.options = options
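    # Plugin registration can therefore be driven purely by environment
    # variables, e.g. (the library path below is hypothetical):
    #
    #   export ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB=/usr/lib/rknpu/op_plugins/librkcst_myop.so
    #   export ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS=1   # also register the Python-based ops below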
    def get_performance_info(self) -> Dict[str, float]:
        """
        Get performance information

        Returns:
            dict: dictionary with performance information
        """
        if not self.options.enable_profiling:
            raise RuntimeError("Profiling is not enabled; set enable_profiling=True in SessionOptions")
        perf = self.runtime.rknn_runtime.get_run_perf()
        return {
            "run_duration": perf.run_duration / 1000.0  # convert to milliseconds
        }
    def set_core_mask(self, core_mask: int) -> None:
        """
        Set the NPU core usage mode

        Args:
            core_mask: NPU core mask, one of the NPU_CORE_* constants
        """
        ret = self.runtime.rknn_runtime.set_core_mask(core_mask)
        if ret != 0:
            raise RuntimeError("Failed to set the NPU core mode")
    @staticmethod
    def _process_model_path(model_path, sess_options):
        """
        Resolve the model path; both .onnx and .rknn files are supported

        Args:
            model_path: path to the model file
        """
        # For ONNX files, check whether a converted RKNN model should be loaded instead
        if model_path.lower().endswith('.onnx'):
            logger.info("ONNX model file detected")
            # Models listed in this variable skip the automatic RKNN substitution
            skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip()
            if skip_models:
                skip_list = [m.strip() for m in skip_models.split(',')]
                # Match against the file name only (no directory part)
                model_name = os.path.basename(model_path)
                if model_name.lower() in [m.lower() for m in skip_list]:
                    logger.info(f"Model {model_name} is in the skip list; using ONNX Runtime")
                    return model_path
            # Build the path of the sibling RKNN file
            rknn_path = os.path.splitext(model_path)[0] + '.rknn'
            if os.path.exists(rknn_path):
                logger.info(f"Found a matching RKNN model, using RKNN: {rknn_path}")
                return rknn_path
            else:
                logger.info("No matching RKNN model found; using ONNX Runtime")
                return model_path
        return model_path
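    # Example (model names are hypothetical): keep selected ONNX models on
    # ONNX Runtime even when a sibling .rknn file exists. The list is a set of
    # comma-separated basenames, matched case-insensitively:
    #
    #   export ZTU_MODELRT_RKNNL2_SKIP=text_encoder.onnx,decoder.onnx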
    def _convert_nhwc_to_nchw(self, shape):
        """Convert an NHWC shape to NCHW"""
        if len(shape) == 4:
            # NHWC -> NCHW
            n, h, w, c = shape
            return [n, c, h, w]
        return shape
    def _init_io_info(self):
        """Collect the model's input and output tensor information"""
        runtime = self.runtime.rknn_runtime
        # Number of inputs and outputs
        n_input, n_output = runtime.get_in_out_num()
        # Input information
        self.input_tensors = []
        for i in range(n_input):
            attr = runtime.get_tensor_attr(i)
            shape = [attr.dims[j] for j in range(attr.n_dims)]
            # Convert 4-D input shapes from NHWC to NCHW
            shape = self._convert_nhwc_to_nchw(shape)
            # Resolve the dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.input_tensors.append(tensor)
        # Output information
        self.output_tensors = []
        for i in range(n_output):
            attr = runtime.get_tensor_attr(i, is_output=True)
            shape = runtime.get_output_shape(i)
            # Resolve the dtype
            dtype = RKNN_DTYPE_MAP.get(attr.type, None)
            tensor = IOTensor(attr.name, shape, dtype)
            self.output_tensors.append(tensor)
    def get_inputs(self):
        """
        Get the model input information

        Returns:
            list: list of input descriptions
        """
        return self.input_tensors

    def get_outputs(self):
        """
        Get the model output information

        Returns:
            list: list of output descriptions
        """
        return self.output_tensors
    def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs):
        """
        Run inference

        Args:
            output_names: list of output node names; only these outputs are returned
            input_feed: input data as a dict or a list
            data_format: input data layout, "nchw" or "nhwc"
            **kwargs: additional runtime parameters

        Returns:
            list: model outputs; restricted to output_names when given
        """
        if input_feed is None:
            logger.error("input_feed must not be None")
            raise ValueError("input_feed must not be None")

        # Prepare the input data
        if isinstance(input_feed, dict):
            # Dict input: order the values according to the model's input order
            inputs = []
            for tensor in self.input_tensors:
                if tensor.name not in input_feed:
                    raise ValueError(f"Missing input: {tensor.name}")
                inputs.append(input_feed[tensor.name])
        elif isinstance(input_feed, (list, tuple)):
            # List input: the length must match the number of model inputs
            if len(input_feed) != len(self.input_tensors):
                raise ValueError(f"Input count mismatch: expected {len(self.input_tensors)}, got {len(input_feed)}")
            inputs = list(input_feed)
        else:
            logger.error("input_feed must be a dict or a list")
            raise ValueError("input_feed must be a dict or a list")

        # Run inference
        try:
            logger.debug("Starting inference")
            all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format)
            # Return all outputs when output_names is not given
            if output_names is None:
                return all_outputs
            # Otherwise select the requested outputs
            output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)}
            selected_outputs = []
            for name in output_names:
                if name not in output_map:
                    raise ValueError(f"Output node not found: {name}")
                selected_outputs.append(all_outputs[output_map[name]])
            return selected_outputs
        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise RuntimeError(f"Inference failed: {str(e)}")
    def close(self):
        """
        Close the session and release resources
        """
        if self.runtime is not None:
            logger.info("Releasing runtime resources")
            self.runtime.release()
            self.runtime = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
    def end_profiling(self) -> Optional[str]:
        """
        Stub for ending profiling

        Returns:
            Optional[str]: None
        """
        warnings.warn("end_profiling() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return None

    def get_profiling_start_time_ns(self) -> int:
        """
        Stub for querying the profiling start time

        Returns:
            int: 0
        """
        warnings.warn("get_profiling_start_time_ns() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return 0

    def get_modelmeta(self) -> Dict[str, str]:
        """
        Stub for querying model metadata

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_modelmeta() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_options(self) -> SessionOptions:
        """
        Get the session options

        Returns:
            SessionOptions: the current session options
        """
        return self.options

    def get_providers(self) -> List[str]:
        """
        Stub for querying the providers in use

        Returns:
            List[str]: ["CPUExecutionProvider"]
        """
        warnings.warn("get_providers() is a stub and always returns CPUExecutionProvider", RuntimeWarning, stacklevel=2)
        return ["CPUExecutionProvider"]
    def get_provider_options(self) -> Dict[str, Dict[str, str]]:
        """
        Stub for querying provider options

        Returns:
            Dict[str, Dict[str, str]]: empty dict
        """
        warnings.warn("get_provider_options() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_config(self) -> Dict[str, str]:
        """
        Stub for querying the session configuration

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_config() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_session_state(self) -> Dict[str, str]:
        """
        Stub for querying the session state

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_session_state() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def set_session_config(self, config: Dict[str, str]) -> None:
        """
        Stub for setting the session configuration

        Args:
            config: session configuration dict
        """
        warnings.warn("set_session_config() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def get_memory_info(self) -> Dict[str, int]:
        """
        Stub for querying memory usage information

        Returns:
            Dict[str, int]: empty dict
        """
        warnings.warn("get_memory_info() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}
    def set_memory_pattern(self, enable: bool) -> None:
        """
        Stub for enabling the memory pattern optimization

        Args:
            enable: whether to enable the memory pattern
        """
        warnings.warn("set_memory_pattern() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def disable_memory_pattern(self) -> None:
        """
        Stub for disabling the memory pattern optimization
        """
        warnings.warn("disable_memory_pattern() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def get_optimization_level(self) -> int:
        """
        Stub for querying the optimization level

        Returns:
            int: 0
        """
        warnings.warn("get_optimization_level() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return 0

    def set_optimization_level(self, level: int) -> None:
        """
        Stub for setting the optimization level

        Args:
            level: optimization level
        """
        warnings.warn("set_optimization_level() is a stub and has no effect", RuntimeWarning, stacklevel=2)

    def get_model_metadata(self) -> Dict[str, str]:
        """
        Stub for querying model metadata (separate interface from get_modelmeta)

        Returns:
            Dict[str, str]: empty dict
        """
        warnings.warn("get_model_metadata() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return {}

    def get_model_path(self) -> str:
        """
        Get the model path

        Returns:
            str: path to the model file
        """
        return self.model_path
    def get_input_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for querying input type information

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_input_type_info() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return []

    def get_output_type_info(self) -> List[Dict[str, str]]:
        """
        Stub for querying output type information

        Returns:
            List[Dict[str, str]]: empty list
        """
        warnings.warn("get_output_type_info() is a stub and has no effect", RuntimeWarning, stacklevel=2)
        return []
    ################### Custom operators ###################
    def _init_custom_op_types(self):
        """Initialize the ctypes type definitions for custom operators"""
        # Constants
        self._RKNN_TENSOR_FLOAT32 = 0
        self._RKNN_TENSOR_UINT8 = 3
        self._RKNN_TENSOR_INT64 = 8
        self._RKNN_TARGET_TYPE_CPU = 1
        # Structure definitions
class RKNN_TensorAttr(ctypes.Structure):
_fields_ = [
("index", ctypes.c_uint32),
("n_dims", ctypes.c_uint32),
("dims", ctypes.c_uint32 * RKNN_MAX_DIMS),
("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
("n_elems", ctypes.c_uint32),
("size", ctypes.c_uint32),
("fmt", ctypes.c_int),
("type", ctypes.c_int),
("qnt_type", ctypes.c_int),
("fl", ctypes.c_int8),
("zp", ctypes.c_int32),
("scale", ctypes.c_float),
("w_stride", ctypes.c_uint32),
("size_with_stride", ctypes.c_uint32),
("pass_through", ctypes.c_uint8),
("h_stride", ctypes.c_uint32),
]
class RKNN_TensorMem(ctypes.Structure):
_fields_ = [
("virt_addr", ctypes.c_void_p),
("phys_addr", ctypes.c_uint64),
("fd", ctypes.c_int32),
("offset", ctypes.c_int32),
("size", ctypes.c_uint32),
("flags", ctypes.c_uint32),
("priv_data", ctypes.c_void_p),
]
class RKNN_CustomOpTensor(ctypes.Structure):
_fields_ = [
("attr", RKNN_TensorAttr),
("mem", RKNN_TensorMem),
]
class RKNN_GPUOpContext(ctypes.Structure):
_fields_ = [
("cl_context", ctypes.c_void_p),
("cl_command_queue", ctypes.c_void_p),
("cl_kernel", ctypes.c_void_p),
]
InternalCtxType = (
ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
)
class RKNN_CustomOpContext(ctypes.Structure):
_fields_ = [
("target", ctypes.c_int),
("internal_ctx", InternalCtxType),
("gpu_ctx", RKNN_GPUOpContext),
("priv_data", ctypes.c_void_p),
]
class RKNN_CustomOpAttr(ctypes.Structure):
_fields_ = [
("name", ctypes.c_char * RKNN_MAX_NAME_LEN),
("dtype", ctypes.c_int),
("n_elems", ctypes.c_uint32),
("data", ctypes.c_void_p),
]
CB_SIG = ctypes.CFUNCTYPE(
ctypes.c_int,
ctypes.POINTER(RKNN_CustomOpContext),
ctypes.POINTER(RKNN_CustomOpTensor),
ctypes.c_uint32,
ctypes.POINTER(RKNN_CustomOpTensor),
ctypes.c_uint32,
)
DESTROY_SIG = ctypes.CFUNCTYPE(
ctypes.c_int, ctypes.POINTER(RKNN_CustomOpContext)
)
class RKNN_CustomOp(ctypes.Structure):
_fields_ = [
("version", ctypes.c_uint32),
("target", ctypes.c_int),
("op_type", ctypes.c_char * RKNN_MAX_NAME_LEN),
("cl_kernel_name", ctypes.c_char * RKNN_MAX_NAME_LEN),
("cl_kernel_source", ctypes.c_char_p),
("cl_source_size", ctypes.c_uint64),
("cl_build_options", ctypes.c_char * RKNN_MAX_NAME_LEN),
("init", CB_SIG),
("prepare", CB_SIG),
("compute", CB_SIG),
("compute_native", CB_SIG),
("destroy", DESTROY_SIG),
]
        # Keep references to the type definitions
self._RKNN_TensorAttr = RKNN_TensorAttr
self._RKNN_TensorMem = RKNN_TensorMem
self._RKNN_CustomOpTensor = RKNN_CustomOpTensor
self._RKNN_CustomOpContext = RKNN_CustomOpContext
self._RKNN_CustomOpAttr = RKNN_CustomOpAttr
self._RKNN_CustomOp = RKNN_CustomOp
self._CB_SIG = CB_SIG
self._DESTROY_SIG = DESTROY_SIG
    def _create_attr_readers(self, get_op_attr):
        """Create helpers for reading custom-op attributes"""
def read_attr_int64(op_ctx_ptr, key: str, default: int = 0) -> int:
attr = self._RKNN_CustomOpAttr()
get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_INT64 and attr.data:
return ctypes.c_int64.from_address(attr.data).value
return default
def read_attr_float32(op_ctx_ptr, key: str, default: float = 0) -> float:
attr = self._RKNN_CustomOpAttr()
get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_FLOAT32 and attr.data:
return ctypes.c_float.from_address(attr.data).value
return default
def read_attr_str(op_ctx_ptr, key: str, default: str = "") -> str:
attr = self._RKNN_CustomOpAttr()
get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr))
if attr.n_elems > 0 and attr.dtype == self._RKNN_TENSOR_UINT8 and attr.data:
buf = (ctypes.c_ubyte * attr.n_elems).from_address(attr.data)
try:
return bytes(buf).decode("utf-8", errors="ignore").strip('"')
except Exception:
return default
return default
return read_attr_int64, read_attr_str, read_attr_float32
    def _build_py_custom_op(self,
                            op_type: str,
                            n_inputs: int,
                            n_outputs: int,
                            on_init,
                            on_compute):
        """Generic builder for Python-based custom operators

        Args:
            op_type: operator type name (string)
            n_inputs: number of inputs
            n_outputs: number of outputs
            on_init: callback, signature on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32) -> state
            on_compute: callback, signature on_compute(op_ctx_p, inputs_p, outputs_p, state) -> int (0 on success)

        Returns:
            (RKNN_CustomOp instance, tuple of callbacks)
        """
        @self._CB_SIG
        def _py_init(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
            try:
                # Attribute readers are provided so on_init can read op attributes on demand
                runtime = self.runtime.rknn_base.rknn_runtime
                read_attr_int64, read_attr_str, read_attr_float32 = self._create_attr_readers(runtime.lib.rknn_custom_op_get_op_attr)
                user_state = on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32)
                # Allocate a unique instance ID and store it in priv_data
                if not hasattr(self, "_custom_op_states"):
                    self._custom_op_states = {}
                if not hasattr(self, "_next_custom_op_id"):
                    self._next_custom_op_id = 1
                inst_id = int(self._next_custom_op_id)
                self._next_custom_op_id += 1
                # Keep the Python-side state
                self._custom_op_states[inst_id] = user_state
                # Write the instance ID into priv_data
                try:
                    op_ctx_p.contents.priv_data = ctypes.c_void_p(inst_id)
                except Exception:
                    # Fallback: write the raw integer
                    op_ctx_p.contents.priv_data = inst_id
                return 0
            except Exception as e:
                logger.error(f"{op_type} init failed: {e}")
                return -1
@self._CB_SIG
def _py_prepare(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
return 0
        @self._CB_SIG
        def _py_compute(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
            try:
                if n_inputs_p != n_inputs or n_outputs_p != n_outputs:
                    return -1
                # Recover this instance's state through priv_data
                try:
                    inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
                except Exception:
                    inst_id = 0
                user_state = None
                if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
                    user_state = self._custom_op_states.get(inst_id)
                else:
                    logger.error(f"{op_type} compute failed: instance state not found, inst_id={inst_id}")
                    return -1
                return on_compute(op_ctx_p, inputs_p, outputs_p, user_state)
            except Exception as e:
                logger.error(f"{op_type} compute failed: {e}")
                import traceback
                logger.error(f"{op_type} compute failed: {traceback.format_exc()}")
                return -1

        @self._DESTROY_SIG
        def _py_destroy(op_ctx_p):
            try:
                # Clean up this instance's state
                try:
                    inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
                except Exception:
                    inst_id = 0
                if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
                    del self._custom_op_states[inst_id]
                # Clear priv_data
                try:
                    op_ctx_p.contents.priv_data = ctypes.c_void_p(0)
                except Exception:
                    op_ctx_p.contents.priv_data = 0
                return 0
            except Exception:
                return -1
op = self._RKNN_CustomOp()
op.version = 1
op.target = self._RKNN_TARGET_TYPE_CPU
op.op_type = op_type.encode("utf-8")
op.cl_kernel_name = b""
op.cl_kernel_source = None
op.cl_source_size = 0
op.cl_build_options = b""
op.init = _py_init
op.prepare = _py_prepare
op.compute = _py_compute
op.compute_native = self._CB_SIG() # NULL
op.destroy = _py_destroy
return op, (_py_init, _py_prepare, _py_compute, _py_destroy)
    def _tensor_to_numpy(self, rknn_tensor):
        """Create a numpy array view over an RKNN_CustomOpTensor"""
        # Resolve the numpy dtype; extend this mapping as needed
        dtype_map = {
            self._RKNN_TENSOR_FLOAT32: (ctypes.c_float, np.float32),
            self._RKNN_TENSOR_UINT8: (ctypes.c_uint8, np.uint8),
            self._RKNN_TENSOR_INT64: (ctypes.c_int64, np.int64),
        }
        c_type, np_dtype = dtype_map.get(rknn_tensor.attr.type, (None, None))
        if c_type is None:
            raise TypeError(f"Unsupported RKNN tensor type: {rknn_tensor.attr.type}")
        # Memory address and shape
        addr = (rknn_tensor.mem.virt_addr or 0) + int(rknn_tensor.mem.offset)
        ptr = ctypes.cast(addr, ctypes.POINTER(c_type))
        shape = tuple(rknn_tensor.attr.dims[i] for i in range(rknn_tensor.attr.n_dims))
        # Build the numpy array view
        return np.ctypeslib.as_array(ptr, shape=shape)
    def _create_onnxscript_op_creator(self,
                                      op_type: str,
                                      # Takes a "function template builder" rather than a finished function
                                      onnxscript_func_builder,
                                      n_inputs: int,
                                      n_outputs: int,
                                      attributes: dict = {},
                                      constants: dict = {}):
        """
        Higher-order factory that creates an ONNXScript-based custom-op constructor.
        The final onnxscript compute function is generated dynamically during on_init.

        Args:
            op_type (str): operator type name.
            onnxscript_func_builder: a function that receives all attributes and constants
                as keyword arguments and returns a compiled onnxscript function.
                Example:
                    def builder(mean, scale):
                        @onnxscript.script()
                        def compute(like):
                            return opset.RandomNormalLike(like, mean=mean, scale=scale)
                        return compute
            attributes (dict): attributes to read from the model.
            constants (dict): compile-time constants.
            n_inputs (int): number of inputs.
            n_outputs (int): number of outputs.
        """
        def creator_func():
            def on_init(op_ctx_p, read_i64, read_s, read_f32):
                # 1. Read all dynamic attributes
                attr_values = {}
                for name, (attr_type, default) in attributes.items():
                    if attr_type == 'int64':
                        attr_values[name] = read_i64(op_ctx_p, name, default)
                    elif attr_type == 'str':
                        attr_values[name] = read_s(op_ctx_p, name, default)
                    elif attr_type == 'float32':
                        attr_values[name] = read_f32(op_ctx_p, name, default)
                    else:
                        raise ValueError(f"Unsupported attribute type: {attr_type}")
                # 2. Merge constants and attributes
                final_kwargs = {**constants, **attr_values}
                # 3. Build the onnxscript function dynamically: every attribute
                #    value is captured by the closure as a constant
                compute_func = onnxscript_func_builder(**final_kwargs)
                # 4. Store the compiled function in the per-instance state
                return {"compute_func": compute_func}

            def on_compute(op_ctx_p, inputs_p, outputs_p, state):
                compute_func = state["compute_func"]
                input_nps = [self._tensor_to_numpy(inputs_p[i]) for i in range(n_inputs)]
                output_nps = [self._tensor_to_numpy(outputs_p[i]) for i in range(n_outputs)]
                results = compute_func(*input_nps)
                if n_outputs == 1:
                    result_val = results[0] if isinstance(results, tuple) else results
                    output_nps[0][...] = result_val
                else:
                    for i in range(n_outputs):
                        output_nps[i][...] = results[i]
                return 0

            return self._build_py_custom_op(
                op_type=op_type,
                n_inputs=n_inputs,
                n_outputs=n_outputs,
                on_init=on_init,
                on_compute=on_compute
            )
        return creator_func
def _create_gridsample_op(self):
import onnxscript
from onnxscript import opset17 as opset
def grid_sample_builder(align_corners, mode, padding_mode):
@onnxscript.script()
def grid_sample_compute(X, G):
return opset.GridSample(X, G, align_corners=align_corners, mode=mode, padding_mode=padding_mode)
return grid_sample_compute
grid_sample_creator = self._create_onnxscript_op_creator(
op_type="GridSample",
            onnxscript_func_builder=grid_sample_builder,  # pass the builder in
attributes={
"align_corners": ("int64", 0),
"mode": ("str", "bilinear"),
"padding_mode": ("str", "zeros"),
},
n_inputs = 2,
n_outputs = 1
)
return grid_sample_creator
def _create_scatterelements_op(self):
import onnxscript
from onnxscript import opset17 as opset
@onnxscript.script()
def scatter_elements_compute(data, indices, updates):
indices_i64 = opset.Cast(indices, to=onnxscript.INT64.dtype)
return opset.ScatterElements(data, indices_i64, updates)
scatter_elements_creator = self._create_onnxscript_op_creator(
op_type="ScatterElements",
onnxscript_func_builder=lambda: scatter_elements_compute,
n_inputs = 3,
n_outputs = 1
)
return scatter_elements_creator
def _create_randomnormallike_op(self):
import onnxscript
from onnxscript import opset17 as opset
def random_normal_like_builder(mean, scale):
@onnxscript.script()
def random_normal_like_compute(like):
return opset.RandomNormalLike(like, mean=mean, scale=scale)
return random_normal_like_compute
        # Use the onnxscript-based op factory
        random_normal_like_creator = self._create_onnxscript_op_creator(
            op_type="RandomNormalLike",
            onnxscript_func_builder=random_normal_like_builder,  # pass the builder in
attributes={
"mean": ("float32", 0.0),
"scale": ("float32", 1.0),
},
n_inputs = 1,
n_outputs = 1
)
return random_normal_like_creator
def _create_einsum_op(self):
import onnxscript
from onnxscript import opset17 as opset
def einsum_builder(equation):
@onnxscript.script()
def einsum_compute(in1, in2):
return opset.Einsum(in1, in2, equation=equation)
return einsum_compute
        # Use the onnxscript-based op factory
        einsum_creator = self._create_onnxscript_op_creator(
            op_type="Einsum",
            onnxscript_func_builder=einsum_builder,  # pass the builder in
attributes={
"equation": ("str", ""),
},
n_inputs = 2,
n_outputs = 1
)
return einsum_creator
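    def _create_my_custom_add_op(self):
        # Illustrative sketch only (referenced, but left commented out, in the
        # factory list of register_bundled_ops below): it shows how another op
        # could be plugged into the onnxscript-based factory. "MyCustomAdd" is a
        # hypothetical op_type; the pattern simply mirrors the ops above.
        import onnxscript
        from onnxscript import opset17 as opset

        @onnxscript.script()
        def my_custom_add_compute(a, b):
            return opset.Add(a, b)

        return self._create_onnxscript_op_creator(
            op_type="MyCustomAdd",
            onnxscript_func_builder=lambda: my_custom_add_compute,
            n_inputs=2,
            n_outputs=1,
        )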
    def register_bundled_ops(self) -> None:
        """Register the bundled Python-based custom operators"""
        if getattr(self, "_custom_ops_registered", False):
            return
        runtime = self.runtime.rknn_base.rknn_runtime
        lib = runtime.lib
        ctx = runtime.context
        try:
            _ = lib.rknn_register_custom_ops
            _ = lib.rknn_custom_op_get_op_attr
        except AttributeError as e:
            logger.debug(f"The SDK does not support custom-op registration: {e}")
            return
        self._init_custom_op_types()
        # Note: plugin-library registration is driven by environment variables after
        # model load and is intentionally not triggered again here
        # The list of op factories is easy to extend
        op_creator_factories = [
            self._create_gridsample_op,
            self._create_scatterelements_op,
            self._create_randomnormallike_op,
            self._create_einsum_op,
            # self._create_my_custom_add_op,  # adding a new op is this simple (see the sketch above)
        ]
        ops_to_register = []
        all_callbacks = []
        for factory in op_creator_factories:
            try:
                # Call the factory to obtain the actual creator
                creator_func = factory()
                # Call the creator to build the op instance
                op, callbacks = creator_func()
                ops_to_register.append(op)
                all_callbacks.extend(callbacks)
                logger.debug(f"Created custom op: {op.op_type.decode()}")
            except Exception as e:
                logger.warning(f"Failed to create custom op: {e}", exc_info=True)
        if not ops_to_register:
            logger.debug("No custom ops to register")
            return
        # Pack all ops into a ctypes array and register them in a single call
        num_ops = len(ops_to_register)
        op_array = (self._RKNN_CustomOp * num_ops)(*ops_to_register)
        ret = lib.rknn_register_custom_ops(ctx, op_array, num_ops)
        if ret != 0:
            logger.error(f"Failed to register custom ops, ret={ret} (possibly a false alarm, continuing...)")
            # raise RuntimeError(f"rknn_register_custom_ops failed, ret={ret}")
        logger.info(f"Registered {len(ops_to_register)} custom ops")
        self._custom_ops_registered = True
        self._registered_ops = ops_to_register
        self._op_callbacks = all_callbacks
    def _load_and_register_plugin_op(self, so_path: str) -> bool:
        """Load a single plugin library and register the custom op it provides.

        The plugin must implement get_rknn_custom_op(), returning an rknn_custom_op*.
        That C pointer is passed straight to rknn_register_custom_ops, avoiding a copy.
        """
        if not os.path.isfile(so_path):
            logger.warning(f"Plugin library does not exist: {so_path}")
            return False
        runtime = self.runtime.rknn_base.rknn_runtime
        lib = runtime.lib
        ctx = runtime.context
        # Choose the ctypes type of rknn_context based on the platform word size
        ContextCType = ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32
        # Declare rknn_register_custom_ops(ctx, op_ptr, num); the second argument is passed
        # as void* to avoid depending on the structure layout
        try:
            lib.rknn_register_custom_ops.argtypes = [ContextCType, ctypes.c_void_p, ctypes.c_uint32]
            lib.rknn_register_custom_ops.restype = ctypes.c_int
        except Exception:
            pass
        # Load the plugin
        try:
            handle = ctypes.CDLL(so_path)
        except Exception as e:
            logger.error(f"dlopen failed: {so_path}, err={e}")
            return False
        # Look up the get_rknn_custom_op symbol
        try:
            get_sym = getattr(handle, "get_rknn_custom_op")
        except AttributeError:
            logger.error(f"Plugin is missing the get_rknn_custom_op symbol: {so_path}")
            return False
        # Use void* as the return type so Python never parses the third-party structure
        try:
            get_sym.argtypes = []
        except Exception:
            pass
        get_sym.restype = ctypes.c_void_p
        op_void_ptr = get_sym()
        if not op_void_ptr:
            logger.error(f"get_rknn_custom_op returned a null pointer: {so_path}")
            return False
        # Register with the native pointer directly (zero copy)
        ctx_val = ContextCType(runtime.context)
        ret = lib.rknn_register_custom_ops(ctx_val, ctypes.c_void_p(op_void_ptr), 1)
        if ret != 0:
            logger.error(f"rknn_register_custom_ops failed, ret={ret}, so={so_path} (possibly a false alarm, continuing...)")
            # return False
        # Keep the handle alive so the library is not unloaded by garbage collection
        if not hasattr(self, "_plugin_handles"):
            self._plugin_handles = []
        self._plugin_handles.append(handle)
        logger.info(f"Registered plugin custom op: {so_path}")
        return True
    def register_plugin_ops(self, plugin_paths: List[str]) -> int:
        """Register custom ops from the given plugin library paths. Returns the number of successes."""
        if not plugin_paths:
            return 0
        success = 0
        for path in plugin_paths:
            try:
                if self._load_and_register_plugin_op(path):
                    success += 1
            except Exception as e:
                logger.error(f"Failed to register plugin: {path}, err={e}")
        return success

    # Public API: register a single custom-op plugin library
    def register_custom_op_lib(self, path: str) -> bool:
        return self._load_and_register_plugin_op(path)
    # Public API: scan the Linux system directory and register every plugin library found
    # (Android is not handled)
    def register_system_custom_op_lib(self) -> int:
        if os.name != 'posix':
            return 0
        # Linux only: the default RKNN plugin directory
        system_dir = "/usr/lib/rknpu/op_plugins/"
        if not os.path.isdir(system_dir):
            return 0
        try:
            entries = os.listdir(system_dir)
        except Exception:
            return 0
        so_list = []
        for name in entries:
            # The official convention requires file names starting with librkcst_
            if name.startswith("librkcst_") and name.endswith('.so'):
                so_list.append(os.path.join(system_dir, name))
        return self.register_plugin_ops(so_list)
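
# ---------------------------------------------------------------------------
# Minimal usage sketch. The model path is hypothetical; point it at any .rknn
# file (or an .onnx file, in which case a sibling .rknn is preferred and
# onnxruntime is the fallback). Zero-filled dummy inputs only make sense for
# models whose input shapes are fully static.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with InferenceSession("model.rknn") as sess:
        print("inputs:", [str(t) for t in sess.get_inputs()])
        print("outputs:", [str(t) for t in sess.get_outputs()])
        dummy = {
            t.name: np.zeros(t.shape, dtype=t.type or np.float32)
            for t in sess.get_inputs()
        }
        results = sess.run(None, dummy)
        for tensor, value in zip(sess.get_outputs(), results):
            print(tensor.name, getattr(value, "shape", None))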