# diffsynth/models/comp_attn_model.py
# Uploaded by PencilHu via huggingface_hub (revision 1146a67, verified).
import math
from dataclasses import dataclass
from typing import Optional, Sequence
import torch
import torch.nn.functional as F
from ..diffusion.base_pipeline import PipelineUnit
@dataclass
class CompAttnConfig:
    """User-facing configuration for compositional attention guidance.

    Bundles the subject phrases plus the switches and tuning knobs for SCI
    (saliency-controlled intervention on the text context) and LAM
    (layout-aware modulation of cross-attention), and the optional per-frame
    subject "state" reweighting.
    """
    # Subject phrases; each must appear verbatim in the tokenized prompt to be used.
    subjects: Sequence[str]
    # Per-subject bounding boxes in pixels, any shape normalizable to (B, M, F, 4).
    bboxes: Optional[Sequence] = None
    # Enable SCI (edits the text context embeddings).
    enable_sci: bool = True
    # Enable LAM (biases cross-attention scores; requires bboxes).
    enable_lam: bool = True
    # Temperature for the exp/softmax used when computing subject saliency.
    temperature: float = 0.2
    # Also apply guidance to the negative (CFG) branch.
    apply_to_negative: bool = False
    # Linearly resample bboxes along time to match the generated frame count.
    interpolate: bool = False
    # Optional per-subject lists of state descriptions (equal length per subject).
    state_texts: Optional[Sequence[Sequence[str]]] = None
    # Per-state weights, shape (M, F, S) or (B, M, F, S).
    state_weights: Optional[Sequence] = None
    # Global multiplier on the state attention bias.
    state_scale: float = 1.0
    # Template used to phrase a (subject, state) pair for the text encoder.
    state_template: str = "{subject} is {state}"
def find_subsequence_indices(prompt_ids: torch.Tensor, subject_ids: torch.Tensor, valid_len: int) -> list[int]:
    """Locate the first occurrence of ``subject_ids`` in the valid prefix of ``prompt_ids``.

    Returns the matched token positions as a list, or an empty list when the
    subject is empty, longer than the valid prefix, or simply not present.
    """
    needle = subject_ids.tolist()
    if not needle or valid_len <= 0 or len(needle) > valid_len:
        return []
    haystack = prompt_ids[:valid_len].tolist()
    width = len(needle)
    for offset in range(valid_len - width + 1):
        if haystack[offset:offset + width] == needle:
            return list(range(offset, offset + width))
    return []
def build_subject_token_mask(indices_list: list[list[int]], seq_len: int) -> torch.Tensor:
    """Build a boolean (num_subjects, seq_len) mask marking each subject's token positions."""
    out = torch.zeros((len(indices_list), seq_len), dtype=torch.bool)
    for row, positions in enumerate(indices_list):
        if positions:
            out[row, torch.tensor(positions, dtype=torch.long)] = True
    return out
def compute_saliency(prompt_vecs: torch.Tensor, anchor_vecs: torch.Tensor, tau: float) -> torch.Tensor:
    """Per-subject saliency: softmax-style affinity of each prompt vector to its own anchor.

    Entry i is exp(cos(prompt_i, anchor_i) / tau) divided by the sum of
    exp-similarities of prompt_i against all anchors.
    """
    eps = 1e-8
    unit_prompt = prompt_vecs / (prompt_vecs.norm(dim=-1, keepdim=True) + eps)
    unit_anchor = anchor_vecs / (anchor_vecs.norm(dim=-1, keepdim=True) + eps)
    similarity = torch.matmul(unit_prompt, unit_anchor.transpose(0, 1))
    exp_sim = torch.exp(similarity / tau)
    return exp_sim.diagonal() / exp_sim.sum(dim=1).clamp(min=eps)
def compute_delta(anchor_vecs: torch.Tensor) -> torch.Tensor:
    """Per-row contrast vector: n * a_i - sum_j a_j, i.e. sum_j (a_i - a_j)."""
    n = anchor_vecs.shape[0]
    return n * anchor_vecs - anchor_vecs.sum(dim=0, keepdim=True)
