| """ |
| Test for using ray collective group. |
| Suppose we Actor and Rollout. Actor contains 4 workers and Rollout contains 2 workers. We established a Worker to |
| Rollout relationship by using collective groups |
| Actor: rank 0, 1 - Rollout rank 0 |
| Rollout rank 2, 3 - Rollout rank 1 |
| Then, we initiate 4 p2p comms from actor to rollout |
| """ |
|
|
import ray
import ray.util.collective as collective
import torch
|
|
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
|
|
|
|
@ray.remote
class Actor(Worker):
    @register(Dispatch.ONE_TO_ALL)
    def init(self):
        # Each pair of Actor workers maps to one Rollout worker. Inside each
        # 2-process collective group, the Actor side is always rank 0 (the sender).
        remote_rank = self.rank // 2
        self.group_name = f"A{self.rank}_R{remote_rank}"
        collective.init_collective_group(world_size=2, rank=0, backend="nccl", group_name=self.group_name)

    @register(Dispatch.ONE_TO_ALL, blocking=False)
    def send_tensors(self):
        # Send a tensor filled with this worker's rank to the paired Rollout worker (rank 1 in the group).
        tensor = torch.ones(size=(4,), dtype=torch.float32, device="cuda") * self.rank
        collective.send(tensor=tensor, dst_rank=1, group_name=self.group_name)
|
|
|
|
@ray.remote
class Rollout(Worker):
    @register(Dispatch.ONE_TO_ALL)
    def init(self):
        # Each Rollout worker joins two collective groups, one per paired Actor worker.
        # Inside each 2-process group, the Rollout side is always rank 1 (the receiver).
        self.remote_first_rank = self.rank * 2
        self.remote_second_rank = self.remote_first_rank + 1
        self.first_group_name = f"A{self.remote_first_rank}_R{self.rank}"
        self.second_group_name = f"A{self.remote_second_rank}_R{self.rank}"

        collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.first_group_name)
        collective.init_collective_group(world_size=2, rank=1, backend="nccl", group_name=self.second_group_name)

    @register(Dispatch.ONE_TO_ALL, blocking=False)
    def receive_tensors(self):
        # Allocate receive buffers, then overwrite them with the tensors sent by the two paired Actor workers.
        self.tensor1 = torch.randn(size=(4,), dtype=torch.float32, device="cuda")
        self.tensor2 = torch.randn(size=(4,), dtype=torch.float32, device="cuda")

        collective.recv(self.tensor1, src_rank=0, group_name=self.first_group_name)
        collective.recv(self.tensor2, src_rank=0, group_name=self.second_group_name)

    @register(Dispatch.ONE_TO_ALL)
    def get_tensors(self):
        return {f"src_{self.remote_first_rank}": self.tensor1, f"src_{self.remote_second_rank}": self.tensor2}
|
|
|
|
def test_ray_collective_group():
    ray.init()

    actor_resource_pool = RayResourcePool([4])
    rollout_resource_pool = RayResourcePool([2])

    actor_cls = RayClassWithInitArgs(cls=Actor)
    rollout_cls = RayClassWithInitArgs(cls=Rollout)

    actor_wg = RayWorkerGroup(
        resource_pool=actor_resource_pool, ray_cls_with_init=actor_cls, name_prefix="collective_group_actor"
    )
    rollout_wg = RayWorkerGroup(
        resource_pool=rollout_resource_pool, ray_cls_with_init=rollout_cls, name_prefix="collective_group_rollout"
    )

    # Set up the 2-process NCCL groups on both sides.
    actor_wg.init()
    rollout_wg.init()

    # Kick off the non-blocking sends and receives, then wait for all of them to complete.
    out1 = actor_wg.send_tensors()
    out2 = rollout_wg.receive_tensors()

    ray.get(out1)
    ray.get(out2)

    output = rollout_wg.get_tensors()

    rollout_0_output = output[0]
    rollout_1_output = output[1]

    # Merge the two per-Rollout dicts (dict union requires Python 3.9+).
    output = rollout_0_output | rollout_1_output

    print(output)

    # Actor rank i sent a tensor of four i's, so the received tensor sums to 4 * i.
    for i in range(4):
        assert torch.sum(output[f"src_{i}"]).item() == 4 * i

    ray.shutdown()
|
|
|
|
if __name__ == "__main__":
    test_ray_collective_group()
|
|