| |
import ctypes
import logging
import os
import re
import warnings
from typing import Dict, List, Optional, Union

import numpy as np

from rknnlite.api import RKNNLite
|
|
| try: |
| import onnxruntime as ort |
| HAS_ORT = True |
| except ImportError: |
| HAS_ORT = False |
| warnings.warn("onnxruntime未安装,只能使用RKNN后端", ImportWarning) |
|
|
| |
# Module logger: defaults to ERROR; the stream handler is attached only once
# so re-imports do not duplicate log output.
logger = logging.getLogger("somemodelruntime_rknnlite2")
logger.setLevel(logging.ERROR)
if not logger.handlers:
    _stream_handler = logging.StreamHandler()
    _stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(_stream_handler)
|
|
| |
| _LOGGING_LEVEL_MAP = { |
| 0: logging.DEBUG, |
| 1: logging.INFO, |
| 2: logging.WARNING, |
| 3: logging.ERROR, |
| 4: logging.CRITICAL |
| } |
|
|
| |
# Optional import-time override of the log level via environment variable.
_env_level = os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL')
if _env_level is not None:
    try:
        _parsed_level = int(_env_level)
    except ValueError:
        logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {_env_level}, 应该是0-4之间的整数")
    else:
        if _parsed_level in _LOGGING_LEVEL_MAP:
            logger.setLevel(_LOGGING_LEVEL_MAP[_parsed_level])
            logger.info(f"从环境变量设置日志级别: {_parsed_level}")
        else:
            logger.warning(f"环境变量ZTU_MODELRT_RKNNL2_LOG_LEVEL的值无效: {_parsed_level}, 应该是0-4之间的整数")
|
|
|
|
def set_default_logger_severity(level: int) -> None:
    """
    Sets the default logging severity. 0:Verbose, 1:Info, 2:Warning, 3:Error, 4:Fatal

    Args:
        level: logging level (0-4)

    Raises:
        ValueError: if level is not an integer in [0, 4].
    """
    mapped = _LOGGING_LEVEL_MAP.get(level)
    if mapped is None:
        raise ValueError(f"无效的日志级别: {level}, 应该是0-4之间的整数")
    logger.setLevel(mapped)
|
|
def set_default_logger_verbosity(level: int) -> None:
    """
    Sets the default logging verbosity level. To activate the verbose log,
    you need to set the default logging severity to 0:Verbose level.

    Kept as a thin alias of set_default_logger_severity for ORT API parity.

    Args:
        level: logging level (0-4)
    """
    set_default_logger_severity(level)
|
|
| |
# rknn tensor-type enum value -> numpy dtype.
# NOTE(review): index 10 is mapped to int8 like index 2 — presumably a
# sub-byte type widened to int8; confirm against the SDK's enum.
RKNN_DTYPE_MAP = dict(enumerate((
    np.float32,   # 0
    np.float16,   # 1
    np.int8,      # 2
    np.uint8,     # 3
    np.int16,     # 4
    np.uint16,    # 5
    np.int32,     # 6
    np.uint32,    # 7
    np.int64,     # 8
    bool,         # 9
    np.int8,      # 10
)))
|
|
def get_available_providers() -> List[str]:
    """
    Compatibility shim mirroring onnxruntime.get_available_providers().

    Returns:
        list: always ["CPUExecutionProvider", "somemodelruntime_rknnlite2_ExecutionProvider"]
    """
    providers = [
        "CPUExecutionProvider",
        "somemodelruntime_rknnlite2_ExecutionProvider",
    ]
    return providers
|
|
|
|
def get_device() -> str:
    """
    Return the identifier of the device this runtime targets.

    Returns:
        str: the constant "RKNN2"
    """
    return "RKNN2"
|
|
def get_version_info() -> Dict[str, str]:
    """
    Query SDK version strings from a throwaway RKNNLite instance.

    NOTE(review): parsing assumes the exact multi-line layout of
    get_sdk_version() output (line 2 "...: <api> ...", line 3
    "...: <driver>") — confirm against the installed SDK.

    Returns:
        dict: {"api_version": ..., "driver_version": ...}
    """
    version_lines = RKNNLite().get_sdk_version().split('\n')
    api_field = version_lines[2].split(': ')[1]
    driver_field = version_lines[3].split(': ')[1]
    return {
        "api_version": api_field.split(' ')[0],
        "driver_version": driver_field,
    }
|
|
class IOTensor:
    """Lightweight description of one model input/output tensor."""

    def __init__(self, name, shape, type=None):
        # Names coming from the C API arrive as bytes; normalize to str.
        if isinstance(name, bytes):
            name = name.decode()
        self.name = name
        self.shape = shape
        self.type = type

    def __str__(self):
        return f"IOTensor(name='{self.name}', shape={self.shape}, type={self.type})"
|
|
class SessionOptions:
    """Subset of onnxruntime.SessionOptions understood by this runtime."""

    def __init__(self):
        # Forwarded to RKNNLite(verbose=...) at session construction.
        self.enable_profiling = False
        # 1/2/3 -> NPU core selection (auto, cores 0+1, cores 0+1+2).
        self.intra_op_num_threads = 1
        # -1 keeps the module-level logger configuration untouched.
        self.log_severity_level = -1
        self.log_verbosity_level = -1
