from math import ceil

from dtypes import DType
from state import Model, Parallelism, Training


class MemoryCalculation:
    """Estimates per-device training memory: parameters, gradients, optimizer
    states, and activations, given model, parallelism, and training configs."""

    def __init__(
        self,
        modelconfig: Model,
        parallelismconfig: Parallelism,
        trainingconfig: Training,
    ):
        self.model = modelconfig
        self.parallelism = parallelismconfig
        self.training = trainingconfig

    def calculate_num_parameters_per_layer(self) -> float:
        """Parameters in a single transformer layer on one device, i.e. after
        tensor-parallel (and, for MoE, expert-parallel) sharding."""
        h, i, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.total_experts,
        )
        tp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.expert_parallelism,
        )

        # Attention block: pre-attention norm, fused QKV projection, and the
        # output projection (weight and bias), sharded across tensor parallelism.
        layer_norm_attn_in = h
        qkv = 3 * h * h / tp
        attn_output_proj = (h * h + h) / tp
        attn = layer_norm_attn_in + qkv + attn_output_proj

        # MLP block: pre-MLP norm plus a gated MLP (up, gate, down projections),
        # each sharded across tensor parallelism.
        layer_norm_mlp_in = h
        mlp_up_proj = (h * i + i) / tp
        mlp_gate_proj = (h * i + i) / tp
        mlp_down_proj = (i * h + h) / tp
        mlp = layer_norm_mlp_in + mlp_up_proj + mlp_gate_proj + mlp_down_proj
        if self.model.is_moe:
            # MoE layer: a replicated router plus one gated MLP per expert,
            # with the experts sharded across expert parallelism.
            router = h * e + e
            expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
            experts = expert * e / ep
            mlp = layer_norm_mlp_in + router + experts

        layer = attn + mlp
        return layer

    def calculate_unshardeable_parameters(self) -> float:
        """Parameters outside the repeated transformer layers: input embedding,
        unembedding (when not weight-tied), and the final layer norm, per device."""
        h, v = self.model.hidden_dim, self.model.vocab_size
        tp = self.parallelism.tensor_parallelism
        pp = self.parallelism.pipeline_parallelism

        input_embedding = v * h / tp
        unembedding = 0
        if not self.model.weight_tied_embeddings:
            unembedding = h * v / tp
        final_layer_norm = h

        # Without pipeline parallelism one device holds both the embedding and
        # the unembedding; with pipeline parallelism they sit on different
        # stages, so the per-device worst case is the larger of the two.
        total_params = 0
        if pp == 1:
            total_params = input_embedding + unembedding + final_layer_norm
        elif pp > 1:
            total_params = max(input_embedding, unembedding) + final_layer_norm
        return total_params

    def calculate_fsdp_sharded_parameters(self) -> float:
        """Per-device parameter count under FSDP (ZeRO-3 style) sharding: layer
        parameters are divided across the FSDP group, except for one layer that
        is kept fully gathered, plus the unshardeable parameters."""
        if not self.parallelism.fsdp_enabled:
            return self.calculate_num_parameters()
        else:
            return (
                self.calculate_num_parameters_per_layer()
                * ceil(
                    (self.model.num_layers - 1) / self.parallelism.pipeline_parallelism
                )
                / self.parallelism.fsdp_parallelism
                + self.calculate_unshardeable_parameters()
                + self.calculate_num_parameters_per_layer()
            )

    def calculate_num_parameters(self) -> float:
        """Total parameters on one device: the per-layer count times the number
        of layers on a pipeline stage, plus the unshardeable parameters."""
        return (
            self.calculate_num_parameters_per_layer()
            * ceil(self.model.num_layers / self.parallelism.pipeline_parallelism)
            + self.calculate_unshardeable_parameters()
        )

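    # Illustrative check (not part of the calculator): for a dense,
    # LLaMA-7B-like configuration with hidden_dim=4096, intermediate_size=11008,
    # num_layers=32, vocab_size=32000, untied embeddings, and tp = pp = 1,
    # the formulas above give
    #     per layer   : attention ≈ 67.1M + MLP ≈ 135.3M  ≈ 202.4M parameters
    #     unshardeable: 2*v*h + h                          ≈ 262.1M parameters
    #     total       : 32 * 202.4M + 262.1M               ≈ 6.74B parameters
    # which is in the right range for a 7B-class model (the small excess over a
    # bias-free count comes from the bias terms included above).
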
    def calculate_activation_parameters(self) -> float:
        """Number of activation elements stored for the backward pass on one
        device, after context-, sequence-, and tensor-parallel sharding."""
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.total_experts,
            self.model.active_experts,
        )
        tp, cp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
            self.parallelism.expert_parallelism,
        )
        # Sequence parallelism shards the norm (and other replicated)
        # activations over the same group as tensor parallelism.
        sp = tp
        if self.training.gradient_checkpointing:
            # With full activation checkpointing only each layer's input is
            # stored; everything else is recomputed during the backward pass.
            embed = 0
            layer = s * b * h / cp / tp
            layers = layer * l
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params
        else:
            # Attention block: layer input, pre-attention norm output,
            # flash-attention output, and output-projection output.
            layer_in = s * b * h / cp / tp
            attn_norm = s * b * h / cp / sp
            flash = s * b * h / cp / tp
            projection = s * b * h / cp / tp
            attn = layer_in + attn_norm + flash + projection

            # MLP block: pre-MLP norm output, up and gate projections, the
            # SwiGLU elementwise (Hadamard) product, and the down projection.
            mlp_norm = s * b * h / cp / sp
            mlp_up = s * b * i / cp / tp
            mlp_gate = s * b * i / cp / tp
            hadamard_swiglu = s * b * i / cp / tp
            mlp_down = s * b * h / cp / tp
            if self.model.is_moe:
                # MoE: router logits plus the activations of the active
                # experts, sharded across expert parallelism.
                router = s * b * e / cp / sp
                expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
                experts = expert * ae / ep
                mlp = mlp_norm + router + experts
            else:
                mlp = mlp_norm + mlp_up + mlp_gate + hadamard_swiglu + mlp_down

            layer = attn + mlp
            layers = layer * l

            # Embedding lookups are not stored; the final layer output, final
            # norm output, and unembedding logits are.
            embed = 0
            final_layer_out = s * b * h / cp / sp
            final_norm = s * b * h / cp / sp
            unembed = s * b * v / cp / tp
            num_params = embed + layers + final_layer_out + final_norm + unembed
            return num_params

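    # Illustrative bookkeeping (not part of the calculator): with gradient
    # checkpointing disabled and tp = cp = ep = 1, the dense per-layer
    # activation count above reduces to
    #     attn = 4 * s * b * h
    #     mlp  = 2 * s * b * h + 3 * s * b * i
    # i.e. roughly 6*s*b*h + 3*s*b*i elements per layer, before the final
    # norm, final layer output, and unembedding logits are added.
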
    def calculate_parameter_memory(self) -> float:
        """Parameter memory per device (parameter count times dtype size). Only
        ZeRO-3 shards the parameters themselves; with mixed precision both the
        full-precision master copy and the working copy are counted."""
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy == "Zero-3":
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()
        if self.training.mixed_precision:
            master_copy = params * self.training.precision
            working_copy = params * self.training.param_dtype
            return master_copy + working_copy
        else:
            return params * self.training.precision

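    # Illustrative check (not part of the calculator): with mixed precision
    # (an FP32 master copy plus a BF16 working copy, i.e. 4 + 2 bytes per
    # parameter) and no ZeRO-3 sharding, a 7B-parameter model needs roughly
    # 7e9 * (4 + 2) bytes ≈ 39 GiB of parameter memory per device.
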
    def calculate_gradient_memory(self) -> float:
        """Gradient memory per device (parameter count times dtype size).
        ZeRO-2 and ZeRO-3 shard gradients; a separate accumulation buffer is
        counted when gradient accumulation is enabled."""
        if self.parallelism.fsdp_enabled and self.parallelism.fsdp_strategy in (
            "Zero-3",
            "Zero-2",
        ):
            params = self.calculate_fsdp_sharded_parameters()
        else:
            params = self.calculate_num_parameters()
        grad_accumulation = 0
        if self.training.grad_accumulation:
            # Accumulation buffer kept in the reduction dtype under mixed
            # precision, otherwise in full precision.
            if self.training.mixed_precision:
                grad_accumulation = params * self.training.reduce_dtype
            else:
                grad_accumulation = params * self.training.precision
        if self.training.mixed_precision:
            gradients = params * self.training.param_dtype
        else:
            gradients = params * self.training.precision
        return grad_accumulation + gradients

    def calculate_optimizer_memory(self) -> float:
        """Optimizer state memory per device, assuming an Adam-style optimizer
        that keeps two FP32 states (first and second moments) per parameter.
        FSDP shards these states across the FSDP group."""
        if self.parallelism.fsdp_enabled:
            return (
                2 * self.calculate_num_parameters() * DType.FP32
            ) / self.parallelism.fsdp_parallelism
        else:
            return 2 * self.calculate_num_parameters() * DType.FP32

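    # Illustrative check (not part of the calculator): for a 7B-parameter model
    # without FSDP, two FP32 states cost 2 * 7e9 * 4 bytes ≈ 52 GiB per device;
    # sharding over an 8-way FSDP group reduces this to ≈ 6.5 GiB.
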
    def calculate_activation_memory(self) -> float:
        """Activation memory per device (activation count times dtype size);
        activations are stored in the working dtype under mixed precision."""
        if self.training.mixed_precision:
            return self.calculate_activation_parameters() * self.training.param_dtype
        else:
            return self.calculate_activation_parameters() * self.training.precision
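

# A minimal usage sketch, not part of the calculator. It assumes Model,
# Parallelism, and Training can be constructed with keyword arguments matching
# the attribute names read above, that dtype fields hold bytes-per-element
# values (as DType.FP32 is used above), and that DType.BF16 exists alongside
# DType.FP32. Adjust to the actual definitions in state.py and dtypes.py.
if __name__ == "__main__":
    model = Model(
        hidden_dim=4096,
        intermediate_size=11008,
        num_layers=32,
        vocab_size=32000,
        total_experts=1,
        active_experts=1,
        is_moe=False,
        weight_tied_embeddings=False,
    )
    parallelism = Parallelism(
        tensor_parallelism=1,
        pipeline_parallelism=1,
        expert_parallelism=1,
        context_parallelism=1,
        fsdp_enabled=False,
        fsdp_parallelism=1,
        fsdp_strategy="Zero-3",  # ignored while fsdp_enabled is False
    )
    training = Training(
        batch_size=1,
        sequence_length=4096,
        gradient_checkpointing=True,
        mixed_precision=True,
        precision=DType.FP32,
        param_dtype=DType.BF16,  # assumed member; only DType.FP32 appears above
        reduce_dtype=DType.FP32,
        grad_accumulation=False,
    )
    calc = MemoryCalculation(model, parallelism, training)
    gib = 1024**3
    print(f"parameters : {calc.calculate_parameter_memory() / gib:.2f} GiB")
    print(f"gradients  : {calc.calculate_gradient_memory() / gib:.2f} GiB")
    print(f"optimizer  : {calc.calculate_optimizer_memory() / gib:.2f} GiB")
    print(f"activations: {calc.calculate_activation_memory() / gib:.2f} GiB")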