| from __future__ import annotations |
|
|
| from typing import NoReturn, Sequence |
|
|
| from torchgen.api.types import ( |
| ArrayRefCType, |
| BaseCType, |
| Binding, |
| boolT, |
| ConstRefCType, |
| deviceT, |
| Expr, |
| intArrayRefT, |
| iOptTensorListRefT, |
| layoutT, |
| ListCType, |
| longT, |
| memoryFormatT, |
| MutRefCType, |
| NamedCType, |
| opmath_t, |
| OptionalCType, |
| optionalIntArrayRefT, |
| optionalScalarRefT, |
| optionalSymIntArrayRefT, |
| optionalTensorRefT, |
| scalar_t, |
| scalarT, |
| scalarTypeT, |
| SpecialArgName, |
| symIntArrayRefT, |
| SymIntT, |
| tensorOptionsT, |
| tensorT, |
| VectorCType, |
| ) |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| options_ctype = NamedCType("options", ConstRefCType(BaseCType(tensorOptionsT))) |
|
|
| out_tensor_ctype = NamedCType("out", ConstRefCType(BaseCType(tensorT))) |
|
|
| longVec_ctype = VectorCType(BaseCType(longT)) |
| longSymVec_ctype = VectorCType(BaseCType(SymIntT)) |
| optionalLongVec_ctype = OptionalCType(VectorCType(BaseCType(longT))) |
| optionalScalar_ctype = OptionalCType(BaseCType(scalarT)) |
| optionalTensor_ctype = OptionalCType(BaseCType(tensorT)) |
|
|
|
|
| class UnsatError(RuntimeError): |
| pass |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| def translate( |
| bindings: Sequence[Expr | Binding], |
| goals: Sequence[NamedCType | Binding], |
| *, |
| method: bool = False, |
| allow_expensive_conversions: bool = False, |
| ) -> list[Expr]: |
| binding_exprs: list[Expr] = [] |
| for b in bindings: |
| if isinstance(b, Binding): |
| binding_exprs.append( |
| Expr( |
| expr=b.name, |
| type=b.nctype, |
| ) |
| ) |
| else: |
| binding_exprs.append(b) |
|
|
| goal_ctypes: list[NamedCType] = [] |
| for g in goals: |
| if isinstance(g, Binding): |
| goal_ctypes.append(g.nctype) |
| else: |
| goal_ctypes.append(g) |
|
|
| |
| ctx: dict[NamedCType, str] = {} |
| for b in binding_exprs: |
| ctx[b.type] = b.expr |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| t = b.type |
| if ( |
| isinstance(t, ConstRefCType) |
| and isinstance(t.elem, OptionalCType) |
| and isinstance(t.elem.elem, BaseCType) |
| and str(t.elem.elem.type) == "at::Tensor" |
| ): |
| ctx[ |
| NamedCType(t.elem.elem.name, ConstRefCType(BaseCType(tensorT))) |
| ] = f"({b.expr}.has_value() ? *{b.expr} : at::Tensor())" |
|
|
| if t.type == ConstRefCType(OptionalCType(BaseCType(tensorT))): |
| ctx[ |
| NamedCType(t.name, BaseCType(optionalTensorRefT)) |
| ] = f"(({b.expr}.has_value() && (*{b.expr}).defined()) ? at::OptionalTensorRef(*{b.expr}) : at::OptionalTensorRef())" |
|
|
| if t.type == ConstRefCType(BaseCType(scalarT)): |
| ctx[NamedCType(t.name, BaseCType(opmath_t))] = f"({b.expr}).to<opmath_t>()" |
|
|
| if t.type == ConstRefCType(OptionalCType(BaseCType(scalarT))): |
| ctx[ |
| NamedCType(t.name, BaseCType(optionalScalarRefT)) |
| ] = f"({b.expr}.has_value() ? at::OptionalScalarRef(&({b.expr}.value())) : at::OptionalScalarRef())" |
|
|
| if t.type == BaseCType(scalar_t): |
| ctx[ |
| NamedCType(t.name, BaseCType(opmath_t)) |
| ] = f"static_cast<opmath_t>({b.expr})" |
|
|
| |
| if t.type == ConstRefCType(ListCType(OptionalCType(BaseCType(tensorT)))): |
| ctx[ |
| NamedCType(t.name, BaseCType(iOptTensorListRefT)) |
| ] = f"at::IOptTensorListRef({b.expr})" |
|
|
| |
| if method: |
| ctx[ |
| NamedCType("self", MutRefCType(BaseCType(tensorT))) |
| ] = "const_cast<Tensor&>(*this)" |
| ctx[ |
| NamedCType("self", ConstRefCType(BaseCType(tensorT))) |
| ] = "const_cast<Tensor&>(*this)" |
| |
| |
|
|
| def unsat(goal: NamedCType) -> NoReturn: |
| ctx_desc = "\n".join( |
| f" {t.cpp_type()} {t.name}; // {e}" for t, e in ctx.items() |
| ) |
| raise UnsatError( |
| f""" |
| Failed to synthesize the expression "{goal.cpp_type()} {goal.name}". |
| When I failed, the following bindings were available in the context: |
| |
| {ctx_desc} |
| |
| This probably means there is a missing rule in the rules of torchgen.api.translate. |
| Check this module for more information. |
| """ |
| ) |
|
|
| |
| |
| |
| |
| |
| |
| |
| def solve(goal: NamedCType, *, direct: bool) -> str: |
| def direct_solve(goal: NamedCType) -> str: |
| return solve(goal, direct=True) |
|
|
| if goal in ctx: |
| |
| return ctx[goal] |
|
|
| |
| if isinstance(goal.type, ConstRefCType): |
| try: |
| |
| |
| |
| return solve( |
| NamedCType(goal.name, MutRefCType(goal.type.elem)), direct=direct |
| ) |
| except UnsatError: |
| pass |
|
|
| |
| if isinstance(goal.type, MutRefCType): |
| try: |
| return solve(NamedCType(goal.name, goal.type.elem), direct=direct) |
| except UnsatError: |
| pass |
|
|
| |
| |
| |
| if goal.type == ArrayRefCType(BaseCType(longT)): |
| return solve(NamedCType(goal.name, BaseCType(intArrayRefT)), direct=direct) |
|
|
| if direct: |
| unsat(goal) |
|
|
| |
| if goal == NamedCType("memory_format", OptionalCType(BaseCType(memoryFormatT))): |
| memory_format = direct_solve( |
| NamedCType( |
| SpecialArgName.possibly_redundant_memory_format, |
| OptionalCType(BaseCType(memoryFormatT)), |
| ) |
| ) |
| |
| |
| if options_ctype in goal_ctypes: |
| return memory_format |
| try: |
| options = direct_solve(options_ctype) |
| return f"c10::impl::check_tensor_options_and_extract_memory_format({options}, {memory_format})" |
| except UnsatError: |
| return memory_format |
| elif goal == NamedCType("options", BaseCType(tensorOptionsT)): |
| dtype = direct_solve( |
| NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))) |
| ) |
| pin_memory = direct_solve( |
| NamedCType("pin_memory", OptionalCType(BaseCType(boolT))) |
| ) |
| device = direct_solve( |
| NamedCType("device", OptionalCType(BaseCType(deviceT))) |
| ) |
| layout = direct_solve( |
| NamedCType("layout", OptionalCType(BaseCType(layoutT))) |
| ) |
| return f"TensorOptions().dtype({dtype}).layout({layout}).device({device}).pinned_memory({pin_memory})" |
|
|
| elif goal == NamedCType("dtype", OptionalCType(BaseCType(scalarTypeT))): |
| try: |
| options = direct_solve(options_ctype) |
| return f"c10::optTypeMetaToScalarType({options}.dtype_opt())" |
| except UnsatError: |
| out_tensor = direct_solve(out_tensor_ctype) |
| return f"{out_tensor}.scalar_type()" |
|
|
| elif goal == NamedCType("layout", OptionalCType(BaseCType(layoutT))): |
| try: |
| options = direct_solve(options_ctype) |
| return f"{options}.layout_opt()" |
| except UnsatError: |
| out_tensor = direct_solve(out_tensor_ctype) |
| return f"{out_tensor}.layout()" |
|
|
| elif goal == NamedCType("device", OptionalCType(BaseCType(deviceT))): |
| try: |
| options = direct_solve(options_ctype) |
| return f"{options}.device_opt()" |
| except UnsatError: |
| out_tensor = direct_solve(out_tensor_ctype) |
| return f"{out_tensor}.device()" |
|
|
| elif goal == NamedCType("pin_memory", OptionalCType(BaseCType(boolT))): |
| try: |
| options = direct_solve(options_ctype) |
| return f"{options}.pinned_memory_opt()" |
| except UnsatError: |
| |
| |
| out_tensor = direct_solve(out_tensor_ctype) |
| return "::std::nullopt" |
|
|
| |
| elif goal.type == BaseCType(intArrayRefT): |
| try: |
| return direct_solve(NamedCType(goal.name, longVec_ctype)) |
| except UnsatError: |
| |
| symIntArrayRef_type = direct_solve( |
| NamedCType(goal.name, BaseCType(symIntArrayRefT)) |
| ) |
| return f"C10_AS_INTARRAYREF_SLOW({symIntArrayRef_type})" |
| elif goal.type == BaseCType(symIntArrayRefT): |
| try: |
| r = direct_solve(NamedCType(goal.name, BaseCType(intArrayRefT))) |
| return f"c10::fromIntArrayRefSlow({r})" |
| except UnsatError: |
| return direct_solve(NamedCType(goal.name, longSymVec_ctype)) |
| elif goal.type == BaseCType(SymIntT): |
| return direct_solve(NamedCType(goal.name, BaseCType(longT))) |
| elif goal.type == OptionalCType(BaseCType(SymIntT)): |
| argname = direct_solve( |
| NamedCType(goal.name, OptionalCType(BaseCType(longT))) |
| ) |
| return f"{argname}.has_value() ? ::std::make_optional(c10::SymInt(*{argname})) : ::std::nullopt" |
| elif goal.type == BaseCType(longT): |
| symInt_type = direct_solve(NamedCType(goal.name, BaseCType(SymIntT))) |
| return f"{symInt_type}.guard_int(__FILE__, __LINE__)" |
| elif goal.type == OptionalCType(BaseCType(longT)): |
| argname = direct_solve( |
| NamedCType(goal.name, OptionalCType(BaseCType(SymIntT))) |
| ) |
| return f"{argname}.has_value() ? ::std::make_optional({argname}->guard_int(__FILE__, __LINE__)) : ::std::nullopt" |
| elif goal.type == BaseCType(optionalIntArrayRefT): |
| try: |
| return direct_solve(NamedCType(goal.name, optionalLongVec_ctype)) |
| except UnsatError: |
| argname = direct_solve( |
| NamedCType(goal.name, BaseCType(optionalSymIntArrayRefT)) |
| ) |
| return f"{argname}.has_value() ? ::std::make_optional(C10_AS_INTARRAYREF_SLOW(*{argname})) : ::std::nullopt" |
| elif goal.type == BaseCType(optionalSymIntArrayRefT): |
| |
| |
| argname = direct_solve( |
| NamedCType(goal.name, BaseCType(optionalIntArrayRefT)) |
| ) |
| return f"{argname}.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*{argname})) : ::std::nullopt" |
| elif goal.type == BaseCType(optionalScalarRefT): |
| return direct_solve(NamedCType(goal.name, optionalScalar_ctype)) |
| elif goal.type == BaseCType(optionalTensorRefT): |
| return direct_solve(NamedCType(goal.name, optionalTensor_ctype)) |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| if allow_expensive_conversions: |
| if goal.type == VectorCType(BaseCType(longT)): |
| intArrayRef_ctype = NamedCType(goal.name, BaseCType(intArrayRefT)) |
| argname = direct_solve(intArrayRef_ctype) |
| return f"{argname}.vec()" |
| if goal.type == VectorCType(BaseCType(SymIntT)): |
| symIntArrayRef_ctype = NamedCType(goal.name, BaseCType(symIntArrayRefT)) |
| argname = direct_solve(symIntArrayRef_ctype) |
| return f"{argname}.vec()" |
| elif goal.type == OptionalCType(VectorCType(BaseCType(longT))): |
| optionalIntArrayRef_ctype = NamedCType( |
| goal.name, BaseCType(optionalIntArrayRefT) |
| ) |
| argname = direct_solve(optionalIntArrayRef_ctype) |
| return f"{argname}.has_value() ? ::std::make_optional({argname}->vec()) : ::std::nullopt" |
| elif goal.type == OptionalCType(BaseCType(scalarT)): |
| optionalScalarRef_ctype = NamedCType( |
| goal.name, BaseCType(optionalScalarRefT) |
| ) |
| argname = direct_solve(optionalScalarRef_ctype) |
| return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" |
| elif goal.type == OptionalCType(BaseCType(scalarT)): |
| optionalTensorRef_ctype = NamedCType( |
| goal.name, BaseCType(optionalTensorRefT) |
| ) |
| argname = direct_solve(optionalTensorRef_ctype) |
| return f"{argname}.has_value() ? ::std::make_optional({argname}) : ::std::nullopt" |
| |
| |
| |
| |
|
|
| |
| |
| if goal.type == MutRefCType(BaseCType(tensorT)): |
| const_ref_tensor_ctype = NamedCType( |
| goal.name, ConstRefCType(BaseCType(tensorT)) |
| ) |
| argname = direct_solve(const_ref_tensor_ctype) |
| return f"const_cast<Tensor&>({argname})" |
|
|
| unsat(goal) |
|
|
| return [Expr(solve(g, direct=False), g) for g in goal_ctypes] |
|
|