| from __future__ import annotations
|
|
|
| import contextlib
|
| import functools
|
| from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
|
|
|
| import torchgen.local as local
|
| from torchgen.model import (
|
| BackendIndex,
|
| DispatchKey,
|
| NativeFunction,
|
| NativeFunctionsGroup,
|
| NativeFunctionsViewGroup,
|
| )
|
| from torchgen.utils import context, S, T
|
|
|
|
|
| if TYPE_CHECKING:
|
| from collections.abc import Iterator
|
|
|
|
|
|
|
|
|
# Constrained type variables shared by the decorators in this file.

# F: the object passed to `native_function_manager` by the decorators: a
# single function, a group, a view group, or one of the two unions below.
F = TypeVar(
    "F",
    NativeFunction,
    NativeFunctionsGroup,
    NativeFunctionsViewGroup,
    Union[NativeFunction, NativeFunctionsGroup],
    Union[NativeFunction, NativeFunctionsViewGroup],
)

# F2: the extra second argument forwarded by `with_native_function_and`.
F2 = TypeVar(
    "F2", NativeFunction, NativeFunctionsGroup, Optional[NativeFunction], bool, str
)

# F3: a tuple or list whose element 0 is a NativeFunction; consumed by
# `method_with_nested_native_function`, which indexes `f[0]`.
F3 = TypeVar("F3", tuple[NativeFunction, Any], list[NativeFunction])
|
|
|
|
|
@contextlib.contextmanager
def native_function_manager(
    g: NativeFunctionsGroup | NativeFunctionsViewGroup | NativeFunction,
) -> Iterator[None]:
    """Run the ``with`` body scoped to one representative native function.

    The representative function is ``g.out`` for a ``NativeFunctionsGroup``,
    ``g.view`` for a ``NativeFunctionsViewGroup``, and ``g`` itself otherwise.
    While the body runs, two context managers are active, entered in this
    order:

    * ``context(...)`` with a message callable that names the function's
      ``loc`` and ``func`` (presumably used to annotate errors raised in the
      body -- confirm against ``torchgen.utils.context``);
    * ``local.parametrize(...)`` with ``use_const_ref_for_mutable_tensors``
      and ``use_ilistref_for_tensor_lists`` read from that function.
    """
    if isinstance(g, NativeFunctionsGroup):
        fn = g.out
    elif isinstance(g, NativeFunctionsViewGroup):
        fn = g.view
    else:
        fn = g

    def describe() -> str:
        # Handed over as a callable so the message is only formatted if and
        # when `context` decides to call it.
        return f"in native_functions.yaml line {fn.loc}:\n {fn.func}"

    # `local.parametrize(...)` is constructed only after `context` has been
    # entered, so a failure while reading the flags still happens inside it.
    with context(describe), local.parametrize(
        use_const_ref_for_mutable_tensors=fn.use_const_ref_for_mutable_tensors,
        use_ilistref_for_tensor_lists=fn.part_of_structured_group,
    ):
        yield
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def with_native_function(func: Callable[[F], T]) -> Callable[[F], T]:
    """Wrap a one-argument function so it runs under `native_function_manager`.

    The returned callable accepts the same single argument ``f``, enters
    ``native_function_manager(f)``, and returns the result of ``func(f)``.
    The metadata of ``func`` is copied onto the returned callable.
    """

    def managed(f: F) -> T:
        with native_function_manager(f):
            return func(f)

    # Equivalent to decorating `managed` with `@functools.wraps(func)`.
    return functools.update_wrapper(managed, func)
|
|
|
|
|
def with_native_function_and(func: Callable[[F, F2], T]) -> Callable[[F, F2], T]:
    """Wrap a two-argument function so it runs under `native_function_manager`.

    The returned callable accepts ``(f, f2)``, enters
    ``native_function_manager(f)`` (only the first argument selects the
    manager), and returns the result of ``func(f, f2)``.  The metadata of
    ``func`` is copied onto the returned callable.
    """

    def managed(f: F, f2: F2) -> T:
        with native_function_manager(f):
            return func(f, f2)

    # Equivalent to decorating `managed` with `@functools.wraps(func)`.
    return functools.update_wrapper(managed, func)
|
|
|
|
|
def method_with_native_function(func: Callable[[S, F], T]) -> Callable[[S, F], T]:
    """Method flavour of the wrapper: the first argument is passed through.

    The returned callable accepts ``(slf, f)``, enters
    ``native_function_manager(f)``, and returns the result of
    ``func(slf, f)``.  The metadata of ``func`` is copied onto the returned
    callable.
    """

    def managed(slf: S, f: F) -> T:
        with native_function_manager(f):
            return func(slf, f)

    # Equivalent to decorating `managed` with `@functools.wraps(func)`.
    return functools.update_wrapper(managed, func)
|
|
|
|
|
def method_with_nested_native_function(
    func: Callable[[S, F3], T],
) -> Callable[[S, F3], T]:
    """Method wrapper for an argument whose first element is the function.

    The returned callable accepts ``(slf, f)``, enters
    ``native_function_manager(f[0])``, and returns the result of
    ``func(slf, f)``; ``f`` itself is forwarded unchanged.  The metadata of
    ``func`` is copied onto the returned callable.
    """

    def managed(slf: S, f: F3) -> T:
        # Only element 0 of the container selects the manager.
        with native_function_manager(f[0]):
            return func(slf, f)

    # Equivalent to decorating `managed` with `@functools.wraps(func)`.
    return functools.update_wrapper(managed, func)
|
|
|
|
|
|
|
|
|
def with_native_function_and_index(
    func: Callable[[F, BackendIndex], T],
) -> Callable[[F, BackendIndex], T]:
    """Wrap a ``(f, backend_index)`` function under `native_function_manager`.

    The returned callable enters ``native_function_manager(f)`` and returns
    the result of ``func(f, backend_index)``; the index is forwarded
    untouched.  The metadata of ``func`` is copied onto the returned
    callable.
    """

    def managed(f: F, backend_index: BackendIndex) -> T:
        with native_function_manager(f):
            return func(f, backend_index)

    # Equivalent to decorating `managed` with `@functools.wraps(func)`.
    return functools.update_wrapper(managed, func)
|
|
|
|
|
|
|
def with_native_function_and_indices(
    func: Callable[[F, dict[DispatchKey, BackendIndex]], T],
) -> Callable[[F, dict[DispatchKey, BackendIndex]], T]:
    """Wrap a ``(f, backend_indices)`` function under `native_function_manager`.

    The returned callable enters ``native_function_manager(f)`` and returns
    the result of ``func(f, backend_indices)``; the mapping is forwarded
    untouched.  The metadata of ``func`` is copied onto the returned
    callable.
    """

    def managed(f: F, backend_indices: dict[DispatchKey, BackendIndex]) -> T:
        with native_function_manager(f):
            return func(f, backend_indices)

    # Equivalent to decorating `managed` with `@functools.wraps(func)`.
    return functools.update_wrapper(managed, func)
|
|
|