diff --git a/torch_npu/csrc/aten/npu_native_functions.yaml b/torch_npu/csrc/aten/npu_native_functions.yaml
index 0b45e3046506ffcdb728513d89dac56a124ce00c..320c18f74c93ba4e6b9f1da3af896c5cf6fdfa59 100644
--- a/torch_npu/csrc/aten/npu_native_functions.yaml
+++ b/torch_npu/csrc/aten/npu_native_functions.yaml
@@ -76,6 +76,7 @@ custom:
   - func: get_storage_size(Tensor self) -> int
   - func: npu_format_cast(Tensor self, int acl_format) -> Tensor
   - func: _npu_format_cast(Tensor self, int acl_format) -> Tensor
+  - func: npu_rms_norm(Tensor self, Tensor gamma, float epsilon) -> (Tensor, Tensor)
 
 symint:
   - as_strided_
diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py
index 660f38934aafc66963e3cb26bec03c3be9f0208e..5ccc3dbec2533d823a526a0a61a81c6b59f975b2 100644
--- a/torch_npu/onnx/wrapper_onnx_ops.py
+++ b/torch_npu/onnx/wrapper_onnx_ops.py
@@ -167,6 +167,15 @@ class NPUGiouOP(torch.autograd.Function):
         return g.op("npu::NPUGiou", self, gtboxes, trans_i=trans, is_cross_i=is_cross, mode_i=mode)
 
 
+class NPURmsNormOP(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        return torch_npu._C._VariableFunctionsClass.npu_rms_norm(*args, **kwargs)
+    @staticmethod
+    def symbolic(g, self: torch.Tensor, gamma: torch.Tensor, epsilon: float = 1e-6):
+        return g.op("npu::NPURmsNorm", self, gamma, epsilon_f=epsilon, outputs=2)
+
+
 class NPUDeformableConv2dOP(torch.autograd.Function):
 
     @staticmethod
@@ -736,6 +745,10 @@ def wrapper_npu_ptiou(bboxes, gtboxes, mode=0):
     return NPUPtiouOP.apply(bboxes, gtboxes, mode)
 
 
+def wrapper_npu_rms_norm(self, gamma, epsilon=1e-6):
+    return NPURmsNormOP.apply(self, gamma, epsilon)
+
+
 def wrapper_npu_normalize_batch(self, seq_len, normalize_type=0):
     return NPUNormalizeBatchOP.apply(self, seq_len, normalize_type)
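
For reference, a minimal export sketch (not part of the patch) showing how the new wrapper would be exercised. It assumes an NPU device is available and that importing torch_npu.onnx rebinds torch_npu.npu_rms_norm to wrapper_npu_rms_norm for the duration of the export, as is done for the other wrappers in wrapper_onnx_ops.py; the RmsNormModel class, the tensor shapes, the opset version, and the output file name are illustrative only.

import torch
import torch_npu
import torch_npu.onnx  # installs the ONNX export wrappers, including NPURmsNormOP


class RmsNormModel(torch.nn.Module):
    """Hypothetical module that applies the NPU fused RMSNorm."""

    def __init__(self, hidden_size):
        super().__init__()
        self.gamma = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x):
        # npu_rms_norm returns two tensors per the yaml signature
        # "(Tensor, Tensor)"; only the normalized output is used here.
        y, _ = torch_npu.npu_rms_norm(x, self.gamma, epsilon=1e-6)
        return y


model = RmsNormModel(1024).npu().eval()
x = torch.randn(8, 1024).npu()

# During export, the symbolic() above emits a custom-domain node
# "npu::NPURmsNorm" with a float attribute "epsilon" and two outputs.
torch.onnx.export(model, x, "rms_norm.onnx",
                  input_names=["x"], output_names=["y"],
                  opset_version=11)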