diff --git a/test/test_network_ops/test_binary_cross_entropy_with_logits_backward.py b/test/test_network_ops/test_binary_cross_entropy_with_logits_backward.py
new file mode 100644
index 0000000000000000000000000000000000000000..bce792bdb8159c082ef736a6b388897c09b1576f
--- /dev/null
+++ b/test/test_network_ops/test_binary_cross_entropy_with_logits_backward.py
@@ -0,0 +1,83 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import copy
+import torch
+import torch_npu
+import numpy as np
+
+from torch_npu.testing.common_utils import TestCase, run_tests
+from torch_npu.testing.common_device_type import instantiate_device_type_tests
+
+def generate_data(min1, max1, shape, dtype):
+    input1 = np.random.uniform(min1, max1, shape).astype(dtype)
+    # convert from numpy.ndarray to torch.Tensor
+    output = torch.from_numpy(input1)
+    # generate a one-hot target with the same shape as input1
+    label = torch.randint(shape[1], size=(shape[0],), dtype=torch.long)
+    target = torch.zeros(shape[0], shape[1])
+    target[range(target.shape[0]), label] = 1
+    target = target.to(output.dtype)
+    return output, target
+
+class TestBinaryCrossEntropyWithLogitsBackward(TestCase):
+    def cpu_op_exec(self, input1, target):
+        input1.requires_grad_(True)
+        output = torch.nn.functional.binary_cross_entropy_with_logits(input1, target)
+        output_cpu = output.detach().numpy()
+        output.backward()
+        res = input1.grad
+        res = res.numpy()
+        return output_cpu, res
+
+    def npu_op_exec(self, input1, target):
+        target = target.to("npu")
+        input1 = input1.to("npu")
+        input1.requires_grad_(True)
+        output = torch.nn.functional.binary_cross_entropy_with_logits(input1, target)
+        output_npu = output.cpu()
+        output_npu = output_npu.detach().numpy()
+        output.backward()
+        res = input1.grad.cpu()
+        res = res.numpy()
+        return output_npu, res
+
+    def test_binary_cross_entropy_with_logits_backward_fp32(self, device):
+        npu_input1, npu_target = generate_data(0, 100, (5, 3), np.float32)
+        cpu_input1 = copy.deepcopy(npu_input1)
+        cpu_target = copy.deepcopy(npu_target)
+        cpu_output, cpu_grad_output = self.cpu_op_exec(cpu_input1, cpu_target)
+        npu_output, npu_grad_output = self.npu_op_exec(npu_input1, npu_target)
+        self.assertRtolEqual(cpu_output, npu_output)
+        self.assertRtolEqual(cpu_grad_output, npu_grad_output)
+
+    def test_binary_cross_entropy_with_logits_backward_fp16(self, device):
+        npu_input1, npu_target = generate_data(0, 100, (5, 3), np.float16)
+        cpu_input1 = copy.deepcopy(npu_input1)
+        cpu_target = copy.deepcopy(npu_target)
+        # compute the CPU reference in fp32, then cast back for comparison
+        cpu_input1 = cpu_input1.to(torch.float32)
+        cpu_target = cpu_target.to(torch.float32)
+        cpu_output, cpu_grad_output = self.cpu_op_exec(cpu_input1, cpu_target)
+        npu_output, npu_grad_output = self.npu_op_exec(npu_input1, npu_target)
+        cpu_output = cpu_output.astype(npu_output.dtype)
+        cpu_grad_output = cpu_grad_output.astype(npu_grad_output.dtype)
+        self.assertRtolEqual(cpu_output, npu_output)
+        self.assertRtolEqual(cpu_grad_output, npu_grad_output)
+
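+# Register the NPU-specific instances of the test class; "cpu" is excluded
+# because every case above already compares against a CPU-computed reference.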
+instantiate_device_type_tests(TestBinaryCrossEntropyWithLogitsBackward, globals(), except_for="cpu")
+if __name__ == "__main__":
+    run_tests()
+
diff --git a/torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsBackwardKernelNpu.cpp b/torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsBackwardKernelNpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..b4fa7a28279c1d24b479695bdbe35da7d5d5b0d8
--- /dev/null
+++ b/torch_npu/csrc/aten/ops/BinaryCrossEntropyWithLogitsBackwardKernelNpu.cpp
@@ -0,0 +1,61 @@
+// Copyright (c) 2020 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/utils/OpAdapter.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+
+namespace at_npu {
+namespace native {
+
+at::Tensor NPUNativeFunctions::binary_cross_entropy_with_logits_backward(
+    const at::Tensor& grad_output,
+    const at::Tensor& self,
+    const at::Tensor& target,
+    const c10::optional<at::Tensor>& weight_opt,
+    const c10::optional<at::Tensor>& pos_weight_opt,
+    int64_t reduction) {
+
+  const at::Tensor& weight = c10::value_or_else(weight_opt, [] { return at::Tensor(); });
+  const at::Tensor& pos_weight = c10::value_or_else(pos_weight_opt, [] { return at::Tensor(); });
+
+  at::Tensor gradInput = OpPreparation::ApplyTensor(self);
+  // The NPU operator expects weight/pos_weight inputs of the same shape as
+  // self, so broadcast them when given and default to all-ones otherwise.
+  at::Tensor weightTensor = weight.defined() ?
+      NPUNativeFunctions::npu_broadcast(weight, self.sizes()) :
+      at::ones(self.sizes(), self.options());
+  at::Tensor posWeightTensor = pos_weight.defined() ?
+      NPUNativeFunctions::npu_broadcast(pos_weight, self.sizes()) :
+      at::ones(self.sizes(), self.options());
+
+  at::Tensor doutTensor = NPUNativeFunctions::npu_broadcast(grad_output, self.sizes());
+  std::string reductionStr = CalcuOpUtil::get_reduction_str(reduction);
+  OpCommand cmd;
+  cmd.Name("SigmoidCrossEntropyWithLogitsGradV2")
+      .Input(self)
+      .Input(target)
+      .Input(doutTensor)
+      .Input(weightTensor)
+      .Input(posWeightTensor)
+      .Output(gradInput)
+      .Attr("reduction", reductionStr)
+      .Run();
+
+  return gradInput;
+}
+} // namespace native
+} // namespace at_npu