import os
import sys
import types
from typing import List, Optional, Tuple

import safetensors
import torch
import torch.nn.functional as F
import transformers
from packaging import version
from peft import PeftModel
from torch.utils.data import DataLoader
from transformers import PreTrainedModel, trainer
from transformers.modeling_utils import unwrap_model

from swift.utils import get_logger, torchacc_trim_graph, use_torchacc

logger = get_logger()


def get_bucket_sizes(max_length: int) -> List[int]:
    """Get the bucket sizes for TorchAcc.

    The bucket sizes can be set explicitly via the environment variable
    TORCHACC_DATA_BUCKETS (a comma-separated list of lengths). If it is not set,
    a fixed set of small buckets [16, 32, 48, 64, 96, 128] is extended with sizes
    that grow geometrically by a factor of TORCHACC_PADDING_P_BASE (default 2, or
    1.4 when TORCHACC_CACHE_PATH is set), each rounded up to a multiple of 128,
    until max_length is reached. max_length itself is always appended as the last
    bucket.
    """
    padding_p_base = 2
    if os.getenv('TORCHACC_DATA_BUCKETS') is not None:
        bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')]
        bucket_sizes.append(max_length)
    else:
        if os.getenv('TORCHACC_CACHE_PATH') is not None:
            padding_p_base = 1.4
        padding_p_base = os.getenv('TORCHACC_PADDING_P_BASE', padding_p_base)
        try:
            padding_p_base = float(padding_p_base)
        except ValueError as e:
            logger.error(f'Expect TORCHACC_PADDING_P_BASE to be a float, but got {padding_p_base}')
            raise e
        bucket_sizes = [16, 32, 48, 64, 96, 128]
        base_size = 256
        while base_size < max_length:
            bucket_sizes.append((int(base_size) + 127) // 128 * 128)
            base_size *= padding_p_base
        bucket_sizes.append(max_length)

    return bucket_sizes
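
# Illustration for get_bucket_sizes (assumed environment, for documentation only):
# with TORCHACC_DATA_BUCKETS and TORCHACC_CACHE_PATH unset (padding_p_base=2) and
# max_length=1024, the returned buckets are [16, 32, 48, 64, 96, 128, 256, 512, 1024].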


def _get_closest_bucket(bucket_sizes, data_length):
    """Select the smallest size in bucket_sizes that is not smaller than
    data_length. If no such bucket exists, data_length itself is appended to
    bucket_sizes and returned. This is required for TorchAcc.
    """
    closest_length = sys.maxsize
    for b in bucket_sizes:
        if b == data_length or ((b < closest_length) and (b > data_length)):
            closest_length = b

    if closest_length == sys.maxsize:
        bucket_sizes.append(data_length)
        closest_length = data_length

    return closest_length
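
# Illustration for _get_closest_bucket (assumed inputs, for documentation only):
# with bucket_sizes=[16, 32, 48, 64, 96, 128, 256, 512, 1024] and data_length=100,
# the smallest bucket >= 100 is 128, which is returned; for data_length=2048 no bucket
# is large enough, so 2048 is appended to bucket_sizes and returned.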


def pad_and_split_batch(padding_to, input_ids, attention_mask, labels, loss_scale, max_length, tokenizer, rank,
                        world_size, padding_right):
    if padding_to is None:
        longest_len = input_ids.shape[-1]
        bucket_sizes = get_bucket_sizes(max_length)
        bucket_data_length = _get_closest_bucket(bucket_sizes, longest_len)
        padding_length = bucket_data_length - input_ids.shape[1]
        pad_tuple = (0, padding_length) if padding_right else (padding_length, 0)
        input_ids = F.pad(input_ids, pad_tuple, 'constant', tokenizer.pad_token_id)
        attention_mask = F.pad(attention_mask, pad_tuple, 'constant', 0)
        if loss_scale is not None:
            loss_scale = F.pad(loss_scale, pad_tuple, 'constant', 0.)
        labels = F.pad(labels, pad_tuple, 'constant', -100)

    # Split the (padded) global batch evenly across the data-parallel ranks.
    batch_size = input_ids.shape[0] // world_size
    if batch_size > 0:
        start = rank * batch_size
        end = (rank + 1) * batch_size
        input_ids = input_ids[start:end, :]
        attention_mask = attention_mask[start:end, :]
        labels = labels[start:end, :]
        if loss_scale is not None:
            loss_scale = loss_scale[start:end, :]
    return input_ids, attention_mask, labels, loss_scale
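
# Illustration for pad_and_split_batch (assumed shapes, for documentation only):
# with a global batch of 8 sequences padded to a bucket length of 256, world_size=4
# and rank=1, batch_size becomes 2 and rank 1 keeps rows 2..3 of every returned tensor.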


def ta_train_dataloader(train_dataset, data_collator, sampler, args, batch_size):

    def acc_skip_first_batches(dataloader, num_batches=0):
        from accelerate.data_loader import SkipBatchSampler
        batch_sampler = SkipBatchSampler(dataloader._loader.batch_sampler, skip_batches=num_batches)
        try:
            dataset = dataloader.dataset
        except AttributeError:
            dataset = dataloader._loader.dataset
        dataloader_params = {
            'collate_fn': data_collator,
            'num_workers': args.dataloader_num_workers,
            'pin_memory': args.dataloader_pin_memory,
            'persistent_workers': args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            dataloader_params['batch_sampler'] = batch_sampler
            dataloader_params['worker_init_fn'] = trainer.seed_worker

        return ta.AsyncLoader(DataLoader(dataset, **dataloader_params), args.device)

    trainer.skip_first_batches = acc_skip_first_batches

    import torchacc as ta

    dataloader_params = {
        'batch_size': batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    if not isinstance(train_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last
        dataloader_params['worker_init_fn'] = trainer.seed_worker

    return ta.AsyncLoader(DataLoader(train_dataset, **dataloader_params), args.device)


def ta_eval_dataloader(eval_dataset, data_collator, sampler, args):
    import torchacc as ta

    dataloader_params = {
        'batch_size': args.eval_batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last

    return ta.AsyncLoader(DataLoader(eval_dataset, **dataloader_params), args.device)


def ta_test_dataloader(test_dataset, data_collator, sampler, args):
    import torchacc as ta

    dataloader_params = {
        'batch_size': args.eval_batch_size,
        'collate_fn': data_collator,
        'num_workers': args.dataloader_num_workers,
        'pin_memory': args.dataloader_pin_memory,
        'persistent_workers': args.dataloader_persistent_workers,
    }

    if not isinstance(test_dataset, torch.utils.data.IterableDataset):
        dataloader_params['sampler'] = sampler
        dataloader_params['drop_last'] = args.dataloader_drop_last

    return ta.AsyncLoader(DataLoader(test_dataset, **dataloader_params), args.device)


def ta_save_optimizer_and_scheduler(optimizer, lr_scheduler, output_dir):
    import torch_xla.core.xla_model as xm
    xm.rendezvous('saving_optimizer_states')
    xm.save(optimizer.state_dict(), os.path.join(output_dir, f'optimizer_{xm.get_ordinal()}.pt'), master_only=False)
    xm.save(lr_scheduler.state_dict(), os.path.join(output_dir, f'scheduler_{xm.get_ordinal()}.pt'), master_only=False)
    xm.rendezvous('saving_optimizer_states_done')


def ta_load_optimizer_and_scheduler(optimizer, lr_scheduler, checkpoint, device):
    import torch_xla.core.xla_model as xm
    optimizer_state = torch.load(os.path.join(checkpoint, f'optimizer_{xm.get_ordinal()}.pt'), map_location='cpu')
    lr_scheduler_state = torch.load(os.path.join(checkpoint, f'scheduler_{xm.get_ordinal()}.pt'), map_location='cpu')
    xm.send_cpu_data_to_device(optimizer_state, device)
    xm.send_cpu_data_to_device(lr_scheduler_state, device)

    optimizer.load_state_dict(optimizer_state)
    lr_scheduler.load_state_dict(lr_scheduler_state)
    return optimizer, lr_scheduler


def save_ta_ddp_checkpoint(self_model, tokenizer, args, output_dir: Optional[str] = None):
    output_dir = output_dir if output_dir is not None else args.output_dir
    import torch_xla.core.xla_model as xm

    model = self_model

    if xm.is_master_ordinal(local=False):
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    xm.mark_step()

    supported_classes = (PreTrainedModel, PeftModel)
    if not isinstance(model, supported_classes):
        if isinstance(unwrap_model(model), supported_classes):
            unwrap_model(model).save_pretrained(
                output_dir,
                is_main_process=args.should_save,
                state_dict=xm._maybe_convert_to_cpu(model.state_dict()),
                save_function=xm.save,
                safe_serialization=args.save_safetensors,
            )
        else:
            logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
            state_dict = xm._maybe_convert_to_cpu(model.state_dict())
            if args.save_safetensors:
                safetensors.torch.save_file(state_dict, os.path.join(output_dir, 'model.safetensors'))
            else:
                torch.save(state_dict, os.path.join(output_dir, 'pytorch_model.bin'))
    else:
        model.save_pretrained(
            output_dir,
            is_main_process=args.should_save,
            save_function=xm.save,
            safe_serialization=args.save_safetensors,
            state_dict=xm._maybe_convert_to_cpu(model.state_dict()))
    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(output_dir)


def save_ta_fsdp_checkpoint(self_model, tokenizer, args, output_dir):
    import torch_xla.core.xla_model as xm
    from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

    xm.mark_step()

    if xm.is_master_ordinal(local=False):
        os.makedirs(output_dir, exist_ok=True)
        torch.save(args, os.path.join(output_dir, 'training_args.bin'))

    supported_classes = (PreTrainedModel, PeftModel)
    model = self_model._get_underlay_model().module.module
    unwrapped_model = unwrap_model(model)

    xm.rendezvous('saving_checkpoint')
    ckpt = {
        'model': self_model._get_underlay_model().state_dict(),
        'shard_metadata': self_model._get_underlay_model().get_shard_metadata(),
    }
    if isinstance(model, PeftModel):
        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-adapter_model.bin')
    else:
        ckpt_path = os.path.join(output_dir, f'rank{args.process_index}-of-{args.global_world_size}-pytorch_model.bin')
    xm.save(ckpt, ckpt_path, master_only=False)
    # Make sure every rank has written its shard before consolidation starts.
    xm.rendezvous('save_full_checkpoints')

    if tokenizer is not None and args.should_save:
        tokenizer.save_pretrained(output_dir, is_main_process=xm.is_master_ordinal(local=False), save_function=xm.save)

    # Consolidate the sharded checkpoints into a full state dict on the master rank.
    if xm.is_master_ordinal(local=False):
        if isinstance(model, PeftModel):
            ckpt_suffix = 'rank*-of-*-adapter_model.bin'
        else:
            ckpt_suffix = 'rank*-of-*-pytorch_model.bin'
        full_state_dict, _ = consolidate_sharded_model_checkpoints(
            ckpt_prefix=os.path.join(output_dir, ''), ckpt_suffix=ckpt_suffix, save_model=False)

        if isinstance(unwrapped_model, supported_classes):
            unwrapped_model.save_pretrained(
                output_dir,
                state_dict=full_state_dict,
                save_function=xm.save,
                safe_serialization=args.save_safetensors,
            )
        else:
            logger.info('Trainer.model is not a `PreTrainedModel`, only saving its state dict.')
            if args.save_safetensors:
                safetensors.torch.save_file(full_state_dict, os.path.join(output_dir, 'model.safetensors'))
            else:
                torch.save(full_state_dict, os.path.join(output_dir, 'pytorch_model.bin'))

    xm.rendezvous('ckpt_consolidation')
    # Each rank removes its own shard file once consolidation is finished.
    os.remove(ckpt_path)


def ta_trim_graph():
    if use_torchacc() and torchacc_trim_graph():
        import torchacc as ta
        ta.mark_step()


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    if position_ids is not None:
        cos = cos[position_ids].unsqueeze(unsqueeze_dim)
        sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    else:
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
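
# Shape note for apply_rotary_pos_emb (restating the docstring above): with q and k of
# shape [batch_size, heads, seq_len, head_dim] and cos/sin of shape
# [batch_size, seq_len, head_dim], the default unsqueeze_dim=1 expands cos/sin to
# [batch_size, 1, seq_len, head_dim] so they broadcast against q and k.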


def patch_acc_model(args, model):
    if not args.use_flash_attn:
        logger.warning('TorchAcc currently uses flash attention; it will be enabled even though use_flash_attn is '
                       'False.')
    if args.model_type.startswith('qwen1half') or args.model_type.startswith('qwen2'):
        model = patch_qwen2_model(model)
    elif args.model_type.startswith('qwen'):
        import torchacc as ta
        model = ta.patch_qwen_model(model)
    elif args.model_type.startswith('baichuan'):
        model = patch_baichuan_model(model)
    elif args.model_type.startswith('llama') or args.model_type.startswith('yi'):
        model = patch_llama_model(model)
    elif args.model_type.startswith('chatglm'):
        model = patch_chatglm_model(model)
    return model


def patch_llama_model(model):

    def update_causal_mask(self, *args, **kwargs):
        # The causal mask is handled inside flash attention (causal=True), so it is not needed here.
        return None

    def llama_attn_forward(self,
                           hidden_states: torch.Tensor,
                           attention_mask: Optional[torch.Tensor] = None,
                           position_ids: Optional[torch.Tensor] = None,
                           past_key_value: Optional[Tuple[torch.Tensor]] = None,
                           output_attentions: bool = False,
                           use_cache: bool = False,
                           cache_position: Optional[torch.LongTensor] = None,
                           **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        from torchacc.ops import flash_attn_varlen_xla
        import einops

        bsz, q_len, _ = hidden_states.size()

        query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        key_states = (
            self.k_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))
        value_states = (
            self.v_proj(hidden_states).view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        assert past_key_value is None, 'past_key_value is not supported'

        if version.parse(transformers.__version__) >= version.parse('4.36'):
            cos, sin = self.rotary_emb(value_states, position_ids)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        assert not output_attentions, 'output_attentions is not supported'

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # Pack the batch into the varlen layout expected by flash attention:
        # (batch * seq_len, heads, head_dim) with cumulative sequence lengths.
        q = einops.rearrange(query_states, 'b h s ... -> (b s) h ...')
        k = einops.rearrange(key_states, 'b h s ... -> (b s) h ...')
        v = einops.rearrange(value_states, 'b h s ... -> (b s) h ...')
        max_s = q_len
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, max_s, max_s, 0.0, softmax_scale=None, causal=True)
        output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)

        return self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)')), None, past_key_value

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(llama_attn_forward, layer.self_attn)

    if version.parse(transformers.__version__) >= version.parse('4.38'):
        model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)

    return model
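
# Note on the attention patches above and below: cu_q_lens is built with a fixed step of
# q_len, so every row of the (bucket-padded) batch is treated as one full-length sequence
# and the attention_mask argument is ignored; padding tokens simply take part in the causal
# attention. This matches the bucketed padding performed in pad_and_split_batch.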


def patch_chatglm_model(model):

    def chatglm_apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
        # x: [sq, b, np, hn]
        sq, _, np, _ = x.size(0), x.size(1), x.size(2), x.size(3)
        rot_dim = rope_cache.shape[-2] * 2
        x, x_pass = x[..., :rot_dim], x[..., rot_dim:]
        # Truncate the rope cache to the current sequence length.
        rope_cache = rope_cache[:sq]
        xshaped = x.reshape(sq, -1, np, rot_dim // 2, 2)
        rope_cache = rope_cache.view(sq, -1, 1, xshaped.size(3), 2)
        x_out2 = torch.stack(
            [
                xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1],
                xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1],
            ],
            -1,
        )
        x_out2 = x_out2.flatten(3)
        return torch.cat((x_out2, x_pass), dim=-1)

    def chatglm_attn_forward(self,
                             hidden_states,
                             attention_mask,
                             rotary_pos_emb,
                             kv_cache=None,
                             use_cache=True,
                             **kwargs):
        # hidden_states: [seq_len, batch, hidden_size]

        # Project hidden states to the packed query/key/value layer.
        mixed_x_layer = self.query_key_value(hidden_states)

        if self.multi_query_attention:
            (query_layer, key_layer, value_layer) = mixed_x_layer.split(
                [
                    self.num_attention_heads_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                    self.num_multi_query_groups_per_partition * self.hidden_size_per_attention_head,
                ],
                dim=-1,
            )
            query_layer = query_layer.view(query_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                                                      self.hidden_size_per_attention_head))
            key_layer = key_layer.view(key_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                                                self.hidden_size_per_attention_head))
            value_layer = value_layer.view(value_layer.size()[:-1] + (self.num_multi_query_groups_per_partition,
                                                                      self.hidden_size_per_attention_head))
        else:
            new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads_per_partition,
                                                            3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] -> 3 x [sq, b, np, hn]. The original code calls ChatGLM's
            # `split_tensor_along_last_dim` helper, which is not importable here; splitting
            # the last dim into three equal chunks is equivalent.
            (query_layer, key_layer, value_layer) = mixed_x_layer.chunk(3, dim=-1)

        # Apply rotary position embeddings.
        if rotary_pos_emb is not None:
            query_layer = chatglm_apply_rotary_pos_emb(query_layer, rotary_pos_emb)
            key_layer = chatglm_apply_rotary_pos_emb(key_layer, rotary_pos_emb)

        # Adjust key and value for the kv cache.
        if kv_cache is not None:
            cache_k, cache_v = kv_cache
            key_layer = torch.cat((cache_k, key_layer), dim=0)
            value_layer = torch.cat((cache_v, value_layer), dim=0)
        if use_cache:
            kv_cache = (key_layer, value_layer)
        else:
            kv_cache = None

        if self.multi_query_attention:
            key_layer = key_layer.unsqueeze(-2)
            key_layer = key_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            key_layer = key_layer.contiguous().view(key_layer.size()[:2] + (self.num_attention_heads_per_partition,
                                                                            self.hidden_size_per_attention_head))
            value_layer = value_layer.unsqueeze(-2)
            value_layer = value_layer.expand(
                -1, -1, -1, self.num_attention_heads_per_partition // self.num_multi_query_groups_per_partition, -1)
            value_layer = value_layer.contiguous().view(value_layer.size()[:2]
                                                        + (self.num_attention_heads_per_partition,
                                                           self.hidden_size_per_attention_head))

        # Core attention computation using TorchAcc's packed varlen flash attention.
        from torchacc.ops import flash_attn_varlen_qkvpacked_xla
        import einops

        query_layer, key_layer, value_layer = [k.permute(1, 2, 0, 3) for k in [query_layer, key_layer, value_layer]]
        bsz, _, q_len, _ = query_layer.size()
        qkv = torch.stack([query_layer, key_layer, value_layer], dim=2)
        qkv = qkv.transpose(1, 3)
        qkv = einops.rearrange(qkv, 'b s ... -> (b s) ...')
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
        context_layer = flash_attn_varlen_qkvpacked_xla(
            qkv, cu_q_lens, q_len, dropout_p=0.0, softmax_scale=None, causal=True)
        context_layer = einops.rearrange(context_layer, '(b s) ... -> b s ...', b=bsz)
        context_layer = context_layer.permute(1, 0, 2, 3)
        new_context_layer_shape = context_layer.size()[:-2] + (self.core_attention.hidden_size_per_partition, )
        context_layer = context_layer.reshape(*new_context_layer_shape)

        # Output projection: [sq, b, h].
        output = self.dense(context_layer)

        return output, kv_cache

    def torchacc_swiglu(x):
        x = torch.chunk(x, 2, dim=-1)
        return F.silu(x[0]).to(x[0].dtype) * x[1]

    # Patch the attention forward and the MLP activation of every transformer layer.
    for layer in model.transformer.encoder.layers:
        layer.self_attention.forward = types.MethodType(chatglm_attn_forward, layer.self_attention)
        layer.mlp.activation_func = torchacc_swiglu

    return model


def patch_baichuan_model(model):

    def baichuan_attn_forward(self,
                              hidden_states: torch.Tensor,
                              attention_mask: Optional[torch.Tensor] = None,
                              past_key_value: Optional[Tuple[torch.Tensor]] = None,
                              output_attentions: bool = False,
                              use_cache: bool = False,
                              **kwargs) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        import einops

        bsz, q_len, _ = hidden_states.size()

        proj = self.W_pack(hidden_states)
        proj = (proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2))
        query_states = (proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        key_states = (proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))
        value_states = (proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2))

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

        if past_key_value is not None:
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        from torchacc.ops import flash_attn_varlen_xla
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)
        output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, 0.0, softmax_scale=None, causal=True)
        output = einops.rearrange(output, '(b s) ... -> b s ...', b=bsz)
        output = self.o_proj(einops.rearrange(output, 'b s h d -> b s (h d)'))
        return output, None, past_key_value

    for layer in model.base_model.layers:
        layer.self_attn.forward = types.MethodType(baichuan_attn_forward, layer.self_attn)

    return model


def patch_qwen2_model(model):

    def update_causal_mask(self, *args, **kwargs):
        return None

    def qwen2_attn_forward(
        self,
        hidden_states,
        attention_mask=None,
        position_ids=None,
        past_key_value=None,
        output_attentions=False,
        use_cache=False,
        cache_position=None,
        position_embeddings=None,
        **kwargs,
    ):

        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f'The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} '
                    'for auto-regressive decoding with k/v caching, please make sure to initialize the attention class '
                    'with a layer index.')
            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)

        rotary_seq_len = kv_seq_len + 1

        if version.parse(transformers.__version__) >= version.parse('4.45'):
            if position_embeddings is None:
                cos, sin = self.rotary_emb(value_states, position_ids)
            else:
                cos, sin = position_embeddings
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        else:
            cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len)
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        dropout_rate = 0.0 if not self.training else self.attention_dropout

        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, '_pre_quantization_dtype'):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        from torchacc.ops import flash_attn_varlen_xla
        import einops

        q, k, v = [einops.rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]
        cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=q.device)

        attn_output = flash_attn_varlen_xla(
            q, k, v, cu_q_lens, cu_q_lens, q_len, q_len, dropout_rate, softmax_scale=None, causal=True)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

    def qwen2_forward(self,
                      input_ids: torch.LongTensor = None,
                      attention_mask: Optional[torch.Tensor] = None,
                      position_ids: Optional[torch.LongTensor] = None,
                      past_key_values: Optional[List[torch.FloatTensor]] = None,
                      inputs_embeds: Optional[torch.FloatTensor] = None,
                      use_cache: Optional[bool] = None,
                      output_attentions: Optional[bool] = None,
                      output_hidden_states: Optional[bool] = None,
                      return_dict: Optional[bool] = None,
                      **kwargs):
        from transformers.cache_utils import Cache, DynamicCache  # needed for the cache handling below

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states)
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError('You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time')
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError('You have to specify either decoder_input_ids or decoder_inputs_embeds')

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        past_key_values_length = 0

        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            past_key_values_length = past_key_values.get_usable_length(seq_length)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states, )

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1], )

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states, )

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        from transformers.modeling_outputs import BaseModelOutputWithPast
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    for layer in model.model.layers:
        layer.self_attn.forward = types.MethodType(qwen2_attn_forward, layer.self_attn)

    if version.parse(transformers.__version__) >= version.parse('4.43'):
        model.model._update_causal_mask = types.MethodType(update_causal_mask, model.model)
    else:
        model.model.forward = types.MethodType(qwen2_forward, model.model)
    return model


def patch_clip_grad_norm(accelerator):
    from accelerate.utils import DistributedType
    from accelerate.optimizer import AcceleratedOptimizer
    import torch_xla.core.xla_model as xm

    def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
        """
        Should be used in place of `torch.nn.utils.clip_grad_norm_`.

        Returns:
            `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).

        Example:

        ```python
        >>> from accelerate import Accelerator

        >>> accelerator = Accelerator(gradient_accumulation_steps=2)
        >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)

        >>> for input, target in dataloader:
        ...     optimizer.zero_grad()
        ...     output = model(input)
        ...     loss = loss_func(output, target)
        ...     accelerator.backward(loss)
        ...     if accelerator.sync_gradients:
        ...         accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
        ...     optimizer.step()
        ```
        """
        if self.distributed_type == DistributedType.FSDP:
            self.unscale_gradients()
            parameters = [p for p in parameters]
            for model in self._models:
                if parameters == [p for p in model.parameters()]:
                    return model.clip_grad_norm_(max_norm, norm_type)
        elif self.distributed_type == DistributedType.DEEPSPEED:
            return None
        elif self.distributed_type == DistributedType.XLA:
            for acc_opt in self._optimizers:
                if not acc_opt.gradient_state.is_xla_gradients_synced:
                    opt = acc_opt
                    while isinstance(opt, AcceleratedOptimizer):
                        opt = opt.optimizer
                    gradients = xm._fetch_gradients(opt)
                    xm.all_reduce('sum', gradients, scale=1.0 / self.num_processes)
                    acc_opt.gradient_state.is_xla_gradients_synced = True
            if os.environ.get('ACCELERATE_USE_FSDP', 'false') == 'true':
                self.unscale_gradients()
                parameters = [p for p in parameters]
                for model in self._models:
                    if parameters == [p for p in model.parameters()]:
                        return model._get_underlay_model().clip_grad_norm_(max_norm, norm_type)
        self.unscale_gradients()
        return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)

    accelerator.clip_grad_norm_ = types.MethodType(clip_grad_norm_, accelerator)
    return accelerator
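
# Usage sketch (assumed call site, for documentation only): the patched method is installed
# on an existing Accelerator instance, e.g.
#     trainer.accelerator = patch_clip_grad_norm(trainer.accelerator)
# after which accelerator.clip_grad_norm_(...) follows the XLA-aware path above.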


def ta_accelerate(model,
                  fsdp_num,
                  layer_cls_name,
                  bf16=True,
                  fp16=False,
                  gradient_checkpointing=True,
                  fsdp_flatten_parameters=False):
    """Accelerate LLM training using TorchAcc (only available internally)."""
    import torchacc as ta
    assert layer_cls_name is not None

    def get_ta_config():
        config = ta.Config()
        config.compute.fp16 = fp16
        config.compute.bf16 = bf16

        config.memory.gc = gradient_checkpointing
        if config.memory.gc:
            config.memory.gc_cls = {layer_cls_name}

        config.dist.fsdp.size = fsdp_num
        config.dist.fsdp.wrap_layer_cls = {layer_cls_name}
        config.dist.fsdp.flatten_parameters = fsdp_flatten_parameters
        config.dist.dp.size = 1

        if fsdp_num > 1:
            os.environ['ACCELERATE_USE_FSDP'] = 'true'

        return config

    ta_config = get_ta_config()
    model = ta.accelerate(model, config=ta_config)
    return model
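
# Usage sketch (assumed values, for documentation only): for a LLaMA-style model one would
# typically call something like
#     model = ta_accelerate(model, fsdp_num=4, layer_cls_name='LlamaDecoderLayer',
#                           bf16=True, gradient_checkpointing=True)
# so that TorchAcc gradient-checkpoints each decoder layer and shards it across 4 FSDP workers.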