import torch
import torch.nn.functional as F

from ._ops import add_op_namespace_prefix


@torch.library.custom_op(add_op_namespace_prefix("silu_and_mul"), mutates_args=())
def _silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and gate the first half with the
    # second: silu(x1) * x2 (SwiGLU-style activation).
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]


def backward(ctx, grad_output):
    (x,) = ctx.saved_tensors
    d = x.shape[-1] // 2
    x1, x2 = x[..., :d], x[..., d:]
    sigmoid_x1 = torch.sigmoid(x1)
    silu_x1 = F.silu(x1)
    # d/dx silu(x) = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)),
    # where x * sigmoid(x) is exactly silu(x).
    dsilu_dx1 = sigmoid_x1 + silu_x1 * (1 - sigmoid_x1)
    dx1 = grad_output * x2 * dsilu_dx1
    dx2 = grad_output * silu_x1
    # Concatenate the gradients of the two halves back to the input's shape.
    return torch.cat([dx1, dx2], dim=-1)


def setup_context(ctx, inputs, output):
    # Save the input so backward can recompute the split and the sigmoid.
    (x,) = inputs
    ctx.save_for_backward(x)


_silu_and_mul.register_autograd(backward, setup_context=setup_context)


@_silu_and_mul.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # The fake (meta) implementation must match the real output shape for any
    # input rank, not just 2-D: the last dimension is halved.
    return x.new_empty(*x.shape[:-1], x.shape[-1] // 2)
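

# A minimal self-check sketch (not part of the original module, added for
# illustration): it uses double precision so torch.autograd.gradcheck can
# validate the hand-written backward against numerical gradients. Because
# this file uses a relative import, run it as a module (python -m ...)
# rather than as a script.
if __name__ == "__main__":
    x = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
    out = _silu_and_mul(x)
    assert out.shape == (4, 4)  # last dimension is halved
    torch.autograd.gradcheck(_silu_and_mul, (x,))
    print("gradcheck passed")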