| import math |
| from dataclasses import dataclass |
| from typing import Optional, Sequence |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| from diffsynth.diffusion.base_pipeline import PipelineUnit |
| from diffsynth.pipelines.wan_video import ( |
| WanVideoPipeline, |
| WanVideoUnit_PromptEmbedder, |
| WanVideoUnit_CfgMerger, |
| ) |
|
|
|
|
@dataclass
class CompAttnConfig:
    """Per-call Comp-Attn settings; CompAttnUnit reads them from ``pipe._comp_attn_config``."""

    # Subject phrases to locate (as token subsequences) in the prompt; subjects whose
    # tokens are not found are skipped with a warning.
    subjects: Sequence[str]
    # Optional boxes as (x0, y0, x1, y1). Accepted shapes: (F, 4), (M, F, 4) or
    # (B, M, F, 4). Coordinates are divided by the image width/height when the layout
    # mask is built, so presumably pixel units — confirm with callers.
    bboxes: Optional[Sequence] = None
    # Enable the context-side subject delta (apply_sci).
    enable_sci: bool = True
    # Enable layout-aware attention modulation; only effective when usable bboxes exist.
    enable_lam: bool = True
    # Temperature passed to compute_saliency.
    temperature: float = 0.2
    # When False, the negative (CFG) branch is left unmodified.
    apply_to_negative: bool = False
    # When True, resample box tracks to the pipeline's num_frames before use.
    interpolate: bool = False
|
|
|
|
def find_subsequence_indices(prompt_ids: torch.Tensor, subject_ids: torch.Tensor, valid_len: int) -> list[int]:
    """Return the positions of the first occurrence of ``subject_ids`` in ``prompt_ids``.

    Only the first ``valid_len`` prompt ids are searched. An empty list is returned when
    the subject is empty, ``valid_len`` is not positive, or no match exists.
    """
    if subject_ids.numel() == 0 or valid_len <= 0:
        return []
    haystack = prompt_ids[:valid_len].tolist()
    needle = subject_ids.tolist()
    width = len(needle)
    last_start = valid_len - width
    if last_start < 0:
        return []
    hit = next(
        (pos for pos in range(last_start + 1) if haystack[pos:pos + width] == needle),
        None,
    )
    if hit is None:
        return []
    return list(range(hit, hit + width))
|
|
|
|
def build_subject_token_mask(indices_list: list[list[int]], seq_len: int) -> torch.Tensor:
    """Build a bool mask of shape ``(len(indices_list), seq_len)``.

    Row ``i`` is True exactly at the token positions listed in ``indices_list[i]``;
    empty index lists leave their row all False.
    """
    mask = torch.zeros((len(indices_list), seq_len), dtype=torch.bool)
    # Iterating a tensor yields row views, so writing into ``row`` fills ``mask``.
    for row, positions in zip(mask, indices_list):
        if positions:
            row[torch.tensor(positions, dtype=torch.long)] = True
    return mask
|
|
|
|
def compute_saliency(prompt_vecs: torch.Tensor, anchor_vecs: torch.Tensor, tau: float) -> torch.Tensor:
    """Return how strongly each prompt vector prefers its own anchor over the others.

    For prompt vector ``i`` this is ``softmax_j(cos(p_i, a_j) / tau)[i]``: the softmax
    probability, over all anchors, of the matching (diagonal) anchor.

    Args:
        prompt_vecs: ``(M, D)`` pooled prompt-side subject vectors.
        anchor_vecs: ``(M, D)`` stand-alone subject vectors, row-aligned with ``prompt_vecs``.
        tau: Softmax temperature; smaller values sharpen the distribution.

    Returns:
        ``(M,)`` tensor of values in ``[0, 1]``.
    """
    prompt_norm = prompt_vecs / (prompt_vecs.norm(dim=-1, keepdim=True) + 1e-8)
    anchor_norm = anchor_vecs / (anchor_vecs.norm(dim=-1, keepdim=True) + 1e-8)
    cosine = torch.matmul(prompt_norm, anchor_norm.transpose(0, 1))
    # softmax subtracts the row maximum before exponentiating; the former explicit
    # exp(cosine / tau) overflowed to inf (and inf / inf = NaN) for small tau.
    return torch.softmax(cosine / tau, dim=1).diagonal()
|
|
|
|
def compute_delta(anchor_vecs: torch.Tensor) -> torch.Tensor:
    """Return, per row ``i``, ``M * a_i - sum_j a_j`` (i.e. the sum of ``a_i - a_j`` over j != i).

    ``anchor_vecs`` is ``(M, D)``; the result has the same shape.
    """
    count = anchor_vecs.shape[0]
    column_sum = anchor_vecs.sum(dim=0, keepdim=True)
    return anchor_vecs * count - column_sum
