| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torchgen.api.types as api_types |
| from torchgen.api import cpp, structured |
| from torchgen.api.types import ( |
| ArgName, |
| BaseCppType, |
| BaseCType, |
| Binding, |
| ConstRefCType, |
| CType, |
| NamedCType, |
| scalarT, |
| ) |
| from torchgen.model import ( |
| Argument, |
| BaseTy, |
| BaseType, |
| DispatchKey, |
| FunctionSchema, |
| NativeFunctionsGroup, |
| Type, |
| ) |
|
|
|
|
def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
    """Build the generated ufunc kernel name for an out= schema.

    The result has the form ``ufunc_<func.name.name>_<dispatch_key>``.

    Args:
        func: the schema to name; it must be an out= schema.
        dispatch_key: interpolated verbatim (via its string formatting) as
            the name suffix.

    Raises:
        AssertionError: if ``func.is_out_fn()`` is false.
    """
    # Only out= variants get a ufunc kernel name.
    assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
    return "ufunc_{}_{}".format(func.name.name, dispatch_key)
|
|
|
|
def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
    """Return the ufunc kernel name for group ``g`` under ``dispatch_key``.

    The name is always derived from the group's out= variant's schema.
    """
    out_schema = g.out.func
    return schema_kernel_name(out_schema, dispatch_key)
|
|
|
|
| |
| |
| |
| |
| |
def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
    """Translate schema type ``t`` into its C++ type for the DispatchStub signature.

    The types are resolved in this order:

    * Anything ``cpp.valuetype_type`` (with ``symint=False``) understands is
      returned as-is.
    * ``Scalar`` becomes a ``ConstRefCType`` wrapping ``BaseCType(scalarT)``.
    * ``Tensor`` yields ``None``. It has no entry of its own here; presumably
      the caller skips it. Confirm against the call sites.

    Raises:
        AssertionError: for any other type.
    """
    value_nctype = cpp.valuetype_type(t, binds=binds, symint=False)
    if value_nctype is not None:
        return value_nctype

    if t == BaseType(BaseTy.Tensor):
        return None
    if t == BaseType(BaseTy.Scalar):
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
    """Return the "opmath" counterpart of ``scalar_t``.

    Only ``api_types.scalar_t`` is supported; it maps to ``api_types.opmath_t``.

    Raises:
        NotImplementedError: for any other ``scalar_t``. The message names
            the offending type, so a failure during codegen is diagnosable.
    """
    if scalar_t == api_types.scalar_t:
        return api_types.opmath_t
    raise NotImplementedError(
        f"opmath_type: no opmath mapping is defined for {scalar_t!r}"
    )
|
|
|
|
| |
| |
| |
| |
| |
def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
    """C++ type of argument ``t`` when passed to a ufunctor's constructor.

    Value types are delegated to ``cpp.valuetype_type`` (``symint=False``).
    Both ``Scalar`` and ``Tensor`` are stored as ``opmath_type(scalar_t)``.
    A ``Tensor`` only reaches the constructor when it is the one tensor that
    ``ufunctor_arguments`` routes there via ``scalar_tensor_idx``.

    Raises:
        AssertionError: for any other type.
    """
    passthrough = cpp.valuetype_type(t, binds=binds, symint=False)
    if passthrough is not None:
        return passthrough

    # Scalar and Tensor share the same constructor representation.
    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
| |
| |
| |
| |
def ufunctor_apply_type(
    t: Type, *, binds: ArgName, scalar_t: BaseCppType
) -> NamedCType:
    """C++ type of argument ``t`` when passed to a ufunctor's apply step.

    Only ``Tensor`` arguments are accepted. They are passed as plain
    ``scalar_t``.

    Raises:
        AssertionError: for any non-Tensor type.
    """
    tensor_type = BaseType(BaseTy.Tensor)
    if t == tensor_type:
        return NamedCType(binds, BaseCType(scalar_t))
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
| |
| |
| |
def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
    """C++ type of argument ``t`` in the ufunc function itself.

    Value types are delegated to ``cpp.valuetype_type`` (``symint=False``).
    ``Scalar`` and ``Tensor`` both become ``compute_t``.

    Raises:
        AssertionError: for any other type.
    """
    passthrough = cpp.valuetype_type(t, binds=binds, symint=False)
    if passthrough is not None:
        return passthrough

    if t == BaseType(BaseTy.Scalar) or t == BaseType(BaseTy.Tensor):
        return NamedCType(binds, compute_t)
    raise AssertionError(f"unrecognized type {repr(t)}")
