From 08692604d8ff56a7a6373e9ec9fecac76a5a3ea9 Mon Sep 17 00:00:00 2001
From: hxf12345677
Date: Wed, 26 Jan 2022 10:13:19 +0800
Subject: [PATCH 1/2] =?UTF-8?q?batch=5Fnorm=5Fstats=E7=AE=97=E5=AD=901.8.1?=
 =?UTF-8?q?=E7=A7=BB=E6=A4=8D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../test_network_ops/test_batch_norm_stats.py | 65 ++++++++++++++
 .../normalization/BatchNormStatsKernelNpu.cpp | 88 +++++++++++++++++++
 2 files changed, 153 insertions(+)
 create mode 100644 test/test_network_ops/test_batch_norm_stats.py
 create mode 100644 torch_npu/csrc/aten/ops/normalization/BatchNormStatsKernelNpu.cpp

diff --git a/test/test_network_ops/test_batch_norm_stats.py b/test/test_network_ops/test_batch_norm_stats.py
new file mode 100644
index 0000000000..8451db4c0d
--- /dev/null
+++ b/test/test_network_ops/test_batch_norm_stats.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2020, Huawei Technologies.All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type_tests
+from torch_npu.testing.util_test import create_common_tensor, create_dtype_tensor, UT_FAST_MODE
+
+class TestBatchNormStats(TestCase):
+    # def cpu_op_exec(self, input1, mean, invstd, running_mean, running_var, momentum, eps, counts, normalize_type):
+    def cuda_op_exec(self, *args):
+        cpu_mean, cpu_invstd = torch.batch_norm_stats(*args)
+        return cpu_mean.cpu().numpy(), cpu_invstd.cpu().numpy()
+
+    def cuda_expect_result(self):
+        cpu_output0 = np.array([5.401827, 5.444219, 5.7656665], dtype=np.float32)
+        cpu_output1 = np.array([0.37123242, 0.38706362, 0.37435925], dtype=np.float32)
+        return cpu_output0, cpu_output1
+
+    def npu_op_exec(self, *args):
+        npu_mean, npu_invstd = torch.batch_norm_stats(*args)
+        out_mean = npu_mean.cpu().numpy()
+        out_invstd = npu_invstd.cpu().numpy()
+        return out_mean, out_invstd
+
+    def test_batch_norm_stats(self, device):
+        shape_format = [
+            [[np.float16, -1, [2, 3, 12, 12]], 1e-5],
+        ]
+        for item in shape_format:
+            # NB: mixup precision ut, benchmarking with fp32 standard
+            cpu_input1, npu_inputfp16 = create_common_tensor(item[0], 1, 10)
+            # fp32 standard
+            npu_input1fp32 = npu_inputfp16.float()
+            if torch.cuda.is_available():
+                cpu_output = self.cuda_op_exec(cpu_input1.cuda(), item[-1])
+            else:
+                cpu_output = self.cuda_expect_result()
+            npu_outputfp16 = self.npu_op_exec(npu_inputfp16, item[-1])
+            npu_outputfp32 = self.npu_op_exec(npu_input1fp32, item[-1])
+
+            self.assertRtolEqual(cpu_output[0], npu_outputfp16[0])
+            self.assertRtolEqual(cpu_output[1], npu_outputfp16[1], 1e-2)
+
+            self.assertRtolEqual(cpu_output[0], npu_outputfp32[0])
+            self.assertRtolEqual(cpu_output[1], npu_outputfp32[1], 1e-2)
+
+
+instantiate_device_type_tests(TestBatchNormStats, globals(), except_for='cpu')
+if __name__ == "__main__":
+    run_tests()
\ No newline at end of file
diff --git a/torch_npu/csrc/aten/ops/normalization/BatchNormStatsKernelNpu.cpp b/torch_npu/csrc/aten/ops/normalization/BatchNormStatsKernelNpu.cpp
new file mode 100644
index 0000000000..6898ea6a7d
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/normalization/BatchNormStatsKernelNpu.cpp
@@ -0,0 +1,88 @@
+// Copyright (c) 2022 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+
+namespace at_npu {
+namespace native {
+
+std::tuple<at::Tensor&, at::Tensor&> batch_norm_stats_out_npu_nocheck(
+    at::Tensor& mean,
+    at::Tensor& invstd,
+    const at::Tensor& self,
+    double eps) {
+  c10::SmallVector<int64_t, N> dim;
+  int dimN = self.ndimension();
+  for(int i = 0; i < dimN; i++){
+    if (i == 1) {
+      continue;
+    }
+    dim.emplace_back(i);
+  }
+  at::Tensor selfCp = self;
+  if (self.scalar_type() != at::kFloat){
+    selfCp = NPUNativeFunctions::npu_dtype_cast(selfCp, at::kFloat);
+  }
+  OpCommand cmd1;
+  cmd1.Name("ReduceMean")
+      .Input(selfCp)
+      .Input(dim, at::kInt)
+      .Output(mean)
+      .Attr("keep_dims", (bool) false)
+      .Run();
+
+  at::Tensor meanCp = mean;
+  if (mean.dim() != 0) {
+    auto dimVector = array_to_small_vector(dim);
+    for (int64_t i = 0; i < dimVector.size(); i++) {
+      meanCp = meanCp.unsqueeze(dimVector[i]);
+    }
+  }
+  meanCp = meanCp.expand(self.sizes());
+  OpCommand cmd2;
+  cmd2.Name("ReduceStdWithMean")
+      .Input(selfCp)
+      .Input(meanCp)
+      .Output(invstd)
+      .Attr("dim", dim)
+      .Attr("unbiased", false)
+      .Attr("keepdim", false)
+      .Attr("invert", true)
+      .Attr("epsilon", static_cast<float>(eps))
+      .Run();
+
+  return std::tie(mean, invstd);
+}
+
+std::tuple<at::Tensor, at::Tensor> NPUNativeFunctions::batch_norm_stats(
+    const at::Tensor& self,
+    double eps) {
+  TORCH_CHECK(
+      self.ndimension() >= 2,
+      "Expected 2D+ Tensor, but got tensor with ",
+      self.ndimension(),
+      " Dimension");
+  at::Tensor mean = OpPreparation::ApplyTensor({self.size(1)}, self.options().dtype(at::kFloat), self);
+  at::Tensor invstd = OpPreparation::ApplyTensor({self.size(1)}, self.options().dtype(at::kFloat), self);
+  batch_norm_stats_out_npu_nocheck(mean, invstd, self, eps);
+
+  return std::tie(mean, invstd);
+}
+
+} // namespace native
+} // namespace at_npu
\ No newline at end of file
-- 
Gitee