|
|
|
|
def apply_sci(context: torch.Tensor, state: dict, timestep: torch.Tensor) -> torch.Tensor:
    """Add the saliency- and timestep-weighted subject delta to the text context (SCI).

    Each subject's ``delta`` row is scaled by ``omega * (1 - saliency)``, with
    ``omega = 1 - clamp(t / timestep_scale, 0, 1)``, and added to that subject's token
    positions as given by ``subject_token_mask``.

    Args:
        context: Text context ``(B, L, D)``.
        state: Comp-Attn state. Reads ``enable_sci``, ``subject_token_mask`` (bool
            ``(M, L)``), ``delta`` (``(M, D)``), ``saliency`` (``(M,)``),
            ``timestep_scale`` (default 1000) and optional ``apply_mask`` (one weight per
            batch row).
        timestep: Tensor whose first element is the current timestep.

    Returns:
        The modified context, or ``context`` itself when SCI is disabled, the state is
        incomplete, or the mask length does not match the context's token count.
    """
    if state is None or not state.get("enable_sci", False):
        return context
    subject_mask = state.get("subject_token_mask")
    delta = state.get("delta")
    saliency = state.get("saliency")
    if subject_mask is None or delta is None or saliency is None:
        return context
    if subject_mask.numel() == 0:
        return context
    # The mask addresses token positions of ``context``; on a length mismatch the per-token
    # delta could not be added, so leave the context untouched instead of failing.
    if subject_mask.shape[-1] != context.shape[-2]:
        return context
    t_scale = float(state.get("timestep_scale", 1000.0))
    t_value = float(timestep.reshape(-1)[0].item())
    t_ratio = max(0.0, min(1.0, t_value / t_scale))
    # omega is 0 at t >= t_scale and rises to 1 as t approaches 0.
    omega = 1.0 - t_ratio
    delta = delta.to(device=context.device, dtype=context.dtype)
    saliency = saliency.to(device=context.device, dtype=context.dtype)
    # Subjects with lower saliency receive a larger push.
    scale = omega * (1.0 - saliency).unsqueeze(-1)
    delta = delta * scale
    mask = subject_mask.to(device=context.device)
    # (L, M) @ (M, D) -> (L, D); tokens outside every subject get zeros.
    token_delta = torch.matmul(mask.to(dtype=context.dtype).transpose(0, 1), delta)
    apply_mask = state.get("apply_mask")
    if apply_mask is not None:
        # One weight per batch row, e.g. [1, 0] leaves the negative CFG row untouched.
        apply_mask = apply_mask.to(device=context.device, dtype=context.dtype).view(-1, 1, 1)
    else:
        apply_mask = 1.0
    return context + token_delta.unsqueeze(0) * apply_mask
|
|
|
|
def interpolate_bboxes(bboxes: torch.Tensor, target_frames: int) -> torch.Tensor:
    """Linearly resample ``(B, M, F, 4)`` box tracks along the frame axis to ``target_frames``.

    Endpoints are preserved (``align_corners=True``). The input object is returned
    unchanged when it already has ``target_frames`` frames.
    """
    if bboxes.shape[2] == target_frames:
        return bboxes
    batch, subjects, frames, _ = bboxes.shape
    # F.interpolate resamples the last axis, so put time there: (B*M, 4, F).
    tracks = bboxes.reshape(batch * subjects, frames, 4).permute(0, 2, 1)
    resampled = F.interpolate(tracks, size=target_frames, mode="linear", align_corners=True)
    return resampled.permute(0, 2, 1).reshape(batch, subjects, target_frames, 4)