|
|
|
|
| class InferenceSession: |
| """ |
| RKNNLite运行时封装类,API风格类似ONNX Runtime |
| """ |
|
|
| def __new__(cls, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs): |
| processed_path = InferenceSession._process_model_path(model_path, sess_options) |
| if isinstance(processed_path, str) and processed_path.lower().endswith('.onnx'): |
| logger.info("使用ONNX Runtime加载模型") |
| if not HAS_ORT: |
| raise RuntimeError("未安装onnxruntime,无法加载ONNX模型") |
| return ort.InferenceSession(processed_path, sess_options=sess_options, **kwargs) |
| else: |
| |
| instance = super().__new__(cls) |
| |
| instance._processed_path = processed_path |
| return instance |
|
|
    def __init__(self, model_path: str, sess_options: Optional[SessionOptions] = None, **kwargs):
        """
        Initialize the runtime and load the model.

        Args:
            model_path: path to the model file (.rknn or .onnx)
            sess_options: session options
            **kwargs: extra initialization arguments

        Raises:
            FileNotFoundError: if the model file does not exist
            RuntimeError: if loading the model or initializing the runtime fails
            ValueError: if sess_options.intra_op_num_threads is not 1, 2 or 3
        """
        options = sess_options or SessionOptions()

        # An explicit environment variable wins over per-session log levels,
        # so only apply the options when the variable is unset.
        if os.getenv('ZTU_MODELRT_RKNNL2_LOG_LEVEL') is None:
            if options.log_severity_level != -1:
                set_default_logger_severity(options.log_severity_level)
            if options.log_verbosity_level != -1:
                set_default_logger_verbosity(options.log_verbosity_level)

        # __new__ may already have resolved the path (e.g. .onnx -> sibling .rknn).
        model_path = getattr(self, '_processed_path', model_path)
        if isinstance(model_path, str) and model_path.lower().endswith('.onnx'):
            # ONNX models are served by the ort.InferenceSession returned from
            # __new__; nothing to initialize on this instance.
            return

        self.model_path = model_path
        if not os.path.exists(self.model_path):
            logger.error(f"模型文件不存在: {self.model_path}")
            raise FileNotFoundError(f"模型文件不存在: {self.model_path}")

        # enable_profiling maps onto RKNNLite's verbose mode.
        self.runtime = RKNNLite(verbose=options.enable_profiling)

        logger.debug(f"正在加载模型: {self.model_path}")
        ret = self.runtime.load_rknn(self.model_path)
        if ret != 0:
            logger.error(f"加载RKNN模型失败: {self.model_path}")
            raise RuntimeError(f'加载RKNN模型失败: {self.model_path}')
        logger.debug("模型加载成功")

        # intra_op_num_threads selects how many NPU cores to enable.
        if options.intra_op_num_threads == 1:
            core_mask = RKNNLite.NPU_CORE_AUTO
        elif options.intra_op_num_threads == 2:
            core_mask = RKNNLite.NPU_CORE_0_1
        elif options.intra_op_num_threads == 3:
            core_mask = RKNNLite.NPU_CORE_0_1_2
        else:
            raise ValueError(f"intra_op_num_threads的值无效: {options.intra_op_num_threads}, 只能是1,2或3")

        logger.debug("正在初始化运行时环境")
        ret = self.runtime.init_runtime(core_mask=core_mask)
        if ret != 0:
            logger.error("初始化运行时环境失败")
            raise RuntimeError('初始化运行时环境失败')

        logger.debug("运行时环境初始化成功")

        # Best-effort auto-registration of custom-op plugin libraries; any
        # failure here is logged but never fatal.
        try:
            # Plugin paths may be separated by comma, semicolon or colon.
            env_custom = os.getenv('ZTU_MODELRT_RKNN2_REG_CUSTOM_OP_LIB', '').strip()
            if env_custom:
                paths = [seg.strip() for seg in re.split(r"[,;:]", env_custom) if seg.strip()]
                ok = 0
                for p in paths:
                    if self.register_custom_op_lib(p):
                        ok += 1
                if ok > 0:
                    logger.info(f"已注册 {ok}/{len(paths)} 个自定义算子插件")
            # System-wide plugins are scanned unless explicitly disabled.
            if os.getenv('ZTU_MODELRT_RKNN2_REG_SYSTEM_CUSTOM_OP_LIB', '1') == '1':
                cnt = self.register_system_custom_op_lib()
                if cnt > 0:
                    logger.info(f"已从系统目录注册 {cnt} 个自定义算子插件")
        except Exception as e:
            logger.warning(f"自动注册自定义算子插件失败: {e}")

        # Bundled onnxscript-based custom ops are opt-in via environment variable.
        if os.getenv('ZTU_MODELRT_RKNN2_REG_BUNDLED_OPS', '0') == '1':
            logger.info("根据环境变量注册捆绑算子")
            self.register_bundled_ops()

        self._init_io_info()
        self.options = options
