diff --git a/test/test_custom_ops/test_moe_gating_top_k_softmax.py b/test/test_custom_ops/test_moe_gating_top_k_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d12e1400be64667b5ee4251be80942bf80891e8
--- /dev/null
+++ b/test/test_custom_ops/test_moe_gating_top_k_softmax.py
@@ -0,0 +1,109 @@
+# Copyright (c) 2023 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+
+
+class TestMoeGatingTopKSoftmax(TestCase):
+    def supported_op_exec(self, gating, finished_opt=None, k=1):
+        def softmax_func(x, axis=None):
+            is_fp16 = x.dtype == np.float16
+            x = x.astype(np.float32)
+            x_max = x.max(axis=-1, keepdims=True)
+            x_sub = x - x_max
+            y = np.exp(x_sub)
+            x_sum = y.sum(axis=-1, keepdims=True)
+            ans = y / x_sum
+            if is_fp16:
+                ans = ans.astype(np.float16)
+                x_max = x_max.astype(np.float16)
+                x_sum = x_sum.astype(np.float16)
+            return ans, x_max, x_sum
+
+        gating = gating.numpy()
+        num_expert = gating.shape[-1]
+        if finished_opt is None:
+            pass
+        else:
+            finished_opt = finished_opt.numpy()
+        softmax, _, _ = softmax_func(gating, -1)
+        indices = np.argsort(-softmax, axis=-1, kind='stable')
+        indices = indices[:, :k].astype(np.int32)
+
+        out = np.take_along_axis(softmax, indices, axis=-1)
+
+        if finished_opt is not None:
+            finished = finished_opt.reshape(finished_opt.shape[0], 1)
+            finished = np.tile(finished, (1, k))
+            indices = np.where(finished, num_expert, indices)
+        source_row = np.arange(indices.shape[0] * indices.shape[1]).reshape(indices.shape[1], indices.shape[0]).transpose(1, 0).astype(np.int32)
+
+        # output = torch.from_numpy(out)
+        # indices_out = torch.from_numpy(indices)
+        # source_row_out = torch.from_numpy(source_row)
+
+        return [out, indices, source_row]
+
+    def custom_op_exec(self, gating, finished=None, k=1):
+        output, indices_out, source_row_out = torch_npu.npu_moe_gating_top_k_softmax(gating, finished, k)
+        return [output.cpu().numpy(), indices_out.cpu().numpy(), source_row_out.cpu().numpy()]
+
+    def test_moe_gating_top_k_softmax_float32(self, device="npu"):
+        item = [np.float32, 0, [48, 32]]
+        cpu_input, npu_input = create_common_tensor(item, 0, 100)
+
+        supported_output = self.supported_op_exec(cpu_input, k=16)
+        custom_output = self.custom_op_exec(npu_input, k=16)
+        self.assertRtolEqual(supported_output[0], custom_output[0])
+        self.assertRtolEqual(supported_output[1], custom_output[1])
+        self.assertRtolEqual(supported_output[2], custom_output[2])
+
+    def test_moe_gating_top_k_softmax_float16(self, device="npu"):
+        item = [np.float16, 0, [48, 32]]
+        cpu_input, npu_input = create_common_tensor(item, 0, 100)
+
+        supported_output = self.supported_op_exec(cpu_input, k=16)
+        custom_output = self.custom_op_exec(npu_input, k=16)
+        self.assertRtolEqual(supported_output[0], custom_output[0])
+        self.assertRtolEqual(supported_output[1], custom_output[1])
+        self.assertRtolEqual(supported_output[2], custom_output[2])
+
+    # def test_moe_gating_top_k_softmax_bfloat16(self, device="npu"):
+    #     item = [np.bfloat16, 0, [48, 32]]
+    #     _, npu_input = create_common_tensor(item, 0, 100)
+
+    #     supported_output = self.supported_op_exec(npu_input)
+    #     custom_output = self.custom_op_exec(npu_input, k)
+    #     self.assertRtolEqual(supported_output, custom_output)
+
+    def test_moe_gating_top_k_softmax_with_finished_float32(self, device="npu"):
+        item = [np.float32, 0, [48, 32]]
+        cpu_input, npu_input = create_common_tensor(item, 0, 100)
+        item_finished = [np.bool_, 0, [48]]
+        cpu_finished, npu_finished = create_common_tensor(item_finished, 0, 2)
+
+        supported_output = self.supported_op_exec(cpu_input, cpu_finished, k=16)
+        custom_output = self.custom_op_exec(npu_input, npu_finished, k=16)
+        self.assertRtolEqual(supported_output[0], custom_output[0])
+        self.assertRtolEqual(supported_output[1], custom_output[1])
+        self.assertRtolEqual(supported_output[2], custom_output[2])
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/test_onnx/test_wrapper_onnx_ops.py b/test/test_onnx/test_wrapper_onnx_ops.py
index 99d7c950fa19f9b3505da72014d0575d2ed28c1f..a0921e8456d06bbc8a0b09663013c87e83583589 100644
--- a/test/test_onnx/test_wrapper_onnx_ops.py
+++ b/test/test_onnx/test_wrapper_onnx_ops.py
@@ -194,6 +194,27 @@ class TestOnnxOps(TestCase):
         assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
                                             onnx_model_name)))
 
+    def test_wrapper_npu_moe_gating_top_k_softmax(self):
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super(Model, self).__init__()
+
+            def forward(self, x, finished=None):
+                k = 2
+                return torch_npu.npu_moe_gating_top_k_softmax(x, finished, k)
+
+        def export_onnx(onnx_model_name):
+            x = torch.rand(2, 16).npu()
+            finished = torch.randint(0, 2, size=(2,), dtype=torch.bool).npu()
+            model = Model().to("npu")
+            model(x, finished)
+            self.onnx_export(model, (x, finished), onnx_model_name, ["x", "finished"], ["y", "expert_idx", "row_idx"])
+
+        onnx_model_name = "model_npu_moe_gating_top_k_softmax.onnx"
+        export_onnx(onnx_model_name)
+        assert (os.path.isfile(os.path.join(TestOnnxOps.test_onnx_path,
+                                            onnx_model_name)))
+
     def test_wrapper_npu_geglu(self):
         class Model(torch.nn.Module):
             def __init__(self):
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 91d660d3d5a1d838a2480089fa6e2b3923c39dd8..43a8371b075e47a0f4b9ba794352795128e557d4 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -593,6 +593,21 @@ class NPUScaledMaskedSoftmaxOP(torch.autograd.Function):
         return g.op("npu::NPUScaledMaskedSoftmax", x, mask,
                     scale_f=scale, fixed_triu_mask_i=fixed_triu_mask)
 
+class NPUMoeGatingTopKSoftmaxOP(torch.autograd.Function):
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch_npu._C._VariableFunctionsClass.npu_moe_gating_top_k_softmax(*args, **kwargs)
+
+    @staticmethod
+    def symbolic(g, gating: torch.Tensor, finished: torch.Tensor, k: int = -1):
+        dtype = torch.bool
+        if finished is None:
+            finished = g.op("Constant", value_t=torch.tensor([]).to(dtype))
+        return g.op("npu::NPUMoeGatingTopKSoftmax", gating, finished, k_i=k, outputs=3)
+
+
+
 class NPUMishOP(torch.autograd.Function):
 
     @staticmethod
@@ -861,6 +876,10 @@ def wrapper_npu_scaled_masked_softmax(x, mask, scale=1, fixed_triu_mask=False):
     return NPUScaledMaskedSoftmaxOP.apply(x, mask, scale, fixed_triu_mask)
 
 
+def wrapper_npu_moe_gating_top_k_softmax(gating, finished=None, k=-1):
+    return NPUMoeGatingTopKSoftmaxOP.apply(gating, finished, k)
+
+
 def wrapper_npu_mish(self):
     return NPUMishOP.apply(self)
 
@@ -876,6 +895,7 @@ def add_onnx_ops():
     torch_npu.npu_iou = wrapper_npu_iou
     torch_npu.npu_batch_nms = wrapper_npu_batch_nms
     torch_npu.fast_gelu = wrapper_npu_fast_gelu
+    torch_npu.npu_moe_gating_top_k_softmax = wrapper_npu_moe_gating_top_k_softmax
     torch_npu.npu_fast_gelu = wrapper_npu_fast_gelu
     torch_npu.npu_geglu = wrapper_npu_geglu
     torch_npu.npu_fused_attention_score = wrapper_npu_fused_attention_score
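As an extra sanity check on the ONNX path, the exported model can be inspected with the onnx package to confirm that the symbolic above actually emitted the custom node. This is only a sketch, not part of the change: the path below is hypothetical, and it assumes the export in test_wrapper_npu_moe_gating_top_k_softmax has already produced the file.

    import os

    import onnx

    # Hypothetical location; in the test suite the file lands under TestOnnxOps.test_onnx_path.
    model_path = os.path.join("test_onnx_path", "model_npu_moe_gating_top_k_softmax.onnx")
    model = onnx.load(model_path)

    # g.op("npu::NPUMoeGatingTopKSoftmax", ..., outputs=3) exports a node with
    # domain "npu", op_type "NPUMoeGatingTopKSoftmax" and three declared outputs.
    nodes = [n for n in model.graph.node if n.op_type == "NPUMoeGatingTopKSoftmax"]
    assert len(nodes) == 1
    assert nodes[0].domain == "npu"
    assert len(nodes[0].output) == 3  # y, expert_idx, row_idx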