| from functools import wraps |
| import torch |
| import diffusers |
|
|
| |
|
|
|
|
| |
# Keep a handle on the unpatched implementation so the wrapper below can
# delegate to it after ipex_diffusers() rebinds the module attribute.
original_fourier_filter = diffusers.utils.torch_utils.fourier_filter


@wraps(original_fourier_filter)
def fourier_filter(x_in, threshold, scale):
    """Run diffusers' ``fourier_filter`` in float32 and restore the input dtype.

    The input is cast to ``torch.float32`` before being handed to the original
    implementation, and the result is cast back to the dtype ``x_in`` had on
    entry, so callers see the same dtype they passed in.
    NOTE(review): presumably needed because the target backend cannot run the
    filter on reduced-precision inputs -- confirm.

    Args:
        x_in: Tensor to filter.
        threshold: Forwarded unchanged to the original ``fourier_filter``.
        scale: Forwarded unchanged to the original ``fourier_filter``.

    Returns:
        The filtered tensor, converted to ``x_in``'s original dtype.
    """
    caller_dtype = x_in.dtype
    upcast_input = x_in.to(dtype=torch.float32)
    filtered = original_fourier_filter(upcast_input, threshold, scale)
    return filtered.to(dtype=caller_dtype)
|
|
|
|
| |
class FluxPosEmbed(torch.nn.Module):
    """Flux rotary position embedding that always uses float32 frequencies.

    ``ipex_diffusers`` installs this class over
    ``diffusers.models.embeddings.FluxPosEmbed`` when the device does not
    support fp64.

    Args:
        theta: Value forwarded as ``theta`` to ``get_1d_rotary_pos_embed``.
        axes_dim: Indexable of per-axis sizes; ``axes_dim[i]`` is passed as the
            embedding dimension for positional axis ``i``.
    """

    def __init__(self, theta: int, axes_dim):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Build the concatenated cos/sin rotary tables for ``ids``.

        Args:
            ids: Position ids. The size of the last dimension is the number of
                axes, and axis ``i`` is read as ``ids[:, i]``
                (presumably a 2-D ``(sequence, n_axes)`` tensor -- confirm).

        Returns:
            ``(freqs_cos, freqs_sin)``: the per-axis results concatenated along
            the last dimension and moved to ``ids.device``.
        """
        n_axes = ids.shape[-1]
        pos = ids.float()
        cos_out = []
        sin_out = []
        for axis in range(n_axes):
            cos, sin = diffusers.models.embeddings.get_1d_rotary_pos_embed(
                self.axes_dim[axis],
                pos[:, axis],
                theta=self.theta,
                repeat_interleave_real=True,
                use_real=True,
                # Pinned to float32: this class is only installed when the
                # device has no fp64 support.
                freqs_dtype=torch.float32,
            )
            cos_out.append(cos)
            sin_out.append(sin)
        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
        return freqs_cos, freqs_sin
|
|
|
|
def ipex_diffusers(device_supports_fp64=False, can_allocate_plus_4gb=False):
    """Install this module's replacements into ``diffusers``.

    Always rebinds ``diffusers.utils.torch_utils.fourier_filter`` to the
    float32 wrapper defined in this module. When the device lacks fp64
    support, also rebinds ``diffusers.models.embeddings.FluxPosEmbed`` to the
    float32-only class defined in this module.

    Args:
        device_supports_fp64: When true, the ``FluxPosEmbed`` replacement is
            skipped and diffusers' own class is left in place.
        can_allocate_plus_4gb: Accepted but not read by this function.
    """
    torch_utils = diffusers.utils.torch_utils
    torch_utils.fourier_filter = fourier_filter

    if device_supports_fp64:
        return

    embeddings = diffusers.models.embeddings
    embeddings.FluxPosEmbed = FluxPosEmbed
|
|