From bfc91fffecdcf9d98d7ba5410908eb8a58181ed2 Mon Sep 17 00:00:00 2001 From: l00855186 Date: Sat, 16 Mar 2024 09:37:58 +0800 Subject: [PATCH] support mm_all_reduce_add_rms_norm --- torch_npu/meta/meta_registrations.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torch_npu/meta/meta_registrations.py b/torch_npu/meta/meta_registrations.py index 5df8ebd955..cdbf78affd 100644 --- a/torch_npu/meta/meta_registrations.py +++ b/torch_npu/meta/meta_registrations.py @@ -326,6 +326,18 @@ def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, ant return x1.new_empty(tuple(dim_list)) +@impl(m, "npu_mm_all_reduce_add_rms_norm") +def npu_mm_all_reduce_add_rms_norm_forward(x1, x2, residual, gamma, hcom, reduce_op='sum', epsilon=1e-6, bias=None, + antiquant_scale=None, antiquant_offset=None, dequant_scale=None, + antiquant_group_size=0, comm_turn=0): + return (torch.empty_like(residual, dtype=residual.dtype), torch.empty_like(residual, dtype=residual.dtype)) + + +@impl(m, "npu_mm_all_reduce_add_rms_norm_") +def npu_inplace_mm_all_reduce_add_rms_norm_forward(x1, x2, residual, gamma, hcom, reduce_op='sum', epsilon=1e-6, bias=None, + antiquant_scale=None, antiquant_offset=None, dequant_scale=None, + antiquant_group_size=0, comm_turn=0): + return (torch.empty_like(residual, dtype=residual.dtype), torch.empty_like(residual, dtype=residual.dtype)) @impl(m, "npu_mm_reduce_scatter_base") def npu_mm_reduce_scatter_base_meta(self, x2, hcom, world_size, reduce_op='sum', -- Gitee