_sci_call_count = [0]  # list so the counter can be mutated inside the function
def apply_sci(context: torch.Tensor, state: dict, timestep: torch.Tensor) -> torch.Tensor:
    """Apply Saliency-Controlled Intervention (SCI) to the text context.

    Shifts each subject's context tokens by its contrast delta, scaled by
    (1 - saliency) and a timestep-dependent weight. Returns ``context``
    unchanged when SCI is disabled or required state entries are missing.
    """
    if state is None or not state.get("enable_sci", False):
        return context
    subject_mask = state.get("subject_token_mask")
    delta = state.get("delta")
    saliency = state.get("saliency")
    if subject_mask is None or delta is None or saliency is None:
        return context
    if subject_mask.numel() == 0:
        return context
    # Map the raw timestep into [0, 1]; omega = 1 - t_ratio grows as
    # denoising progresses, so late steps receive a stronger shift.
    t_scale = float(state.get("timestep_scale", 1000.0))
    t_value = float(timestep.reshape(-1)[0].item())
    t_ratio = max(0.0, min(1.0, t_value / t_scale))
    omega = 1.0 - t_ratio
    delta = delta.to(device=context.device, dtype=context.dtype)
    saliency = saliency.to(device=context.device, dtype=context.dtype)
    # Low-saliency (under-represented) subjects get a larger shift.
    scale = omega * (1.0 - saliency).unsqueeze(-1)
    delta = delta * scale
    mask = subject_mask.to(device=context.device)
    # Scatter per-subject deltas onto their token positions: (seq, M) @ (M, dim).
    token_delta = torch.matmul(mask.to(dtype=context.dtype).transpose(0, 1), delta)
    apply_mask = state.get("apply_mask")
    if apply_mask is not None:
        # Per-batch gate (e.g. 1 for the positive CFG branch, 0 for the negative).
        apply_mask = apply_mask.to(device=context.device, dtype=context.dtype).view(-1, 1, 1)
    else:
        apply_mask = 1.0
    # ========== DEBUG: print SCI info ==========
    _sci_call_count[0] += 1
    if _sci_call_count[0] % 100 == 1:  # throttle: print once every 100 calls
        print(f"\n{'='*60}")
        print(f"[SCI (Saliency-Controlled Intervention) #{_sci_call_count[0]}]")
        print(f" timestep: {t_value:.2f}, t_ratio: {t_ratio:.4f}, omega: {omega:.4f}")
        print(f" saliency per subject: {saliency.tolist()}")
        print(f" delta shape: {delta.shape}")
        print(f" delta norm per subject: {delta.norm(dim=-1).tolist()}")
        print(f" token_delta shape: {token_delta.shape}")
        print(f" context modification norm: {(token_delta.unsqueeze(0) * apply_mask).norm().item():.6f}")
        print(f"{'='*60}\n")
    return context + token_delta.unsqueeze(0) * apply_mask
def interpolate_bboxes(bboxes: torch.Tensor, target_frames: int) -> torch.Tensor:
    """Linearly resample (B, M, F, 4) boxes along the frame axis to ``target_frames``.

    Returns the input unchanged when the frame count already matches.
    """
    if bboxes.shape[2] == target_frames:
        return bboxes
    batch, num_subjects, num_frames, _ = bboxes.shape
    flat = bboxes.reshape(batch * num_subjects, num_frames, 4).transpose(1, 2)
    resampled = F.interpolate(flat, size=target_frames, mode="linear", align_corners=True)
    return resampled.transpose(1, 2).reshape(batch, num_subjects, target_frames, 4)
