From 4928168221c6a64cbf5825cdc890eb345a91591e Mon Sep 17 00:00:00 2001 From: FelixTang7 Date: Tue, 5 Dec 2023 09:45:19 +0800 Subject: [PATCH] Title ads Type: Feature Team: PyTorch_Ops_Dev InventoryUpdate: False Issue: issue_no Description: --- ads/common/__init__.py | 16 +- ads/common/ops/csrc/NpuSilu.cpp | 83 ++++++++ ads/common/ops/csrc/RotaryMulKernelNpu.cpp | 91 ++++++++ .../ops/csrc/RotatedBoxDecodeKernelNpu.cpp | 37 ++++ .../ops/csrc/RotatedBoxEncodeKernelNpu.cpp | 44 ++++ ads/common/ops/csrc/RotatedIouKernelNpu.cpp | 86 ++++++++ .../ops/csrc/RotatedOverlapsKernelNpu.cpp | 67 ++++++ ads/common/ops/csrc/ScatterV1KernelNpu.cpp | 38 ++++ ads/common/ops/csrc/SignBitsPackKernelNpu.cpp | 41 ++++ .../ops/csrc/SignBitsUnpackKernelNpu.cpp | 62 ++++++ ...SoftmaxCrossEntropyWithLogitsKernelNpu.cpp | 77 +++++++ ads/common/ops/csrc/StrideAddKernelNpu.cpp | 75 +++++++ ads/common/ops/csrc/TransposeKernelNpu.cpp | 76 +++++++ .../ops/csrc/YoloBoxesEncodeKernelNpu.cpp | 81 ++++++++ ads/common/ops/csrc/common.cpp | 195 ++++++++++++++++++ ads/common/ops/csrc/common.h | 29 +++ ads/common/ops/csrc/functions.h | 37 ++++ ads/common/ops/csrc/pybind.cpp | 36 ++++ ads/common/ops/rotary_mul.py | 22 ++ ads/common/ops/rotated_box_decode.py | 5 + ads/common/ops/rotated_box_encode.py | 5 + ads/common/ops/rotated_iou.py | 5 + ads/common/ops/rotated_overlaps.py | 5 + ads/common/ops/scatter.py | 5 + ads/common/ops/sign_bits_pack.py | 5 + ads/common/ops/sign_bits_unpack.py | 5 + ads/common/ops/silu.py | 25 +++ .../ops/softmax_cross_entropy_with_logits.py | 23 +++ ads/common/ops/stride_add.py | 5 + ads/common/ops/transpose.py | 5 + ads/common/ops/yolo_boxes_encode.py | 5 + tests/test_npu_rotary_mul.py | 54 +++++ tests/test_npu_scatter.py | 64 ++++++ tests/test_npu_silu.py | 59 ++++++ ...t_npu_softmax_cross_entropy_with_logits.py | 47 +++++ tests/test_npu_stride_add.py | 43 ++++ tests/test_npu_transpose.py | 37 ++++ tests/test_rotated_box.py | 52 +++++ tests/test_rotated_iou.py | 72 +++++++ tests/test_rotated_overlaps.py | 84 ++++++++ tests/test_sign_bits_pack.py | 42 ++++ tests/test_sign_bits_unpack.py | 45 ++++ tests/test_yolo_boxes_encode.py | 35 ++++ 43 files changed, 1924 insertions(+), 1 deletion(-) create mode 100644 ads/common/ops/csrc/NpuSilu.cpp create mode 100644 ads/common/ops/csrc/RotaryMulKernelNpu.cpp create mode 100644 ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/RotatedIouKernelNpu.cpp create mode 100644 ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp create mode 100644 ads/common/ops/csrc/ScatterV1KernelNpu.cpp create mode 100644 ads/common/ops/csrc/SignBitsPackKernelNpu.cpp create mode 100644 ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp create mode 100644 ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp create mode 100644 ads/common/ops/csrc/StrideAddKernelNpu.cpp create mode 100644 ads/common/ops/csrc/TransposeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/common.cpp create mode 100644 ads/common/ops/csrc/common.h create mode 100644 ads/common/ops/rotary_mul.py create mode 100644 ads/common/ops/rotated_box_decode.py create mode 100644 ads/common/ops/rotated_box_encode.py create mode 100644 ads/common/ops/rotated_iou.py create mode 100644 ads/common/ops/rotated_overlaps.py create mode 100644 ads/common/ops/scatter.py create mode 100644 ads/common/ops/sign_bits_pack.py create mode 100644 
ads/common/ops/sign_bits_unpack.py create mode 100644 ads/common/ops/silu.py create mode 100644 ads/common/ops/softmax_cross_entropy_with_logits.py create mode 100644 ads/common/ops/stride_add.py create mode 100644 ads/common/ops/transpose.py create mode 100644 ads/common/ops/yolo_boxes_encode.py create mode 100644 tests/test_npu_rotary_mul.py create mode 100644 tests/test_npu_scatter.py create mode 100644 tests/test_npu_silu.py create mode 100644 tests/test_npu_softmax_cross_entropy_with_logits.py create mode 100644 tests/test_npu_stride_add.py create mode 100644 tests/test_npu_transpose.py create mode 100644 tests/test_rotated_box.py create mode 100644 tests/test_rotated_iou.py create mode 100644 tests/test_rotated_overlaps.py create mode 100644 tests/test_sign_bits_pack.py create mode 100644 tests/test_sign_bits_unpack.py create mode 100644 tests/test_yolo_boxes_encode.py diff --git a/ads/common/__init__.py b/ads/common/__init__.py index 0249033a..ec5a1f3a 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -1 +1,15 @@ -from .ops.scatter_max import scatter_max \ No newline at end of file +from .ops.scatter_max import scatter_max +from .ops.rotated_box_decode import npu_rotated_box_decode +from .ops.rotated_box_encode import npu_rotated_box_encode +from .ops.rotated_iou import npu_rotated_iou +from .ops.rotated_overlaps import npu_rotated_overlaps +from .ops.sign_bits_pack import npu_sign_bits_pack +from .ops.sign_bits_unpack import npu_sign_bits_unpack +from .ops.softmax_cross_entropy_with_logits import npu_softmax_cross_entropy_with_logits +from .ops.stride_add import npu_stride_add +from .ops.transpose import npu_transpose +from .ops.yolo_boxes_encode import npu_yolo_boxes_encode +from .ops.scatter import npu_scatter +from .ops.silu import npu_silu +from .ops.silu import npu_silu_ +from .ops.rotary_mul import npu_rotary_mul \ No newline at end of file diff --git a/ads/common/ops/csrc/NpuSilu.cpp b/ads/common/ops/csrc/NpuSilu.cpp new file mode 100644 index 00000000..0b81c0f4 --- /dev/null +++ b/ads/common/ops/csrc/NpuSilu.cpp @@ -0,0 +1,83 @@ +#include <ATen/ATen.h> + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" +#include "common.h" + +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using npu_preparation = at_npu::native::OpPreparation; +using npu_utils = at_npu::native::NpuUtils; +using tensor_list = std::vector<at::Tensor>; + +at::Tensor &silu_out_npu_nocheck(at::Tensor &result, const at::Tensor &self) +{ + at_npu::native::OpCommand cmd; + cmd.Name("Swish") + .Input(self) + .Output(result) + .Attr("scale", (float)1.0) + .Run(); + return result; +} + +at::Tensor &silu_out_npu(const at::Tensor &self, at::Tensor &result) +{ + if (!check_match(result)) { + at::Tensor contiguous_result = result.contiguous(); + silu_out_npu_nocheck(contiguous_result, self); + format_fresh_view(result, contiguous_result); + } else { + silu_out_npu_nocheck(result, self); + } + + return result; +} + +at::Tensor silu_kernel_npu(const at::Tensor &self) +{ + at::Tensor result = at::empty(self.sizes(), self.options()); + + silu_out_npu_nocheck(result, self); + + return result; +} + +at::Tensor &silu_backward_out_npu_nocheck( + at::Tensor &result, + const at::Tensor &grad_output, + const at::Tensor &x0, + const at::Tensor &x1) +{ + at_npu::native::OpCommand cmd;
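+ // SwishGrad consumes the upstream gradient together with the forward input (x0) and the forward output (x1); the Python wrapper saves both during forward.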
+ cmd.Name("SwishGrad") + .Input(grad_output) + .Input(x0) + .Input(x1) + .Output(result) + .Run(); + + return result; +} + +at::Tensor npu_silu_backward(const at::Tensor &grad_output, const at::Tensor &x0, const at::Tensor &x1) +{ + at::Tensor grad_input = at::empty(grad_output.sizes(), grad_output.options()); + silu_backward_out_npu_nocheck(grad_input, grad_output, x0, x1); + + return grad_input; +} + +at::Tensor npu_silu(const at::Tensor &self) +{ + return silu_kernel_npu(self); +} + +at::Tensor &npu_silu_(at::Tensor &self) +{ + silu_out_npu(self, self); + return self; +} \ No newline at end of file diff --git a/ads/common/ops/csrc/RotaryMulKernelNpu.cpp b/ads/common/ops/csrc/RotaryMulKernelNpu.cpp new file mode 100644 index 00000000..3e814e4e --- /dev/null +++ b/ads/common/ops/csrc/RotaryMulKernelNpu.cpp @@ -0,0 +1,91 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <ATen/ATen.h> +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/aten/CustomFunctions.h" +#include "functions.h" +#include "common.h" + +using npu_preparation = at_npu::native::OpPreparation; +using torch::autograd::Function; +using torch::autograd::AutogradContext; +using tensor_tuple = std::tuple<at::Tensor, at::Tensor, at::Tensor>; + +namespace { +at::Tensor &rotary_mul_nocheck(at::Tensor &y, const at::Tensor &x, const at::Tensor &r1, const at::Tensor &r2) +{ + if (x.sizes()[3] % 64 != 0) { + std::vector<at::Tensor> chunkResult = x.chunk(2, -1); + at::Tensor x_new = at::cat({chunkResult[1] * (-1), chunkResult[0]}, 3); + y = at::mul(r1, x) + at::mul(r2, x_new); + } else { + at_npu::native::OpCommand cmd; + cmd.Name("RotaryMul").Input(x).Input(r1).Input(r2).Output(y).Run(); + } + return y; +} + +tensor_tuple rotary_mul_backward_nocheck(at::Tensor &dx, at::Tensor &dr1, at::Tensor &dr2, const at::Tensor &x, + const at::Tensor &r1, const at::Tensor &r2, const at::Tensor &dy) +{ + TORCH_CHECK(x.dim() == 4, "The dim of input tensor [x] should equal four."); + TORCH_CHECK(r1.dim() == 4, "The dim of input tensor [r1] should equal four."); + TORCH_CHECK(r2.dim() == 4, "The dim of input tensor [r2] should equal four."); + if (x.sizes()[3] % 64 != 0) { + at::Tensor x_grad_mul = at::mul(x, dy); + at::Tensor x1_grad_mul = at::mul(r1, dy); + at::Tensor x2_grad_mul = at::mul(r2, dy); + std::vector<at::Tensor> x2_chunk = x2_grad_mul.chunk(2, -1); + at::Tensor x2_chunk_cat = at::cat({x2_chunk[1], x2_chunk[0] * (-1)}, 3); + dx = at::add(x2_chunk_cat, x1_grad_mul); + c10::SmallVector<int64_t, N> dims; + for (int i = 0; i < 4; i++) { + if (x.sizes()[i] != r1.sizes()[i]) { + dims.emplace_back(i); + } + } + std::vector<at::Tensor> xq_chunk = x_grad_mul.chunk(2, -1); + at::Tensor xq_chunk_cat = at::cat({xq_chunk[1] * (-1), xq_chunk[0]}, 3); + dr2 = at::sum(xq_chunk_cat, dims, true); + dr1 = at::sum(x_grad_mul, dims, true); + } else {
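+ // When the last dim is a multiple of 64, the fused RotaryMulGrad kernel computes dx, dr1 and dr2 in a single launch.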
+ at_npu::native::OpCommand cmd; + cmd.Name("RotaryMulGrad").Input(x).Input(r1).Input(r2).Input(dy).Output(dx).Output(dr1).Output(dr2).Run(); + } + return std::tie(dx, dr1, dr2); +} +} // namespace + +at::Tensor npu_rotary_mul(const at::Tensor &self, const at::Tensor &r1, const at::Tensor &r2) +{ + at::Tensor result = at::empty(self.sizes(), self.options()); + rotary_mul_nocheck(result, self, r1, r2); + return result; +} + +std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_rotary_mul_backward(const at::Tensor &grad, const at::Tensor &self, + const at::Tensor &r1, const at::Tensor &r2) +{ + at::Tensor dx = at::empty(self.sizes(), self.options()); + at::Tensor dr1 = at::empty(r1.sizes(), r1.options()); + at::Tensor dr2 = at::empty(r2.sizes(), r2.options()); + rotary_mul_backward_nocheck(dx, dr1, dr2, self, r1, r2, grad); + return std::tie(dx, dr1, dr2); +} diff --git a/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp b/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp new file mode 100644 index 00000000..db949fc9 --- /dev/null +++ b/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp @@ -0,0 +1,37 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" + +using npu_preparation = at_npu::native::OpPreparation; +using npu_utils = at_npu::native::NpuUtils; + +at::Tensor npu_rotated_box_decode(const at::Tensor &self, const at::Tensor &deltas, const at::Tensor &weight) +{ + at::Tensor result = at::empty(self.sizes(), self.options()); + at::Tensor weight_cpu = weight.to(at::Device(at::kCPU), at::kFloat); + auto weight_ptr = weight_cpu.data_ptr<float>(); + TORCH_CHECK(weight_ptr != nullptr, "weight_ptr is nullptr."); + at::ArrayRef<float> weight_list(weight_ptr, weight_cpu.numel()); + + at_npu::native::OpCommand cmd; + cmd.Name("RotatedBoxDecode").Input(self).Input(deltas).Output(result).Attr("weight", weight_list).Run(); + return result; +} diff --git a/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp b/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp new file mode 100644 index 00000000..cfe515ba --- /dev/null +++ b/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp @@ -0,0 +1,44 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" + +using npu_preparation = at_npu::native::OpPreparation; + +at::Tensor npu_rotated_box_encode( + const at::Tensor &self, + const at::Tensor &gtBox, + const at::Tensor &weight) +{ + at::Tensor result = at::empty(self.sizes(), self.options()); + at::Tensor weight_cpu = weight.to(at::Device(at::kCPU), at::kFloat); + auto weight_ptr = weight_cpu.data_ptr<float>(); + TORCH_CHECK(weight_ptr != nullptr, "weight_ptr is nullptr."); + at::ArrayRef<float> weight_list(weight_ptr, weight_cpu.numel()); + + at_npu::native::OpCommand cmd; + cmd.Name("RotatedBoxEncode") + .Input(self) + .Input(gtBox) + .Output(result) + .Attr("weight", weight_list) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/RotatedIouKernelNpu.cpp b/ads/common/ops/csrc/RotatedIouKernelNpu.cpp new file mode 100644 index 00000000..7c8c334a --- /dev/null +++ b/ads/common/ops/csrc/RotatedIouKernelNpu.cpp @@ -0,0 +1,86 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/aten/CustomFunctions.h" +#include "functions.h" + +using npu_preparation = at_npu::native::OpPreparation; + +namespace { +at::Tensor &rotated_iou_npu_nocheck( + at::Tensor &iou, + const at::Tensor &boxes, + const at::Tensor &query_boxes, + bool trans, + int64_t mode, + bool is_cross, + double v_threshold, + double e_threshold) +{ + std::string mode_str = (mode == 0) ? "iou" : "iof"; + + at_npu::native::OpCommand cmd; + cmd.Name("RotatedIou") + .Input(boxes) + .Input(query_boxes) + .Output(iou) + .Attr("trans", trans) + .Attr("mode", mode_str) + .Attr("is_cross", is_cross) + .Attr("value", static_cast<float>(v_threshold)) + .Attr("value", static_cast<float>(e_threshold)) + .Run(); + return iou; +} +} // namespace + +at::Tensor npu_rotated_iou( + const at::Tensor &boxes, + const at::Tensor &query_boxes, + bool trans, + int64_t mode, + bool is_cross, + double v_threshold, + double e_threshold) +{ + TORCH_CHECK(boxes.ndimension() == 3 && query_boxes.ndimension() == 3, "boxes and query_boxes should both be 3D tensors"); + + auto origin_dtype = boxes.scalar_type(); + + at::Tensor boxes_cp = boxes.permute({0, 2, 1}); + if (origin_dtype == at::kHalf) { + boxes_cp = boxes_cp.to(at::kFloat); + } + at::Tensor query_boxes_cp = query_boxes.permute({0, 2, 1}); + if (query_boxes_cp.scalar_type() == at::kHalf) { + query_boxes_cp = query_boxes_cp.to(at::kFloat); + } + + int64_t B = boxes_cp.size(0); + int64_t N = boxes_cp.size(-1); + int64_t K = query_boxes_cp.size(-1); + + c10::SmallVector<int64_t, 8> output_size({B, N, K}); + at::Tensor iou = at::empty(output_size, boxes_cp.options()); + + rotated_iou_npu_nocheck(iou, boxes_cp, query_boxes_cp, trans, mode, is_cross, v_threshold, e_threshold); + iou = iou.to(origin_dtype); + return iou; +} diff --git a/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp b/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp new file mode 100644 index 00000000..ac476c62 --- /dev/null +++ b/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp @@ -0,0 +1,67 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/aten/CustomFunctions.h" +#include "functions.h" + +using npu_preparation = at_npu::native::OpPreparation; + +namespace { +at::Tensor &rotated_overlaps_npu_nocheck( + at::Tensor &overlaps, + const at::Tensor &self, + const at::Tensor &query_boxes, + bool trans) +{ + at_npu::native::OpCommand cmd; + cmd.Name("RotatedOverlaps") + .Input(self) + .Input(query_boxes) + .Output(overlaps) + .Attr("trans", trans) + .Run(); + return overlaps; +} +} // namespace + +at::Tensor npu_rotated_overlaps( + const at::Tensor &self, + const at::Tensor &query_boxes, + bool trans) +{ + TORCH_CHECK(self.ndimension() == 3 && query_boxes.ndimension() == 3, + "boxes' ndimension() should equal query_boxes' ndimension() ", + "and equal 3!"); + auto origin_dtype = self.scalar_type(); + // the Op only supports fp32 currently!
+ at::Tensor self_cp = self.to(at::kFloat).permute({0, 2, 1}); + at::Tensor query_boxes_cp = query_boxes.to(at::kFloat).permute({0, 2, 1}); + + int64_t B = self_cp.size(0); + int64_t N = self_cp.size(-1); + int64_t K = query_boxes_cp.size(-1); + + c10::SmallVector<int64_t, 8> output_size({B, N, K}); + at::Tensor overlaps = at::empty(output_size, self_cp.options()); + + rotated_overlaps_npu_nocheck(overlaps, self_cp, query_boxes_cp, trans); + overlaps = overlaps.to(origin_dtype); + return overlaps; +} diff --git a/ads/common/ops/csrc/ScatterV1KernelNpu.cpp b/ads/common/ops/csrc/ScatterV1KernelNpu.cpp new file mode 100644 index 00000000..155ea383 --- /dev/null +++ b/ads/common/ops/csrc/ScatterV1KernelNpu.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" + +using npu_preparation = at_npu::native::OpPreparation; + +at::Tensor npu_scatter(const at::Tensor &self, const at::Tensor &indices, const at::Tensor &updates, int64_t dim) +{ + at::Tensor outputs = at::empty(self.sizes(), self.options()); + at_npu::native::OpCommand cmd; + cmd.Name("ArgMaxGrad") + .Input(self) + .Input(indices) + .Input(updates) + .Output(outputs) + .Attr("dimension", dim) + .Run(); + + return outputs; +} diff --git a/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp b/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp new file mode 100644 index 00000000..5fb1139c --- /dev/null +++ b/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp @@ -0,0 +1,41 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.
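+// Packs the signs of a 1-D float tensor into bits, 8 signs per uint8 byte; the packed stream is reshaped into 'size' rows.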
+ +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" + +using npu_preparation = at_npu::native::OpPreparation; + +at::Tensor npu_sign_bits_pack(const at::Tensor &self, int64_t size) +{ + TORCH_CHECK(self.dim() == 1, "input must be one-dimensional"); + TORCH_CHECK(self.scalar_type() == at::ScalarType::Half || self.scalar_type() == at::ScalarType::Float, + "input only supports torch.float16 and torch.float32 dtypes"); + auto ysize = (self.numel() + 7) / 8; + TORCH_CHECK(size != 0 && ysize % size == 0, "(self.numel() + 7) / 8 must be divisible by size"); + at::Tensor result = at::empty({size, ysize / size}, self.options().dtype(at::kByte)); + + at_npu::native::OpCommand cmd; + cmd.Name("SignBitsPack") + .Input(self) + .Output(result) + .Attr("size", size) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp b/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp new file mode 100644 index 00000000..e7a35680 --- /dev/null +++ b/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp @@ -0,0 +1,62 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" +#include "common.h" + +using npu_preparation = at_npu::native::OpPreparation; + +at::Tensor npu_sign_bits_unpack_compute( + const at::Tensor &input, + int64_t size, + c10::ScalarType dtype) +{ + int64_t dim = input.dim(); + TORCH_CHECK(dim == 1, "input value should be a 1-dimensional tensor"); + TORCH_CHECK(input.scalar_type() == at::ScalarType::Byte, "sign_bits_unpack input only supports torch.uint8 "); + TORCH_CHECK(size > 0, "The argument 'size' is not valid because it is less than or equal to zero"); + + int64_t input_size = input.numel(); + TORCH_CHECK((input_size * 8) % size == 0, "input value length * 8 must be a multiple of size"); + TORCH_CHECK(dtype == at::ScalarType::Float || dtype == at::ScalarType::Half, "The argument 'dtype' must be torch.float32 or torch.float16"); + int64_t m = input_size * 8 / size; + at::Tensor result = at::empty({size, m}, input.options().dtype(dtype));
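+ // The op takes the requested output dtype as an integer attr: 1 for fp16, 0 for fp32.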
+ int64_t type_enum = dtype == at::ScalarType::Half ? 1 : 0; + at_npu::native::OpCommand cmd; + cmd.Name("SignBitsUnpack") + .Input(input) + .Output(result) + .Attr("dtype", type_enum) + .Attr("size", size) + .Run(); + return result; +} + +at::Tensor npu_sign_bits_unpack(py::args args) +{ + TORCH_CHECK(args.size() == 3, "input args size should be 3"); + at::Tensor input = py::cast<at::Tensor>(args[0]); + int64_t size = py::cast<int64_t>(args[1]); + auto typeStr = py::cast<std::string>(py::str(args[2])); + auto typePair = trans_torch_type_to_scalar(typeStr); + TORCH_CHECK(typePair.first, "input dtype is wrong"); + return npu_sign_bits_unpack_compute(input, size, typePair.second); +} diff --git a/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp b/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp new file mode 100644 index 00000000..e936c819 --- /dev/null +++ b/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp @@ -0,0 +1,77 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <ATen/ATen.h> +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" +#include "common.h" + +using npu_preparation = at_npu::native::OpPreparation; +using torch::autograd::AutogradContext; +using torch::autograd::Function; +using tensor_list = std::vector<at::Tensor>; + +namespace { +std::tuple<at::Tensor&, at::Tensor&> softmax_cross_entropy_with_logits_out_nocheck( + at::Tensor &result, + at::Tensor &backprop, + const at::Tensor &self, + const at::Tensor &labels) +{ + at_npu::native::OpCommand cmd; + cmd.Name("SoftmaxCrossEntropyWithLogits") + .Input(self) + .Input(labels) + .Output(result) + .Output(backprop) + .Run(); + + return std::tuple<at::Tensor&, at::Tensor&>(result, backprop); +} + +std::tuple<at::Tensor, at::Tensor> softmax_cross_entropy_with_logits_impl_out_nocheck( + const at::Tensor &self, + const at::Tensor &labels) +{ + auto output_sizes = softmax_cross_entropy_with_logits_impl_npu_output_size(self); + at::Tensor result = at::empty(std::get<0>(output_sizes), self.options()); + at::Tensor backprop = at::empty(std::get<1>(output_sizes), self.options()); + + softmax_cross_entropy_with_logits_out_nocheck(result, backprop, self, labels); + + return std::make_tuple(result, backprop); +} +} // namespace + +at::Tensor npu_softmax_cross_entropy_with_logits_backward( + const at::Tensor &grad, + const at::Tensor &self, + const at::Tensor &labels) +{ + at::Tensor result1 = std::get<1>(softmax_cross_entropy_with_logits_impl_out_nocheck(self, labels)); + return result1 * grad.unsqueeze(-1); +} + +at::Tensor npu_softmax_cross_entropy_with_logits( + const at::Tensor &self, + const at::Tensor &labels) +{ + TORCH_CHECK(torch_npu::utils::is_npu(self), "input [self] must be an NPU tensor"); + return std::get<0>(softmax_cross_entropy_with_logits_impl_out_nocheck(self, labels)); +} diff --git a/ads/common/ops/csrc/StrideAddKernelNpu.cpp b/ads/common/ops/csrc/StrideAddKernelNpu.cpp new file mode
100644 index 00000000..ebcfbfda --- /dev/null +++ b/ads/common/ops/csrc/StrideAddKernelNpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <ATen/ATen.h> +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" +#include "common.h" + +using npu_preparation = at_npu::native::OpPreparation; + +namespace { +at::Tensor &stride_add_out_npu_nocheck( + at::Tensor &result, + const at::Tensor &self, + const at::Tensor &other, + c10::Scalar offset1, + c10::Scalar offset2, + c10::Scalar c1_len) +{ + at_npu::native::OpCommand cmd; + cmd.Name("StrideAdd") + .Input(self, "x1") + .Input(other, "x2") + .Output(result, "y") + .Attr("x1_c1_offset", (int64_t)offset1.toInt()) + .Attr("x2_c1_offset", (int64_t)offset2.toInt()) + .Attr("c1_len", (int64_t)c1_len.toInt()) + .Run(); + return result; +} +} // namespace + +at::Tensor npu_stride_add_compute( + const at::Tensor &self, + const at::Tensor &other, + const c10::Scalar &offset1, + const c10::Scalar &offset2, + const c10::Scalar &c1_len) +{ + auto output_size = infersize_stride_add(self.sizes(), other.sizes()); + output_size[1] = c1_len.toInt() * 16; + at::Tensor result = at::empty(output_size, self.options()); + stride_add_out_npu_nocheck(result, self, other, offset1, offset2, c1_len); + return result; +} + +at::Tensor npu_stride_add(py::args args) +{ + TORCH_CHECK(args.size() == 5U, "input arg size: ", args.size(), " is wrong, size should be 5"); + at::Tensor self = py::cast<at::Tensor>(args[0]); + at::Tensor other = py::cast<at::Tensor>(args[1]); + int offsetTmp1 = py::cast<int>(args[2]); + int offsetTmp2 = py::cast<int>(args[3]); + int lenTmp = py::cast<int>(args[4]); + c10::Scalar offset1 = c10::Scalar(offsetTmp1); + c10::Scalar offset2 = c10::Scalar(offsetTmp2); + c10::Scalar c10Len = c10::Scalar(lenTmp); + return npu_stride_add_compute(self, other, offset1, offset2, c10Len); +} diff --git a/ads/common/ops/csrc/TransposeKernelNpu.cpp b/ads/common/ops/csrc/TransposeKernelNpu.cpp new file mode 100644 index 00000000..ad9d2e97 --- /dev/null +++ b/ads/common/ops/csrc/TransposeKernelNpu.cpp @@ -0,0 +1,76 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "functions.h" +#include "common.h" + +using npu_utils = at_npu::native::NpuUtils; + +namespace { +at::Tensor &npu_transpose_out_nocheck( + at::Tensor &result, + const at::Tensor &self, + at::IntArrayRef perm, + bool require_contiguous) +{ + at_npu::native::OpCommand cmd; + if (require_contiguous) { + // Any tensor-view(discontiguous) Input Tensor from users should be transformed to be contiguous here. + cmd.Name("Transpose") + .Input(self) + .Input(perm) + .Output(result) + .Run(); + } else { + // For permute-opt in trans-contiguous, it accepts transposed(discontiguous) Input Tensor. + cmd.Name("Transpose") + .InputWithoutContiguous(self) + .Input(perm) + .Output(result) + .Run(); + } + return result; +} +} // namespace + +at::Tensor npu_transpose(const at::Tensor &self, at::IntArrayRef perm, bool require_contiguous) +{ + auto output_size = transpose_npu_output_size(self, perm); + at::Tensor result = at::empty(output_size, self.options()); + npu_transpose_out_nocheck(result, self, perm, require_contiguous); + + return result; +} + +at::Tensor &npu_transpose_out( + const at::Tensor &self, + at::IntArrayRef perm, + bool require_contiguous, + at::Tensor &result) +{ + if (!check_match(result)) { + at::Tensor contiguous_result = result.contiguous(); + npu_transpose_out_nocheck(contiguous_result, self, perm, require_contiguous); + format_fresh_view(result, contiguous_result); + } else { + npu_transpose_out_nocheck(result, self, perm, require_contiguous); + } + return result; +} diff --git a/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp b/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp new file mode 100644 index 00000000..f3cd4201 --- /dev/null +++ b/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp @@ -0,0 +1,81 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
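+// Encodes ground-truth boxes against anchor boxes for YOLO; both are (N, 4) tensors, and stride holds one int32 entry per gt box (validated below).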
+ +#include "torch_npu/csrc/framework/OpCommand.h" +#include "torch_npu/csrc/framework/utils/OpPreparation.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "torch_npu/csrc/aten/CustomFunctions.h" +#include "functions.h" +#include "common.h" + +using npu_preparation = at_npu::native::OpPreparation; + +namespace { +inline void yolo_boxes_encode_check( + const at::Tensor &anchor_boxes, + const at::Tensor &gt_bboxes, + const at::Tensor &stride) +{ + TORCH_CHECK( + anchor_boxes.dim() == 2 && anchor_boxes.size(1) == 4, + "Non-empty 2D anchor_boxes tensor expected but got a tensor with sizes ", + anchor_boxes.sizes()); + TORCH_CHECK( + anchor_boxes.size(0) <= 20480, + "anchor_boxes only support max [20480] num, but got num ", + anchor_boxes.size(0)); + TORCH_CHECK( + gt_bboxes.dim() == 2 && gt_bboxes.size(1) == 4, + "Non-empty 2D gt_bboxes tensor expected but got a tensor with sizes ", + gt_bboxes.sizes()); + TORCH_CHECK( + stride.dim() == 1, + "Non-empty 1D stride tensor expected but got a tensor with sizes ", + stride.sizes()); + TORCH_CHECK( + stride.size(0) == gt_bboxes.size(0), + "stride's length should equal gt_bboxes' num, but got stride length ", + stride.size(0), + ", gt_bboxes num ", + gt_bboxes.size(0)); + TORCH_CHECK( + at::isIntegralType(stride.scalar_type(), true) && stride.scalar_type() != at::ScalarType::Long, + "int32 stride tensor expected but got a tensor with dtype: ", + stride.scalar_type()); +} +} // namespace + +at::Tensor npu_yolo_boxes_encode( + const at::Tensor &anchor_boxes, + const at::Tensor &gt_bboxes, + const at::Tensor &stride, + bool performance_mode) +{ + yolo_boxes_encode_check(anchor_boxes, gt_bboxes, stride); + at::Tensor result = at::empty(gt_bboxes.sizes(), gt_bboxes.options()); + std::string impl_mode_str = performance_mode ? "high_performance" : "high_precision"; + at::Tensor stride_cp = stride.to(at::ScalarType::Int); + at_npu::native::OpCommand cmd; + cmd.Name("YoloBoxesEncode") + .Input(anchor_boxes) + .Input(gt_bboxes) + .Input(stride_cp) + .Output(result) + .Attr("performance_mode", impl_mode_str) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/common.cpp b/ads/common/ops/csrc/common.cpp new file mode 100644 index 00000000..8e4b3037 --- /dev/null +++ b/ads/common/ops/csrc/common.cpp @@ -0,0 +1,195 @@ +#include <ATen/ATen.h> +#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" +#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "torch_npu/csrc/aten/mirror/NPUMemoryOverlap.h" +#include "torch_npu/csrc/aten/NPUNativeFunctions.h" +#include "third_party/acl/inc/acl/acl_base.h" +#include "common.h" + +using npu_utils = at_npu::native::NpuUtils; +using CalcuOpUtil = at_npu::native::CalcuOpUtil; + +#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ + _(at::ScalarType::Byte, ACL_UINT8) \ + _(at::ScalarType::Char, ACL_INT8) \ + _(at::ScalarType::Short, ACL_INT16) \ + _(at::ScalarType::Int, ACL_INT32) \ + _(at::ScalarType::Long, ACL_INT64) \ + _(at::ScalarType::Half, ACL_FLOAT16) \ + _(at::ScalarType::Float, ACL_FLOAT) \ + _(at::ScalarType::Double, ACL_DOUBLE) \ + _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED) \ + _(at::ScalarType::ComplexFloat, ACL_COMPLEX64) \ + _(at::ScalarType::ComplexDouble, ACL_COMPLEX128) \ + _(at::ScalarType::Bool, ACL_BOOL) \ + _(at::ScalarType::QInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QInt32, ACL_DT_UNDEFINED) \ + _(at::ScalarType::BFloat16, ACL_BF16) \ + _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED) \ + _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED) \ + _(at::ScalarType::Undefined, ACL_DT_UNDEFINED) \ + _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED) + +static std::unordered_map<std::string, at::ScalarType> dTypeTransMap{ + {"torch.float16", at::ScalarType::Half}, {"torch.half", at::ScalarType::Half}, {"torch.float32", at::ScalarType::Float}, {"torch.float", at::ScalarType::Float}, {"torch.float64", at::ScalarType::Double}, {"torch.double", at::ScalarType::Double}, {"torch.int8", at::ScalarType::Char}, {"torch.char", at::ScalarType::Char}, {"torch.int16", at::ScalarType::Short}, {"torch.short", at::ScalarType::Short}, {"torch.int32", at::ScalarType::Int}, {"torch.int", at::ScalarType::Int}, {"torch.int64", at::ScalarType::Long}, {"torch.long", at::ScalarType::Long}}; + +static bool check_inplace_tensor(const std::initializer_list<at::Tensor> &src_list, const at::Tensor &dst) +{ + bool is_inplace_tensor = false; + // check whether dst is contained in src_list + for (const auto &src : src_list) { + if (dst.is_same(src)) { + is_inplace_tensor = true; + break; + } + } + return is_inplace_tensor; +} + +static void check_tensor_size(const std::initializer_list<at::Tensor> &src_list, at::Tensor &dst, + c10::IntArrayRef expect_size) +{ + bool is_inplace = check_inplace_tensor(src_list, dst); + // Preserve legacy resizing behavior of out=... arguments
+ if (!dst.sizes().equals(expect_size)) { + TORCH_CHECK(!is_inplace, "output with shape ", dst.sizes(), " doesn't match the broadcast shape ", + expect_size); + dst.resize_(expect_size); + } + return; +} + +constexpr aclDataType kATenScalarTypeToAclDataTypeTable + [static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = { +#define DEFINE_ENUM(_1, n) n, + AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM) +#undef DEFINE_ENUM +}; + +aclDataType ConvertToAclDataType(const at::ScalarType &data_type) +{ + auto acl_dtype = + kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(data_type)]; + TORCH_CHECK(acl_dtype != ACL_DT_UNDEFINED, + std::string(c10::toString(data_type)) + " has not been supported"); + return acl_dtype; +} + +c10::SmallVector<int64_t, N> array_to_small_vector(c10::IntArrayRef shape) +{ + c10::SmallVector<int64_t, N> shape_small_vec; + for (uint64_t i = 0; i < shape.size(); i++) { + shape_small_vec.emplace_back(shape[i]); + } + return shape_small_vec; +} + +c10::SmallVector<int64_t, SIZE> conv_transpose2d_npu_output_size(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, c10::IntArrayRef padding, + c10::IntArrayRef output_padding, + c10::IntArrayRef stride, c10::IntArrayRef dilation, + int64_t groups) +{ + int64_t N = input.size(0); + int64_t H = input.size(2); + int64_t W = input.size(3); + int64_t Co = weight.size(1) * groups; + auto kernel_size = weight.sizes().slice(2); + + int64_t Ho = (H - 1) * stride[0] - 2 * padding[0] + dilation[0] * (kernel_size[0] - 1) + output_padding[0] + 1; + int64_t Wo = (W - 1) * stride[1] - 2 * padding[1] + dilation[1] * (kernel_size[1] - 1) + output_padding[1] + 1; + + c10::SmallVector<int64_t, SIZE> outputSize = {N, Co, Ho, Wo}; + + return outputSize; +} + +std::pair<bool, at::ScalarType> trans_torch_type_to_scalar(const std::string &type) +{ + if (dTypeTransMap.find(type) != dTypeTransMap.end()) { + return {true, dTypeTransMap[type]}; + } + return {false, at::ScalarType::Byte}; +} + +tuple_vector softmax_cross_entropy_with_logits_impl_npu_output_size(const at::Tensor &self) +{ + c10::SmallVector<int64_t, N> resultSize = array_to_small_vector(self.size(0)); + c10::SmallVector<int64_t, N> backpropSize = array_to_small_vector(self.sizes()); + + return tuple_vector(resultSize, backpropSize); +} + +c10::SmallVector<int64_t, N> convert_array_to_vector(c10::IntArrayRef intArray) +{ + c10::SmallVector<int64_t, N> intVec; + for (uint64_t i = 0; i < intArray.size(); i++) { + intVec.emplace_back(intArray[i]); + } + return intVec; +} + +int64_t make_warp_dim(int64_t dim, int64_t dim_post_expr) +{ + if (dim_post_expr <= 0) { + dim_post_expr = 1; // this will make range [-1, 0] + } + if (dim < 0) { + dim += dim_post_expr; + } + return dim; +} + +// This logic is specially made for stride_add, and will be removed in future version. +c10::SmallVector<int64_t, SIZE> infersize_stride_add(c10::IntArrayRef shape1_, c10::IntArrayRef shape2_) +{ + auto shape1 = array_to_small_vector(shape1_); + auto shape2 = array_to_small_vector(shape2_); + + c10::SmallVector<int64_t, SIZE> output_shape; + if (shape1.size() < shape2.size()) { + c10::SmallVector<int64_t, N> shapeTemp = shape1; + shape1 = shape2; + shape2 = shapeTemp; + } + + uint64_t shape1_size = shape1.size(); + uint64_t shape2_size = shape2.size(); + for (uint64_t i = 0; i < shape1_size - shape2_size; i++) { + shape2.insert(shape2.begin(), 1); + } + + for (uint64_t i = 0; i < shape1_size; i++) { + if (shape1[i] == 0 || shape2[i] == 0) { + output_shape.emplace_back((int64_t)0); + } else {
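+ // No explicit broadcast validation here: stride_add simply keeps the larger extent per dim.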
+ output_shape.emplace_back((shape1[i] > shape2[i]) ? shape1[i] : shape2[i]); + } + } + return output_shape; +} + +c10::SmallVector<int64_t, SIZE> transpose_npu_output_size(const at::Tensor &self, c10::IntArrayRef perm) +{ + auto sizes = self.sizes(); + c10::SmallVector<int64_t, SIZE> shape; + for (uint64_t i = 0; i < perm.size(); i++) { + shape.emplace_back(sizes[perm[i]]); + } + + return shape; +} + +bool check_match(const at::Tensor &self) +{ + static auto op = c10::Dispatcher::singleton().findSchemaOrThrow("aten::check_match", "").typed<bool(const at::Tensor &)>(); + return op.call(self); +} + +void format_fresh_view(at::Tensor &x, const at::Tensor &y) +{ + x.copy_(y); +} \ No newline at end of file diff --git a/ads/common/ops/csrc/common.h b/ads/common/ops/csrc/common.h new file mode 100644 index 00000000..49a75653 --- /dev/null +++ b/ads/common/ops/csrc/common.h @@ -0,0 +1,29 @@ +#include <ATen/ATen.h> +#include <string> +#include <tuple> +#include <utility> +#include <vector> +#include "torch_npu/csrc/core/npu/NPUMacros.h" +#include "torch_npu/csrc/framework/utils/NPUDefinition.h" +#include "third_party/acl/inc/acl/acl_base.h" + +const int N = 32; +const int SIZE = 8; + +using tuple_vector = std::tuple<c10::SmallVector<int64_t, N>, c10::SmallVector<int64_t, N>>; +aclDataType ConvertToAclDataType(const at::ScalarType &data_type); +c10::SmallVector<int64_t, N> array_to_small_vector(c10::IntArrayRef shape); +c10::SmallVector<int64_t, SIZE> conv_transpose2d_npu_output_size(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, c10::IntArrayRef padding, + c10::IntArrayRef output_padding, + c10::IntArrayRef stride, c10::IntArrayRef dilation, + int64_t groups); + +std::pair<bool, at::ScalarType> trans_torch_type_to_scalar(const std::string &type); +tuple_vector softmax_cross_entropy_with_logits_impl_npu_output_size(const at::Tensor& self); +int64_t make_warp_dim(int64_t dim, int64_t dim_post_expr); +c10::SmallVector<int64_t, N> convert_array_to_vector(c10::IntArrayRef intArray); +c10::SmallVector<int64_t, SIZE> infersize_stride_add(c10::IntArrayRef shape1_, c10::IntArrayRef shape2_); +c10::SmallVector<int64_t, SIZE> transpose_npu_output_size(const at::Tensor &self, c10::IntArrayRef perm); +bool check_match(const at::Tensor &self); +void format_fresh_view(at::Tensor &x, const at::Tensor &y); \ No newline at end of file diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index d5b69716..daa82f4d 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -1,7 +1,44 @@ #include <ATen/ATen.h> #include <torch/extension.h> +#include <string> +#include <tuple> +#include <utility> +#include <vector> void init_common(pybind11::module &m); std::tuple<at::Tensor, at::Tensor> npu_scatter_max(const at::Tensor& updates, const at::Tensor& indices, c10::optional<at::Tensor> out); at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segment_ids, const at::Tensor& num_segments); + +at::Tensor npu_rotated_box_decode(const at::Tensor &self, const at::Tensor &deltas, const at::Tensor &weight); +at::Tensor npu_rotated_box_encode( + const at::Tensor& self, + const at::Tensor& gtBox, + const at::Tensor& weight); +at::Tensor npu_rotated_iou( + const at::Tensor& boxes, + const at::Tensor& query_boxes, + bool trans, + int64_t mode, + bool is_cross, + double v_threshold, + double e_threshold); +at::Tensor npu_rotated_overlaps( + const at::Tensor& self, + const at::Tensor& query_boxes, + bool trans); +at::Tensor npu_scatter(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim); +at::Tensor npu_sign_bits_pack(const at::Tensor& self, int64_t size); +at::Tensor npu_sign_bits_unpack(py::args args); +at::Tensor npu_softmax_cross_entropy_with_logits(const at::Tensor &self, const at::Tensor &labels); +at::Tensor npu_softmax_cross_entropy_with_logits_backward(const at::Tensor &grad, const at::Tensor &self, const at::Tensor &labels); +at::Tensor npu_stride_add(py::args args);
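+// npu_sign_bits_unpack and npu_stride_add take py::args and cast their operands manually in their kernel files.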
+at::Tensor npu_transpose(const at::Tensor &self, at::IntArrayRef perm, bool require_contiguous); +at::Tensor npu_yolo_boxes_encode( + const at::Tensor& anchor_boxes, + const at::Tensor& gt_bboxes, + const at::Tensor& stride, + bool performance_mode); +at::Tensor npu_rotary_mul(const at::Tensor &self, const at::Tensor &r1, const at::Tensor &r2); +std::tuple<at::Tensor, at::Tensor, at::Tensor> npu_rotary_mul_backward(const at::Tensor &grad, const at::Tensor &self, const at::Tensor &r1, const at::Tensor &r2); +at::Tensor npu_silu(const at::Tensor& self); +at::Tensor& npu_silu_(at::Tensor& self); +at::Tensor npu_silu_backward(const at::Tensor& grad_output, const at::Tensor& x0, const at::Tensor& x1); \ No newline at end of file diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index 7feba1da..c4383ac8 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -5,4 +5,40 @@ void init_common(pybind11::module &m) { m.def("npu_scatter_max", &npu_scatter_max); m.def("npu_scatter_max_backward", &npu_scatter_max_backward); + + // rotatedBox kernel + m.def("npu_rotated_box_decode", &npu_rotated_box_decode, "npu_rot_box_decode NPU version"); + m.def("npu_rotated_box_encode", &npu_rotated_box_encode, "npu_rot_box_encode NPU version"); + + // rotated iou + m.def("npu_rotated_iou", &npu_rotated_iou, "npu_rotated_iou NPU version"); + + // rotated overlaps + m.def("npu_rotated_overlaps", &npu_rotated_overlaps, "npu_rotated_overlap NPU version"); + + // sign bits + m.def("npu_sign_bits_pack", &npu_sign_bits_pack, "npu_sign_bits_pack NPU version"); + m.def("npu_sign_bits_unpack", &npu_sign_bits_unpack, "npu_sign_bits_unpack NPU version"); + + // softmax + m.def("npu_softmax_cross_entropy_with_logits", &npu_softmax_cross_entropy_with_logits, "npu_softmax_cross_entropy_with_logits NPU version"); + m.def("npu_softmax_cross_entropy_with_logits_backward", &npu_softmax_cross_entropy_with_logits_backward, "npu_softmax_cross_entropy_with_logits_backward NPU version"); + + // stride add + m.def("npu_stride_add", &npu_stride_add, "npu_stride_add NPU version"); + + // transpose + m.def("npu_transpose", &npu_transpose, "npu_transpose NPU version"); + + // yolo encode + m.def("npu_yolo_boxes_encode", &npu_yolo_boxes_encode, "npu_yolo_boxes_encode NPU version"); + + // scatter + m.def("npu_scatter", &npu_scatter, "npu_scatter NPU version"); + + // silu + m.def("npu_silu_", &npu_silu_); + m.def("npu_silu", &npu_silu); + m.def("npu_silu_backward", &npu_silu_backward); + + // rotary mul + m.def("npu_rotary_mul", &npu_rotary_mul); + m.def("npu_rotary_mul_backward", &npu_rotary_mul_backward); } diff --git a/ads/common/ops/rotary_mul.py b/ads/common/ops/rotary_mul.py new file mode 100644 index 00000000..bf9c2a9b --- /dev/null +++ b/ads/common/ops/rotary_mul.py @@ -0,0 +1,22 @@ +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads_c + + +class RotaryMulFunction(Function): + @staticmethod + def forward(ctx, input, r1, r2): + result = ads_c.npu_rotary_mul(input, r1, r2) + ctx.save_for_backward(input, r1, r2) + return result + + @staticmethod + def backward(ctx, grad_output): + input, r1, r2 = ctx.saved_tensors + result = ads_c.npu_rotary_mul_backward(grad_output, input, r1, r2) + return result + +npu_rotary_mul = RotaryMulFunction.apply \ No newline at end of file diff --git a/ads/common/ops/rotated_box_decode.py b/ads/common/ops/rotated_box_decode.py new file mode 100644 index 00000000..621218f1 --- /dev/null +++ b/ads/common/ops/rotated_box_decode.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_rotated_box_decode = ads_c.npu_rotated_box_decode diff --git a/ads/common/ops/rotated_box_encode.py b/ads/common/ops/rotated_box_encode.py new file mode 100644 index 00000000..17c21f92 --- /dev/null +++ b/ads/common/ops/rotated_box_encode.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_rotated_box_encode = ads_c.npu_rotated_box_encode
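+# Usage sketch (hypothetical tensors; the kernels validate the real shapes/dtypes):
+#   deltas = npu_rotated_box_encode(anchors, gt_boxes, weight)
+#   boxes = npu_rotated_box_decode(anchors, deltas, weight)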
diff --git a/ads/common/ops/rotated_iou.py b/ads/common/ops/rotated_iou.py new file mode 100644 index 00000000..d88d3e9b --- /dev/null +++ b/ads/common/ops/rotated_iou.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_rotated_iou = ads_c.npu_rotated_iou \ No newline at end of file diff --git a/ads/common/ops/rotated_overlaps.py b/ads/common/ops/rotated_overlaps.py new file mode 100644 index 00000000..40753235 --- /dev/null +++ b/ads/common/ops/rotated_overlaps.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_rotated_overlaps = ads_c.npu_rotated_overlaps \ No newline at end of file diff --git a/ads/common/ops/scatter.py b/ads/common/ops/scatter.py new file mode 100644 index 00000000..7d89109c --- /dev/null +++ b/ads/common/ops/scatter.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_scatter = ads_c.npu_scatter \ No newline at end of file diff --git a/ads/common/ops/sign_bits_pack.py b/ads/common/ops/sign_bits_pack.py new file mode 100644 index 00000000..c09d486a --- /dev/null +++ b/ads/common/ops/sign_bits_pack.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_sign_bits_pack = ads_c.npu_sign_bits_pack \ No newline at end of file diff --git a/ads/common/ops/sign_bits_unpack.py b/ads/common/ops/sign_bits_unpack.py new file mode 100644 index 00000000..efa1a2dd --- /dev/null +++ b/ads/common/ops/sign_bits_unpack.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_sign_bits_unpack = ads_c.npu_sign_bits_unpack \ No newline at end of file diff --git a/ads/common/ops/silu.py b/ads/common/ops/silu.py new file mode 100644 index 00000000..8ca866db --- /dev/null +++ b/ads/common/ops/silu.py @@ -0,0 +1,25 @@ +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads_c + + +class SiluFunction(Function): + @staticmethod + def forward(ctx, input): + func = ads_c.npu_silu + result = func(input) + ctx.save_for_backward(input, result) + return result + + @staticmethod + def backward(ctx, grad_outputs): + x0, x1 = ctx.saved_tensors + result = ads_c.npu_silu_backward(grad_outputs, x0, x1) + return result + +npu_silu = SiluFunction.apply + +npu_silu_ = ads_c.npu_silu_ \ No newline at end of file diff --git a/ads/common/ops/softmax_cross_entropy_with_logits.py b/ads/common/ops/softmax_cross_entropy_with_logits.py new file mode 100644 index 00000000..f09d2a3e --- /dev/null +++ b/ads/common/ops/softmax_cross_entropy_with_logits.py @@ -0,0 +1,23 @@ +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads_c + + +class SoftMaxFunction(Function): + @staticmethod + def forward(ctx, feature, labels): + func = ads_c.npu_softmax_cross_entropy_with_logits + result = func(feature, labels) + ctx.save_for_backward(feature, labels) + return result + + @staticmethod + def backward(ctx, grad_output): + feature, labels = ctx.saved_tensors + result = ads_c.npu_softmax_cross_entropy_with_logits_backward(grad_output, feature, labels) + return result, None + +npu_softmax_cross_entropy_with_logits = SoftMaxFunction.apply \ No newline at end of file diff --git a/ads/common/ops/stride_add.py b/ads/common/ops/stride_add.py new file mode 100644 index 00000000..24a3946b --- /dev/null +++ b/ads/common/ops/stride_add.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_stride_add = ads_c.npu_stride_add \ No newline at end of file diff --git a/ads/common/ops/transpose.py b/ads/common/ops/transpose.py new file mode 100644 index
00000000..14972299 --- /dev/null +++ b/ads/common/ops/transpose.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_transpose = ads_c.npu_transpose \ No newline at end of file diff --git a/ads/common/ops/yolo_boxes_encode.py b/ads/common/ops/yolo_boxes_encode.py new file mode 100644 index 00000000..585adb58 --- /dev/null +++ b/ads/common/ops/yolo_boxes_encode.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_yolo_boxes_encode = ads_c.npu_yolo_boxes_encode \ No newline at end of file diff --git a/tests/test_npu_rotary_mul.py b/tests/test_npu_rotary_mul.py new file mode 100644 index 00000000..c84cf1e7 --- /dev/null +++ b/tests/test_npu_rotary_mul.py @@ -0,0 +1,54 @@ +import unittest +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + +DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] + + +class TestRotaryMul(TestCase): + def rotary_mul(self, x, r1, r2): + x1, x2 = torch.chunk(x, 2, -1) + x_new = torch.cat((-x2, x1), dim=-1) + output = r1 * x + r2 * x_new + return output + + def gen_data(self, shape, dtype): + cpu_input = torch.rand(shape, dtype=dtype) + npu_input = cpu_input.npu() + return cpu_input, npu_input + + def cpu_to_exec(self, x, r1, r2): + out = self.rotary_mul(x, r1, r2) + return out.cpu().numpy() + + def npu_to_exec(self, x, r1, r2): + out = ads.common.npu_rotary_mul(x, r1, r2) + return out.cpu().numpy() + + @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `RotaryMul` is only supported on 910B, skip this ut!") + def test_rotary_mul(self): + dtype_list = [torch.float16, torch.float32] + shape_list = [ + [[2, 8192, 5, 128], [1, 8192, 1, 128], [1, 8192, 1, 128]], + [[8192, 2, 5, 128], [8192, 1, 1, 128], [8192, 1, 1, 128]], + [[2048, 4, 32, 64], [2048, 4, 1, 64], [2048, 4, 1, 64]], + ] + items = [ + [shape, dtype] + for shape in shape_list + for dtype in dtype_list + ] + for shape, dtype in items: + cpu_x, npu_x = self.gen_data(shape[0], dtype) + cpu_r1, npu_r1 = self.gen_data(shape[1], dtype) + cpu_r2, npu_r2 = self.gen_data(shape[2], dtype) + cpu_out = self.cpu_to_exec(cpu_x, cpu_r1, cpu_r2) + npu_out = self.npu_to_exec(npu_x, npu_r1, npu_r2) + self.assertRtolEqual(cpu_out, npu_out) + + +if __name__ == '__main__': + run_tests() diff --git a/tests/test_npu_scatter.py b/tests/test_npu_scatter.py new file mode 100644 index 00000000..afdf691e --- /dev/null +++ b/tests/test_npu_scatter.py @@ -0,0 +1,64 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/test_npu_scatter.py b/tests/test_npu_scatter.py
new file mode 100644
index 00000000..afdf691e
--- /dev/null
+++ b/tests/test_npu_scatter.py
@@ -0,0 +1,64 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+import ads.common
+
+
+class TestNpuScatter(TestCase):
+    def supported_op_exec(self, input1, indices, updates, dim):
+        tmp = input1.reshape(-1)
+        shape = input1.shape
+        dim_len = shape[dim]
+
+        for i in range(indices.numel()):
+            tmp[i * dim_len + indices[i]] = updates[i]
+
+        output = tmp.reshape(shape).to('cpu')
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, indices, updates, dim):
+        output = ads.common.npu_scatter(input1, indices, updates, dim)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_npu_scatter(self, device="npu"):
+        input1_list = [[[1.6279, 0.1226], [0.9041, 1.0980]]]
+        indices_list = [[0, 1]]
+        updates_list = [[-1.1993, -1.5247]]
+        dim_list = [0]
+        exoutput_list = [[[-1.1993, 0.1226], [0.9041, -1.5247]]]
+
+        shape_format = [[i, j, k, h, f] for i in input1_list
+                        for j in indices_list for k in updates_list for h in dim_list for f in exoutput_list]
+
+        for item in shape_format:
+            input1_tensor = torch.tensor(item[0]).npu()
+            indices_tensor = torch.tensor(item[1]).npu().to(torch.int32)
+            updates_tensor = torch.tensor(item[2]).npu()
+            dim = item[3]
+            exoutput_tensor = torch.tensor(item[4])
+            output1 = self.npu_op_exec(input1_tensor, indices_tensor, updates_tensor, dim)
+            output2 = self.supported_op_exec(input1_tensor, indices_tensor, updates_tensor, dim)
+            self.assertRtolEqual(exoutput_tensor.numpy(), output1)
+            self.assertRtolEqual(output1, output2)
+
+
+if __name__ == "__main__":
+    run_tests()
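For the dim=0 case the test exercises, the reference loop writes updates[i] into row i at column indices[i] of the flattened tensor. The same thing in plain PyTorch advanced indexing, as a hedged cross-check:

    # Hedged sketch: the scatter reference loop as one indexing assignment.
    import torch

    input1 = torch.tensor([[1.6279, 0.1226], [0.9041, 1.0980]])
    indices = torch.tensor([0, 1])
    updates = torch.tensor([-1.1993, -1.5247])

    out = input1.clone()
    out[torch.arange(indices.numel()), indices] = updates  # row i, column indices[i]
    print(out)  # [[-1.1993, 0.1226], [0.9041, -1.5247]]; matches exoutput_list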
diff --git a/tests/test_npu_silu.py b/tests/test_npu_silu.py
new file mode 100644
index 00000000..ef561ce0
--- /dev/null
+++ b/tests/test_npu_silu.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2023 Huawei Technologies Co., Ltd
+# Copyright (c) 2019, Facebook CORPORATION.
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+import ads.common
+
+
+class TestNpuSilu(TestCase):
+
+    def cpu_op_exec_silu(self, input1):
+        output = input1 * torch.sigmoid(input1)
+        output = output.cpu().numpy()
+        return output
+
+    def cpu_op_exec_silu_(self, input1):
+        result = input1 * torch.sigmoid(input1)
+        input1 = result.cpu().numpy()
+        return input1
+
+    def npu_op_exec_silu(self, input1):
+        output = ads.common.npu_silu(input1)
+        output = output.cpu().numpy()
+        return output
+
+    def npu_op_exec_silu_(self, input1):
+        ads.common.npu_silu_(input1)
+        return input1.cpu().numpy()
+
+    def test_silu(self):
+        input1 = torch.randn(5, 5).npu()
+        cpu_out = self.cpu_op_exec_silu(input1)
+        npu_out = self.npu_op_exec_silu(input1)
+        self.assertRtolEqual(cpu_out, npu_out)
+
+    def test_silu_(self):
+        input1 = torch.randn(5, 5).npu()
+        input2 = torch.clone(input1).npu()
+        input1 = self.cpu_op_exec_silu_(input1)
+        input2 = self.npu_op_exec_silu_(input2)
+        self.assertRtolEqual(input1, input2)
+
+
+if __name__ == "__main__":
+    run_tests()
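These tests only check the forward; the backward goes through ads_c.npu_silu_backward, which receives both the saved input and the saved forward result. The closed-form gradient it is expected to compute is standard calculus; a hedged CPU sketch:

    # Hedged sketch: d/dx [x * sigmoid(x)] = s * (1 + x * (1 - s)), s = sigmoid(x).
    import torch

    def silu_grad(grad_output, x):
        s = torch.sigmoid(x)
        return grad_output * s * (1 + x * (1 - s))

    x = torch.randn(5, 5, requires_grad=True)
    torch.nn.functional.silu(x).sum().backward()
    assert torch.allclose(x.grad, silu_grad(torch.ones_like(x), x.detach()))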
diff --git a/tests/test_npu_softmax_cross_entropy_with_logits.py b/tests/test_npu_softmax_cross_entropy_with_logits.py
new file mode 100644
index 00000000..bf2a24d8
--- /dev/null
+++ b/tests/test_npu_softmax_cross_entropy_with_logits.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2023 Huawei Technologies Co., Ltd
+# All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+import ads.common
+
+
+class TestSoftmaxCrossEntropyWithLogits(TestCase):
+
+    def supported_op_exec(self, input1, label):
+        softmax = torch.nn.functional.softmax(input1, dim=1)
+        log_softmax = torch.log(softmax)
+        loss = torch.sum(-label * log_softmax, dim=1)
+        return loss.cpu().detach()
+
+    def custom_op_exec(self, input1, label):
+        output = ads.common.npu_softmax_cross_entropy_with_logits(input1, label)
+        return output.cpu().detach()
+
+    def test_npu_softmax_cross_entropy_with_logits(self, device="npu"):
+        item = [np.float32, 0, (64, 10)]
+        _, npu_input = create_common_tensor(item, -1, 1)
+        _, label = create_common_tensor(item, 0, 1)
+        supported_output = self.supported_op_exec(npu_input, label)
+        custom_output = self.custom_op_exec(npu_input, label)
+        self.assertRtolEqual(supported_output, custom_output)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/test_npu_stride_add.py b/tests/test_npu_stride_add.py
new file mode 100644
index 00000000..873335ad
--- /dev/null
+++ b/tests/test_npu_stride_add.py
@@ -0,0 +1,43 @@
+import torch
+import numpy as np
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.decorator import Dtypes, instantiate_tests
+import ads.common
+
+
+@instantiate_tests
+class TestNpuStrideAdd(TestCase):
+    def _split_npu_stride_add(self, x1, x2, offset1, offset2, c1_len):
+        x1 = x1[:, offset1:, :, :, :]
+        x2 = x2[:, offset2:, :, :, :]
+        x1_size = list(x1.size())
+        x2_size = list(x2.size())
+        x1_pad_size = [x1_size[0]] + [16 * c1_len - x1_size[1]] + x1_size[2:]
+        x2_pad_size = [x2_size[0]] + [16 * c1_len - x2_size[1]] + x2_size[2:]
+        x1_pad = torch.cat((x1, torch.zeros(*x1_pad_size, device='npu')), 1)
+        x2_pad = torch.cat((x2, torch.zeros(*x2_pad_size, device='npu')), 1)
+        return torch.add(x1_pad, x2_pad)
+
+    def npu_op_exec(self, input1, input2, offset1, offset2, c1_len):
+        output = ads.common.npu_stride_add(input1, input2, offset1, offset2, c1_len)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def split_npu_op_exec(self, input1, input2, offset1, offset2, c1_len):
+        output = self._split_npu_stride_add(input1, input2, offset1, offset2, c1_len)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_StrideAdd(self):
+        input1 = torch.tensor([[[[[1.]]]]]).npu()
+        input2 = input1
+        exoutput = self.npu_op_exec(input1, input2, 0, 0, 1)
+        output = self.split_npu_op_exec(input1, input2, 0, 0, 1)
+        self.assertRtolEqual(exoutput, output)
+
+
+if __name__ == "__main__":
+    run_tests()
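npu_stride_add slices both inputs along the channel axis and zero-pads the sum out to 16 * c1_len channels, exactly as the _split_npu_stride_add reference encodes. A device-free restatement for readers without NPU hardware (a hedged sketch; the channels-in-dim-1 reading of the five axes is inferred from the padding to multiples of 16):

    # Hedged sketch: CPU restatement of the stride_add reference.
    import torch

    def stride_add_ref(x1, x2, offset1, offset2, c1_len):
        x1, x2 = x1[:, offset1:], x2[:, offset2:]
        pad1 = 16 * c1_len - x1.shape[1]
        pad2 = 16 * c1_len - x2.shape[1]
        x1 = torch.cat((x1, x1.new_zeros(x1.shape[0], pad1, *x1.shape[2:])), dim=1)
        x2 = torch.cat((x2, x2.new_zeros(x2.shape[0], pad2, *x2.shape[2:])), dim=1)
        return x1 + x2

    out = stride_add_ref(torch.ones(1, 1, 1, 1, 1), torch.ones(1, 1, 1, 1, 1), 0, 0, 1)
    print(out.shape)  # torch.Size([1, 16, 1, 1, 1])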
diff --git a/tests/test_npu_transpose.py b/tests/test_npu_transpose.py
new file mode 100644
index 00000000..787a30e0
--- /dev/null
+++ b/tests/test_npu_transpose.py
@@ -0,0 +1,37 @@
+import torch
+import numpy as np
+
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+import ads.common
+
+
+class TestNpuTranspose(TestCase):
+    def custom_op_exec(self, input1, perm):
+        output = input1.permute(perm)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def npu_op_exec(self, input1, perm):
+        output = ads.common.npu_transpose(input1, perm, True)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_npu_transpose(self):
+        shape_format = [
+            [[np.float32, 0, (5, 3, 6, 4)], [1, 0, 2, 3]],
+            [[np.float16, 0, (5, 3, 6, 4)], [0, 3, 2, 1]],
+        ]
+
+        for item in shape_format:
+            _, npu_input1 = create_common_tensor(item[0], 0, 100)
+            custom_output = self.custom_op_exec(npu_input1, item[1])
+            npu_output = self.npu_op_exec(npu_input1, item[1])
+            self.assertRtolEqual(custom_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/test_rotated_box.py b/tests/test_rotated_box.py
new file mode 100644
index 00000000..e66dd520
--- /dev/null
+++ b/tests/test_rotated_box.py
@@ -0,0 +1,52 @@
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+import ads.common
+
+
+class TestRotatedBox(TestCase):
+    def npu_op_encode_exec(self, anchor_boxes, gt_bboxes, weight):
+        out = ads.common.npu_rotated_box_encode(anchor_boxes, gt_bboxes, weight)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def npu_op_decode_exec(self, anchor_boxes, deltas, weight):
+        out = ads.common.npu_rotated_box_decode(anchor_boxes, deltas, weight)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def test_rotated_boxes_encode_fp32(self, device="npu"):
+        anchor_boxes = torch.tensor([[[44.2877], [9.1412], [88.7575], [25.8879], [64.8047]]]).to("npu")
+        gt_bboxes = torch.tensor([[[39.1763], [0.9838], [78.1028], [29.5997], [51.5907]]]).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.]).npu()
+        expect_cpu = torch.tensor([[[-0.1773], [-0.1327], [-0.1331], [0.5358], [-0.8643]]])
+        npu_output = self.npu_op_encode_exec(anchor_boxes, gt_bboxes, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+    def test_rotated_boxes_decode_fp32(self, device="npu"):
+        anchor_boxes = torch.tensor([[[32.1855], [41.9922], [64.1435], [62.5325], [34.607]]]).to("npu")
+        deltas = torch.tensor([[[1.8725], [-1.8915], [0.2395], [-0.4622], [-34.6539]]]).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.]).npu()
+        expect_cpu = torch.tensor([[[87.70366], [6.9412346], [128.31055], [19.879467], [-88.313515]]])
+        npu_output = self.npu_op_decode_exec(anchor_boxes, deltas, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+    def test_rotated_boxes_encode_fp16(self, device="npu"):
+        anchor_boxes = torch.tensor([[[30.69], [32.6], [45.94], [59.88], [-44.53]]], dtype=torch.float16).to("npu")
+        gt_bboxes = torch.tensor([[[30.44], [18.72], [33.22], [45.56], [8.5]]], dtype=torch.float16).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu()
+        expect_cpu = torch.tensor([[[-0.4253], [-0.5166], [-1.702], [-0.0162], [1.133]]], dtype=torch.float16)
+        npu_output = self.npu_op_encode_exec(anchor_boxes, gt_bboxes, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+    def test_rotated_boxes_decode_fp16(self, device="npu"):
+        anchor_boxes = torch.tensor([[[4.137], [33.72], [29.4], [54.06], [41.28]]], dtype=torch.float16).to("npu")
+        deltas = torch.tensor([[[0.0244], [-1.992], [0.2109], [0.315], [-37.25]]], dtype=torch.float16).to("npu")
+        weight = torch.tensor([1., 1., 1., 1., 1.], dtype=torch.float16).npu()
+        expect_cpu = torch.tensor([[[1.786], [-10.58], [33.], [17.3], [-88.44]]], dtype=torch.float16)
+        npu_output = self.npu_op_decode_exec(anchor_boxes, deltas, weight)
+        self.assertRtolEqual(expect_cpu.numpy(), npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
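The rotated-box tests feed boxes in a (B, 5, N) layout: one column per box, rows ordered (x, y, w, h, angle). A hedged helper for getting there from the more common (N, 5) box list; the function name is made up for illustration:

    # Hedged sketch: (N, 5) rotated boxes -> the (1, 5, N) layout the tests use.
    import torch

    def to_b5n(boxes_n5: torch.Tensor) -> torch.Tensor:
        return boxes_n5.t().unsqueeze(0).contiguous()

    boxes = torch.tensor([[44.2877, 9.1412, 88.7575, 25.8879, 64.8047]])
    print(to_b5n(boxes).shape)  # torch.Size([1, 5, 1]); matches the fixtures above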
diff --git a/tests/test_rotated_iou.py b/tests/test_rotated_iou.py
new file mode 100644
index 00000000..94b50b6a
--- /dev/null
+++ b/tests/test_rotated_iou.py
@@ -0,0 +1,72 @@
+import torch
+import numpy as np
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+import ads.common
+
+
+class TestRotatedIou(TestCase):
+    def generate_rto_data(self, item):
+        np.random.seed(1234)
+        min_value, max_value = 20, 60
+        scope = 20
+        dtype = item[0][0]
+        shape_one = item[0][-1]
+        shape_two = item[1][-1]
+        trans = item[-1]
+
+        boxes_center = np.random.uniform(min_value, max_value, shape_one[:2] + [2]).astype(dtype)
+        boxes_wh = np.random.randint(1, scope, size=shape_one[:2] + [2])
+        boxes_angle = np.random.randint(-180, 180, size=shape_one[:2] + [1])
+        boxes = np.concatenate([boxes_center, boxes_wh, boxes_angle], dtype=dtype, axis=-1)
+        # query_boxes
+        query_boxes_center = np.random.uniform(min_value, max_value, shape_two[:2] + [2]).astype(dtype)
+        query_boxes_wh = np.random.randint(1, scope, size=shape_two[:2] + [2])
+        query_boxes_angle = np.random.randint(-180, 180, size=shape_two[:2] + [1])
+        query_boxes = np.concatenate([query_boxes_center, query_boxes_wh, query_boxes_angle], dtype=dtype, axis=-1)
+
+        cpu_input1 = torch.from_numpy(boxes)
+        cpu_input2 = torch.from_numpy(query_boxes)
+        npu_input1 = cpu_input1.npu()
+        npu_input2 = cpu_input2.npu()
+        list1 = [boxes, query_boxes, npu_input1, npu_input2]
+        return list1
+
+    def cpu_expect_result(self, dtype):
+        if dtype == np.float32:
+            output = np.array([[[0., 0.00045966, 0.], [0., 0., 0.]],
+                               [[0., 0., 0.], [0., 0., 0.]],
+                               [[0., 0., 0.], [0.00600622, 0.10504241, 0.]],
+                               [[0., 0., 0.], [0., 0., 0.]]], dtype=np.float32)
+        else:
+            output = np.array([[[0., 0.00045966, 0.], [0., 0., 0.]],
+                               [[0., 0., 0.], [0., 0., 0.]],
+                               [[0., 0., 0.], [0.00600622, 0.10504241, 0.]],
+                               [[0., 0., 0.], [0., 0., 0.]]], dtype=np.float16)
+        return output
+
+    def npu_op_exec(self, box1, box2, trans=False):
+        output = ads.common.npu_rotated_iou(box1, box2, trans, 0, True, 0.0, 0.0)
+        output = output.detach().cpu().numpy()
+        return output
+
+    def test_rotated_iou_shape_format_fp32(self):
+        dtype = np.float32
+        shape_format = [[dtype, -1, [4, 2, 5]], [dtype, -1, [4, 3, 5]], False]
+        list2 = self.generate_rto_data(shape_format)
+        cpu_output = self.cpu_expect_result(dtype)
+        npu_output = self.npu_op_exec(list2[2], list2[3], shape_format[-1])
+        self.assertRtolEqual(cpu_output, npu_output)
+
+    def test_rotated_iou_shape_format_fp16(self):
+        dtype = np.float16
+        shape_format = [[dtype, -1, [4, 2, 5]], [dtype, -1, [4, 3, 5]], False]
+        list1 = self.generate_rto_data(shape_format)
+        cpu_output = self.cpu_expect_result(dtype)
+        npu_output = self.npu_op_exec(list1[2], list1[3], shape_format[-1])
+        self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
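The IoU test validates against hard-coded values; an independent CPU cross-check can be built from polygon clipping. A hedged sketch using shapely, which is not a dependency of this patch, assuming the (x, y, w, h, angle-in-degrees) layout the data generator produces:

    # Hedged sketch: rotated IoU on CPU via shapely polygon intersection.
    import numpy as np
    from shapely.geometry import Polygon

    def rotated_box_to_polygon(x, y, w, h, angle_deg):
        a = np.deg2rad(angle_deg)
        c, s = np.cos(a), np.sin(a)
        dx = np.array([-w / 2, w / 2, w / 2, -w / 2])   # corner offsets
        dy = np.array([-h / 2, -h / 2, h / 2, h / 2])
        return Polygon(list(zip(x + c * dx - s * dy, y + s * dx + c * dy)))

    def rotated_iou(b1, b2):
        p1, p2 = rotated_box_to_polygon(*b1), rotated_box_to_polygon(*b2)
        inter = p1.intersection(p2).area
        return inter / (p1.area + p2.area - inter)

    print(rotated_iou((0, 0, 4, 2, 0), (0, 0, 4, 2, 90)))  # 1/3 for this cross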
diff --git a/tests/test_rotated_overlaps.py b/tests/test_rotated_overlaps.py
new file mode 100644
index 00000000..63d561d8
--- /dev/null
+++ b/tests/test_rotated_overlaps.py
@@ -0,0 +1,84 @@
+import torch
+import numpy as np
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+import ads.common
+
+
+class TestRotatedOverlaps(TestCase):
+    def generate_rto_data(self, item):
+        np.random.seed(1234)
+        min_value, max_value = 30, 60
+        scope = 20
+        dtype = item[0][0]
+        shape_one = item[0][-1]
+        shape_two = item[1][-1]
+
+        boxes_center = np.random.uniform(min_value, max_value, shape_one[:2] + [2]).astype(dtype)
+        boxes_wh = np.random.randint(1, scope, size=shape_one[:2] + [2])
+        boxes_angle = np.random.randint(-180, 180, size=shape_one[:2] + [1])
+        boxes = np.concatenate([boxes_center, boxes_wh, boxes_angle], axis=-1, dtype=dtype)
+        # query_boxes
+        query_boxes_center = np.random.uniform(min_value, max_value, shape_two[:2] + [2]).astype(dtype)
+        query_boxes_wh = np.random.randint(1, scope, size=shape_two[:2] + [2])
+        query_boxes_angle = np.random.randint(-180, 180, size=shape_two[:2] + [1])
+        query_boxes = np.concatenate([query_boxes_center, query_boxes_wh, query_boxes_angle], axis=-1, dtype=dtype)
+
+        cpu_input1 = torch.from_numpy(boxes)
+        cpu_input2 = torch.from_numpy(query_boxes)
+        npu_input1 = cpu_input1.npu()
+        npu_input2 = cpu_input2.npu()
+        return npu_input1, npu_input2
+
+    def cpu_expect_result(self, dtype):
+        if dtype == np.float16:
+            output = np.array([[[0., 13.27, 1.022, 0.],
+                                [0., 0., 54.12, 0.],
+                                [0., 0., 0., 19.17]]], dtype=np.float16)
+        else:
+            output = np.array([[[0., 10.289731],
+                                [0., 0.],
+                                [0., 0.]]], dtype=np.float32)
+        return output
+
+    def npu_op_exec(self, box1, box2, trans=False):
+        output = ads.common.npu_rotated_overlaps(box1, box2, trans)
+        output = output.detach().cpu().numpy()
+        return output
+
+    def test_rotated_overlaps_shape_format_fp32(self, device="npu"):
+        dtype = np.float32
+        shape_list = [
+            [[1, 3, 5], [1, 2, 5]],
+        ]
+        is_trans_list = [False]
+        shape_format = [[[dtype, -1, m[0]], [dtype, -1, m[1]], k]
+                        for m in shape_list
+                        for k in is_trans_list]
+
+        for item in shape_format:
+            npu_input1, npu_input2 = self.generate_rto_data(item[:-1])
+            cpu_output = self.cpu_expect_result(dtype)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, item[-1])
+            # fp32 precision is not sufficient here; this tolerance matches current model needs.
+            self.assertRtolEqual(cpu_output, npu_output, prec=0.00005)
+
+    def test_rotated_overlaps_shape_format_fp16(self, device="npu"):
+        dtype = np.float16
+        shape_list = [
+            [[1, 3, 5], [1, 4, 5]],
+        ]
+        # trans=True means xyxyt boxes, trans=False means xywh format
+        is_trans_list = [False]
+        shape_format = [[[dtype, -1, m[0]], [dtype, -1, m[1]], k]
+                        for m in shape_list
+                        for k in is_trans_list]
+        for item in shape_format:
+            npu_input1, npu_input2 = self.generate_rto_data(item)
+            cpu_output = self.cpu_expect_result(dtype)
+            npu_output = self.npu_op_exec(npu_input1, npu_input2, item[-1])
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/test_sign_bits_pack.py b/tests/test_sign_bits_pack.py
new file mode 100644
index 00000000..4a33b905
--- /dev/null
+++ b/tests/test_sign_bits_pack.py
@@ -0,0 +1,42 @@
+import torch
+import numpy as np
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+import ads.common
+
+
+class TestSignBitsPack(TestCase):
+    def cpu_op_exec(self, input1, size):
+        sign_data = np.sign(input1)
+        sign_data = sign_data + 1
+        bool_data = np.bool_(sign_data)
+        pack_bit = np.packbits(bool_data, bitorder="little")
+        return pack_bit.reshape(size, pack_bit.shape[0] // size)
+
+    def npu_op_exec(self, input1, size):
+        output = ads.common.npu_sign_bits_pack(input1, size)
+        output = output.to("cpu")
+        output = output.numpy()
+        return output
+
+    def test_sign_bits_pack(self):
+        shape_format = [
+            [[np.float16, (17,)], 1],
+            [[np.float32, (8,)], 1],
+            [[np.float32, (32,)], 1],
+            [[np.float32, (33,)], 1],
+            [[np.float32, (16,)], 2]
+        ]
+
+        for item in shape_format:
+            input1 = np.random.uniform(-10, 10, item[0][1]).astype(item[0][0])
+            npu_input1 = torch.from_numpy(input1).npu()
+            cpu_output = self.cpu_op_exec(input1, item[1])
+            npu_output = self.npu_op_exec(npu_input1, item[1])
+
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
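npu_sign_bits_pack and npu_sign_bits_unpack (tested next) are inverses over sign bits, as the two numpy reference implementations imply. A hedged numpy-only round trip:

    # Hedged sketch: sign-bit pack/unpack round trip, mirroring both references.
    import numpy as np

    x = np.random.uniform(-10, 10, 16).astype(np.float32)

    packed = np.packbits(np.bool_(np.sign(x) + 1), bitorder="little")  # 1 bit/elem
    unpacked = (np.unpackbits(packed, bitorder="little").astype(np.float32) - 0.5) * 2.0

    # +1 where x >= 0, -1 where x < 0
    assert np.array_equal(unpacked, np.where(x >= 0, 1.0, -1.0).astype(np.float32))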
diff --git a/tests/test_sign_bits_unpack.py b/tests/test_sign_bits_unpack.py
new file mode 100644
index 00000000..07b9c1d2
--- /dev/null
+++ b/tests/test_sign_bits_unpack.py
@@ -0,0 +1,45 @@
+import torch
+import numpy as np
+
+import torch_npu
+
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+import ads.common
+
+
+class TestSignBitsUnpack(TestCase):
+
+    def custom_sign_unpack(self, input_data, size, dtype):
+        bits = 8
+        mask = 2 ** torch.arange(bits).to(input_data.device, input_data.dtype)
+        unpack_data = input_data.unsqueeze(-1).bitwise_and(mask).ne(0).byte().reshape(-1).to(dtype)
+        unpack_data = (unpack_data - 0.5) * 2.0
+        return unpack_data.reshape(size, unpack_data.shape[0] // size)
+
+    def custom_op_exec(self, input_data, dtype, size):
+        output = self.custom_sign_unpack(input_data, size, dtype)
+        return output.cpu().numpy()
+
+    def npu_op_exec(self, npu_input, dtype, size):
+        npu_out = ads.common.npu_sign_bits_unpack(npu_input, size, dtype)
+        return npu_out.cpu().numpy()
+
+    def test_sign_bits_unpack(self):
+        shape = np.random.uniform(1, 10 ** 5, 1)
+        shape = shape // (10 ** int(np.random.uniform(0, int(np.log10(shape) + 1), 1)))
+        shape = max(int(shape), 1)
+        size = int(np.random.uniform(1, 100))
+        shape = shape * size
+
+        shape_format = [np.uint8, 2, [shape]]
+        cpu_input, npu_input = create_common_tensor(shape_format, 0, 255)
+        dtypes = [torch.float16, torch.float32]
+        for dtype in dtypes:
+            cpu_output = self.custom_op_exec(npu_input, dtype, size)
+            npu_output = self.npu_op_exec(npu_input, dtype, size)
+            self.assertRtolEqual(cpu_output, npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/tests/test_yolo_boxes_encode.py b/tests/test_yolo_boxes_encode.py
new file mode 100644
index 00000000..70a037db
--- /dev/null
+++ b/tests/test_yolo_boxes_encode.py
@@ -0,0 +1,35 @@
+import torch
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+import ads.common
+
+
+class TestYoloBoxesEncode(TestCase):
+    def npu_op_exec(self, anchor_boxes, gt_bboxes, stride, impl_mode=False):
+        out = ads.common.npu_yolo_boxes_encode(anchor_boxes, gt_bboxes, stride, impl_mode)
+        out = out.to("cpu")
+        return out.detach().numpy()
+
+    def test_yolo_boxes_encode(self, device="npu"):
+        torch.manual_seed(1234)
+        anchor_boxes_list = [(2, 4)]
+        gt_bboxes_list = [(2, 4)]
+        stride_list = [[2, 2]]
+        expect_cpu_list = [[[0.7921727, 0.5314963, -0.74224466, -13.815511],
+                            [0.7360072, 0.58343244, 4.3334002, -0.51378196]]]
+
+        shape_format = [[i, j, k, h] for i in anchor_boxes_list
+                        for j in gt_bboxes_list for k in stride_list for h in expect_cpu_list]
+
+        for item in shape_format:
+            anchor_boxes_tensor = torch.rand(item[0], dtype=torch.float32).to("npu")
+            gt_bboxes_tensor = torch.rand(item[1], dtype=torch.float32).to("npu")
+            stride_tensor = torch.tensor(item[2], dtype=torch.int32).to("npu")
+            expect_cpu_tensor = torch.tensor(item[3], dtype=torch.float32)
+            npu_output = self.npu_op_exec(anchor_boxes_tensor, gt_bboxes_tensor, stride_tensor, False)
+
+            self.assertRtolEqual(expect_cpu_tensor.numpy(), npu_output)
+
+
+if __name__ == "__main__":
+    run_tests()
--
Gitee