|
|
|
|
def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Wrap schema argument ``a`` as a ufunctor-constructor ``Binding`` (no default)."""
    nctype = ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(argument=a, name=a.name, nctype=nctype, default=None)
|
|
|
|
def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
    """Wrap schema argument ``a`` as a ufunctor-apply ``Binding`` (no default)."""
    nctype = ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t)
    return Binding(argument=a, name=a.name, nctype=nctype, default=None)
|
|
|
|
def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
    """Wrap schema argument ``a`` as a ufunc ``Binding`` typed via ``compute_t`` (no default)."""
    nctype = ufunc_type(a.type, binds=a.name, compute_t=compute_t)
    return Binding(argument=a, name=a.name, nctype=nctype, default=None)
|
|
|
|
@dataclass(frozen=True)
class UfunctorBindings:
    """Arguments of a ufunctor, split between its constructor and its apply step.

    Built by ``ufunctor_arguments``. Fields cannot be reassigned (frozen), but
    the lists themselves are ordinary mutable lists.
    """

    # Bindings taken by the ufunctor constructor.
    ctor: list[Binding]
    # Bindings taken by the ufunctor apply step.
    apply: list[Binding]
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
def ufunctor_arguments(
    g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
) -> UfunctorBindings:
    """Partition the functional variant's non-out arguments into ctor/apply bindings.

    Arguments are visited in ``g.functional.func.arguments.flat_non_out`` order:

    * Non-tensor-like arguments always go to the constructor.
    * Tensor-like arguments go to apply, with one exception. When
      ``scalar_tensor_idx`` is an int ``k``, the k-th (0-based) tensor-like
      argument is routed to the constructor instead.

    Raises:
        AssertionError: if ``scalar_tensor_idx`` was given but never matched a
            tensor-like argument (too large, or negative).
    """
    ctor: list[Binding] = []
    apply: list[Binding] = []
    # Number of tensor-like arguments still to skip before the one that goes
    # to the ctor; becomes None once it has been placed (or if never requested).
    remaining = scalar_tensor_idx

    for arg in g.functional.func.arguments.flat_non_out:
        if not arg.type.is_tensor_like():
            ctor.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            continue
        if remaining == 0:
            # This is the selected tensor: it belongs in the ctor regardless.
            ctor.append(ufunctor_ctor_argument(arg, scalar_t=scalar_t))
            remaining = None
            continue
        if remaining is not None:
            remaining -= 1
        apply.append(ufunctor_apply_argument(arg, scalar_t=scalar_t))

    assert remaining is None
    return UfunctorBindings(ctor=ctor, apply=apply)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
    """Bindings for every non-out argument of the functional variant, typed via ``compute_t``."""
    bindings: list[Binding] = []
    for arg in g.functional.func.arguments.flat_non_out:
        bindings.append(ufunc_argument(arg, compute_t=compute_t))
    return bindings
|
|
|
|
| |
| |
| |
| |
| |
def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
    """Bindings for the stub: the out= variant's non-out, non-tensor-like arguments.

    Tensor-like arguments are dropped entirely. Every other argument is
    expanded through ``structured.argument``, which may yield several bindings
    per argument, and the results are concatenated in order.
    """
    bindings: list[Binding] = []
    for arg in g.out.func.arguments.flat_non_out:
        if arg.type.is_tensor_like():
            continue
        bindings.extend(structured.argument(arg))
    return bindings
|
|