import copy
from typing import Callable, Optional, Union

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import PipelineStage
from torch.distributed.pipelining.schedules import ScheduleZBVZeroBubble, _PipelineSchedule, get_schedule_class
from transformers import PretrainedConfig

from flame.models.parallelize_fla import get_blocks, get_components_name, get_model
from torchtitan.config_manager import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.distributed.pipeline import build_pipeline_schedule, generate_split_points, stage_ids_this_rank
from torchtitan.tools.logging import logger

DeviceType = Union[int, str, torch.device]


def pipeline_fla(
    model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
    loss_fn: Callable[..., torch.Tensor],
) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]:
    """
    Split the model into pipeline stages for this rank and build a pipeline schedule.

    Returns the schedule, the model chunks owned by this rank, and flags indicating
    whether this rank holds the first and/or last stage of the pipeline.
    """
    stages, models = pipeline_fla_manual_split(
        model, pp_mesh, parallel_dims, job_config, device, model_config
    )

    pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

    # These flags tell the training loop whether this rank should be fed the
    # input ids (first stage) and/or the labels for loss computation (last stage).
    has_first_stage = False
    has_last_stage = False
    for stage in stages:
        if stage.is_first:
            has_first_stage = True
        if stage.is_last:
            has_last_stage = True

    return pp_schedule, models, has_first_stage, has_last_stage
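

# The sketch below is illustrative only: a minimal example of how the values
# returned by `pipeline_fla` are typically consumed in a training step (the
# first stage feeds the inputs, the last stage receives the targets). The
# argument names `input_ids` and `labels` are hypothetical and not part of
# this module's API.
def _example_pp_train_step(
    pp_schedule: _PipelineSchedule,
    has_first_stage: bool,
    has_last_stage: bool,
    input_ids: torch.Tensor,
    labels: torch.Tensor,
) -> None:
    # Only the last stage computes the loss, so only it needs the targets.
    targets, losses = (labels, []) if has_last_stage else (None, None)
    if has_first_stage:
        # The first stage is the only one that receives the micro-batched inputs.
        pp_schedule.step(input_ids, target=targets, losses=losses)
    else:
        # Intermediate and last stages receive activations from the previous stage.
        pp_schedule.step(target=targets, losses=losses)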


def pipeline_fla_manual_split(
    whole_model: nn.Module,
    pp_mesh: DeviceMesh,
    parallel_dims: ParallelDims,
    job_config: JobConfig,
    device: DeviceType,
    model_config: PretrainedConfig,
) -> tuple[list[PipelineStage], list[nn.Module]]:
    """
    This API extracts one torch.nn.Module object per pipeline stage assigned to this rank, containing the part of
    the model configured to run inside that stage.

    It wraps each model chunk in a PipelineStage object and returns both the stage and model objects.

    The stage objects are used to create a pipeline schedule, and the model objects can be used for applying SPMD
    parallelism.
    """
    pp_rank = pp_mesh.get_local_rank()
    pp_size = pp_mesh.size()

    # Use user-provided split points if given; otherwise derive them from the
    # schedule and the number of hidden layers.
    splits = (
        job_config.experimental.pipeline_parallel_split_points
        or generate_split_points(
            job_config, parallel_dims.pp, model_config.num_hidden_layers
        )
    )
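
    # Illustrative example (hypothetical values): with 4 pipeline ranks and one
    # stage per rank, `splits` might be ["layers.2", "layers.4", "layers.6"],
    # giving num_stages = len(splits) + 1 = 4. Each stage keeps the layers from
    # its start point (inclusive) up to its stop point (exclusive).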

    def _build_stage(
        stage_idx: int,
        start_layer: Optional[str],
        stop_layer: Optional[str],
        is_first: bool = False,
        is_last: bool = False,
    ) -> tuple[PipelineStage, nn.Module]:
        model = copy.deepcopy(whole_model)
        if not is_first:
            # Only the first stage needs the token embeddings; drop them elsewhere.
            real_model = get_model(model)
            tok_embeddings_name = get_components_name(real_model, "tok_embeddings")
            setattr(real_model, tok_embeddings_name, None)

        drop_layers = start_layer is not None

        # Access the block container's underlying _modules dict so that
        # individual layers can be deleted by name.
        module_dict = get_blocks(model)._modules
        layer_names = list(module_dict.keys())

        for name in layer_names:
            # Split points are qualified names such as "layers.8"; rebuild the
            # same form from the block name for comparison.
            prefix = start_layer.split(".")[0] if start_layer else "layers"
            layer_name = f"{prefix}.{name}"

            # Keep the layers in a contiguous region between start (inclusive)
            # and stop (exclusive); everything else is dropped from this chunk.
            if layer_name == start_layer:
                drop_layers = False
            if layer_name == stop_layer:
                drop_layers = True

            if drop_layers:
                del module_dict[name]

        if not is_last:
            # Only the last stage needs the final norm and the output head.
            real_model = get_model(model)
            norm_name = get_components_name(real_model, "norm")
            setattr(real_model, norm_name, None)

            head_name = get_components_name(model, "lm_head")
            setattr(model, head_name, None)

        stage = PipelineStage(
            model,
            stage_idx,
            num_stages,
            device,
            group=pp_mesh.get_group("pp"),
        )
        return stage, model

    num_stages = len(splits) + 1
    stage_idx = pp_rank

    stages = []
    models = []

    # ZBVZeroBubble schedules place stages on ranks in a "V" pattern; otherwise
    # the default looped placement is used.
    schedule_class = get_schedule_class(
        job_config.experimental.pipeline_parallel_schedule
    )
    style = "v" if schedule_class == ScheduleZBVZeroBubble else "loop"
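
    # Illustrative example (hypothetical numbers): with pp_size = 4 and
    # num_stages = 8, a "loop" placement typically assigns stages (0, 4) to
    # rank 0 and (1, 5) to rank 1, while a "v" placement assigns (0, 7) to
    # rank 0 and (1, 6) to rank 1.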

    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style=style):
        start_layer = splits[stage_idx - 1] if stage_idx > 0 else None
        stop_layer = splits[stage_idx] if stage_idx < num_stages - 1 else None
        stage, model_chunk = _build_stage(
            stage_idx,
            start_layer,
            stop_layer,
            is_first=stage_idx == 0,
            is_last=stage_idx == num_stages - 1,
        )
        logger.info(
            f"PP rank {pp_rank} is building stage_idx {stage_idx}"
            f" with start_layer {start_layer}, stop_layer {stop_layer}"
        )
        stages.append(stage)
        models.append(model_chunk)
    return stages, models