From 00c91a3d77ff486f3d72e1c63ac54f39667b2309 Mon Sep 17 00:00:00 2001
From: hxf12345677
Date: Thu, 27 Jan 2022 08:39:01 +0000
Subject: [PATCH 2/2] =?UTF-8?q?batch=5Fnorm=5Fstats=20=E4=BF=AE=E6=94=B9?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 test/test_network_ops/test_batch_norm_stats.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/test/test_network_ops/test_batch_norm_stats.py b/test/test_network_ops/test_batch_norm_stats.py
index 8451db4c0d..88f1598df8 100644
--- a/test/test_network_ops/test_batch_norm_stats.py
+++ b/test/test_network_ops/test_batch_norm_stats.py
@@ -21,7 +21,6 @@ from torch_npu.testing.common_device_type import Dtypes, instantiate_device_type
 from torch_npu.testing.util_test import create_common_tensor, create_dtype_tensor, UT_FAST_MODE
 
 class TestBatchNormStats(TestCase):
-    # def cpu_op_exec(self, input1, mean, invstd, running_mean, running_var, momentum, eps, counts, normalize_type):
     def cuda_op_exec(self, *args):
         cpu_mean, cpu_invstd = torch.batch_norm_stats(*args)
         return cpu_mean.cpu().numpy(), cpu_invstd.cpu().numpy()
@@ -42,9 +41,7 @@ class TestBatchNormStats(TestCase):
             [[np.float16, -1, [2, 3, 12, 12]], 1e-5],
         ]
         for item in shape_format:
-            # NB: mixup precision ut, benchmarking with fp32 standard
             cpu_input1, npu_inputfp16 = create_common_tensor(item[0], 1, 10)
-            # fp32 standard
             npu_input1fp32 = npu_inputfp16.float()
             if torch.cuda.is_available():
                 cpu_output = self.cuda_op_exec(cpu_input1.cuda(), item[-1])
-- 
Gitee