diff --git a/.gitmodules b/.gitmodules
index fd07c1e062060acebfa8d123adf037543d88f3b3..98d633f8a686d8f900dd71e33eff34ca5843f743 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -1,15 +1,15 @@
 [submodule "third_party/op-plugin"]
 	path = third_party/op-plugin
-	url = https://gitee.com/ascend/op-plugin.git
-	ignore = dirty
-	branch = 6.0.0
+	url = https://gitee.com/caoqiku/op-plugin.git
+	branch = feature_dispatch_combine3
+	# ignore = dirty
 [submodule "third_party/googletest"]
 	path = third_party/googletest
 	url = https://gitee.com/mirrors/googletest.git
 [submodule "third_party/torchair/torchair"]
 	path = third_party/torchair/torchair
-	url = https://gitee.com/ascend/torchair.git
-	branch = 6.0.0
+	url = https://gitee.com/caoqiku/torchair.git
+	branch = feature_dispatch_combine2
 [submodule "third_party/Tensorpipe"]
 	path = third_party/Tensorpipe
 	url = https://gitee.com/ascend/Tensorpipe.git
diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index 6268278a21e3ae4b3a1abf1556cd2780cdf0bff4..034e1af2e4384308080ce1b0c5f51da7a6a153f0 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -267,6 +267,62 @@ def npu_rms_norm_backward_meta(dy, self, gamma, rstd):
     return (torch.empty_like(self, dtype=self.dtype), torch.empty_like(gamma, dtype=gamma.dtype))
 
 
+@impl(m, "npu_moe_distribute_dispatch")
+def npu_moe_distribute_dispatch_meta(x, expert_ids, group_ep, ep_world_size, ep_rank_id, moe_expert_num, scales=None, group_tp="", tp_world_size=0,
+                                     tp_rank_id=0, expert_shard_type=0, shared_expert_rank_num=0, quant_mode=0, global_bs=0):
+    n = x.size(0)
+    h = x.size(1)
+    k = expert_ids.size(1)
+
+    shared_front = 0
+    outDtype = x.dtype
+    if expert_shard_type == 0:
+        shared_front = 1
+
+    local_moe_expert_num = 0
+    global_bs_real = 0
+    if global_bs == 0:
+        global_bs_real = n * ep_world_size
+    else:
+        global_bs_real = global_bs
+    a = 0
+    if shared_front == 1:
+        if ep_rank_id < shared_expert_rank_num:
+            local_moe_expert_num = 1
+            a = global_bs_real // shared_expert_rank_num
+        else:
+            local_moe_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)
+            a = global_bs_real * local_moe_expert_num
+    else:
+        if ep_rank_id >= ep_world_size - shared_expert_rank_num:
+            local_moe_expert_num = 1
+            a = global_bs_real // shared_expert_rank_num
+        else:
+            local_moe_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)
+            a = global_bs_real * local_moe_expert_num
+
+    if scales is not None or quant_mode != 0:
+        outDtype = torch.int8
+    local_moe_expert_num = int(local_moe_expert_num)
+    expand_x = x.new_empty(tuple([a * tp_world_size, h]), dtype=outDtype)
+    dynamic_scales = x.new_empty(tuple([a * tp_world_size]), dtype=torch.float32)
+    expand_idx = x.new_empty(tuple([n * k]), dtype=torch.int32)
+    expert_token_nums = x.new_empty(tuple([local_moe_expert_num]), dtype=torch.int64)
+    ep_recv_counts = x.new_empty(tuple([moe_expert_num + shared_expert_rank_num]), dtype=torch.int32)
+    tp_recv_counts = x.new_empty(tuple([tp_world_size]), dtype=torch.int32)
+    return (expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts)
+
+
+@impl(m, "npu_moe_distribute_combine")
+def npu_moe_distribute_combine_meta(expand_x, expert_ids, expand_idx, ep_send_counts, expert_scales, group_ep, ep_world_size, ep_rank_id, moe_expert_num,
+                                    tp_send_counts=None, group_tp="", tp_world_size=0, tp_rank_id=0, expert_shard_type=0, shared_expert_rank_num=0, global_bs=0):
+    dim_list = []
+    dim_list.append(expert_ids.size(0))
+    dim_list.append(expand_x.size(1))
+
+    return expand_x.new_empty(tuple(dim_list), dtype=expand_x.dtype)
+
+
 @impl(m, "_npu_dropout")
 def _npu_dropout_meta(self, p):
     mask = math.floor(math.floor((self.numel() + BIT_NUMBER - 1) / BIT_NUMBER) * BIT_NUMBER / UINT8_BIT_NUMBER)
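Reviewer note: below is a minimal standalone sketch of the expand_x row count that the new npu_moe_distribute_dispatch_meta allocates, to make the shared-expert/MoE-expert branch easier to check. The helper name dispatch_rows and every numeric value (ep_world_size=16, shared_expert_rank_num=2, moe_expert_num=56, n=8, tp_world_size=1) are illustrative assumptions, not values taken from the patch.

# Sketch of the expand_x row count in npu_moe_distribute_dispatch_meta.
# All numbers and the helper name are illustrative assumptions.
def dispatch_rows(n, ep_world_size, ep_rank_id, moe_expert_num,
                  shared_expert_rank_num, tp_world_size=1, global_bs=0,
                  expert_shard_type=0):
    # Batch size aggregated over the EP group unless global_bs is given.
    global_bs_real = n * ep_world_size if global_bs == 0 else global_bs
    # Shared-expert ranks sit at the front when expert_shard_type == 0,
    # otherwise at the back of the EP group.
    shared_front = expert_shard_type == 0
    is_shared_rank = (
        ep_rank_id < shared_expert_rank_num
        if shared_front
        else ep_rank_id >= ep_world_size - shared_expert_rank_num
    )
    if is_shared_rank:
        # Shared-expert ranks split the global batch among themselves.
        a = global_bs_real // shared_expert_rank_num
    else:
        # MoE-expert ranks reserve room for every local expert times global batch.
        local_moe_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)
        a = global_bs_real * local_moe_expert_num
    return a * tp_world_size

# A MoE-expert rank (rank 5 of 16, 2 shared ranks, 56 MoE experts, n=8):
# 56 // 14 = 4 local experts, so 8 * 16 * 4 = 512 rows in expand_x.
print(dispatch_rows(n=8, ep_world_size=16, ep_rank_id=5, moe_expert_num=56,
                    shared_expert_rank_num=2))  # -> 512

The combine meta output, by contrast, is simply [expert_ids.size(0), expand_x.size(1)], i.e. one hidden-size row per input token.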