|
|
|
|
def build_layout_mask_from_bboxes(
    bboxes: torch.Tensor,
    grid_size: tuple[int, int, int],
    image_size: tuple[int, int],
    device: torch.device,
    dtype: torch.dtype,
) -> Optional[torch.Tensor]:
    """Rasterize ``(B, M, F, 4)`` boxes onto the latent token grid.

    Each box ``(x0, y0, x1, y1)`` is scaled from ``image_size = (height, width)`` to the
    ``(h_grid, w_grid)`` grid, rounding outward (floor for the start, ceil for the end).
    Layout frame ``ti`` covers latent frames
    ``[ti * f_grid // F, max(that + 1, (ti + 1) * f_grid // F))``. Every latent frame is
    covered when ``F < f_grid``, and each layout frame maps to exactly one latent frame
    otherwise. Degenerate boxes (``x1 <= x0`` or ``y1 <= y0``) are skipped.

    Args:
        bboxes: Box tracks ``(B, M, F, 4)``, or None.
        grid_size: ``(f_grid, h_grid, w_grid)`` latent token grid.
        image_size: ``(height, width)`` of the coordinate space of the boxes.
        device: Device of the returned mask.
        dtype: Dtype of the returned mask.

    Returns:
        ``(B, M, f_grid * h_grid * w_grid)`` mask of 0/1 values, or None if ``bboxes`` is None.
    """
    if bboxes is None:
        return None
    b, m, f_layout, _ = bboxes.shape
    # Read all coordinates once, on the CPU, in float32. This avoids a device sync per
    # element and keeps pixel coordinates exact (bf16 would round e.g. 257 to 256).
    boxes = bboxes.detach().to(device="cpu", dtype=torch.float32).tolist()
    f_grid, h_grid, w_grid = grid_size
    height, width = image_size
    frame_span = max(1, f_layout)
    layout = torch.zeros((b, m, f_grid, h_grid, w_grid), dtype=torch.float32)
    for bi in range(b):
        for mi in range(m):
            for ti in range(f_layout):
                x0, y0, x1, y1 = boxes[bi][mi][ti]
                if x1 <= x0 or y1 <= y0:
                    continue
                px0 = int(math.floor(x0 / max(1.0, width) * w_grid))
                px1 = int(math.ceil(x1 / max(1.0, width) * w_grid))
                py0 = int(math.floor(y0 / max(1.0, height) * h_grid))
                py1 = int(math.ceil(y1 / max(1.0, height) * h_grid))
                px0 = max(0, min(w_grid, px0))
                px1 = max(0, min(w_grid, px1))
                py0 = max(0, min(h_grid, py0))
                py1 = max(0, min(h_grid, py1))
                if px1 <= px0 or py1 <= py0:
                    continue
                t_start = max(0, min(f_grid - 1, ti * f_grid // frame_span))
                t_end = max(t_start + 1, min(f_grid, (ti + 1) * f_grid // frame_span))
                layout[bi, mi, t_start:t_end, py0:py1, px0:px1] = 1.0
    return layout.flatten(2).to(device=device, dtype=dtype)
|
|
|
|
def lam_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    state: dict,
) -> Optional[torch.Tensor]:
    """Multi-head cross-attention with layout-aware modulation (LAM) of subject logits.

    For each subject ``k`` and each query position inside that subject's layout region,
    the logits for subject ``k``'s tokens are pushed toward the row maximum and the logits
    for the other subjects' tokens toward the row minimum. The push is scaled by
    ``1 - IoU``, where IoU compares the layout region with the query positions whose
    head-averaged attention to subject ``k`` is at or above its mean.

    Args:
        q: Queries ``(B, Lq, D)``.
        k: Keys ``(B, Lk, D)``.
        v: Values ``(B, Lk, D)``.
        num_heads: Number of heads; ``D`` is assumed divisible by it.
        state: Comp-Attn runtime state. Reads ``subject_token_mask_lam`` /
            ``subject_token_mask`` (bool ``(M, Lk)``; the first whose length equals ``Lk``
            is used), ``layout_mask`` (``(B or 1, M, Lq)``) and optional ``apply_mask``
            (one weight per batch row).

    Returns:
        ``(B, Lq, D)`` attention output, or None when the state does not fit these tensors
        (the caller then falls back to its regular attention).
    """
    b, q_len, dim = q.shape
    _, k_len, _ = k.shape
    # Pick the first candidate mask matching the keys used here. ``is None`` checks are
    # explicit: ``tensor or other`` raises for multi-element tensors.
    subject_mask = None
    for key in ("subject_token_mask_lam", "subject_token_mask"):
        candidate = state.get(key)
        if candidate is not None and candidate.numel() > 0 and candidate.shape[-1] == k_len:
            subject_mask = candidate
            break
    layout_mask = state.get("layout_mask")
    if subject_mask is None or layout_mask is None or layout_mask.numel() == 0:
        return None
    if layout_mask.dim() != 3 or layout_mask.shape[-1] != q_len:
        return None
    if layout_mask.shape[1] != subject_mask.shape[0]:
        return None
    head_dim = dim // num_heads
    qh = q.reshape(b, q_len, num_heads, head_dim).transpose(1, 2)
    kh = k.reshape(b, k_len, num_heads, head_dim).transpose(1, 2)
    vh = v.reshape(b, k_len, num_heads, head_dim).transpose(1, 2)
    # Logits in float32: (B, H, Lq, Lk).
    attn_scores = torch.matmul(qh.float(), kh.float().transpose(-2, -1)) / math.sqrt(head_dim)
    attn_max = attn_scores.max(dim=-1, keepdim=True).values
    attn_min = attn_scores.min(dim=-1, keepdim=True).values
    g_plus = attn_max - attn_scores  # >= 0: lifts a logit up to its row max
    g_minus = attn_min - attn_scores  # <= 0: lowers a logit down to its row min
    subject_mask = subject_mask.to(device=attn_scores.device, dtype=torch.bool)
    layout_mask = layout_mask.to(device=attn_scores.device, dtype=attn_scores.dtype)
    apply_mask = state.get("apply_mask")
    if apply_mask is not None:
        layout_mask = layout_mask * apply_mask.to(device=layout_mask.device, dtype=layout_mask.dtype).view(-1, 1, 1)
    # A layout built for a single sample is shared by every batch row.
    if layout_mask.shape[0] != b:
        if layout_mask.shape[0] != 1:
            return None
        layout_mask = layout_mask.expand(b, -1, -1)
    subject_any = subject_mask.any(dim=0)
    bias = torch.zeros_like(attn_scores)
    for k_idx in range(subject_mask.shape[0]):
        mask_k = subject_mask[k_idx]
        if not mask_k.any():
            continue
        mask_other = subject_any & (~mask_k)
        # Head-averaged attention from each query to subject k's tokens: (B, Lq).
        attn_k = attn_scores[..., mask_k].mean(dim=-1).mean(dim=1)
        adapt_f = (attn_k >= attn_k.mean(dim=-1, keepdim=True)).to(attn_scores.dtype)
        layout_k = layout_mask[:, k_idx]
        inter = (adapt_f * layout_k).sum(dim=-1)
        union = (adapt_f + layout_k - adapt_f * layout_k).sum(dim=-1)
        iou = inter / union.clamp(min=1e-6)
        strength = (1.0 - iou).reshape(b, 1, 1, 1)
        own = mask_k.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
        other = mask_other.to(dtype=attn_scores.dtype).view(1, 1, 1, k_len)
        g_k = g_plus * own + g_minus * other
        bias = bias + g_k * strength * layout_k.reshape(b, 1, q_len, 1)
    attn_probs = torch.softmax(attn_scores + bias, dim=-1).to(vh.dtype)
    out = torch.matmul(attn_probs, vh)
    return out.transpose(1, 2).reshape(b, q_len, dim)
|
|
|
|
class CompAttnUnit(PipelineUnit):
    """Pipeline unit that precomputes the Comp-Attn state for one prompt branch.

    It runs per branch (``seperate_cfg=True``; positive reads ``prompt``, negative reads
    ``negative_prompt``). It locates each configured subject's tokens in the prompt and
    pools per-subject vectors from ``context``. It also encodes each subject on its own
    as an anchor, derives the saliency/delta tensors used by ``apply_sci``, and prepares
    the boxes used by the LAM path.
    """

    def __init__(self):
        super().__init__(
            # NOTE: keyword spelling kept as-is; verify against PipelineUnit before renaming.
            seperate_cfg=True,
            input_params_posi={"prompt": "prompt", "context": "context"},
            input_params_nega={"prompt": "negative_prompt", "context": "context"},
            output_params=("comp_attn_state",),
            onload_model_names=("text_encoder",),
        )

    def _clean_text(self, pipe: WanVideoPipeline, text: str) -> str:
        """Apply the tokenizer's text cleaning to ``text`` when the tokenizer has it enabled."""
        # NOTE(review): tests the ``clean`` attribute but calls ``_clean``; presumably
        # ``clean`` is the cleaning mode and ``_clean`` applies it (mirroring how the prompt
        # itself is cleaned) — confirm against the tokenizer class.
        if getattr(pipe.tokenizer, "clean", None):
            return pipe.tokenizer._clean(text)
        return text

    def _tokenize_subject(self, pipe: WanVideoPipeline, text: str) -> torch.Tensor:
        """Return the token ids of ``text`` (first row of ``input_ids``) without special tokens."""
        text = self._clean_text(pipe, text)
        # Call the underlying tokenizer directly with no special tokens, so the ids can be
        # searched for as a subsequence of the full prompt ids.
        tokens = pipe.tokenizer.tokenizer(text, add_special_tokens=False, return_tensors="pt")
        return tokens["input_ids"][0]

    def _normalize_bboxes(self, bboxes: Sequence) -> torch.Tensor:
        """Convert user boxes to a float32 tensor of shape (B, M, F, 4).

        (F, 4) -> (1, 1, F, 4); (M, F, 4) -> (1, M, F, 4); (B, M, F, 4) is kept.

        Raises:
            ValueError: if the input is not one of those shapes.
        """
        bboxes = torch.as_tensor(bboxes, dtype=torch.float32)
        if bboxes.dim() == 2 and bboxes.shape[-1] == 4:
            bboxes = bboxes.unsqueeze(0).unsqueeze(0)
        elif bboxes.dim() == 3 and bboxes.shape[-1] == 4:
            bboxes = bboxes.unsqueeze(0)
        elif bboxes.dim() != 4 or bboxes.shape[-1] != 4:
            raise ValueError(f"comp_attn_bboxes must be (..., 4), got shape {tuple(bboxes.shape)}")
        return bboxes

    def process(self, pipe: WanVideoPipeline, prompt, context) -> dict:
        """Build the Comp-Attn state for one branch.

        Args:
            pipe: The pipeline. Reads ``_comp_attn_config``,
                ``_comp_attn_last_negative_prompt`` and ``_comp_attn_num_frames``, and uses
                its tokenizer and text encoder.
            prompt: This branch's prompt text.
            context: This branch's text embeddings; ``context[0]`` is pooled per subject
                (assumed ``(B, L, D)`` — confirm).

        Returns:
            ``{"comp_attn_state": state}``, or ``{}`` when any of these hold:
            there is no config, context or prompt; no subjects are configured; none of
            the subjects is found in the prompt; or this is the negative branch and
            ``apply_to_negative`` is False. The state is also stored on
            ``pipe._comp_attn_state_neg`` (negative branch) or ``pipe._comp_attn_state_pos``.
        """
        config: Optional[CompAttnConfig] = getattr(pipe, "_comp_attn_config", None)
        if context is None or prompt is None or config is None:
            return {}
        if not config.subjects:
            return {}
        # The negative branch is recognised by string equality with the negative prompt
        # last recorded on the pipeline.
        negative_prompt = getattr(pipe, "_comp_attn_last_negative_prompt", None)
        if (not config.apply_to_negative) and negative_prompt and prompt == negative_prompt:
            return {}
        pipe.load_models_to_device(self.onload_model_names)
        ids, mask = pipe.tokenizer(prompt, return_mask=True, add_special_tokens=True)
        prompt_ids = ids[0]
        valid_len = int(mask[0].sum().item())
        # Token positions of each subject in the prompt. Subjects not found are dropped;
        # valid_subjects keeps the config indices of the remaining ones.
        indices_list = []
        valid_subjects = []
        for idx, subject in enumerate(config.subjects):
            subject_ids = self._tokenize_subject(pipe, subject)
            indices = find_subsequence_indices(prompt_ids, subject_ids, valid_len)
            if not indices:
                print(f"Comp-Attn: subject tokens not found in prompt: {subject}")
                continue
            indices_list.append(indices)
            valid_subjects.append(idx)
        if not indices_list:
            return {}
        subject_token_mask = build_subject_token_mask(indices_list, prompt_ids.shape[0]).to(device=context.device)
        # Mean of the first sample's context vectors over each subject's tokens.
        mask_float = subject_token_mask.to(dtype=context.dtype)
        denom = mask_float.sum(dim=1, keepdim=True).clamp(min=1)
        prompt_vecs = (mask_float @ context[0]) / denom
        # Anchor vectors: each subject encoded on its own, mean-pooled over its valid tokens.
        anchor_vecs = []
        for idx in valid_subjects:
            subject = config.subjects[idx]
            sub_ids, sub_mask = pipe.tokenizer(subject, return_mask=True, add_special_tokens=True)
            sub_ids = sub_ids.to(pipe.device)
            sub_mask = sub_mask.to(pipe.device)
            emb = pipe.text_encoder(sub_ids, sub_mask)
            pooled = (emb * sub_mask.unsqueeze(-1)).sum(dim=1) / sub_mask.sum(dim=1, keepdim=True).clamp(min=1)
            anchor_vecs.append(pooled)
        anchor_vecs = torch.cat(anchor_vecs, dim=0)
        # Saliency is computed in float32, then cast back to the context dtype.
        saliency = compute_saliency(prompt_vecs.float(), anchor_vecs.float(), float(config.temperature)).to(prompt_vecs.dtype)
        delta = compute_delta(anchor_vecs.to(prompt_vecs.dtype))
        bboxes = None
        if config.bboxes is not None:
            bboxes = self._normalize_bboxes(config.bboxes)
            # When boxes are given for every configured subject, keep those of the found ones.
            if bboxes.shape[1] >= len(config.subjects):
                bboxes = bboxes[:, valid_subjects]
            if bboxes.shape[1] != len(valid_subjects):
                print("Comp-Attn: bboxes subject count mismatch, disable LAM")
                bboxes = None
        if bboxes is not None and config.interpolate and getattr(pipe, "_comp_attn_num_frames", None) is not None:
            bboxes = interpolate_bboxes(bboxes, int(pipe._comp_attn_num_frames))
        state = {
            "enable_sci": bool(config.enable_sci),
            # LAM needs boxes; it is disabled when none are usable.
            "enable_lam": bool(config.enable_lam) and bboxes is not None,
            "subject_token_mask": subject_token_mask,
            "saliency": saliency,
            "delta": delta,
            "layout_bboxes": bboxes,
            # apply_sci divides the timestep by this to get a [0, 1] ratio.
            "timestep_scale": 1000.0,
            "apply_to_negative": bool(config.apply_to_negative),
        }
        # Kept on the pipeline for CompAttnMergeUnit (merged-CFG path).
        if negative_prompt and prompt == negative_prompt:
            pipe._comp_attn_state_neg = state
        else:
            pipe._comp_attn_state_pos = state
        return {"comp_attn_state": state}
|
|
|
|
class CompAttnMergeUnit(PipelineUnit):
    """Supply one Comp-Attn state for the merged-CFG path.

    With ``cfg_merge`` enabled (positive and negative branches presumably run as one
    batch, row 0 positive — confirm), the state built by CompAttnUnit is reused with an
    ``apply_mask`` that selects which rows it affects. Row 1 is included only when
    ``apply_to_negative`` is set.
    """

    def __init__(self):
        super().__init__(input_params=("cfg_merge",), output_params=("comp_attn_state",))

    def process(self, pipe: WanVideoPipeline, cfg_merge) -> dict:
        """Return ``{"comp_attn_state": ...}`` when CFG is merged and a state exists, else ``{}``."""
        if not cfg_merge:
            return {}
        source = getattr(pipe, "_comp_attn_state_pos", None) or getattr(pipe, "_comp_attn_state_neg", None)
        if source is None:
            return {}
        negative_weight = 1.0 if source.get("apply_to_negative", False) else 0.0
        # Shallow copy: the stored per-branch state is not modified.
        merged = {**source, "apply_mask": torch.tensor([1.0, negative_weight])}
        return {"comp_attn_state": merged}
|
|
|
|
def _patch_cross_attention(pipe: WanVideoPipeline):
    """Replace ``forward`` of every ``block.cross_attn`` in ``pipe.dit`` with a LAM-aware version.

    Idempotent per module via the ``_comp_attn_patched`` flag. The patched forward calls
    the original forward unless ``pipe._comp_attn_runtime_state`` enables LAM.
    """
    for block in pipe.dit.blocks:
        cross_attn = block.cross_attn
        if getattr(cross_attn, "_comp_attn_patched", False):
            continue
        orig_forward = cross_attn.forward

        # ``_orig`` / ``_pipe`` are bound as defaults so each closure keeps its own
        # module's original forward (avoids late binding of the loop variable).
        def forward_with_lam(self, x, y, _orig=orig_forward, _pipe=pipe):
            state = getattr(_pipe, "_comp_attn_runtime_state", None)
            if state is None or not state.get("enable_lam", False):
                return _orig(x, y)
            # NOTE(review): assumes the first 257 tokens of ``y`` are image (CLIP) tokens
            # when has_image_input, as in the module's original forward — confirm.
            if self.has_image_input:
                img = y[:, :257]
                ctx = y[:, 257:]
            else:
                ctx = y
            q = self.norm_q(self.q(x))
            k = self.norm_k(self.k(ctx))
            v = self.v(ctx)
            # LAM modulates only the text branch. If the state does not fit these
            # tensors, lam_attention returns None and the module's own attention is used.
            lam_out = lam_attention(q, k, v, self.num_heads, state)
            if lam_out is None:
                out = self.attn(q, k, v)
            else:
                out = lam_out
            if self.has_image_input:
                k_img = self.norm_k_img(self.k_img(img))
                v_img = self.v_img(img)
                img_out = self.attn(q, k_img, v_img)
                out = out + img_out
            return self.o(out)

        # Bind as a method of this specific module instance.
        cross_attn.forward = forward_with_lam.__get__(cross_attn, cross_attn.__class__)
        cross_attn._comp_attn_patched = True
|
|
|
|
| def _get_grid_from_latents(latents: torch.Tensor, patch_size: tuple[int, int, int]) -> tuple[int, int, int]: |
| f = latents.shape[2] // patch_size[0] |
| h = latents.shape[3] // patch_size[1] |
| w = latents.shape[4] // patch_size[2] |
| return f, h, w |
|
|
|
|
def _wrap_model_fn(pipe: WanVideoPipeline):
    """Wrap ``pipe.model_fn`` so each call applies Comp-Attn (idempotent).

    The wrapper does the following:
    - pops ``comp_attn_state`` from the keyword arguments;
    - applies SCI to ``context``;
    - prepares the LAM masks, caching the layout mask inside the state dict;
    - publishes the state as ``pipe._comp_attn_runtime_state`` for the patched
      cross-attention, for the duration of the call only.
    It assumes model_fn is called with keyword arguments (it reads ``kwargs`` only) —
    confirm with the pipeline.
    """
    if getattr(pipe, "_comp_attn_model_fn_patched", False):
        return
    orig_model_fn = pipe.model_fn

    def model_fn_wrapper(*args, **kwargs):
        # Popped so the original model_fn never receives it.
        comp_attn_state = kwargs.pop("comp_attn_state", None)
        height = kwargs.get("height")
        width = kwargs.get("width")
        num_frames = kwargs.get("num_frames")
        # Remembered on the pipeline; CompAttnUnit reads it for bbox interpolation.
        if num_frames is not None:
            pipe._comp_attn_num_frames = num_frames
        if comp_attn_state is None:
            return orig_model_fn(*args, **kwargs)
        latents = kwargs.get("latents")
        timestep = kwargs.get("timestep")
        context = kwargs.get("context")
        clip_feature = kwargs.get("clip_feature")
        reference_latents = kwargs.get("reference_latents")
        if context is not None and timestep is not None:
            context = apply_sci(context, comp_attn_state, timestep)
            kwargs["context"] = context
        if comp_attn_state.get("enable_lam", False) and latents is not None and height is not None and width is not None:
            f, h, w = _get_grid_from_latents(latents, pipe.dit.patch_size)
            base_f = f
            q_len = f * h * w
            # Reference latents add one extra frame of query tokens; the layout gets an
            # all-zero frame prepended for it below.
            if reference_latents is not None:
                q_len = (f + 1) * h * w
            layout_mask = comp_attn_state.get("layout_mask")
            layout_shape = comp_attn_state.get("layout_shape")
            # Rebuild the cached layout only when batch size or query length changed.
            if layout_mask is None or layout_shape != (latents.shape[0], q_len):
                layout_mask = build_layout_mask_from_bboxes(
                    comp_attn_state.get("layout_bboxes"),
                    (base_f, h, w),
                    (int(height), int(width)),
                    device=latents.device,
                    dtype=latents.dtype,
                )
                if reference_latents is not None:
                    pad = torch.zeros((layout_mask.shape[0], layout_mask.shape[1], h * w), device=latents.device, dtype=latents.dtype)
                    layout_mask = torch.cat([pad, layout_mask], dim=-1)
                if layout_mask.shape[0] != latents.shape[0]:
                    layout_mask = layout_mask.repeat(latents.shape[0], 1, 1)
                comp_attn_state["layout_mask"] = layout_mask
                comp_attn_state["layout_shape"] = (latents.shape[0], q_len)
            subject_mask = comp_attn_state.get("subject_token_mask")
            # NOTE(review): this prepends clip_feature.shape[1] False columns, but the
            # patched cross-attention builds k from y[:, 257:] (text only), so this mask's
            # length will not match k_len there — confirm which key layout it targets.
            if subject_mask is not None and clip_feature is not None and pipe.dit.require_clip_embedding:
                pad_len = clip_feature.shape[1]
                pad = torch.zeros((subject_mask.shape[0], pad_len), dtype=torch.bool)
                comp_attn_state["subject_token_mask_lam"] = torch.cat([pad, subject_mask.cpu()], dim=1)
        # A two-sample batch without an explicit apply_mask: modify row 0 only.
        # NOTE(review): assumes rows are [positive, negative] — confirm.
        if (
            latents is not None
            and latents.shape[0] == 2
            and not comp_attn_state.get("apply_to_negative", False)
            and "apply_mask" not in comp_attn_state
        ):
            comp_attn_state["apply_mask"] = torch.tensor([1.0, 0.0], device=latents.device, dtype=latents.dtype)
        # Visible to forward_with_lam only while the original model_fn runs.
        pipe._comp_attn_runtime_state = comp_attn_state
        try:
            return orig_model_fn(*args, **kwargs)
        finally:
            pipe._comp_attn_runtime_state = None

    pipe.model_fn = model_fn_wrapper
    pipe._comp_attn_model_fn_patched = True
|
|
|
|
def attach_comp_attn(pipe: WanVideoPipeline) -> WanVideoPipeline:
    """Install Comp-Attn on ``pipe`` in place (idempotent) and return it.

    The following changes are made:
    - CompAttnUnit is inserted right after the first WanVideoUnit_PromptEmbedder.
    - CompAttnMergeUnit is inserted right after the first WanVideoUnit_CfgMerger.
    - Either unit is appended instead when its anchor unit is missing.
    - Every DiT cross-attention is patched and ``pipe.model_fn`` is wrapped.
    """
    if getattr(pipe, "_comp_attn_attached", False):
        return pipe

    def _first_index(unit_type):
        # Position of the first unit of ``unit_type`` in pipe.units, or None.
        return next((i for i, unit in enumerate(pipe.units) if isinstance(unit, unit_type)), None)

    prompt_idx = _first_index(WanVideoUnit_PromptEmbedder)
    if prompt_idx is not None:
        pipe.units.insert(prompt_idx + 1, CompAttnUnit())
    else:
        pipe.units.append(CompAttnUnit())
    # Look the CFG merger up only now. The insertion above shifts every later unit, so an
    # index computed earlier would place the merge unit *before* the merger.
    cfg_idx = _first_index(WanVideoUnit_CfgMerger)
    if cfg_idx is not None:
        pipe.units.insert(cfg_idx + 1, CompAttnMergeUnit())
    else:
        pipe.units.append(CompAttnMergeUnit())
    _patch_cross_attention(pipe)
    _wrap_model_fn(pipe)
    pipe._comp_attn_attached = True
    return pipe
|
|
|
|
class CompAttnPipelineWrapper:
    """Callable front-end that attaches Comp-Attn to a pipeline and passes a per-call config.

    Attributes not defined on the wrapper are looked up on the wrapped pipeline.
    """

    def __init__(self, pipe: WanVideoPipeline):
        self.pipe = attach_comp_attn(pipe)

    def __getattr__(self, name):
        # Only reached for attributes missing on the wrapper. Guard ``pipe`` itself so an
        # instance created without __init__ (e.g. by copy/pickle) raises AttributeError
        # instead of recursing until RecursionError.
        if name == "pipe":
            raise AttributeError(name)
        return getattr(self.pipe, name)

    def _default_num_frames(self):
        """Return the default ``num_frames`` declared by the wrapped pipeline's ``__call__``, or None."""
        import inspect

        try:
            param = inspect.signature(self.pipe.__call__).parameters.get("num_frames")
        except (TypeError, ValueError):
            return None
        if param is None or param.default is inspect.Parameter.empty:
            return None
        return param.default

    def __call__(self, prompt: str, negative_prompt: str = "", comp_attn: Optional[CompAttnConfig] = None, **kwargs):
        """Run the pipeline with ``comp_attn`` as the Comp-Attn config for this call only.

        Args:
            prompt: Positive prompt.
            negative_prompt: Negative prompt. CompAttnUnit also compares against it to
                recognise the negative branch.
            comp_attn: Config for this call; None disables Comp-Attn.
            **kwargs: Forwarded unchanged to the wrapped pipeline.

        Returns:
            Whatever the wrapped pipeline returns.
        """
        pipe = self.pipe
        num_frames = kwargs.get("num_frames")
        if num_frames is None:
            # Otherwise the value left behind by an earlier call would be used.
            num_frames = self._default_num_frames()
        pipe._comp_attn_num_frames = num_frames
        # CompAttnMergeUnit reads these, but CompAttnUnit only overwrites them when it
        # actually builds a state. Clear them so a call never inherits the previous one's.
        pipe._comp_attn_state_pos = None
        pipe._comp_attn_state_neg = None
        pipe._comp_attn_config = comp_attn
        pipe._comp_attn_last_prompt = prompt
        pipe._comp_attn_last_negative_prompt = negative_prompt
        return pipe(prompt=prompt, negative_prompt=negative_prompt, **kwargs)
|
|
|
|
def build_comp_attn_pipeline(*args, **kwargs) -> CompAttnPipelineWrapper:
    """Load a WanVideoPipeline via ``from_pretrained`` (all arguments forwarded) and wrap it for Comp-Attn."""
    return CompAttnPipelineWrapper(WanVideoPipeline.from_pretrained(*args, **kwargs))
|
|