From 94ff932abb99a966f8fca8244c8eeda7db6f3d01 Mon Sep 17 00:00:00 2001
From: TD_lihan
Date: Fri, 23 Feb 2024 17:44:05 +0800
Subject: [PATCH] add tomeunmerge

---
 torch_npu/meta/meta_registrations.py |  9 +++++++++
 torch_npu/onnx/wrapper_onnx_ops.py   | 20 ++++++++++++++++++++
 2 files changed, 29 insertions(+)

diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py
index daf84d799d..5a25134abc 100644
--- a/torch_npu/meta/meta_registrations.py
+++ b/torch_npu/meta/meta_registrations.py
@@ -224,6 +224,15 @@ def npu_ffn_meta(x, weight1, weight2, activation, *, expert_tokens=None, bias1=N
     return x.new_empty(tuple(dim_list))
 
 
+@impl(m, "npu_tome_unmerge")
+def npu_tome_unmerge_meta(atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, top_r_rate=0.5):
+    dim_list = []
+    dim_list.append(atten_out.size(0))
+    dim_list.append(ori_indice_a.size(1) + ori_indice_b.size(1))
+    dim_list.append(atten_out.size(2))
+    return atten_out.new_empty(tuple(dim_list))
+
+
 @impl(m, "npu_group_norm_silu")
 def group_norm_silu_meta(self, gemma, beta, group, eps=0.00001):
     N = self.size(1)
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 839d06eef3..f3f4c12369 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -657,6 +657,19 @@ class NPUPromptFlashAttentionOP(torch.autograd.Function):
                     input_layout, num_key_value_heads)
 
 
+class NPUTomeUnmergeOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch.ops.npu.npu_tome_unmerge(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, atten_out: Tensor, ori_indice_a: Tensor, ori_indice_b: Tensor, topk_indice: Tensor,
+                 arg_max: Tensor, top_r_rate: float = 0.5):
+        return g.op("npu::NPUTomeUnmerge", atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max,
+                    top_r_rate_f=top_r_rate)
+
+
 class NPUIncreFlashAttentionOP(torch.autograd.Function):
 
     @staticmethod
@@ -793,6 +806,12 @@ def wrapper_npu_iou(bboxes, gtboxes, mode=0):
     return NPUIouOP.apply(bboxes, gtboxes, mode)
 
 
+def wrapper_npu_tome_unmerge(atten_out, ori_indice_a, ori_indice_b, topk_indice,
+                             arg_max, top_r_rate=0.5):
+    return NPUTomeUnmergeOP.apply(atten_out, ori_indice_a, ori_indice_b, topk_indice,
+                                  arg_max, top_r_rate)
+
+
 def wrapper_npu_batch_nms(self, scores, score_threshold, iou_threshold,
                           max_size_per_class, max_total_size,
                           change_coordinate_frame=False, transpose_box=False):
@@ -1048,6 +1067,7 @@ def add_onnx_ops():
     torch_npu.npu_roi_align = wrapper_npu_roi_align
     torch_npu.npu_group_norm_silu = wrapper_npu_group_norm_silu
     torch_npu.npu_iou = wrapper_npu_iou
+    torch_npu.npu_tome_unmerge = wrapper_npu_tome_unmerge
    torch_npu.npu_batch_nms = wrapper_npu_batch_nms
     torch_npu.fast_gelu = wrapper_npu_fast_gelu
     torch_npu.npu_fast_gelu = wrapper_npu_fast_gelu
--
Gitee
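
A quick way to sanity-check the new meta registration is to call the op on meta
tensors. The sketch below is not part of the patch, and every size and dtype in
it is an illustrative assumption; only the output shape rule, which is
(atten_out.size(0), ori_indice_a.size(1) + ori_indice_b.size(1),
atten_out.size(2)), comes from npu_tome_unmerge_meta above.

    import torch
    import torch_npu  # importing torch_npu registers the "npu_tome_unmerge" meta impl added above

    # Assumed toy sizes: batch 2, 48 tokens from indice_a, 16 from indice_b, hidden 64.
    B, S_A, S_B, H = 2, 48, 16, 64
    atten_out = torch.empty(B, S_A, H, device="meta")
    ori_indice_a = torch.empty(B, S_A, dtype=torch.int64, device="meta")
    ori_indice_b = torch.empty(B, S_B, dtype=torch.int64, device="meta")
    topk_indice = torch.empty(B, S_B, dtype=torch.int64, device="meta")
    arg_max = torch.empty(B, S_B, dtype=torch.int64, device="meta")

    # Runs entirely on the meta device: only shapes are computed, no NPU is touched.
    out = torch.ops.npu.npu_tome_unmerge(
        atten_out, ori_indice_a, ori_indice_b, topk_indice, arg_max, 0.5)
    assert out.shape == (B, S_A + S_B, H)

On ONNX export, the symbolic added in this patch emits a single custom
npu::NPUTomeUnmerge node whose top_r_rate_f attribute carries the float
top_r_rate, so any backend consuming the exported graph must provide that op.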