# Source: commit 8a62807 by Lakonik -- "Add AsymFLUX.2-klein Space demo"
from typing import Optional
import torch
@torch.jit.script
def guidance_jit(
        pos_mean, neg_mean, guidance_scale,
        orthogonal: float = 1.0, parallel_dir: Optional[torch.Tensor] = None):
    """Return the guidance offset between two predictions.

    The raw offset is ``(pos_mean - neg_mean) * (guidance_scale - 1)``.
    When ``orthogonal`` is non-zero, ``orthogonal`` times the component of
    that offset lying along ``parallel_dir`` is subtracted from it. So
    ``orthogonal=1.0`` keeps only the part perpendicular to ``parallel_dir``,
    and ``orthogonal=0.0`` returns the raw offset unchanged.

    Args:
        pos_mean: Tensor. Dim 0 is kept separate in the projection
            (presumably a batch dimension -- confirm with callers); every
            later dim is reduced when computing the projection coefficient.
        neg_mean: Tensor, broadcastable against ``pos_mean``.
        guidance_scale: Unannotated, so TorchScript types it as ``Tensor``;
            pass a tensor (a 0-d tensor works) rather than a Python number.
        orthogonal: Fraction of the parallel component to remove.
        parallel_dir: Direction to project out; defaults to ``pos_mean``.

    Returns:
        The offset tensor, with the parallel component (partially) removed
        when ``orthogonal`` is non-zero.
    """
    offset = (pos_mean - neg_mean) * (guidance_scale - 1)
    # bool(x) is False exactly when x == 0.0 (including -0.0), so this guard
    # mirrors a plain truthiness test on ``orthogonal``.
    if orthogonal == 0.0:
        return offset

    if parallel_dir is None:
        direction = pos_mean
    else:
        direction = parallel_dir

    # Reduce over every dim except dim 0: one projection coefficient per
    # leading index, kept broadcastable via keepdim=True.
    reduce_dims = list(range(1, pos_mean.dim()))
    numer = (offset * direction).mean(dim=reduce_dims, keepdim=True)
    # Floor the squared magnitude so a (near-)zero direction cannot divide
    # by zero; the coefficient then collapses toward 0 instead of blowing up.
    denom = (direction * direction).mean(dim=reduce_dims, keepdim=True).clamp(min=1e-6)
    parallel_part = numer / denom * direction
    return offset - parallel_part * orthogonal