sam-motamed commited on
Commit
c6cc81a
·
verified ·
1 Parent(s): f3b6d71

Move dist_utils.py to diffusers/

Browse files
Files changed (1) hide show
  1. dist_utils.py +0 -138
dist_utils.py DELETED
@@ -1,138 +0,0 @@
1
- from typing import Optional
2
-
3
- import torch
4
- import torch.nn.functional as F
5
- import torch.distributed as dist
6
- from diffusers.models.attention import Attention
7
- from diffusers.models.embeddings import apply_rotary_emb
8
-
9
- try:
10
- import xfuser
11
- from xfuser.core.distributed import (get_sequence_parallel_rank,
12
- get_sequence_parallel_world_size,
13
- get_sp_group, get_world_group,
14
- init_distributed_environment,
15
- initialize_model_parallel)
16
- from xfuser.core.long_ctx_attention import xFuserLongContextAttention
17
- except Exception as ex:
18
- get_sequence_parallel_world_size = None
19
- get_sequence_parallel_rank = None
20
- xFuserLongContextAttention = None
21
- get_sp_group = None
22
- get_world_group = None
23
- init_distributed_environment = None
24
- initialize_model_parallel = None
25
-
26
- def set_multi_gpus_devices(ulysses_degree, ring_degree):
27
- if ulysses_degree > 1 or ring_degree > 1:
28
- if get_sp_group is None:
29
- raise RuntimeError("xfuser is not installed.")
30
- dist.init_process_group("nccl")
31
- print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
32
- ulysses_degree, ring_degree, dist.get_rank(),
33
- dist.get_world_size()))
34
- assert dist.get_world_size() == ring_degree * ulysses_degree, \
35
- "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
36
- init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
37
- initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
38
- ring_degree=ring_degree,
39
- ulysses_degree=ulysses_degree)
40
- # device = torch.device("cuda:%d" % dist.get_rank())
41
- device = torch.device(f"cuda:{get_world_group().local_rank}")
42
- print('rank=%d device=%s' % (get_world_group().rank, str(device)))
43
- else:
44
- device = "cuda"
45
- return device
46
-
47
- class CogVideoXMultiGPUsAttnProcessor2_0:
48
- r"""
49
- Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
50
- query and key vectors, but does not include spatial normalization.
51
- """
52
-
53
- def __init__(self):
54
- if xFuserLongContextAttention is not None:
55
- try:
56
- self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
57
- except Exception:
58
- self.hybrid_seq_parallel_attn = None
59
- else:
60
- self.hybrid_seq_parallel_attn = None
61
- if not hasattr(F, "scaled_dot_product_attention"):
62
- raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
63
-
64
- def __call__(
65
- self,
66
- attn: Attention,
67
- hidden_states: torch.Tensor,
68
- encoder_hidden_states: torch.Tensor,
69
- attention_mask: Optional[torch.Tensor] = None,
70
- image_rotary_emb: Optional[torch.Tensor] = None,
71
- ) -> torch.Tensor:
72
- text_seq_length = encoder_hidden_states.size(1)
73
-
74
- hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
75
-
76
- batch_size, sequence_length, _ = (
77
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
78
- )
79
-
80
- if attention_mask is not None:
81
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
82
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
83
-
84
- query = attn.to_q(hidden_states)
85
- key = attn.to_k(hidden_states)
86
- value = attn.to_v(hidden_states)
87
-
88
- inner_dim = key.shape[-1]
89
- head_dim = inner_dim // attn.heads
90
-
91
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
92
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
93
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
94
-
95
- if attn.norm_q is not None:
96
- query = attn.norm_q(query)
97
- if attn.norm_k is not None:
98
- key = attn.norm_k(key)
99
-
100
- # Apply RoPE if needed
101
- if image_rotary_emb is not None:
102
- query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
103
- if not attn.is_cross_attention:
104
- key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
105
-
106
- if self.hybrid_seq_parallel_attn is None:
107
- hidden_states = F.scaled_dot_product_attention(
108
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
109
- )
110
- hidden_states = hidden_states
111
- else:
112
- img_q = query[:, :, text_seq_length:].transpose(1, 2)
113
- txt_q = query[:, :, :text_seq_length].transpose(1, 2)
114
- img_k = key[:, :, text_seq_length:].transpose(1, 2)
115
- txt_k = key[:, :, :text_seq_length].transpose(1, 2)
116
- img_v = value[:, :, text_seq_length:].transpose(1, 2)
117
- txt_v = value[:, :, :text_seq_length].transpose(1, 2)
118
-
119
- hidden_states = self.hybrid_seq_parallel_attn(
120
- None,
121
- img_q, img_k, img_v, dropout_p=0.0, causal=False,
122
- joint_tensor_query=txt_q,
123
- joint_tensor_key=txt_k,
124
- joint_tensor_value=txt_v,
125
- joint_strategy='front',
126
- ).transpose(1, 2)
127
-
128
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
129
-
130
- # linear proj
131
- hidden_states = attn.to_out[0](hidden_states)
132
- # dropout
133
- hidden_states = attn.to_out[1](hidden_states)
134
-
135
- encoder_hidden_states, hidden_states = hidden_states.split(
136
- [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
137
- )
138
- return hidden_states, encoder_hidden_states