| import ctypes |
| import numpy as np |
| from pyaxdev import _lib, AxDeviceType, AxDevices, check_error |
|
|
|
|
class SRInit(ctypes.Structure):
    """Configuration block passed by pointer to the native ``ax_sr_init``.

    The order and types of the fields below define the binary layout shared
    with the native library, so they must only change together with it.
    """

    _fields_ = [
        # Device backend selector, typed with pyaxdev's AxDeviceType.
        ('dev_type', AxDeviceType),
        # Device index, held in a single C char.
        ('devid', ctypes.c_char),
        # Fixed 256-byte buffer holding the UTF-8 encoded model path.
        ('model_path', ctypes.c_char * 256),
    ]
|
|
class SRImage(ctypes.Structure):
    """Image descriptor handed by pointer to the native ``ax_sr_run``.

    It carries only the dimensions and a raw byte pointer (no strides, no
    dtype), so the pointed-to buffer must be densely packed bytes.
    """

    _fields_ = [
        ('width', ctypes.c_int),                       # pixels per row
        ('height', ctypes.c_int),                      # number of rows
        ('pVirAddr', ctypes.POINTER(ctypes.c_ubyte)),  # start of pixel bytes
    ]
|
|
|
|
# Declare the C prototypes of the ax_sr_* entry points. Every one of them
# returns an int status code that callers hand to check_error.
for _fn_name, _fn_argtypes in (
    ('ax_sr_init', [ctypes.POINTER(SRInit), ctypes.POINTER(ctypes.c_void_p)]),
    ('ax_sr_deinit', [ctypes.c_void_p]),
    ('ax_sr_run', [ctypes.c_void_p, ctypes.POINTER(SRImage), ctypes.POINTER(SRImage)]),
):
    _fn = getattr(_lib, _fn_name)
    _fn.argtypes = _fn_argtypes
    _fn.restype = ctypes.c_int
# Do not leak the loop temporaries into the module namespace.
del _fn_name, _fn_argtypes, _fn
|
|
|
|
class SR:
    """Python wrapper around the native ``ax_sr_*`` super-resolution API.

    ``__init__`` creates a native handle with ``ax_sr_init``. ``close()``
    (also invoked from ``__del__``) releases it with ``ax_sr_deinit``.
    Calling the instance runs ``ax_sr_run`` on one image and returns a new
    image twice as tall and twice as wide.
    """

    def __init__(self, init_info: dict):
        """Initialise the native SR engine.

        Args:
            init_info: configuration mapping. ``'model_path'`` (str) is
                required. ``'dev_type'`` defaults to
                ``AxDeviceType.axcl_device`` and ``'devid'`` defaults to ``0``.

        Raises:
            KeyError: ``'model_path'`` is missing from *init_info*.
            ValueError: the UTF-8 encoded model path does not fit, together
                with a NUL terminator, in ``SRInit.model_path``.
            Whatever ``check_error`` raises when ``ax_sr_init`` fails.
        """
        # Set first so close()/__del__ always find the attribute, even if a
        # later step raises.
        self.handle = None
        self.init_info = SRInit()

        self.init_info.dev_type = init_info.get('dev_type', AxDeviceType.axcl_device)
        self.init_info.devid = init_info.get('devid', 0)

        model_path = init_info['model_path'].encode('utf-8')
        # ctypes silently accepts bytes that fill the char array exactly, but
        # then stores no NUL terminator. Reserve one byte so native code that
        # reads the path as a C string cannot run past the buffer.
        capacity = ctypes.sizeof(dict(SRInit._fields_)['model_path'])
        if len(model_path) >= capacity:
            raise ValueError(
                f"model_path is {len(model_path)} bytes when UTF-8 encoded; "
                f"at most {capacity - 1} bytes fit in SRInit.model_path"
            )
        self.init_info.model_path = model_path

        handle = ctypes.c_void_p()
        check_error(_lib.ax_sr_init(ctypes.byref(self.init_info), ctypes.byref(handle)))
        self.handle = handle

    def close(self):
        """Release the native handle. Calling this again is a no-op.

        This is best-effort: the ``ax_sr_deinit`` status code is not checked,
        so the method is safe to call from ``__del__``.
        """
        # getattr: __del__ may run on an instance whose __init__ never ran
        # (e.g. one created via __new__), where ``handle`` does not exist.
        handle = getattr(self, 'handle', None)
        if handle:
            # Clear first so a repeated close()/__del__ cannot deinit twice.
            self.handle = None
            _lib.ax_sr_deinit(handle)

    def __del__(self):
        self.close()

    @staticmethod
    def _prepare_input(image_data: np.ndarray) -> np.ndarray:
        """Return *image_data* as a C-contiguous uint8 array for the native call.

        ``SRImage`` passes only width, height and a raw byte pointer (no
        strides, no dtype). A strided view such as ``img[:, :, ::-1]`` or a
        non-uint8 array would therefore be misread by the native code.

        Raises:
            ValueError: the array is not uint8 or has fewer than 2 dimensions.
        """
        arr = np.asarray(image_data)
        if arr.dtype != np.uint8:
            raise ValueError(f"expected a uint8 image, got dtype {arr.dtype}")
        if arr.ndim < 2:
            raise ValueError(
                f"expected an image with at least 2 dimensions, got shape {arr.shape}"
            )
        # Copies only when the input is not already C-contiguous.
        return np.ascontiguousarray(arr)

    def __call__(self, image_data: np.ndarray) -> np.ndarray:
        """Upscale *image_data* 2x with the native model.

        Args:
            image_data: uint8 image whose first two axes are (height, width).
                NOTE(review): the output is allocated with 3 channels and
                ``SRImage`` carries no channel count, so the native side
                presumably expects (H, W, 3) input too. Confirm.

        Returns:
            A new ``(2*H, 2*W, 3)`` uint8 array filled by ``ax_sr_run``.

        Raises:
            RuntimeError: the instance has been closed.
            ValueError: *image_data* is not a uint8 array with >= 2 dimensions.
            Whatever ``check_error`` raises when ``ax_sr_run`` fails.
        """
        # A cleared handle would reach C as a NULL pointer.
        if not self.handle:
            raise RuntimeError("SR instance has been closed")

        # ``src`` stays referenced until after the native call, which reads
        # this buffer through a raw pointer.
        src = self._prepare_input(image_data)
        height, width = src.shape[:2]

        image = SRImage()
        image.width = width
        image.height = height
        image.pVirAddr = src.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))

        np_sr_image = np.zeros((height * 2, width * 2, 3), dtype=np.uint8)
        sr_image = SRImage()
        sr_image.width = width * 2
        sr_image.height = height * 2
        sr_image.pVirAddr = np_sr_image.ctypes.data_as(ctypes.POINTER(ctypes.c_ubyte))

        check_error(_lib.ax_sr_run(self.handle, ctypes.byref(image), ctypes.byref(sr_image)))
        return np_sr_image
|
|