diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index daf84d799dafec823f27e14676e810e30c9d4d3e..d574722a4c453ad3421849bc78b4ddc74c934d56 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -246,7 +246,19 @@ def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, ant
     else:
         return x1.new_empty(tuple(dim_list))
 
-
+@impl(m, "npu_tome_merge")
+def npu_tome_merge(token_a, token_b, topk_indice, arg_max, top_rate=0.5):
+    batch = token_a.size(0)
+    seq_len_a = token_a.size(1)
+    hidden_size = token_a.size(2)
+    seq_len_b = token_b.size(1)
+    topR = int((seq_len_a + seq_len_b) * top_rate)
+    heads = 8
+    unmerge_token_a_dim_list = [batch, seq_len_a - topR, hidden_size]
+    unmerge_token_b_dim_list = [batch, heads, seq_len_b, hidden_size]
+    unreduce_count_dim_list = [batch, heads, seq_len_b]
+    unreduce_count = torch.empty(unreduce_count_dim_list, dtype=torch.float32, device='meta')
+    return (token_a.new_empty(tuple(unmerge_token_a_dim_list)), token_a.new_empty(tuple(unmerge_token_b_dim_list)), torch.empty_like(unreduce_count))
 
 @impl(m, "npu_mm_reduce_scatter_base")
 def npu_mm_reduce_scatter_base_meta(self, x2, hcom, world_size, reduce_op='sum',
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 839d06eef33e814a95ed6b8315891318e8973d5c..a756e654b5e21efa71811ac6639116ce6ef048e0 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -677,6 +677,18 @@ class NPUIncreFlashAttentionOP(torch.autograd.Function):
                                                block_size, inner_precise)
 
 
+class NPUTomeMergeOp(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch.ops.npu.npu_tome_merge(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, token_a: Tensor, token_b: Tensor, token_indice: Tensor, arg_max: Tensor, top_rate: float = 1.0):
+        return g.op("npu::NPUTomeMerge", token_a, token_b, token_indice, arg_max, top_rate)
+
+
+
class NPUMaskedSoftmaxWithRelPosBiasOP(torch.autograd.Function): @staticmethod @@ -904,6 +916,10 @@ def wrapper_npu_rms_norm(self, gamma, epsilon=1e-6): return NPURmsNormOP.apply(self, gamma, epsilon) +def wrapper_npu_tome_merge(self, token_a, token_b, token_indice, arg_max, top_rate = 1.0): + return NPUTomeMergeOp.apply(self, token_a, token_b, token_indice, arg_max, top_rate) + + def wrapper_npu_add_rms_norm(x1, x2, gamma, epsilon=1e-6): return NPUAddRmsNormOP.apply(x1, x2, gamma, epsilon) @@ -1097,3 +1113,4 @@ def add_onnx_ops(): torch_npu.npu_mm_all_reduce_base = wrapper_npu_mm_all_reduce_base torch_npu.npu_weight_quant_batchmatmul = wrapper_npu_weight_quant_batchmatmul torch_npu.npu_anti_quant = wrapper_npu_anti_quant + torch_npu.npu_tome_merge = wrapper_npu_tome_merge