def build_layout_mask_from_bboxes(
    bboxes: torch.Tensor,
    grid_size: tuple[int, int, int],
    image_size: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Rasterize pixel-space boxes into a flat latent-grid occupancy mask.

    Args:
        bboxes: (B, M, F_layout, 4) boxes as (x0, y0, x1, y1) in pixels.
        grid_size: latent token grid (frames, height, width).
        image_size: pixel (height, width) used to normalize coordinates.
        device: device for the returned mask.
        dtype: dtype for the returned mask.

    Returns:
        (B, M, F_grid * H_grid * W_grid) mask with 1.0 inside each box,
        or ``None`` when ``bboxes`` is ``None``.
    """
    if bboxes is None:
        return None
    bboxes = bboxes.to(device=device, dtype=dtype)
    b, m, f_layout, _ = bboxes.shape
    f_grid, h_grid, w_grid = grid_size
    height, width = image_size
    layout = torch.zeros((b, m, f_grid, h_grid, w_grid), device=device, dtype=dtype)
    for bi in range(b):
        for mi in range(m):
            for ti in range(f_layout):
                # Map this layout frame to the nearest latent-grid frame.
                pt = int(ti * f_grid / max(1, f_layout))
                pt = max(0, min(f_grid - 1, pt))
                x0, y0, x1, y1 = bboxes[bi, mi, ti]
                x0 = float(x0)
                y0 = float(y0)
                x1 = float(x1)
                y1 = float(y1)
                # Skip degenerate (empty or inverted) boxes.
                if x1 <= x0 or y1 <= y0:
                    continue
                # Pixels -> grid cells; floor/ceil so partially covered cells
                # are included. max(1.0, ...) guards against a zero image size.
                px0 = int(math.floor(x0 / max(1.0, width) * w_grid))
                px1 = int(math.ceil(x1 / max(1.0, width) * w_grid))
                py0 = int(math.floor(y0 / max(1.0, height) * h_grid))
                py1 = int(math.ceil(y1 / max(1.0, height) * h_grid))
                px0 = max(0, min(w_grid, px0))
                px1 = max(0, min(w_grid, px1))
                py0 = max(0, min(h_grid, py0))
                py1 = max(0, min(h_grid, py1))
                if px1 <= px0 or py1 <= py0:
                    continue
                layout[bi, mi, pt, py0:py1, px0:px1] = 1.0
    # Flatten (f, h, w) -> token axis so the mask aligns with query positions.
    return layout.flatten(2)
_lam_attention_call_count = [0]  # list so the counter can be mutated inside the function
def lam_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    state: dict,
) -> Optional[torch.Tensor]:
    """Cross-attention with Layout-Aware Modulation and optional state bias.

    Adds two additive biases to the raw attention logits before softmax:
    (1) LAM: pushes queries inside a subject's layout box toward that
    subject's tokens and away from other subjects' tokens, with strength
    proportional to 1 - IoU between the observed high-attention region and
    the target layout; (2) state bias: per-frame weights on appended
    state-phrase tokens. Returns ``None`` when neither mechanism applies
    (caller then falls back to its standard attention path).
    """
    subject_mask = state.get("subject_token_mask_lam")
    if subject_mask is None:
        subject_mask = state.get("subject_token_mask")
    layout_mask = state.get("layout_mask")
    state_token_mask = state.get("state_token_mask")
    state_token_weights = state.get("state_token_weights")
    state_scale = float(state.get("state_scale", 1.0))
    grid_shape = state.get("grid_shape")
    enable_lam = bool(state.get("enable_lam", False))
    enable_state = state_token_mask is not None and state_token_weights is not None and grid_shape is not None
    if not enable_lam and not enable_state:
        return None
    b, q_len, dim = q.shape
    _, k_len, _ = k.shape
    # Bail out (caller uses plain attention) on any shape mismatch.
    if enable_lam:
        if subject_mask is None or layout_mask is None:
            return None
        if subject_mask.numel() == 0 or layout_mask.numel() == 0:
            return None
        if layout_mask.shape[-1] != q_len:
            return None
        if subject_mask.shape[-1] != k_len:
            return None
    if enable_state:
        if state_token_mask.shape[-1] != k_len:
            return None
    head_dim = dim // num_heads
    qh = q.view(b, q_len, num_heads, head_dim).transpose(1, 2)
    kh = k.view(b, k_len, num_heads, head_dim).transpose(1, 2)
    vh = v.view(b, k_len, num_heads, head_dim).transpose(1, 2)
    # Logits computed in float32 for numerical stability.
    attn_scores = torch.matmul(qh.float(), kh.float().transpose(-2, -1)) / math.sqrt(head_dim)
    # ========== DEBUG: print attention map info ==========
    _lam_attention_call_count[0] += 1
    call_id = _lam_attention_call_count[0]
    # Print once every 100 calls to avoid flooding the output.
    if call_id % 100 == 1:
        print(f"\n{'='*60}")
        print(f"[LAM Attention #{call_id}]")
        print(f" Q shape: {q.shape}, K shape: {k.shape}, V shape: {v.shape}")
        print(f" num_heads: {num_heads}, head_dim: {head_dim}")
        print(f" attn_scores shape: {attn_scores.shape}")
        print(f" attn_scores stats: min={attn_scores.min().item():.4f}, max={attn_scores.max().item():.4f}, mean={attn_scores.mean().item():.4f}")
        if enable_lam and layout_mask is not None:
            print(f" layout_mask shape: {layout_mask.shape}")
            print(f" layout_mask sum per subject: {layout_mask.sum(dim=-1)}")
        if subject_mask is not None:
            print(f" subject_token_mask shape: {subject_mask.shape}")
            print(f" subject_token_mask active tokens per subject: {subject_mask.sum(dim=-1).tolist()}")
        if grid_shape is not None:
            print(f" grid_shape (f, h, w): {grid_shape}")
        print(f"{'='*60}")
    bias = torch.zeros_like(attn_scores)
    if enable_lam:
        # g_plus raises a logit to the row max; g_minus lowers it to the row min.
        attn_max = attn_scores.max(dim=-1, keepdim=True).values
        attn_min = attn_scores.min(dim=-1, keepdim=True).values
        g_plus = attn_max - attn_scores
        g_minus = attn_min - attn_scores
        subject_mask = subject_mask.to(device=attn_scores.device)
        layout_mask = layout_mask.to(device=attn_scores.device, dtype=attn_scores.dtype)
        apply_mask = state.get("apply_mask")
        if apply_mask is not None:
            # Per-batch gate (e.g. disables LAM on the negative CFG branch).
            layout_mask = layout_mask * apply_mask.to(device=layout_mask.device, dtype=layout_mask.dtype).view(-1, 1, 1)
        subject_any = subject_mask.any(dim=0)
        for k_idx in range(subject_mask.shape[0]):
            mask_k = subject_mask[k_idx]
            if not mask_k.any():
                continue
            # Boost this subject's key tokens, suppress the other subjects'.
            mask_other = subject_any & (~mask_k)
            mask_k = mask_k.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
            mask_other = mask_other.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
            g_k = g_plus * mask_k + g_minus * mask_other
            # Mean logit toward this subject per query (averaged over heads).
            attn_k = attn_scores[..., subject_mask[k_idx]].mean(dim=-1).mean(dim=1)
            # Queries attending above average form the "observed" region.
            adapt_mask = attn_k >= attn_k.mean(dim=-1, keepdim=True)
            layout_k = layout_mask[:, k_idx]
            adapt_f = adapt_mask.to(layout_k.dtype)
            # Soft IoU between observed region and target layout; low IoU ->
            # strong correction.
            inter = (adapt_f * layout_k).sum(dim=-1)
            union = (adapt_f + layout_k - adapt_f * layout_k).sum(dim=-1)
            iou = inter / union.clamp(min=1e-6)
            strength = (1.0 - iou).view(b, 1, 1, 1)
            # Only queries inside the layout box receive the bias.
            bias = bias + g_k * strength * layout_k.view(b, 1, q_len, 1)
    if enable_state:
        f, h, w = grid_shape
        if f * h * w != q_len:
            return None
        state_token_mask = state_token_mask.to(device=attn_scores.device)
        state_indices = torch.nonzero(state_token_mask, as_tuple=False).flatten()
        if state_indices.numel() == 0:
            return None
        weights = state_token_weights.to(device=attn_scores.device, dtype=attn_scores.dtype)
        if weights.shape[1] != f:
            return None
        # Each query's frame index; state weights are defined per frame.
        time_index = torch.arange(q_len, device=attn_scores.device) // (h * w)
        weights_q = weights[:, time_index, :]
        if weights_q.shape[-1] != state_indices.numel():
            return None
        state_bias = torch.zeros((b, 1, q_len, k_len), device=attn_scores.device, dtype=attn_scores.dtype)
        state_bias[:, :, :, state_indices] = weights_q.unsqueeze(1) * state_scale
        bias = bias + state_bias
    attn_probs = torch.softmax(attn_scores + bias, dim=-1).to(vh.dtype)
    # ========== DEBUG: print attention probs and bias info ==========
    if _lam_attention_call_count[0] % 100 == 1:
        print(f"\n[LAM Attention #{_lam_attention_call_count[0]} - After Bias]")
        print(f" bias shape: {bias.shape}")
        print(f" bias stats: min={bias.min().item():.4f}, max={bias.max().item():.4f}, mean={bias.mean().item():.4f}")
        print(f" bias non-zero ratio: {(bias != 0).float().mean().item():.4f}")
        print(f" attn_probs shape: {attn_probs.shape}")
        print(f" attn_probs stats: min={attn_probs.min().item():.6f}, max={attn_probs.max().item():.6f}")
        # Print the average attention weight on each subject's tokens.
        if subject_mask is not None:
            for subj_idx in range(subject_mask.shape[0]):
                mask_k = subject_mask[subj_idx]
                if mask_k.any():
                    # Average attention from all queries to this subject's tokens.
                    subj_attn = attn_probs[:, :, :, mask_k.to(attn_probs.device)].mean()
                    print(f" Subject {subj_idx} avg attention weight: {subj_attn.item():.6f}")
        print(f"{'='*60}\n")
    out = torch.matmul(attn_probs, vh)
    out = out.transpose(1, 2).reshape(b, q_len, dim)
    return out
class CompAttnUnit(PipelineUnit):
    """Pipeline unit that precomputes the compositional-attention state.

    Locates each configured subject inside the tokenized prompt, pools the
    corresponding context embeddings, encodes standalone subject "anchor"
    phrases, and derives the saliency / delta tensors used by SCI plus the
    bbox / state-weight tensors used by LAM.
    """
    def __init__(self):
        super().__init__(
            seperate_cfg=True,  # (sic) spelling follows the PipelineUnit API
            input_params_posi={"prompt": "prompt", "context": "context"},
            input_params_nega={"prompt": "negative_prompt", "context": "context"},
            output_params=("comp_attn_state",),
            onload_model_names=("text_encoder",),
        )
    def _clean_text(self, pipe, text: str) -> str:
        """Run the tokenizer's internal cleanup on ``text`` when available.

        NOTE(review): this checks for an attribute named ``clean`` but calls
        ``_clean`` — verify both names exist on the tokenizer, or this will
        raise AttributeError when only ``clean`` is defined.
        """
        if getattr(pipe.tokenizer, "clean", None):
            return pipe.tokenizer._clean(text)
        return text
    def _tokenize_subject(self, pipe, text: str) -> torch.Tensor:
        """Tokenize a subject phrase without special tokens; returns a 1-D id tensor."""
        text = self._clean_text(pipe, text)
        tokens = pipe.tokenizer.tokenizer(text, add_special_tokens=False, return_tensors="pt")
        return tokens["input_ids"][0]
    def _normalize_bboxes(self, bboxes: Sequence) -> torch.Tensor:
        """Coerce user bboxes to a float32 (B, M, F, 4) tensor.

        Accepts (F, 4) for a single subject or (M, F, 4) for one batch; raises
        ValueError on any other shape.
        """
        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
        if bboxes.dim() == 2 and bboxes.shape[-1] == 4:
            bboxes = bboxes.unsqueeze(0).unsqueeze(0)
        elif bboxes.dim() == 3 and bboxes.shape[-1] == 4:
            bboxes = bboxes.unsqueeze(0)
        elif bboxes.dim() != 4 or bboxes.shape[-1] != 4:
            raise ValueError(f"comp_attn_bboxes must be (..., 4), got shape {tuple(bboxes.shape)}")
        return bboxes
    def process(self, pipe, prompt, context) -> dict:
        """Build and cache the comp-attn state dict for one prompt branch.

        Returns ``{"comp_attn_state": state}`` or ``{}`` when the feature is
        unconfigured, no subject matches the prompt, or the negative branch
        should be skipped.
        """
        config: Optional[CompAttnConfig] = getattr(pipe, "_comp_attn_config", None)
        if context is None or prompt is None or config is None:
            return {}
        if not config.subjects:
            return {}
        # Skip the negative-prompt branch unless explicitly enabled.
        negative_prompt = getattr(pipe, "_comp_attn_last_negative_prompt", None)
        if (not config.apply_to_negative) and negative_prompt and prompt == negative_prompt:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        # Locate each subject's token span inside the prompt.
        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        prompt_ids = ids[0]
        valid_len = int(mask[0].sum().item())
        indices_list = []
        valid_subjects = []
        for idx, subject in enumerate(config.subjects):
            subject_ids = self._tokenize_subject(pipe, subject)
            indices = find_subsequence_indices(prompt_ids, subject_ids, valid_len)
            if not indices:
                print(f"Comp-Attn: subject tokens not found in prompt: {subject}")
                continue
            indices_list.append(indices)
            valid_subjects.append(idx)
        if not indices_list:
            return {}
        subject_token_mask = build_subject_token_mask(indices_list, prompt_ids.shape[0]).to(device=context.device)
        # Mean-pool the prompt context over each subject's tokens.
        mask_float = subject_token_mask.to(dtype=context.dtype)
        denom = mask_float.sum(dim=1, keepdim=True).clamp(min=1)
        prompt_vecs = (mask_float @ context[0]) / denom
        # Encode each subject standalone to obtain its anchor embedding.
        anchor_vecs = []
        for idx in valid_subjects:
            subject = config.subjects[idx]
            sub_ids, sub_mask = pipe.tokenizer(subject, return_mask=True, add_special_tokens=True)
            sub_ids = sub_ids.to(pipe.device)
            sub_mask = sub_mask.to(pipe.device)
            emb = pipe.text_encoder(sub_ids, sub_mask)
            # Masked mean-pool over valid tokens.
            pooled = (emb * sub_mask.unsqueeze(-1)).sum(dim=1) / sub_mask.sum(dim=1, keepdim=True).clamp(min=1)
            anchor_vecs.append(pooled)
        anchor_vecs = torch.cat(anchor_vecs, dim=0)
        saliency = compute_saliency(prompt_vecs.float(), anchor_vecs.float(), float(config.temperature)).to(prompt_vecs.dtype)
        delta = compute_delta(anchor_vecs.to(prompt_vecs.dtype))
        bboxes = None
        state_vectors = None
        state_weights = None
        state_len = 0
        if config.bboxes is not None:
            bboxes = self._normalize_bboxes(config.bboxes)
            # Keep only the bboxes of subjects actually found in the prompt.
            if bboxes.shape[1] >= len(config.subjects):
                bboxes = bboxes[:, valid_subjects]
            if bboxes.shape[1] != len(valid_subjects):
                print("Comp-Attn: bboxes subject count mismatch, disable LAM")
                bboxes = None
            if bboxes is not None and config.interpolate and getattr(pipe, "_comp_attn_num_frames", None) is not None:
                bboxes = interpolate_bboxes(bboxes, int(pipe._comp_attn_num_frames))
        if config.state_texts is not None and config.state_weights is not None:
            state_texts = config.state_texts
            # Restrict state texts/weights to the matched subjects.
            if len(valid_subjects) != len(config.subjects):
                subject_names = [config.subjects[i] for i in valid_subjects]
                state_texts = [state_texts[i] for i in valid_subjects]
            else:
                subject_names = list(config.subjects)
            if len(state_texts) != len(subject_names):
                raise ValueError("state_texts must align with subjects")
            state_count = len(state_texts[0])
            for row in state_texts:
                if len(row) != state_count:
                    raise ValueError("state_texts must have the same number of states per subject")
            # Encode one phrase per (subject, state) pair via the template.
            phrases = []
            for subject, states in zip(subject_names, state_texts):
                for state in states:
                    phrases.append(config.state_template.format(subject=subject, state=state))
            ids, mask = pipe.tokenizer(phrases, return_mask=True, add_special_tokens=True)
            ids = ids.to(pipe.device)
            mask = mask.to(pipe.device)
            emb = pipe.text_encoder(ids, mask)
            pooled = (emb * mask.unsqueeze(-1)).sum(dim=1) / mask.sum(dim=1, keepdim=True).clamp(min=1)
            # Kept on CPU until the model wrapper moves them to the latents' device.
            state_vectors = pooled.to(dtype=prompt_vecs.dtype, device="cpu")
            state_len = state_vectors.shape[0]
            weights = torch.as_tensor(config.state_weights, dtype=torch.float32)
            if weights.dim() == 3:
                weights = weights.unsqueeze(0)
            if weights.dim() != 4:
                raise ValueError("state_weights must be (M,F,S) or (B,M,F,S)")
            if weights.shape[1] >= len(config.subjects) and len(valid_subjects) != len(config.subjects):
                weights = weights[:, valid_subjects]
            if weights.shape[1] != len(subject_names) or weights.shape[3] != state_count:
                raise ValueError("state_weights shape does not match state_texts")
            weights = weights[:, :len(subject_names)]
            # (B, M, F, S) -> (B, F, M*S) to line up with the flattened phrases.
            weights = weights.permute(0, 2, 1, 3).contiguous()
            weights = weights.reshape(weights.shape[0], weights.shape[1], weights.shape[2] * weights.shape[3])
            state_weights = weights.to(device="cpu")
        state = {
            "enable_sci": bool(config.enable_sci),
            "enable_lam": bool(config.enable_lam) and bboxes is not None,
            "subject_token_mask": subject_token_mask,
            "saliency": saliency,
            "delta": delta,
            "layout_bboxes": bboxes,
            "state_vectors": state_vectors,
            "state_weights": state_weights,
            "state_scale": float(config.state_scale),
            "prompt_len": int(prompt_ids.shape[0]),
            "state_len": int(state_len),
            "timestep_scale": 1000.0,
            "apply_to_negative": bool(config.apply_to_negative),
        }
        # Cache per branch so CompAttnMergeUnit can merge them for fused CFG.
        if negative_prompt and prompt == negative_prompt:
            pipe._comp_attn_state_neg = state
        else:
            pipe._comp_attn_state_pos = state
        return {"comp_attn_state": state}
class CompAttnMergeUnit(PipelineUnit):
    """Selects the cached comp-attn state for a fused (batched) CFG pass.

    When positive and negative prompts run as one batch, reuses the cached
    positive (or, failing that, negative) state and attaches a per-batch
    ``apply_mask`` gating the negative half.
    """
    def __init__(self):
        super().__init__(input_params=("cfg_merge",), output_params=("comp_attn_state",))
    def process(self, pipe, cfg_merge) -> dict:
        """Return the merged state dict, or {} when CFG is not fused or no state is cached."""
        if not cfg_merge:
            return {}
        chosen = getattr(pipe, "_comp_attn_state_pos", None) or getattr(pipe, "_comp_attn_state_neg", None)
        if chosen is None:
            return {}
        result = dict(chosen)
        negative_gate = 1.0 if bool(result.get("apply_to_negative", False)) else 0.0
        result["apply_mask"] = torch.tensor([1.0, negative_gate])
        return {"comp_attn_state": result}
def patch_cross_attention(pipe) -> None:
    """Monkey-patch every DiT block's cross-attention to route through LAM.

    The replacement forward delegates to the original implementation whenever
    no comp-attn runtime state is active, and is idempotent: each attention
    module is tagged with ``_comp_attn_patched`` after patching.
    """
    for block in pipe.dit.blocks:
        cross_attn = block.cross_attn
        if getattr(cross_attn, "_comp_attn_patched", False):
            continue
        orig_forward = cross_attn.forward
        # _orig/_pipe are bound via default args to avoid the late-binding
        # closure pitfall inside this loop.
        def forward_with_lam(self, x, y, _orig=orig_forward, _pipe=pipe):
            state = getattr(_pipe, "_comp_attn_runtime_state", None)
            enable_lam = bool(state.get("enable_lam", False)) if state else False
            enable_state = bool(state.get("state_token_weights") is not None) if state else False
            if state is None or (not enable_lam and not enable_state):
                return _orig(x, y)
            if self.has_image_input:
                # NOTE(review): assumes the first 257 tokens of y are image
                # (CLIP) tokens and the rest is text — confirm against the
                # DiT cross-attention definition.
                img = y[:, :257]
                ctx = y[:, 257:]
            else:
                ctx = y
            q = self.norm_q(self.q(x))
            k = self.norm_k(self.k(ctx))
            v = self.v(ctx)
            lam_out = lam_attention(q, k, v, self.num_heads, state)
            if lam_out is None:
                # LAM declined (disabled or shape mismatch): plain attention.
                out = self.attn(q, k, v)
            else:
                out = lam_out
            if self.has_image_input:
                # Image tokens always use the unmodified attention path.
                k_img = self.norm_k_img(self.k_img(img))
                v_img = self.v_img(img)
                img_out = self.attn(q, k_img, v_img)
                out = out + img_out
            return self.o(out)
        # Bind as a method on this specific attention module.
        cross_attn.forward = forward_with_lam.__get__(cross_attn, cross_attn.__class__)
        cross_attn._comp_attn_patched = True
def get_grid_from_latents(latents: torch.Tensor, patch_size: tuple[int, int, int]) -> tuple[int, int, int]:
    """Token-grid size (frames, height, width) of a 5-D latent after patchification."""
    patch_f, patch_h, patch_w = patch_size
    return (
        latents.shape[2] // patch_f,
        latents.shape[3] // patch_h,
        latents.shape[4] // patch_w,
    )
def wrap_model_fn(pipe) -> None:
    """Wrap ``pipe.model_fn`` to inject comp-attn state into each model call.

    The wrapper applies SCI to the context, appends state-phrase vectors,
    builds/caches the layout mask and per-token LAM masks, and publishes the
    runtime state on the pipe for the patched cross-attention. Idempotent via
    the ``_comp_attn_model_fn_patched`` flag.
    """
    if getattr(pipe, "_comp_attn_model_fn_patched", False):
        return
    orig_model_fn = pipe.model_fn
    def model_fn_wrapper(*args, **kwargs):
        # Pop the state so the wrapped model never sees this extra kwarg.
        comp_attn_state = kwargs.pop("comp_attn_state", None)
        height = kwargs.get("height")
        width = kwargs.get("width")
        num_frames = kwargs.get("num_frames")
        if num_frames is not None:
            # Remembered so CompAttnUnit can interpolate bboxes on later calls.
            pipe._comp_attn_num_frames = num_frames
        if comp_attn_state is None:
            return orig_model_fn(*args, **kwargs)
        latents = kwargs.get("latents")
        timestep = kwargs.get("timestep")
        context = kwargs.get("context")
        clip_feature = kwargs.get("clip_feature")
        reference_latents = kwargs.get("reference_latents")
        state_vectors = comp_attn_state.get("state_vectors")
        state_weights = comp_attn_state.get("state_weights")
        state_len = int(comp_attn_state.get("state_len", 0))
        prompt_len = int(comp_attn_state.get("prompt_len", context.shape[1] if context is not None else 0))
        if context is not None and timestep is not None:
            # SCI shifts subject tokens in the text context.
            context = apply_sci(context, comp_attn_state, timestep)
            if state_vectors is not None and state_len > 0:
                # Append state-phrase embeddings as extra context tokens.
                state_vec = state_vectors.to(device=context.device, dtype=context.dtype)
                if state_vec.dim() == 2:
                    state_vec = state_vec.unsqueeze(0)
                if state_vec.shape[0] != context.shape[0]:
                    state_vec = state_vec.repeat(context.shape[0], 1, 1)
                context = torch.cat([context, state_vec], dim=1)
            kwargs["context"] = context
        subject_mask = comp_attn_state.get("subject_token_mask")
        if subject_mask is not None:
            # Re-align the subject mask with the key sequence the DiT sees:
            # [clip tokens | prompt tokens | state tokens].
            clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0
            pad_clip = torch.zeros((subject_mask.shape[0], clip_len), dtype=torch.bool)
            pad_state = torch.zeros((subject_mask.shape[0], state_len), dtype=torch.bool)
            comp_attn_state["subject_token_mask_lam"] = torch.cat([pad_clip, subject_mask.cpu(), pad_state], dim=1)
        if state_vectors is not None and state_len > 0:
            # Boolean mask over key positions marking the appended state tokens.
            clip_len = clip_feature.shape[1] if clip_feature is not None and pipe.dit.require_clip_embedding else 0
            pad_prompt = torch.zeros((state_len, clip_len + prompt_len), dtype=torch.bool)
            ones_state = torch.ones((state_len, state_len), dtype=torch.bool)
            state_token_mask = torch.cat([pad_prompt, ones_state], dim=1).any(dim=0)
            comp_attn_state["state_token_mask"] = state_token_mask
        if latents is not None and height is not None and width is not None:
            f, h, w = get_grid_from_latents(latents, pipe.dit.patch_size)
            if comp_attn_state.get("enable_lam", False):
                q_len = f * h * w
                if reference_latents is not None:
                    # Reference latents prepend one extra frame of query tokens.
                    q_len = (f + 1) * h * w
                layout_mask = comp_attn_state.get("layout_mask")
                layout_shape = comp_attn_state.get("layout_shape")
                # Rebuild the layout mask only when the cached one is stale.
                if layout_mask is None or layout_shape != (latents.shape[0], q_len):
                    layout_mask = build_layout_mask_from_bboxes(
                        comp_attn_state.get("layout_bboxes"),
                        (f, h, w),
                        (int(height), int(width)),
                        device=latents.device,
                        dtype=latents.dtype,
                    )
                    if reference_latents is not None:
                        # Zero-pad for the reference frame's query tokens.
                        pad = torch.zeros((layout_mask.shape[0], layout_mask.shape[1], h * w), device=latents.device, dtype=latents.dtype)
                        layout_mask = torch.cat([pad, layout_mask], dim=-1)
                    if layout_mask.shape[0] != latents.shape[0]:
                        layout_mask = layout_mask.repeat(latents.shape[0], 1, 1)
                    comp_attn_state["layout_mask"] = layout_mask
                    comp_attn_state["layout_shape"] = (latents.shape[0], q_len)
            if state_weights is not None:
                weights = state_weights.to(device=latents.device, dtype=latents.dtype)
                if weights.shape[0] != latents.shape[0]:
                    weights = weights.repeat(latents.shape[0], 1, 1)
                if weights.shape[1] != f:
                    # Resample per-frame weights to the latent frame count.
                    weights = weights.transpose(1, 2)
                    weights = F.interpolate(weights, size=f, mode="linear", align_corners=True)
                    weights = weights.transpose(1, 2)
                if reference_latents is not None:
                    # Zero weights for the prepended reference frame.
                    pad = torch.zeros((weights.shape[0], 1, weights.shape[2]), device=weights.device, dtype=weights.dtype)
                    weights = torch.cat([pad, weights], dim=1)
                    f = f + 1
                comp_attn_state["state_token_weights"] = weights
            comp_attn_state["grid_shape"] = (f, h, w)
        # Default CFG gate: positive branch only, when states come batched as 2.
        if (
            latents is not None
            and latents.shape[0] == 2
            and not comp_attn_state.get("apply_to_negative", False)
            and "apply_mask" not in comp_attn_state
        ):
            comp_attn_state["apply_mask"] = torch.tensor([1.0, 0.0], device=latents.device, dtype=latents.dtype)
        # Publish for the patched cross-attention; always cleared afterwards.
        pipe._comp_attn_runtime_state = comp_attn_state
        try:
            return orig_model_fn(*args, **kwargs)
        finally:
            pipe._comp_attn_runtime_state = None
    pipe.model_fn = model_fn_wrapper
    pipe._comp_attn_model_fn_patched = True