error when batch size >1
#1
by loulou2 - opened
Hi, I ran into a small error when running this MoE with a batch size > 1. :(
I am getting the following error from the modeling_gemma4 file:

    dispatched_experts = input_expanded * dispatch_mask
                         ~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~
    RuntimeError: The size of tensor a (2432) must match the size of tensor b (128) at non-singleton dimension 1
My guess is that the x.repeat in the forward function produces a tensor whose shape does not match the dispatch mask coming from the router when the batch size is > 1.
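To make the suspected mismatch concrete, here is a minimal pure-Python sketch of the broadcasting rule that an elementwise multiply has to satisfy. The helper `find_broadcast_mismatch` is a hypothetical name introduced for illustration, and the shapes are assumptions chosen only so that dimension 1 reproduces the sizes in the error message (2432 vs. 128); the real shapes of `input_expanded` and `dispatch_mask` in the model may differ.

```python
def find_broadcast_mismatch(shape_a, shape_b):
    """Return (dim, size_a, size_b) for an incompatible dimension, or None.

    Shapes are right-aligned and compared starting from the trailing
    dimension. Two sizes are compatible when they are equal or when one
    of them is 1.
    """
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both have ndim entries.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    for dim in range(ndim - 1, -1, -1):
        if a[dim] != b[dim] and a[dim] != 1 and b[dim] != 1:
            return dim, a[dim], b[dim]
    return None


# Hypothetical shapes (assumptions, not read from the model code).
input_expanded_shape = (2, 2432, 64)
dispatch_mask_shape = (2, 128, 1)

mismatch = find_broadcast_mismatch(input_expanded_shape, dispatch_mask_shape)
print(mismatch)  # (1, 2432, 128)
```

Printing `input_expanded.shape` and `dispatch_mask.shape` right before the failing line, once with batch size 1 and once with a larger batch, should show which of the two tensors changes in an unexpected way.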