| from __future__ import annotations |
|
|
| from typing import Sequence |
|
|
| from torchgen import local |
| from torchgen.api import cpp |
| from torchgen.api.types import ( |
| ArgName, |
| BaseCType, |
| Binding, |
| boolT, |
| ConstRefCType, |
| CType, |
| deviceT, |
| layoutT, |
| ListCType, |
| MutRefCType, |
| NamedCType, |
| OptionalCType, |
| scalarT, |
| scalarTypeT, |
| tensorT, |
| ) |
| from torchgen.model import ( |
| Argument, |
| FunctionSchema, |
| Return, |
| SelfArgument, |
| TensorOptionsArguments, |
| Type, |
| ) |
| from torchgen.utils import assert_never |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
def name(func: FunctionSchema) -> str:
    """Return the C++ identifier used for the native kernel of ``func``.

    The identifier is the base operator name, then ``_out`` for out= variants,
    then ``_<overload_name>`` when the schema has a non-empty overload name.
    """
    base = str(func.name.name)
    out_suffix = "_out" if func.is_out_fn() else ""
    overload = func.name.overload_name
    overload_suffix = f"_{overload}" if overload else ""
    return base + out_suffix + overload_suffix
|
|
|
|
def argumenttype_type(
    t: Type, *, mutable: bool, binds: ArgName, symint: bool
) -> NamedCType:
    """Map a schema ``Type`` to the C++ type used by the native API.

    A few types get native-specific handling:

    * ``Tensor?`` becomes an optional tensor, passed by mutable reference when
      ``mutable`` is set and ``local.use_const_ref_for_mutable_tensors()`` is
      false, and by const reference otherwise.
    * ``Tensor?[]`` becomes a const reference to a list of optional tensors.
    * ``Scalar`` / ``Scalar?`` become const references to a (possibly optional)
      scalar.

    Every other type is delegated to ``cpp.argumenttype_type``.
    """
    type_str = str(t)

    if type_str == "Tensor?":
        opt_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        use_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        ref = MutRefCType(opt_tensor) if use_mut_ref else ConstRefCType(opt_tensor)
        return NamedCType(binds, ref)

    if type_str == "Tensor?[]":
        # `mutable` is not consulted for lists of optional tensors.
        return NamedCType(
            binds, ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT))))
        )

    if type_str == "Scalar":
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if type_str == "Scalar?":
        return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(scalarT))))

    # No native-specific override: use the generic C++ API mapping.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds, symint=symint)
|
|
|
|
def returns_type(rs: Sequence[Return], *, symint: bool) -> CType:
    """Return the C++ type for a native function's return values.

    The native API has no return-type overrides of its own; the whole mapping
    is forwarded to ``cpp.returns_type``.
    """
    ctype = cpp.returns_type(rs, symint=symint)
    return ctype
|
|
|
|
def argument_type(a: Argument, *, binds: ArgName, symint: bool) -> NamedCType:
    """Return the native C++ type of schema argument ``a``, named ``binds``.

    Mutability comes from ``a.is_write``; the type mapping itself is done by
    ``argumenttype_type``.
    """
    is_mutable = a.is_write
    return argumenttype_type(a.type, binds=binds, mutable=is_mutable, symint=symint)
|
|
|
|
def argument(
    a: Argument | SelfArgument | TensorOptionsArguments,
    *,
    is_out: bool,
    symint: bool,
) -> list[Binding]:
    """Translate one schema argument into its native-API ``Binding`` list.

    Args:
        a: A plain ``Argument``, a ``SelfArgument`` wrapper, or a
            ``TensorOptionsArguments`` bundle.
        is_out: When true, no C++ default expression is attached to any of the
            resulting bindings.
        symint: Forwarded to ``argument_type`` and ``cpp.default_expr``.

    Returns:
        A single binding for ``Argument`` / ``SelfArgument``. For
        ``TensorOptionsArguments``, four bindings in the order ``dtype``,
        ``layout``, ``device``, ``pin_memory``.
    """
    # Out variants never receive defaults; everything else may.
    emit_defaults = not is_out

    if isinstance(a, Argument):
        arg_default: str | None = None
        if emit_defaults and a.default is not None:
            arg_default = cpp.default_expr(a.default, a.type, symint=symint)
        binding = Binding(
            nctype=argument_type(a, binds=a.name, symint=symint),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        # `self` is bound exactly like the Argument it wraps.
        return argument(a.argument, is_out=is_out, symint=symint)

    if isinstance(a, TensorOptionsArguments):
        # When defaults are emitted, each unpacked option gets the C++ default
        # expression "{}". Every binding's `argument` is the whole
        # TensorOptionsArguments object, not a per-field piece of it.
        opts_default = "{}" if emit_defaults else None
        option_fields = (
            ("dtype", scalarTypeT),
            ("layout", layoutT),
            ("device", deviceT),
            ("pin_memory", boolT),
        )
        return [
            Binding(
                nctype=NamedCType(opt_name, OptionalCType(BaseCType(opt_ctype))),
                name=opt_name,
                default=opts_default,
                argument=a,
            )
            for opt_name, opt_ctype in option_fields
        ]

    assert_never(a)
|
|
|
|
def arguments(func: FunctionSchema, *, symint: bool) -> list[Binding]:
    """Flatten every argument of ``func`` into native-API bindings.

    Non-out arguments come first, then out arguments, each group in schema
    order. The ``is_out`` flag passed to ``argument`` is the same for every
    argument: it is ``func.is_out_fn()``, so for out= functions none of the
    bindings (including the non-out ones) receive defaults.
    """
    # This value does not change between arguments; compute it once instead of
    # once per argument inside the comprehension.
    is_out = func.is_out_fn()
    args: list[Argument | TensorOptionsArguments | SelfArgument] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    return [r for arg in args for r in argument(arg, symint=symint, is_out=is_out)]
|
|