"""Utilities for kernel tests: tolerance-aware tensor comparison and
torch.library.opcheck wrappers."""
|
|
import unittest
import unittest.mock
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import torch
from torch._prims_common import TensorLikeType

from .allclose_default import get_default_atol, get_default_rtol
|
|
| |
| |
# Default subset of torch.library.opcheck test utilities; omits
# "test_aot_dispatch_dynamic" (present only in ALL_OPCHECK_TEST_UTILS below).
DEFAULT_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
)

# Complete set of opcheck test utilities, including the
# AOT-dispatch dynamic-shapes check.
ALL_OPCHECK_TEST_UTILS: Tuple[str, ...] = (
    "test_schema",
    "test_autograd_registration",
    "test_faketensor",
    "test_aot_dispatch_dynamic",
)
|
|
|
|
| def assert_close( |
| a: TensorLikeType, |
| b: TensorLikeType, |
| atol: float | None = None, |
| rtol: float | None = None, |
| ) -> None: |
| atol = atol if atol is not None else get_default_atol(a) |
| rtol = rtol if rtol is not None else get_default_rtol(a) |
| torch.testing.assert_close(a, b, atol=atol, rtol=rtol) |
|
|
|
|
| |
| def fp8_allclose( |
| a: TensorLikeType, |
| b: TensorLikeType, |
| rtol: float = 1e-05, |
| atol: float = 1e-08, |
| equal_nan: bool = False, |
| ) -> bool: |
| """ |
| Reference implementation of torch.allclose |
| """ |
| torch._refs._check_close_args(name="torch.allclose", |
| a=a, |
| b=b, |
| rtol=rtol, |
| atol=atol) |
|
|
| return bool( |
| torch.all( |
| torch.isclose(a.double(), |
| b.double(), |
| rtol=rtol, |
| atol=atol, |
| equal_nan=equal_nan)).item()) |
|
|
|
|
| |
| |
def opcheck(
    op: Union[
        torch._ops.OpOverload,
        torch._ops.OpOverloadPacket,
        torch._library.custom_ops.CustomOpDef,
    ],
    args: Tuple[Any, ...],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    test_utils: Union[str, Sequence[str]] = ALL_OPCHECK_TEST_UTILS,
    raise_exception: bool = True,
    cond: bool = True,
) -> Dict[str, str]:
    """Run ``torch.library.opcheck`` on ``op`` with ``torch.allclose``
    patched to :func:`fp8_allclose`, so result comparisons also work for
    dtypes plain ``torch.allclose`` cannot handle (e.g. fp8).

    Args:
        op: The custom op (overload, overload packet, or CustomOpDef) to check.
        args: Positional arguments forwarded to the op.
        kwargs: Keyword arguments forwarded to the op.
        test_utils: Which opcheck test utilities to run
            (defaults to ALL_OPCHECK_TEST_UTILS).
        raise_exception: Forwarded to ``torch.library.opcheck``; when True,
            failures raise instead of being reported in the result dict.
        cond: When False, skip the check entirely and return ``{}``.

    Returns:
        The dict of per-test results from ``torch.library.opcheck``,
        or ``{}`` when ``cond`` is False.
    """
    # Skip early: no need to install the allclose patch when disabled.
    if not cond:
        return {}
    # NOTE: requires `import unittest.mock` at module level; plain
    # `import unittest` does not guarantee the `mock` submodule is loaded.
    with unittest.mock.patch("torch.allclose", new=fp8_allclose):
        return torch.library.opcheck(op,
                                     args,
                                     kwargs,
                                     test_utils=test_utils,
                                     raise_exception=raise_exception)
|
|