|
|
| def get_performance_info(self) -> Dict[str, float]: |
| """ |
| 获取性能信息 |
| |
| Returns: |
| dict: 包含性能信息的字典 |
| """ |
| if not self.options.perf_debug: |
| raise RuntimeError("性能分析未启用,请在SessionOptions中设置perf_debug=True") |
| |
| perf = self.runtime.rknn_runtime.get_run_perf() |
| return { |
| "run_duration": perf.run_duration / 1000.0 |
| } |
|
|
| def set_core_mask(self, core_mask: int) -> None: |
| """ |
| 设置NPU核心使用模式 |
| |
| Args: |
| core_mask: NPU核心掩码,使用NPU_CORE_*常量 |
| """ |
| ret = self.runtime.rknn_runtime.set_core_mask(core_mask) |
| if ret != 0: |
| raise RuntimeError("设置NPU核心模式失败") |
|
|
| @staticmethod |
| def _process_model_path(model_path, sess_options): |
| """ |
| 处理模型路径,支持.onnx和.rknn文件 |
| |
| Args: |
| model_path: 模型文件路径 |
| """ |
| |
| if model_path.lower().endswith('.onnx'): |
| logger.info("检测到ONNX模型文件") |
| |
| |
| skip_models = os.getenv('ZTU_MODELRT_RKNNL2_SKIP', '').strip() |
| if skip_models: |
| skip_list = [m.strip() for m in skip_models.split(',')] |
| |
| model_name = os.path.basename(model_path) |
| if model_name.lower() in [m.lower() for m in skip_list]: |
| logger.info(f"模型{model_name}在跳过列表中,将使用ONNX Runtime") |
| return model_path |
| |
| |
| rknn_path = os.path.splitext(model_path)[0] + '.rknn' |
| if os.path.exists(rknn_path): |
| logger.info(f"找到对应的RKNN模型,将使用RKNN: {rknn_path}") |
| return rknn_path |
| else: |
| logger.info("未找到对应的RKNN模型,将使用ONNX Runtime") |
| return model_path |
| |
| return model_path |
| |
| def _convert_nhwc_to_nchw(self, shape): |
| """将NHWC格式的shape转换为NCHW格式""" |
| if len(shape) == 4: |
| |
| n, h, w, c = shape |
| return [n, c, h, w] |
| return shape |
| |
| def _init_io_info(self): |
| """初始化模型的输入输出信息""" |
| runtime = self.runtime.rknn_runtime |
| |
| |
| n_input, n_output = runtime.get_in_out_num() |
| |
| |
| self.input_tensors = [] |
| for i in range(n_input): |
| attr = runtime.get_tensor_attr(i) |
| shape = [attr.dims[j] for j in range(attr.n_dims)] |
| |
| shape = self._convert_nhwc_to_nchw(shape) |
| |
| dtype = RKNN_DTYPE_MAP.get(attr.type, None) |
| tensor = IOTensor(attr.name, shape, dtype) |
| self.input_tensors.append(tensor) |
| |
| |
| self.output_tensors = [] |
| for i in range(n_output): |
| attr = runtime.get_tensor_attr(i, is_output=True) |
| shape = runtime.get_output_shape(i) |
| |
| dtype = RKNN_DTYPE_MAP.get(attr.type, None) |
| tensor = IOTensor(attr.name, shape, dtype) |
| self.output_tensors.append(tensor) |
| |
| def get_inputs(self): |
| """ |
| 获取模型输入信息 |
| |
| Returns: |
| list: 包含输入信息的列表 |
| """ |
| return self.input_tensors |
| |
| def get_outputs(self): |
| """ |
| 获取模型输出信息 |
| |
| Returns: |
| list: 包含输出信息的列表 |
| """ |
| return self.output_tensors |
| |
| def run(self, output_names=None, input_feed=None, data_format="nchw", **kwargs): |
| """ |
| 执行模型推理 |
| |
| Args: |
| output_names: 输出节点名称列表,指定需要返回哪些输出 |
| input_feed: 输入数据字典或列表 |
| data_format: 输入数据格式,"nchw"或"nhwc" |
| **kwargs: 其他运行时参数 |
| |
| Returns: |
| list: 模型输出结果列表,如果指定了output_names则只返回指定的输出 |
| """ |
| if input_feed is None: |
| logger.error("input_feed不能为None") |
| raise ValueError("input_feed不能为None") |
| |
| |
| if isinstance(input_feed, dict): |
| |
| inputs = [] |
| input_map = {tensor.name: i for i, tensor in enumerate(self.input_tensors)} |
| for tensor in self.input_tensors: |
| if tensor.name not in input_feed: |
| raise ValueError(f"缺少输入: {tensor.name}") |
| inputs.append(input_feed[tensor.name]) |
| elif isinstance(input_feed, (list, tuple)): |
| |
| if len(input_feed) != len(self.input_tensors): |
| raise ValueError(f"输入数量不匹配: 期望{len(self.input_tensors)}, 实际{len(input_feed)}") |
| inputs = list(input_feed) |
| else: |
| logger.error("input_feed必须是字典或列表类型") |
| raise ValueError("input_feed必须是字典或列表类型") |
| |
| |
| try: |
| logger.debug("开始执行推理") |
| all_outputs = self.runtime.inference(inputs=inputs, data_format=data_format) |
| |
| |
| if output_names is None: |
| return all_outputs |
| |
| |
| output_map = {tensor.name: i for i, tensor in enumerate(self.output_tensors)} |
| selected_outputs = [] |
| for name in output_names: |
| if name not in output_map: |
| raise ValueError(f"未找到输出节点: {name}") |
| selected_outputs.append(all_outputs[output_map[name]]) |
| |
| return selected_outputs |
| |
| except Exception as e: |
| logger.error(f"推理执行失败: {str(e)}") |
| raise RuntimeError(f"推理执行失败: {str(e)}") |
| |
| def close(self): |
| """ |
| 关闭会话,释放资源 |
| """ |
| if self.runtime is not None: |
| logger.info("正在释放运行时资源") |
| self.runtime.release() |
| self.runtime = None |
| |
| def __enter__(self): |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| self.close() |
|
|
| def end_profiling(self) -> Optional[str]: |
| """ |
| 结束性能分析的存根方法 |
| |
| Returns: |
| Optional[str]: None |
| """ |
| warnings.warn("end_profiling()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return None |
| |
| def get_profiling_start_time_ns(self) -> int: |
| """ |
| 获取性能分析开始时间的存根方法 |
| |
| Returns: |
| int: 0 |
| """ |
| warnings.warn("get_profiling_start_time_ns()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return 0 |
| |
| def get_modelmeta(self) -> Dict[str, str]: |
| """ |
| 获取模型元数据的存根方法 |
| |
| Returns: |
| Dict[str, str]: 空字典 |
| """ |
| warnings.warn("get_modelmeta()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return {} |
| |
| def get_session_options(self) -> SessionOptions: |
| """ |
| 获取会话选项 |
| |
| Returns: |
| SessionOptions: 当前会话选项 |
| """ |
| return self.options |
| |
| def get_providers(self) -> List[str]: |
| """ |
| 获取当前使用的providers的存根方法 |
| |
| Returns: |
| List[str]: ["CPUExecutionProvider"] |
| """ |
| warnings.warn("get_providers()是存根方法,始终返回CPUExecutionProvider", RuntimeWarning, stacklevel=2) |
| return ["CPUExecutionProvider"] |
| |
| def get_provider_options(self) -> Dict[str, Dict[str, str]]: |
| """ |
| 获取provider选项的存根方法 |
| |
| Returns: |
| Dict[str, Dict[str, str]]: 空字典 |
| """ |
| warnings.warn("get_provider_options()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return {} |
|
|
| def get_session_config(self) -> Dict[str, str]: |
| """ |
| 获取会话配置的存根方法 |
| |
| Returns: |
| Dict[str, str]: 空字典 |
| """ |
| warnings.warn("get_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return {} |
|
|
| def get_session_state(self) -> Dict[str, str]: |
| """ |
| 获取会话状态的存根方法 |
| |
| Returns: |
| Dict[str, str]: 空字典 |
| """ |
| warnings.warn("get_session_state()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return {} |
|
|
| def set_session_config(self, config: Dict[str, str]) -> None: |
| """ |
| 设置会话配置的存根方法 |
| |
| Args: |
| config: 会话配置字典 |
| """ |
| warnings.warn("set_session_config()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
|
|
| def get_memory_info(self) -> Dict[str, int]: |
| """ |
| 获取内存使用信息的存根方法 |
| |
| Returns: |
| Dict[str, int]: 空字典 |
| """ |
| warnings.warn("get_memory_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return {} |
|
|
| def set_memory_pattern(self, enable: bool) -> None: |
| """ |
| 设置内存模式的存根方法 |
| |
| Args: |
| enable: 是否启用内存模式 |
| """ |
| warnings.warn("set_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
|
|
| def disable_memory_pattern(self) -> None: |
| """ |
| 禁用内存模式的存根方法 |
| """ |
| warnings.warn("disable_memory_pattern()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
|
|
| def get_optimization_level(self) -> int: |
| """ |
| 获取优化级别的存根方法 |
| |
| Returns: |
| int: 0 |
| """ |
| warnings.warn("get_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return 0 |
|
|
| def set_optimization_level(self, level: int) -> None: |
| """ |
| 设置优化级别的存根方法 |
| |
| Args: |
| level: 优化级别 |
| """ |
| warnings.warn("set_optimization_level()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
|
|
| def get_model_metadata(self) -> Dict[str, str]: |
| """ |
| 获取模型元数据的存根方法(与get_modelmeta不同的接口) |
| |
| Returns: |
| Dict[str, str]: 空字典 |
| """ |
| warnings.warn("get_model_metadata()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return {} |
|
|
| def get_model_path(self) -> str: |
| """ |
| 获取模型路径 |
| |
| Returns: |
| str: 模型文件路径 |
| """ |
| return self.model_path |
|
|
| def get_input_type_info(self) -> List[Dict[str, str]]: |
| """ |
| 获取输入类型信息的存根方法 |
| |
| Returns: |
| List[Dict[str, str]]: 空列表 |
| """ |
| warnings.warn("get_input_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return [] |
|
|
| def get_output_type_info(self) -> List[Dict[str, str]]: |
| """ |
| 获取输出类型信息的存根方法 |
| |
| Returns: |
| List[Dict[str, str]]: 空列表 |
| """ |
| warnings.warn("get_output_type_info()是存根方法,不提供实际功能", RuntimeWarning, stacklevel=2) |
| return [] |
|
|
| |
|
|
| def _init_custom_op_types(self): |
| """初始化自定义算子的类型定义""" |
| |
| self._RKNN_TENSOR_FLOAT32 = 0 |
| self._RKNN_TENSOR_UINT8 = 3 |
| self._RKNN_TENSOR_INT64 = 8 |
| self._RKNN_TARGET_TYPE_CPU = 1 |
|
|
| |
| class RKNN_TensorAttr(ctypes.Structure): |
| _fields_ = [ |
| ("index", ctypes.c_uint32), |
| ("n_dims", ctypes.c_uint32), |
| ("dims", ctypes.c_uint32 * RKNN_MAX_DIMS), |
| ("name", ctypes.c_char * RKNN_MAX_NAME_LEN), |
| ("n_elems", ctypes.c_uint32), |
| ("size", ctypes.c_uint32), |
| ("fmt", ctypes.c_int), |
| ("type", ctypes.c_int), |
| ("qnt_type", ctypes.c_int), |
| ("fl", ctypes.c_int8), |
| ("zp", ctypes.c_int32), |
| ("scale", ctypes.c_float), |
| ("w_stride", ctypes.c_uint32), |
| ("size_with_stride", ctypes.c_uint32), |
| ("pass_through", ctypes.c_uint8), |
| ("h_stride", ctypes.c_uint32), |
| ] |
|
|
| class RKNN_TensorMem(ctypes.Structure): |
| _fields_ = [ |
| ("virt_addr", ctypes.c_void_p), |
| ("phys_addr", ctypes.c_uint64), |
| ("fd", ctypes.c_int32), |
| ("offset", ctypes.c_int32), |
| ("size", ctypes.c_uint32), |
| ("flags", ctypes.c_uint32), |
| ("priv_data", ctypes.c_void_p), |
| ] |
|
|
| class RKNN_CustomOpTensor(ctypes.Structure): |
| _fields_ = [ |
| ("attr", RKNN_TensorAttr), |
| ("mem", RKNN_TensorMem), |
| ] |
|
|
| class RKNN_GPUOpContext(ctypes.Structure): |
| _fields_ = [ |
| ("cl_context", ctypes.c_void_p), |
| ("cl_command_queue", ctypes.c_void_p), |
| ("cl_kernel", ctypes.c_void_p), |
| ] |
|
|
| InternalCtxType = ( |
| ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32 |
| ) |
|
|
| class RKNN_CustomOpContext(ctypes.Structure): |
| _fields_ = [ |
| ("target", ctypes.c_int), |
| ("internal_ctx", InternalCtxType), |
| ("gpu_ctx", RKNN_GPUOpContext), |
| ("priv_data", ctypes.c_void_p), |
| ] |
|
|
| class RKNN_CustomOpAttr(ctypes.Structure): |
| _fields_ = [ |
| ("name", ctypes.c_char * RKNN_MAX_NAME_LEN), |
| ("dtype", ctypes.c_int), |
| ("n_elems", ctypes.c_uint32), |
| ("data", ctypes.c_void_p), |
| ] |
|
|
| CB_SIG = ctypes.CFUNCTYPE( |
| ctypes.c_int, |
| ctypes.POINTER(RKNN_CustomOpContext), |
| ctypes.POINTER(RKNN_CustomOpTensor), |
| ctypes.c_uint32, |
| ctypes.POINTER(RKNN_CustomOpTensor), |
| ctypes.c_uint32, |
| ) |
|
|
| DESTROY_SIG = ctypes.CFUNCTYPE( |
| ctypes.c_int, ctypes.POINTER(RKNN_CustomOpContext) |
| ) |
|
|
| class RKNN_CustomOp(ctypes.Structure): |
| _fields_ = [ |
| ("version", ctypes.c_uint32), |
| ("target", ctypes.c_int), |
| ("op_type", ctypes.c_char * RKNN_MAX_NAME_LEN), |
| ("cl_kernel_name", ctypes.c_char * RKNN_MAX_NAME_LEN), |
| ("cl_kernel_source", ctypes.c_char_p), |
| ("cl_source_size", ctypes.c_uint64), |
| ("cl_build_options", ctypes.c_char * RKNN_MAX_NAME_LEN), |
| ("init", CB_SIG), |
| ("prepare", CB_SIG), |
| ("compute", CB_SIG), |
| ("compute_native", CB_SIG), |
| ("destroy", DESTROY_SIG), |
| ] |
|
|
| |
| self._RKNN_TensorAttr = RKNN_TensorAttr |
| self._RKNN_TensorMem = RKNN_TensorMem |
| self._RKNN_CustomOpTensor = RKNN_CustomOpTensor |
| self._RKNN_CustomOpContext = RKNN_CustomOpContext |
| self._RKNN_CustomOpAttr = RKNN_CustomOpAttr |
| self._RKNN_CustomOp = RKNN_CustomOp |
| self._CB_SIG = CB_SIG |
| self._DESTROY_SIG = DESTROY_SIG |
|
|
| def _create_attr_readers(self, get_op_attr): |
| """创建属性读取函数""" |
| def read_attr_int64(op_ctx_ptr, key: str, default: int = 0) -> int: |
| attr = self._RKNN_CustomOpAttr() |
| get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr)) |
| if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_INT64 and attr.data: |
| return ctypes.c_int64.from_address(attr.data).value |
| return default |
|
|
| def read_attr_float32(op_ctx_ptr, key: str, default: float = 0) -> float: |
| attr = self._RKNN_CustomOpAttr() |
| get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr)) |
| if attr.n_elems == 1 and attr.dtype == self._RKNN_TENSOR_FLOAT32 and attr.data: |
| return ctypes.c_float.from_address(attr.data).value |
| return default |
|
|
| def read_attr_str(op_ctx_ptr, key: str, default: str = "") -> str: |
| attr = self._RKNN_CustomOpAttr() |
| get_op_attr(op_ctx_ptr, key.encode("utf-8"), ctypes.byref(attr)) |
| if attr.n_elems > 0 and attr.dtype == self._RKNN_TENSOR_UINT8 and attr.data: |
| buf = (ctypes.c_ubyte * attr.n_elems).from_address(attr.data) |
| try: |
| return bytes(buf).decode("utf-8", errors="ignore").strip('"') |
| except Exception: |
| return default |
| return default |
| |
| |
| return read_attr_int64, read_attr_str, read_attr_float32 |
| |
    def _build_py_custom_op(self,
                            op_type: str,
                            n_inputs: int,
                            n_outputs: int,
                            on_init,
                            on_compute):
        """Generic builder for a Python-implemented RKNN custom op.

        Args:
            op_type: op type name (string)
            n_inputs: number of inputs
            n_outputs: number of outputs
            on_init: callback, signature on_init(op_ctx_p, read_attr_int64,
                read_attr_str, read_attr_float32) -> state
            on_compute: callback, signature on_compute(op_ctx_p, inputs_p,
                outputs_p, state) -> int (0 means success)
        Returns:
            (RKNN_CustomOp object, tuple of ctypes callbacks). The caller must
            keep the callbacks referenced, or Python would garbage-collect the
            C trampolines while the runtime still holds pointers to them.
        """
        @self._CB_SIG
        def _py_init(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
            try:
                # Build the typed attribute readers bound to the SDK getter,
                # then let the user callback produce its per-instance state.
                runtime = self.runtime.rknn_base.rknn_runtime
                read_attr_int64, read_attr_str, read_attr_float32 = self._create_attr_readers(runtime.lib.rknn_custom_op_get_op_attr)
                user_state = on_init(op_ctx_p, read_attr_int64, read_attr_str, read_attr_float32)
                # State is stored in a dict keyed by a small integer id; only
                # the id (pointer-sized) fits into the C-side priv_data slot.
                if not hasattr(self, "_custom_op_states"):
                    self._custom_op_states = {}
                if not hasattr(self, "_next_custom_op_id"):
                    self._next_custom_op_id = 1
                inst_id = int(self._next_custom_op_id)
                self._next_custom_op_id += 1

                self._custom_op_states[inst_id] = user_state

                try:
                    op_ctx_p.contents.priv_data = ctypes.c_void_p(inst_id)
                except Exception:
                    # Fallback: some ctypes versions accept a plain int here.
                    op_ctx_p.contents.priv_data = inst_id
                return 0
            except Exception as e:
                logger.error(f"{op_type} init失败: {e}")
                return -1

        @self._CB_SIG
        def _py_prepare(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
            # Nothing to do between init and compute.
            return 0

        @self._CB_SIG
        def _py_compute(op_ctx_p, inputs_p, n_inputs_p, outputs_p, n_outputs_p):
            try:
                # Reject calls whose arity differs from what was declared.
                if n_inputs_p != n_inputs or n_outputs_p != n_outputs:
                    return -1
                # Recover the per-instance state stashed by _py_init.
                try:
                    inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
                except Exception:
                    inst_id = 0
                user_state = None
                if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
                    user_state = self._custom_op_states.get(inst_id)
                else:
                    logger.error(f"{op_type} compute失败: 找不到实例状态, inst_id={inst_id}")
                    return -1
                return on_compute(op_ctx_p, inputs_p, outputs_p, user_state)
            except Exception as e:
                logger.error(f"{op_type} compute失败: {e}")
                import traceback
                logger.error(f"{op_type} compute失败: {traceback.format_exc()}")
                return -1

        @self._DESTROY_SIG
        def _py_destroy(op_ctx_p):
            try:
                # Drop the per-instance state and clear the priv_data slot.
                try:
                    inst_id = int(op_ctx_p.contents.priv_data) if op_ctx_p.contents.priv_data else 0
                except Exception:
                    inst_id = 0
                if hasattr(self, "_custom_op_states") and inst_id in self._custom_op_states:
                    del self._custom_op_states[inst_id]

                try:
                    op_ctx_p.contents.priv_data = ctypes.c_void_p(0)
                except Exception:
                    op_ctx_p.contents.priv_data = 0
                return 0
            except Exception:
                return -1

        # Fill in the rknn_custom_op struct: CPU target, no OpenCL kernel.
        op = self._RKNN_CustomOp()
        op.version = 1
        op.target = self._RKNN_TARGET_TYPE_CPU
        op.op_type = op_type.encode("utf-8")
        op.cl_kernel_name = b""
        op.cl_kernel_source = None
        op.cl_source_size = 0
        op.cl_build_options = b""
        op.init = _py_init
        op.prepare = _py_prepare
        op.compute = _py_compute
        op.compute_native = self._CB_SIG()  # NULL function pointer
        op.destroy = _py_destroy

        # Returning the callbacks lets the caller keep them alive.
        return op, (_py_init, _py_prepare, _py_compute, _py_destroy)
|
|
|
|
| def _tensor_to_numpy(self, rknn_tensor): |
| """将 RKNN_CustomOpTensor 转换为 Numpy 数组视图""" |
| |
| |
| dtype_map = { |
| self._RKNN_TENSOR_FLOAT32: (ctypes.c_float, np.float32), |
| self._RKNN_TENSOR_UINT8: (ctypes.c_uint8, np.uint8), |
| self._RKNN_TENSOR_INT64: (ctypes.c_int64, np.int64), |
| } |
| c_type, np_dtype = dtype_map.get(rknn_tensor.attr.type, (None, None)) |
| if c_type is None: |
| raise TypeError(f"不支持的RKNN张量类型: {rknn_tensor.attr.type}") |
|
|
| |
| addr = (rknn_tensor.mem.virt_addr or 0) + int(rknn_tensor.mem.offset) |
| ptr = ctypes.cast(addr, ctypes.POINTER(c_type)) |
| shape = tuple(rknn_tensor.attr.dims[i] for i in range(rknn_tensor.attr.n_dims)) |
| |
| |
| return np.ctypeslib.as_array(ptr, shape=shape) |
|
|
|
|
| def _create_onnxscript_op_creator(self, |
| op_type: str, |
| |
| onnxscript_func_builder, |
| n_inputs: int, |
| n_outputs: int, |
| attributes: dict = {}, |
| constants: dict = {}): |
| """ |
| 一个高阶工厂函数,用于创建基于ONNXScript的自定义算子构造器。 |
| 它在 on_init 阶段动态生成最终的 onnxscript 计算函数。 |
| |
| Args: |
| op_type (str): 算子类型名。 |
| onnxscript_func_builder: 一个函数,它接收所有属性和常量作为关键字参数, |
| 并返回一个编译好的 onnxscript 函数。 |
| 例如: def builder(mean, scale): |
| @onnxscript.script() |
| def compute(like): |
| return opset.RandomNormalLike(like, mean=mean, scale=scale) |
| return compute |
| attributes (dict): 从模型中读取的属性字典。 |
| constants (dict): 编译时常量字典。 |
| n_inputs (int): 输入个数。 |
| n_outputs (int): 输出个数。 |
| """ |
|
|
| def creator_func(): |
| def on_init(op_ctx_p, read_i64, read_s, read_f32): |
| |
| attr_values = {} |
| for name, (attr_type, default) in attributes.items(): |
| if attr_type == 'int64': |
| attr_values[name] = read_i64(op_ctx_p, name, default) |
| elif attr_type == 'str': |
| attr_values[name] = read_s(op_ctx_p, name, default) |
| elif attr_type == 'float32': |
| attr_values[name] = read_f32(op_ctx_p, name, default) |
| else: |
| raise ValueError(f"不支持的属性类型: {attr_type}") |
|
|
| |
| final_kwargs = {**constants, **attr_values} |
|
|
| |
| |
| compute_func = onnxscript_func_builder(**final_kwargs) |
|
|
| |
| return {"compute_func": compute_func} |
|
|
| def on_compute(op_ctx_p, inputs_p, outputs_p, state): |
| compute_func = state["compute_func"] |
|
|
| input_nps = [self._tensor_to_numpy(inputs_p[i]) for i in range(n_inputs)] |
| output_nps = [self._tensor_to_numpy(outputs_p[i]) for i in range(n_outputs)] |
| |
| results = compute_func(*input_nps) |
|
|
| if n_outputs == 1: |
| result_val = results[0] if isinstance(results, tuple) else results |
| output_nps[0][...] = result_val |
| else: |
| for i in range(n_outputs): |
| output_nps[i][...] = results[i] |
| |
| return 0 |
|
|
| return self._build_py_custom_op( |
| op_type=op_type, |
| n_inputs=n_inputs, |
| n_outputs=n_outputs, |
| on_init=on_init, |
| on_compute=on_compute |
| ) |
| |
| return creator_func |
|
|
| def _create_gridsample_op(self): |
| import onnxscript |
| from onnxscript import opset17 as opset |
|
|
| def grid_sample_builder(align_corners, mode, padding_mode): |
| @onnxscript.script() |
| def grid_sample_compute(X, G): |
| return opset.GridSample(X, G, align_corners=align_corners, mode=mode, padding_mode=padding_mode) |
| return grid_sample_compute |
|
|
| grid_sample_creator = self._create_onnxscript_op_creator( |
| op_type="GridSample", |
| onnxscript_func_builder=grid_sample_builder, |
| attributes={ |
| "align_corners": ("int64", 0), |
| "mode": ("str", "bilinear"), |
| "padding_mode": ("str", "zeros"), |
| }, |
| n_inputs = 2, |
| n_outputs = 1 |
| ) |
| return grid_sample_creator |
|
|
| def _create_scatterelements_op(self): |
| import onnxscript |
| from onnxscript import opset17 as opset |
| |
| @onnxscript.script() |
| def scatter_elements_compute(data, indices, updates): |
| indices_i64 = opset.Cast(indices, to=onnxscript.INT64.dtype) |
| return opset.ScatterElements(data, indices_i64, updates) |
| |
| scatter_elements_creator = self._create_onnxscript_op_creator( |
| op_type="ScatterElements", |
| onnxscript_func_builder=lambda: scatter_elements_compute, |
| n_inputs = 3, |
| n_outputs = 1 |
| ) |
| return scatter_elements_creator |
|
|
| def _create_randomnormallike_op(self): |
| import onnxscript |
| from onnxscript import opset17 as opset |
|
|
| def random_normal_like_builder(mean, scale): |
| @onnxscript.script() |
| def random_normal_like_compute(like): |
| return opset.RandomNormalLike(like, mean=mean, scale=scale) |
| |
| return random_normal_like_compute |
| |
| |
| random_normal_like_creator = self._create_onnxscript_op_creator( |
| op_type="RandomNormalLike", |
| onnxscript_func_builder=random_normal_like_builder, |
| attributes={ |
| "mean": ("float32", 0.0), |
| "scale": ("float32", 1.0), |
| }, |
| n_inputs = 1, |
| n_outputs = 1 |
| ) |
| return random_normal_like_creator |
|
|
| def _create_einsum_op(self): |
| import onnxscript |
| from onnxscript import opset17 as opset |
|
|
| def einsum_builder(equation): |
|
|
| @onnxscript.script() |
| def einsum_compute(in1, in2): |
| return opset.Einsum(in1, in2, equation=equation) |
| |
| return einsum_compute |
| |
| |
| einsum_creator = self._create_onnxscript_op_creator( |
| op_type="Einsum", |
| onnxscript_func_builder=einsum_builder, |
| attributes={ |
| "equation": ("str", ""), |
| }, |
| n_inputs = 2, |
| n_outputs = 1 |
| ) |
| return einsum_creator |
|
|
    def register_bundled_ops(self) -> None:
        """Register the bundled Python custom ops with the runtime (idempotent).

        Builds the onnxscript-backed ops (GridSample, ScatterElements,
        RandomNormalLike, Einsum) and registers them in a single
        rknn_register_custom_ops call. Silently does nothing when the
        installed SDK lacks the custom-op entry points.
        """
        if getattr(self, "_custom_ops_registered", False):
            return

        runtime = self.runtime.rknn_base.rknn_runtime
        lib = runtime.lib
        ctx = runtime.context

        # Probe for the required SDK symbols; old SDKs do not export them.
        try:
            _ = lib.rknn_register_custom_ops
            _ = lib.rknn_custom_op_get_op_attr
        except AttributeError as e:
            logger.debug(f"SDK不支持自定义算子注册: {e}")
            return

        self._init_custom_op_types()

        # Each factory yields a creator; each creator yields (op, callbacks).
        op_creator_factories = [
            self._create_gridsample_op,
            self._create_scatterelements_op,
            self._create_randomnormallike_op,
            self._create_einsum_op,
        ]

        ops_to_register = []
        all_callbacks = []

        for factory in op_creator_factories:
            try:
                creator_func = factory()
                op, callbacks = creator_func()
                ops_to_register.append(op)
                all_callbacks.extend(callbacks)
                logger.debug(f"成功创建自定义算子: {op.op_type.decode()}")
            except Exception as e:
                # A single failing op (e.g. onnxscript missing) must not
                # prevent the others from registering.
                logger.warning(f"创建自定义算子失败: {e}", exc_info=True)

        if not ops_to_register:
            logger.debug("没有可注册的自定义算子")
            return

        # Register all ops in one call; the C API expects a contiguous array.
        num_ops = len(ops_to_register)
        op_array = (self._RKNN_CustomOp * num_ops)(*ops_to_register)
        ret = lib.rknn_register_custom_ops(ctx, op_array, num_ops)
        if ret != 0:
            # Nonzero returns have been observed even on success, so this is
            # logged but not raised.
            logger.error(f"注册自定义算子失败, ret={ret} (可能是误报, 继续执行...)")

        logger.info(f"成功注册 {len(ops_to_register)} 个自定义算子")

        # Keep ops and callbacks referenced so the ctypes trampolines stay
        # alive for as long as the runtime may call them.
        self._custom_ops_registered = True
        self._registered_ops = ops_to_register
        self._op_callbacks = all_callbacks
|
|
| def _load_and_register_plugin_op(self, so_path: str) -> bool: |
| """加载单个插件库并注册其中的自定义算子。 |
| |
| 要求插件实现 get_rknn_custom_op(),返回 rknn_custom_op*。 |
| 我们将该 C 指针直接传递给 rknn_register_custom_ops,避免复制。 |
| """ |
| if not os.path.isfile(so_path): |
| logger.warning(f"插件库不存在: {so_path}") |
| return False |
|
|
| runtime = self.runtime.rknn_base.rknn_runtime |
| lib = runtime.lib |
| ctx = runtime.context |
|
|
| |
| ContextCType = ctypes.c_uint64 if ctypes.sizeof(ctypes.c_void_p) == 8 else ctypes.c_uint32 |
| |
| try: |
| lib.rknn_register_custom_ops.argtypes = [ContextCType, ctypes.c_void_p, ctypes.c_uint32] |
| lib.rknn_register_custom_ops.restype = ctypes.c_int |
| except Exception: |
| pass |
|
|
| |
| try: |
| handle = ctypes.CDLL(so_path) |
| except Exception as e: |
| logger.error(f"dlopen 失败: {so_path}, err={e}") |
| return False |
|
|
| |
| try: |
| get_sym = getattr(handle, "get_rknn_custom_op") |
| except AttributeError: |
| logger.error(f"插件缺少符号 get_rknn_custom_op: {so_path}") |
| return False |
|
|
| |
| try: |
| get_sym.argtypes = [] |
| except Exception: |
| pass |
| get_sym.restype = ctypes.c_void_p |
|
|
| op_void_ptr = get_sym() |
| if not op_void_ptr: |
| logger.error(f"get_rknn_custom_op 返回空指针: {so_path}") |
| return False |
|
|
| |
| ctx_val = ContextCType(runtime.context) |
| ret = lib.rknn_register_custom_ops(ctx_val, ctypes.c_void_p(op_void_ptr), 1) |
| if ret != 0: |
| logger.error(f"rknn_register_custom_ops 失败, ret={ret}, so={so_path} (可能是误报, 继续执行...)") |
| |
|
|
| |
| if not hasattr(self, "_plugin_handles"): |
| self._plugin_handles = [] |
| self._plugin_handles.append(handle) |
| logger.info(f"成功注册插件自定义算子: {so_path}") |
| return True |
|
|
| def register_plugin_ops(self, plugin_paths: List[str]) -> int: |
| """按给定路径列表注册插件库中的自定义算子。返回成功数量。""" |
| if not plugin_paths: |
| return 0 |
| success = 0 |
| for path in plugin_paths: |
| try: |
| if self._load_and_register_plugin_op(path): |
| success += 1 |
| except Exception as e: |
| logger.error(f"注册插件失败: {path}, err={e}") |
| return success |
|
|
| |
| def register_custom_op_lib(self, path: str) -> bool: |
| return self._load_and_register_plugin_op(path) |
|
|
| |
| def register_system_custom_op_lib(self) -> int: |
| if os.name != 'posix': |
| return 0 |
| |
| system_dir = "/usr/lib/rknpu/op_plugins/" |
| if not os.path.isdir(system_dir): |
| return 0 |
| try: |
| entries = os.listdir(system_dir) |
| except Exception: |
| return 0 |
| so_list = [] |
| for name in entries: |
| |
| if name.startswith("librkcst_") and name.endswith('.so'): |
| so_list.append(os.path.join(system_dir, name)) |
| return self.register_plugin_ops(so_list) |