| """ |
| This generates .pyi stubs for the cffi Python bindings generated by regenerate.py |
| """ |
| import sys, re, itertools |
| sys.path.extend(['.', '..']) |
|
|
| from pycparser import c_ast, parse_file, CParser |
| import pycparser.plyparser |
| from pycparser.c_ast import PtrDecl, TypeDecl, FuncDecl, EllipsisParam, IdentifierType, Struct, Enum, Typedef |
| from typing import Tuple |
|
|
| __c_type_to_python_type = { |
| 'void': 'None', '_Bool': 'bool', |
| 'char': 'int', 'short': 'int', 'int': 'int', 'long': 'int', |
| 'ptrdiff_t': 'int', 'size_t': 'int', |
| 'int8_t': 'int', 'uint8_t': 'int', |
| 'int16_t': 'int', 'uint16_t': 'int', |
| 'int32_t': 'int', 'uint32_t': 'int', |
| 'int64_t': 'int', 'uint64_t': 'int', |
| 'float': 'float', 'double': 'float', |
| 'ggml_fp16_t': 'np.float16', |
| } |
|
|
| def format_type(t: TypeDecl): |
| if isinstance(t, PtrDecl) or isinstance(t, Struct): |
| return 'ffi.CData' |
| if isinstance(t, Enum): |
| return 'int' |
| if isinstance(t, TypeDecl): |
| return format_type(t.type) |
| if isinstance(t, IdentifierType): |
| assert len(t.names) == 1, f'Expected a single name, got {t.names}' |
| return __c_type_to_python_type.get(t.names[0]) or 'ffi.CData' |
| return t.name |
|
|
| class PythonStubFuncDeclVisitor(c_ast.NodeVisitor): |
| def __init__(self): |
| self.sigs = {} |
| self.sources = {} |
|
|
| def get_source_snippet_lines(self, coord: pycparser.plyparser.Coord) -> Tuple[list[str], list[str]]: |
| if coord.file not in self.sources: |
| with open(coord.file, 'rt') as f: |
| self.sources[coord.file] = f.readlines() |
| source_lines = self.sources[coord.file] |
| ncomment_lines = len(list(itertools.takewhile(lambda i: re.search(r'^\s*(//|/\*)', source_lines[i]), range(coord.line - 2, -1, -1)))) |
| comment_lines = [l.strip() for l in source_lines[coord.line - 1 - ncomment_lines:coord.line - 1]] |
| decl_lines = [] |
| for line in source_lines[coord.line - 1:]: |
| decl_lines.append(line.rstrip()) |
| if (';' in line) or ('{' in line): break |
| return (comment_lines, decl_lines) |
|
|
| def visit_Enum(self, node: Enum): |
| if node.values is not None: |
| for e in node.values.enumerators: |
| self.sigs[e.name] = f' @property\n def {e.name}(self) -> int: ...' |
|
|
| def visit_Typedef(self, node: Typedef): |
| pass |
|
|
| def visit_FuncDecl(self, node: FuncDecl): |
| ret_type = node.type |
| is_ptr = False |
| while isinstance(ret_type, PtrDecl): |
| ret_type = ret_type.type |
| is_ptr = True |
|
|
| fun_name = ret_type.declname |
| if fun_name.startswith('__'): |
| return |
|
|
| args = [] |
| argnames = [] |
| def gen_name(stem): |
| i = 1 |
| while True: |
| new_name = stem if i == 1 else f'{stem}{i}' |
| if new_name not in argnames: return new_name |
| i += 1 |
|
|
| for a in node.args.params: |
| if isinstance(a, EllipsisParam): |
| arg_name = gen_name('args') |
| argnames.append(arg_name) |
| args.append('*' + gen_name('args')) |
| elif format_type(a.type) == 'None': |
| continue |
| else: |
| arg_name = a.name or gen_name('arg') |
| argnames.append(arg_name) |
| args.append(f'{arg_name}: {format_type(a.type)}') |
|
|
| ret = format_type(ret_type if not is_ptr else node.type) |
|
|
| comment_lines, decl_lines = self.get_source_snippet_lines(node.coord) |
|
|
| lines = [f' def {fun_name}({", ".join(args)}) -> {ret}:'] |
| if len(comment_lines) == 0 and len(decl_lines) == 1: |
| lines += [f' """{decl_lines[0]}"""'] |
| else: |
| lines += [' """'] |
| lines += [f' {c.lstrip("/* ")}' for c in comment_lines] |
| if len(comment_lines) > 0: |
| lines += [''] |
| lines += [f' {d}' for d in decl_lines] |
| lines += [' """'] |
| lines += [' ...'] |
| self.sigs[fun_name] = '\n'.join(lines) |
|
|
| def generate_stubs(header: str): |
| """ |
| Generates a .pyi Python stub file for the GGML API using C header files. |
| """ |
|
|
| v = PythonStubFuncDeclVisitor() |
| v.visit(CParser().parse(header, "<input>")) |
|
|
| keys = list(v.sigs.keys()) |
| keys.sort() |
|
|
| return '\n'.join([ |
| '# auto-generated file', |
| 'import ggml.ffi as ffi', |
| 'import numpy as np', |
| 'class lib:', |
| *[v.sigs[k] for k in keys] |
| ]) |
|
|