From 1b50bfbb023ef38d240c42778378e152c5360a60 Mon Sep 17 00:00:00 2001 From: 123456Qq Date: Thu, 9 Nov 2023 22:13:27 +0800 Subject: [PATCH 1/2] test --- torch_npu/csrc/aten/npu_native_functions.yaml | 3 ++- torch_npu/onnx/wrapper_onnx_ops.py | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml index 035b47a11a..02a6eaffc2 100644 --- a/torch_npu/csrc/aten/npu_native_functions.yaml +++ b/torch_npu/csrc/aten/npu_native_functions.yaml @@ -85,7 +85,8 @@ custom: - func: npu_hcom_allgather(Tensor self, int rank_size, str group, float alpha, float beta, int? hccl_comm) -> Tensor - func: npu_hcom_allgather.out(Tensor self, int rank_size, str group, float alpha, float beta, int? hccl_comm, *, Tensor(a!) out) -> Tensor(a!) - func: npu_format_cast(Tensor self, int acl_format) -> Tensor - - func: _npu_format_cast(Tensor self, int acl_format) -> Tensor + - func: _npu_format_cast(Tensor self, int acl_format) -> Tensor + - func: npu_rms_norm(Tensor self, Tensor gamma, float epsilon) -> (Tensor, Tensor) custom_autograd: - func: npu_multi_head_attention_score(Tensor query, Tensor key, Tensor value, int head_num, str input_layout, Tensor? pse=None, Tensor? padding_mask=None, Tensor? 
atten_mask=None, float scale=1., float keep_prob=1., int pre_tockens=2147483647, int next_tockens=2147483647, bool gen_mask_parallel=True, bool sync=False) -> Tensor[] diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index 854b37ff14..05386e29e4 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -169,6 +169,15 @@ class NPUMultiHeadAttentionOP(torch.autograd.Function): outputs=8) +class NPURmsNormOP(torch.autograd.Function): + @staticmethod + def forward(ctx, *args, **kwargs): + return torch_npu._C._VariableFunctionsClass.npu_rms_norm(*args, **kwargs) + @staticmethod + def symbolic(g, self: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-6): + return g.op("npu::NPURmsNorm", self, gamma, epsilon_f=epsilon, outputs=2) + + class NPUDiouOP(torch.autograd.Function): @staticmethod @@ -739,6 +748,10 @@ def wrapper_npu_grid_assign_positive(self, overlaps, box_responsible_flags, max_ num_gts, pos_iou_thr, min_pos_iou, gt_max_assign_all) +def wrapper_npu_rms_norm(self, gamma, epsilon=1e-6): + return NPURmsNormOP.apply(self, gamma, epsilon) + + def wrapper_npu_ifmr(data, data_min, data_max, cumsum, min_percentile, max_percentile, search_start, search_end, search_step, with_offset): return NPUIfmrOP.apply(data, data_min, data_max, cumsum, min_percentile, max_percentile, @@ -915,4 +928,5 @@ def add_onnx_ops(): torch_npu.npu_scaled_masked_softmax = wrapper_npu_scaled_masked_softmax torch_npu.npu_mish = wrapper_npu_mish torch_npu.npu_rotary_mul = wrapper_npu_rotary_mul + torch_npu.npu_rms_norm = wrapper_npu_rms_norm torch_npu.npu_flash_attention = wrapper_npu_flash_attention -- Gitee From d8b59c1c3a33038512e096a3e2c1f2137c84f7f1 Mon Sep 17 00:00:00 2001 From: "liyefeng803@huawei.com" Date: Tue, 21 Nov 2023 19:26:51 +0800 Subject: [PATCH 2/2] test --- test/test_onnx/test_wrapper_onnx_ops.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git
a/test/test_onnx/test_wrapper_onnx_ops.py b/test/test_onnx/test_wrapper_onnx_ops.py index 542069e4d7..d496541402 100644 --- a/test/test_onnx/test_wrapper_onnx_ops.py +++ b/test/test_onnx/test_wrapper_onnx_ops.py @@ -1184,6 +1184,27 @@ class TestOnnxOps(TestCase): export_onnx(onnx_model_name) assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, onnx_model_name))) + + def test_wrapper_npu_rms_norm(self): + class Model(torch.nn.Module): + def __init__(self): + super(Model, self).__init__() + + def forward(self, x, gamma): + epsilon = 1e-6 + x = torch_npu.npu_rms_norm(x, gamma, epsilon) + return x + + def export_onnx(onnx_model_name): + x = torch.rand(10, 1024).uniform_(-3, 3).npu().half() + gamma = torch.rand(10).uniform_(-3, 3).npu().half() + model = Model().to("npu") + model(x, gamma) + self.onnx_export(model, (x, gamma), onnx_model_name) + onnx_model_name = "model_npu_rms_norm.onnx" + export_onnx(onnx_model_name) + assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path, + onnx_model_name))) def test_wrapper_npu_rotary_mul(self): class Model(torch.nn.Module): -- Gitee