| import ctypes |
| import numpy as np |
| import os |
|
|
class GPUNanoF1:
    """Thin ctypes wrapper around ``launch_f1_kernel`` from ``f1_kernel.so``.

    By default the shared library is looked up next to this module; it is
    built with ``sh compile.sh`` (per the error message raised in
    ``__init__``).
    NOTE(review): the class name suggests a GPU kernel; nothing in this file
    confirms where the computation actually runs.
    """

    def __init__(self, lib_path=None):
        """Load the shared library and declare the kernel's argument types.

        Args:
            lib_path: Optional path to the shared library. Defaults to
                ``f1_kernel.so`` in the directory containing this module.

        Raises:
            FileNotFoundError: if the library file does not exist. This is a
                subclass of ``Exception``, so existing ``except Exception``
                handlers keep working.
            OSError: from ``ctypes.CDLL`` if the file exists but cannot be
                loaded.
        """
        if lib_path is None:
            lib_path = os.path.join(
                os.path.dirname(os.path.abspath(__file__)), 'f1_kernel.so'
            )
        # Make the path absolute: dlopen() only treats names containing a
        # slash as file paths, so a bare 'f1_kernel.so' (which dirname('')
        # would produce) is searched on the system library path instead of
        # the current directory.
        path = os.path.abspath(lib_path)
        if not os.path.exists(path):
            raise FileNotFoundError(
                "Le Kernel F-1 n'est pas compilé. Lancez 'sh compile.sh' d'abord."
            )

        self.lib = ctypes.CDLL(path)
        float_p = ctypes.POINTER(ctypes.c_float)
        # Signature: (A, B, C, size). restype is left at the ctypes default
        # (c_int); compute() ignores the return value.
        self.lib.launch_f1_kernel.argtypes = [
            float_p,
            float_p,
            float_p,
            ctypes.c_int,
        ]

    def compute(self, A, B):
        """Run ``launch_f1_kernel`` on ``A`` and ``B`` and return its output.

        Both inputs are copied into fresh C-contiguous (row-major) float32
        arrays, matching the layout of the output buffer allocated here. The
        caller's own arrays are never passed to native code.

        Args:
            A: array-like; its first dimension (after conversion) is passed
                to the kernel as ``size``.
            B: array-like, converted the same way as ``A``.

        Returns:
            np.ndarray of shape ``(size, size)`` and dtype float32,
            zero-initialised and then filled by the kernel. If ``size`` is 0
            the kernel is not called and an empty ``(0, 0)`` array is
            returned.

        NOTE(review): the kernel receives only ``size``, so it presumably
        expects ``A`` and ``B`` to be ``size x size`` matrices. Input shapes
        are NOT validated before the raw pointers are handed over; confirm
        the kernel's contract and consider adding checks here.
        """
        # Unlike ``astype`` (default order='K', which preserves a transposed /
        # Fortran-ordered layout), np.array(..., order='C') always yields a
        # row-major copy, so the kernel reads the elements the caller meant.
        a = np.array(A, dtype=np.float32, order='C')
        b = np.array(B, dtype=np.float32, order='C')
        size = a.shape[0]
        out = np.zeros((size, size), dtype=np.float32)
        if size == 0:
            # Nothing to compute; avoid launching the native call.
            return out

        float_p = ctypes.POINTER(ctypes.c_float)
        self.lib.launch_f1_kernel(
            a.ctypes.data_as(float_p),
            b.ctypes.data_as(float_p),
            out.ctypes.data_as(float_p),
            size,
        )
        return out
| |