sam-motamed commited on
Commit
b51f7da
·
verified ·
1 Parent(s): ea2b214

Add dist_utils.py

Browse files
Files changed (1) hide show
  1. dist_utils.py +138 -0
dist_utils.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+ from diffusers.models.attention import Attention
7
+ from diffusers.models.embeddings import apply_rotary_emb
8
+
9
+ try:
10
+ import xfuser
11
+ from xfuser.core.distributed import (get_sequence_parallel_rank,
12
+ get_sequence_parallel_world_size,
13
+ get_sp_group, get_world_group,
14
+ init_distributed_environment,
15
+ initialize_model_parallel)
16
+ from xfuser.core.long_ctx_attention import xFuserLongContextAttention
17
+ except Exception as ex:
18
+ get_sequence_parallel_world_size = None
19
+ get_sequence_parallel_rank = None
20
+ xFuserLongContextAttention = None
21
+ get_sp_group = None
22
+ get_world_group = None
23
+ init_distributed_environment = None
24
+ initialize_model_parallel = None
25
+
26
+ def set_multi_gpus_devices(ulysses_degree, ring_degree):
27
+ if ulysses_degree > 1 or ring_degree > 1:
28
+ if get_sp_group is None:
29
+ raise RuntimeError("xfuser is not installed.")
30
+ dist.init_process_group("nccl")
31
+ print('parallel inference enabled: ulysses_degree=%d ring_degree=%d rank=%d world_size=%d' % (
32
+ ulysses_degree, ring_degree, dist.get_rank(),
33
+ dist.get_world_size()))
34
+ assert dist.get_world_size() == ring_degree * ulysses_degree, \
35
+ "number of GPUs(%d) should be equal to ring_degree * ulysses_degree." % dist.get_world_size()
36
+ init_distributed_environment(rank=dist.get_rank(), world_size=dist.get_world_size())
37
+ initialize_model_parallel(sequence_parallel_degree=dist.get_world_size(),
38
+ ring_degree=ring_degree,
39
+ ulysses_degree=ulysses_degree)
40
+ # device = torch.device("cuda:%d" % dist.get_rank())
41
+ device = torch.device(f"cuda:{get_world_group().local_rank}")
42
+ print('rank=%d device=%s' % (get_world_group().rank, str(device)))
43
+ else:
44
+ device = "cuda"
45
+ return device
46
+
47
+ class CogVideoXMultiGPUsAttnProcessor2_0:
48
+ r"""
49
+ Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
50
+ query and key vectors, but does not include spatial normalization.
51
+ """
52
+
53
+ def __init__(self):
54
+ if xFuserLongContextAttention is not None:
55
+ try:
56
+ self.hybrid_seq_parallel_attn = xFuserLongContextAttention()
57
+ except Exception:
58
+ self.hybrid_seq_parallel_attn = None
59
+ else:
60
+ self.hybrid_seq_parallel_attn = None
61
+ if not hasattr(F, "scaled_dot_product_attention"):
62
+ raise ImportError("CogVideoXAttnProcessor requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
63
+
64
+ def __call__(
65
+ self,
66
+ attn: Attention,
67
+ hidden_states: torch.Tensor,
68
+ encoder_hidden_states: torch.Tensor,
69
+ attention_mask: Optional[torch.Tensor] = None,
70
+ image_rotary_emb: Optional[torch.Tensor] = None,
71
+ ) -> torch.Tensor:
72
+ text_seq_length = encoder_hidden_states.size(1)
73
+
74
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
75
+
76
+ batch_size, sequence_length, _ = (
77
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
78
+ )
79
+
80
+ if attention_mask is not None:
81
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
82
+ attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
83
+
84
+ query = attn.to_q(hidden_states)
85
+ key = attn.to_k(hidden_states)
86
+ value = attn.to_v(hidden_states)
87
+
88
+ inner_dim = key.shape[-1]
89
+ head_dim = inner_dim // attn.heads
90
+
91
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
92
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
93
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
94
+
95
+ if attn.norm_q is not None:
96
+ query = attn.norm_q(query)
97
+ if attn.norm_k is not None:
98
+ key = attn.norm_k(key)
99
+
100
+ # Apply RoPE if needed
101
+ if image_rotary_emb is not None:
102
+ query[:, :, text_seq_length:] = apply_rotary_emb(query[:, :, text_seq_length:], image_rotary_emb)
103
+ if not attn.is_cross_attention:
104
+ key[:, :, text_seq_length:] = apply_rotary_emb(key[:, :, text_seq_length:], image_rotary_emb)
105
+
106
+ if self.hybrid_seq_parallel_attn is None:
107
+ hidden_states = F.scaled_dot_product_attention(
108
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
109
+ )
110
+ hidden_states = hidden_states
111
+ else:
112
+ img_q = query[:, :, text_seq_length:].transpose(1, 2)
113
+ txt_q = query[:, :, :text_seq_length].transpose(1, 2)
114
+ img_k = key[:, :, text_seq_length:].transpose(1, 2)
115
+ txt_k = key[:, :, :text_seq_length].transpose(1, 2)
116
+ img_v = value[:, :, text_seq_length:].transpose(1, 2)
117
+ txt_v = value[:, :, :text_seq_length].transpose(1, 2)
118
+
119
+ hidden_states = self.hybrid_seq_parallel_attn(
120
+ None,
121
+ img_q, img_k, img_v, dropout_p=0.0, causal=False,
122
+ joint_tensor_query=txt_q,
123
+ joint_tensor_key=txt_k,
124
+ joint_tensor_value=txt_v,
125
+ joint_strategy='front',
126
+ ).transpose(1, 2)
127
+
128
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
129
+
130
+ # linear proj
131
+ hidden_states = attn.to_out[0](hidden_states)
132
+ # dropout
133
+ hidden_states = attn.to_out[1](hidden_states)
134
+
135
+ encoder_hidden_states, hidden_states = hidden_states.split(
136
+ [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
137
+ )
138
+ return hidden_states, encoder_hidden_states