| from __future__ import annotations |
|
|
| from torchgen.api import cpp |
| from torchgen.api.types import ( |
| ArgName, |
| ArrayRefCType, |
| BaseCType, |
| Binding, |
| ConstRefCType, |
| dimnameListT, |
| intArrayRefT, |
| iOptTensorListRefT, |
| iTensorListRefT, |
| NamedCType, |
| OptionalCType, |
| optionalIntArrayRefT, |
| optionalScalarRefT, |
| optionalTensorRefT, |
| scalarT, |
| tensorT, |
| ) |
| from torchgen.model import ( |
| Argument, |
| BaseTy, |
| BaseType, |
| ListType, |
| NativeFunctionsGroup, |
| OptionalType, |
| SelfArgument, |
| TensorOptionsArguments, |
| Type, |
| ) |
| from torchgen.utils import assert_never |
|
|
|
|
| |
| |
| |
|
|
|
|
| |
| |
| |
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Translate a JIT schema type into the C++ argument type used by structured kernels.

    Value types are handled by ``cpp.valuetype_type`` (always with
    ``symint=False``). Any other type is mapped here to a const-reference,
    ref-wrapper or array-ref C++ type.

    Args:
        t: the schema type of the argument being translated.
        mutable: forwarded to ``cpp.valuetype_type`` and to recursive calls on
            element types; the non-value branches below do not otherwise
            consult it.
        binds: the name the resulting ``NamedCType`` is bound to.

    Returns:
        A ``NamedCType`` pairing ``binds`` with the translated C++ type.

    Raises:
        AssertionError: if ``t`` is a ``BaseType`` other than Tensor/Scalar that
            ``cpp.valuetype_type`` did not translate, or if ``t`` is not a
            ``BaseType``, ``OptionalType`` or ``ListType``.
    """
    # symint is always disabled here; presumably because structured kernels
    # operate on concrete ints (NOTE(review): confirm rationale).
    value_type = cpp.valuetype_type(t, symint=False, binds=binds, mutable=mutable)
    if value_type is not None:
        return value_type

    if isinstance(t, BaseType):
        if t.name == BaseTy.Tensor:
            return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
        if t.name == BaseTy.Scalar:
            return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
        raise AssertionError(f"base type should have been value type {t}")

    if isinstance(t, OptionalType):
        inner = t.elem
        # Optional Tensor / Scalar / int[] have dedicated "optional ref" types.
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, BaseCType(optionalTensorRefT))
        if inner == BaseType(BaseTy.Scalar):
            return NamedCType(binds, BaseCType(optionalScalarRefT))
        if isinstance(inner, ListType) and str(inner.elem) == "int":
            return NamedCType(binds, BaseCType(optionalIntArrayRefT))
        # Everything else: wrap the translated element type in an optional.
        wrapped = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, OptionalCType(wrapped.type))

    if isinstance(t, ListType):
        inner = t.elem
        # Tensor lists use the iterable list-ref types.
        if inner == BaseType(BaseTy.Tensor):
            return NamedCType(binds, ConstRefCType(BaseCType(iTensorListRefT)))
        if inner == OptionalType(BaseType(BaseTy.Tensor)):
            return NamedCType(binds, BaseCType(iOptTensorListRefT))
        # Special-cased element types, matched on their string spelling.
        special_list_types = {"int": intArrayRefT, "Dimname": dimnameListT}
        special = special_list_types.get(str(inner))
        if special is not None:
            return NamedCType(binds, BaseCType(special))
        # Generic fallback: an ArrayRef over the translated element type.
        element = argumenttype_type(inner, mutable=mutable, binds=binds)
        return NamedCType(binds, ArrayRefCType(element.type))

    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Translate a schema ``Argument`` into its structured-kernel C++ type.

    The argument's ``is_write`` flag is used as the ``mutable`` input of
    ``argumenttype_type``.

    Args:
        a: the schema argument to translate.
        binds: the name the resulting ``NamedCType`` is bound to.

    Returns:
        The ``NamedCType`` produced by ``argumenttype_type`` for ``a.type``.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, mutable=is_mutable, binds=binds)
|
|
|
|
| |
| |
| |
| |
|
|
|
|
| |
def argument(a: Argument | SelfArgument | TensorOptionsArguments) -> list[Binding]:
    """Produce the C++ bindings for one schema argument of a structured kernel.

    Args:
        a: a plain ``Argument``, a ``SelfArgument`` (unwrapped to its inner
            argument), or a ``TensorOptionsArguments`` (rejected).

    Returns:
        A one-element list holding a ``Binding`` whose name matches the
        argument's name and which has no default.

    Raises:
        AssertionError: if ``a`` is a ``TensorOptionsArguments``; structured
            kernels do not support TensorOptions yet.
    """
    if isinstance(a, Argument):
        nctype = argument_type(a, binds=a.name)
        binding = Binding(nctype=nctype, name=a.name, default=None, argument=a)
        return [binding]
    if isinstance(a, SelfArgument):
        # ``self`` is bound exactly like the argument it wraps.
        return argument(a.argument)
    if isinstance(a, TensorOptionsArguments):
        raise AssertionError("structured kernels don't support TensorOptions yet")
    assert_never(a)
|
|
|
|
def impl_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Compute the bindings for the ``impl`` function of a structured kernel group.

    The arguments are taken from ``g.out``'s signature in this order:

    1. The non-out arguments. When ``g.out.precomputed`` is set, any
       ``Argument`` whose name is a key of ``precomputed.replace`` is swapped
       for the arguments listed under that key.
    2. When ``g.out.precomputed`` is set, the arguments in ``precomputed.add``.
    3. The out arguments.

    Args:
        g: the native function group whose out overload is used.

    Returns:
        The concatenated bindings of all of the above, in that order.
    """
    out_func = g.out.func
    precomputed = g.out.precomputed
    args: list[Argument | TensorOptionsArguments | SelfArgument] = []

    if precomputed:
        replacements = precomputed.replace
        for arg in out_func.arguments.non_out:
            if isinstance(arg, Argument) and arg.name in replacements:
                # A replaced argument may expand into several arguments.
                args.extend(replacements[arg.name])
            else:
                args.append(arg)
        # Extra precomputed arguments come after the (replaced) non-out args.
        args.extend(precomputed.add)
    else:
        args.extend(out_func.arguments.non_out)

    # Out arguments always come last.
    args.extend(out_func.arguments.out)

    bindings: list[Binding] = []
    for arg in args:
        bindings.extend(argument(arg))
    return bindings
|
|
|
|
def meta_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Compute the bindings for the ``meta`` function of a structured kernel group.

    Args:
        g: the native function group; only its functional overload is used.

    Returns:
        The bindings of the functional overload's non-out arguments, in order.
    """
    bindings: list[Binding] = []
    for arg in g.functional.func.arguments.non_out:
        bindings.extend(argument(arg))
    return bindings
|
|
|
|
def out_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Compute the bindings for the out arguments of a structured kernel group.

    Args:
        g: the native function group; only its out overload is used.

    Returns:
        The bindings of the out overload's out arguments, in order.
    """
    bindings: list[Binding] = []
    for arg in g.out.func.arguments.out:
        bindings.extend(argument(arg))
    return bindings
|
|