From adb291acf65f575570c43ec0d19258214cc60cf7 Mon Sep 17 00:00:00 2001
From: zhanhao
Date: Wed, 13 Dec 2023 14:30:53 +0800
Subject: [PATCH 1/3] add op

---
 ads/common/__init__.py                        |  10 +-
 ads/common/ops/csrc/AbsOpApi.cpp              |  13 +
 .../ops/csrc/AnchorResponseFlagsKernelNpu.cpp |  72 +++
 ads/common/ops/csrc/BatchNms.cpp              |  49 ++
 .../ops/csrc/BoundingBoxDecodeKernelNpu.cpp   |  57 ++
 .../ops/csrc/BoundingBoxEncodeKernelNpu.cpp   |  51 ++
 ads/common/ops/csrc/BroadCastKernelNpu.cpp    |  49 ++
 .../ops/csrc/ConfusionTransposeKernelNpu.cpp  |  97 +++
 ads/common/ops/csrc/FastGeluKernelNpu.cpp     |  52 ++
 ads/common/ops/csrc/NpuSilu.cpp               |  13 +-
 ads/common/ops/csrc/OpApiCommon.h             | 588 ++++++++++++++++++
 ads/common/ops/csrc/RotaryMulKernelNpu.cpp    |  10 +-
 .../ops/csrc/RotatedBoxDecodeKernelNpu.cpp    |   5 -
 .../ops/csrc/RotatedBoxEncodeKernelNpu.cpp    |   5 -
 ads/common/ops/csrc/RotatedIouKernelNpu.cpp   |   6 -
 .../ops/csrc/RotatedOverlapsKernelNpu.cpp     |   6 -
 ads/common/ops/csrc/ScatterMaxKernelNpu.cpp   |  16 +-
 ads/common/ops/csrc/ScatterV1KernelNpu.cpp    |   4 -
 ads/common/ops/csrc/SignBitsPackKernelNpu.cpp |   4 -
 .../ops/csrc/SignBitsUnpackKernelNpu.cpp      |   4 -
 ...SoftmaxCrossEntropyWithLogitsKernelNpu.cpp |   9 -
 ads/common/ops/csrc/StrideAddKernelNpu.cpp    |   5 -
 ads/common/ops/csrc/TransposeKernelNpu.cpp    |   5 -
 .../ops/csrc/YoloBoxesEncodeKernelNpu.cpp     |   5 -
 ads/common/ops/csrc/common.cpp                |   5 +-
 ads/common/ops/csrc/common.h                  |   6 +-
 ads/common/ops/csrc/functions.h               |  78 ++-
 ads/common/ops/csrc/pybind.cpp                |  25 +
 ads/common/ops/fast_gelu.py                   |  23 +
 ads/common/ops/npu_abs.py                     |   5 +
 ads/common/ops/npu_anchor_response_flags.py   |  13 +
 ads/common/ops/npu_batch_nms.py               |  30 +
 ads/common/ops/npu_bounding_box_decode.py     |  20 +
 ads/common/ops/npu_bounding_box_encode.py     |  18 +
 ads/common/ops/npu_broadcast.py               |  16 +
 ads/common/ops/npu_confusion_transpose.py     |  24 +
 ads/common/ops/rotary_mul.py                  |   2 +-
 ads/common/ops/rotated_iou.py                 |   2 +-
 ads/common/ops/rotated_overlaps.py            |   2 +-
 ads/common/ops/scatter.py                     |   2 +-
 ads/common/ops/sign_bits_pack.py              |   2 +-
 ads/common/ops/sign_bits_unpack.py            |   2 +-
 ads/common/ops/silu.py                        |   4 +-
 .../ops/softmax_cross_entropy_with_logits.py  |   2 +-
 ads/common/ops/stride_add.py                  |   2 +-
 ads/common/ops/transpose.py                   |   2 +-
 ads/common/ops/yolo_boxes_encode.py           |   2 +-
 setup.py                                      |   6 +
 tests/test_abs.py                             |  51 ++
 tests/test_batch_nms.py                       |  44 ++
 tests/test_fast_gelu.py                       |  51 ++
 tests/test_fast_gelu_backward.py              |  43 ++
 tests/test_npu_anchor_response_flags.py       |  60 ++
 tests/test_npu_bounding_box_decode.py         | 107 ++++
 tests/test_npu_bounding_box_encode.py         |  92 +++
 tests/test_npu_broadcast.py                   |  48 ++
 56 files changed, 1814 insertions(+), 110 deletions(-)
 create mode 100644 ads/common/ops/csrc/AbsOpApi.cpp
 create mode 100644 ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp
 create mode 100644 ads/common/ops/csrc/BatchNms.cpp
 create mode 100644 ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp
 create mode 100644 ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp
 create mode 100644 ads/common/ops/csrc/BroadCastKernelNpu.cpp
 create mode 100644 ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp
 create mode 100644 ads/common/ops/csrc/FastGeluKernelNpu.cpp
 create mode 100644 ads/common/ops/csrc/OpApiCommon.h
 create mode 100644 ads/common/ops/fast_gelu.py
 create mode 100644 ads/common/ops/npu_abs.py
 create mode 100644 ads/common/ops/npu_anchor_response_flags.py
 create mode 100644 ads/common/ops/npu_batch_nms.py
 create mode 100644 ads/common/ops/npu_bounding_box_decode.py
 create mode 100644 ads/common/ops/npu_bounding_box_encode.py
 create mode 100644 ads/common/ops/npu_broadcast.py
 create mode 100644 ads/common/ops/npu_confusion_transpose.py
 create mode 100644 tests/test_abs.py
 create mode 100644 tests/test_batch_nms.py
 create mode 100644 tests/test_fast_gelu.py
 create mode 100644 tests/test_fast_gelu_backward.py
 create mode 100644 tests/test_npu_anchor_response_flags.py
 create mode 100644 tests/test_npu_bounding_box_decode.py
 create mode 100644 tests/test_npu_bounding_box_encode.py
 create mode 100644 tests/test_npu_broadcast.py

diff --git a/ads/common/__init__.py b/ads/common/__init__.py
index ec5a1f3a..8e59829b 100644
--- a/ads/common/__init__.py
+++ b/ads/common/__init__.py
@@ -12,4 +12,12 @@ from .ops.yolo_boxes_encode import npu_yolo_boxes_encode
 from .ops.scatter import npu_scatter
 from .ops.silu import npu_silu
 from .ops.silu import npu_silu_
-from .ops.rotary_mul import npu_rotary_mul
\ No newline at end of file
+from .ops.rotary_mul import npu_rotary_mul
+from .ops.npu_abs import npu_abs
+from .ops.fast_gelu import fast_gelu
+from .ops.npu_anchor_response_flags import npu_anchor_response_flags
+from .ops.npu_bounding_box_decode import npu_bounding_box_decode
+from .ops.npu_bounding_box_encode import npu_bounding_box_encode
+from .ops.npu_batch_nms import npu_batch_nms
+from .ops.npu_confusion_transpose import npu_confusion_transpose
+from .ops.npu_broadcast import npu_broadcast
diff --git a/ads/common/ops/csrc/AbsOpApi.cpp b/ads/common/ops/csrc/AbsOpApi.cpp
new file mode 100644
index 00000000..824b60ce
--- /dev/null
+++ b/ads/common/ops/csrc/AbsOpApi.cpp
@@ -0,0 +1,13 @@
+#include
+#include "OpApiCommon.h"
+#include "functions.h"
+
+at::Tensor npu_abs(const at::Tensor& self)
+{
+    // construct the output tensor of the NPU
+    at::Tensor result = at::empty(self.sizes(), self.options());
+
+    // calculate the output result of the NPU
+    EXEC_NPU_CMD(aclnnAbs, self, result);
+    return result;
+}
diff --git a/ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp b/ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp
new file mode 100644
index 00000000..f414633c
--- /dev/null
+++ b/ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp
@@ -0,0 +1,72 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "common.h"
+
+namespace {
+c10::SmallVector<int64_t, 8> infersize_npu_anchor_response_flags(
+    at::IntArrayRef featmap_size,
+    int64_t num_base_anchors)
+{
+    int64_t output_value = featmap_size[0] * featmap_size[1] * num_base_anchors;
+    c10::SmallVector<int64_t, 8> output_size = {output_value};
+    return output_size;
+}
+
+inline void anchor_response_flags_check(
+    const at::Tensor& self,
+    at::IntArrayRef featmap_size,
+    at::IntArrayRef stride)
+{
+    TORCH_CHECK(
+        featmap_size.size() == 2,
+        "expected featmap_size to have 2 elements, but got size ",
+        featmap_size.size());
+    TORCH_CHECK(
+        self.dim() == 2 && self.size(1) == 4,
+        "Non-empty 2D gt_bboxes tensor expected but got a tensor with sizes ",
+        self.sizes());
+    TORCH_CHECK(
+        self.scalar_type() == at::kHalf || self.scalar_type() == at::kFloat,
+        "float16 or float32 tensor expected but got a tensor with dtype: ",
+        self.scalar_type());
+}
+} // namespace
+
+at::Tensor npu_anchor_response_flags(
+    const at::Tensor& self,
+    at::IntArrayRef featmap_size,
+    at::IntArrayRef stride,
+    int64_t num_base_anchors)
+{
+    anchor_response_flags_check(self, featmap_size, stride);
+    auto output_size = infersize_npu_anchor_response_flags(featmap_size, num_base_anchors);
+    auto options = self.options().dtype(at::kByte);
+    at::Tensor result = at::empty(output_size, options);
+
+    at::Tensor self_cp = self.to(at::kFloat);
+
+    at_npu::native::OpCommand cmd;
+    cmd.Name("AnchorResponseFlags")
+        .Input(self_cp)
+        .Output(result)
+        .Attr("featmap_size", featmap_size)
+        .Attr("strides", stride)
+        .Attr("num_base_anchors", num_base_anchors)
+        .Run();
+    return result;
+}
diff --git a/ads/common/ops/csrc/BatchNms.cpp b/ads/common/ops/csrc/BatchNms.cpp
new file mode 100644
index 00000000..a7051437
--- /dev/null
+++ b/ads/common/ops/csrc/BatchNms.cpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "common.h"
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_batch_nms(
+    const at::Tensor& self,
+    const at::Tensor& scores,
+    double score_threshold,
+    double iou_threshold,
+    int64_t max_size_per_class,
+    int64_t max_total_size,
+    bool change_coordinate_frame,
+    bool transpose_box)
+{
+    at::Tensor nmsed_boxes = at::empty({self.size(0), max_total_size, 4}, self.options());
+    at::Tensor nmsed_scores = at::empty({self.size(0), max_total_size}, self.options());
+    at::Tensor nmsed_classes = at::empty({self.size(0), max_total_size}, self.options());
+    at::Tensor nmsed_num = at::empty({self.size(0)}, self.options().dtype(at::kInt));
+    at_npu::native::OpCommand cmd;
+    cmd.Name("BatchMultiClassNonMaxSuppression")
+        .Input(self)
+        .Input(scores)
+        .Output(nmsed_boxes)
+        .Output(nmsed_scores)
+        .Output(nmsed_classes)
+        .Output(nmsed_num)
+        .Attr("score_threshold", static_cast<float>(score_threshold))
+        .Attr("iou_threshold", static_cast<float>(iou_threshold))
+        .Attr("max_size_per_class", max_size_per_class)
+        .Attr("max_total_size", max_total_size)
+        .Attr("change_coordinate_frame", change_coordinate_frame)
+        .Attr("transpose_box", transpose_box)
+        .Run();
+    return std::tie(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num);
+}
diff --git a/ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp b/ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp
new file mode 100644
index 00000000..85fc0764
--- /dev/null
+++ b/ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp
@@ -0,0 +1,57 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "common.h"
+
+at::Tensor npu_bounding_box_decode(
+    const at::Tensor& rois,
+    const at::Tensor& deltas,
+    double means0,
+    double means1,
+    double means2,
+    double means3,
+    double stds0,
+    double stds1,
+    double stds2,
+    double stds3,
+    at::IntArrayRef max_shape,
+    double wh_ratio_clip)
+{
+    c10::SmallVector<int64_t, 2> output_size = {rois.size(0), 4};
+    at::Tensor result = at::empty(output_size, rois.options());
+    c10::SmallVector<float, 4> means = {
+        static_cast<float>(means0),
+        static_cast<float>(means1),
+        static_cast<float>(means2),
+        static_cast<float>(means3)};
+    c10::SmallVector<float, 4> stds = {
+        static_cast<float>(stds0),
+        static_cast<float>(stds1),
+        static_cast<float>(stds2),
+        static_cast<float>(stds3)};
+    at_npu::native::OpCommand cmd;
+    cmd.Name("BoundingBoxDecode")
+        .Input(rois)
+        .Input(deltas)
+        .Output(result)
+        .Attr("means", means)
+        .Attr("stds", stds)
+        .Attr("max_shape", max_shape)
+        .Attr("wh_ratio_clip", static_cast<float>(wh_ratio_clip))
+        .Run();
+    return result;
+}
diff --git a/ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp b/ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp
new file mode 100644
index 00000000..aa5bad77
--- /dev/null
+++ b/ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp
@@ -0,0 +1,51 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "common.h"
+
+at::Tensor npu_bounding_box_encode(
+    const at::Tensor& anchor_box,
+    const at::Tensor& ground_truth_box,
+    double means0,
+    double means1,
+    double means2,
+    double means3,
+    double stds0,
+    double stds1,
+    double stds2,
+    double stds3)
+{
+    at::Tensor result = at::empty({anchor_box.size(0), 4}, anchor_box.options());
+    c10::SmallVector<float, 4> means = {
+        static_cast<float>(means0),
+        static_cast<float>(means1),
+        static_cast<float>(means2),
+        static_cast<float>(means3)};
+    c10::SmallVector<float, 4> stds = {
+        static_cast<float>(stds0),
+        static_cast<float>(stds1),
+        static_cast<float>(stds2),
+        static_cast<float>(stds3)};
+    at_npu::native::OpCommand cmd;
+    cmd.Name("BoundingBoxEncode")
+        .Input(anchor_box)
+        .Input(ground_truth_box)
+        .Output(result)
+        .Attr("means", means)
+        .Attr("stds", stds)
+        .Run();
+    return result;
+}
diff --git a/ads/common/ops/csrc/BroadCastKernelNpu.cpp b/ads/common/ops/csrc/BroadCastKernelNpu.cpp
new file mode 100644
index 00000000..b4858034
--- /dev/null
+++ b/ads/common/ops/csrc/BroadCastKernelNpu.cpp
@@ -0,0 +1,49 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/OpCommand.h"
+
+
+namespace {
+at::Tensor& npu_broadcast_out_nocheck(at::Tensor& result, const at::Tensor& self, at::IntArrayRef size)
+{
+    at_npu::native::OpCommand cmd;
+    cmd.Name("BroadcastTo")
+        .Input(self)
+        .Input(size)
+        .Output(result)
+        .Run();
+    return result;
+}
+} // namespace
+
+at::Tensor& npu_broadcast_out(const at::Tensor& self, at::IntArrayRef size, at::Tensor& result)
+{
+    npu_broadcast_out_nocheck(result, self, size);
+
+    return result;
+}
+
+at::Tensor npu_broadcast(const at::Tensor& self, at::IntArrayRef size)
+{
+    at::Tensor self_cp = self.dtype() == at::kBool ? self.to(at::kInt) : self;
+    at::Tensor result = at::empty(size, self_cp.options());
+    npu_broadcast_out_nocheck(result, self_cp, size);
+
+    if (self.dtype() == at::kBool) {
+        result = result.to(at::kBool);
+    }
+    return result;
+}
diff --git a/ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp b/ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp
new file mode 100644
index 00000000..a12d1d7c
--- /dev/null
+++ b/ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp
@@ -0,0 +1,97 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "common.h"
+
+at::Tensor npu_confusion_transpose(
+    const at::Tensor& self,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first)
+{
+    c10::SmallVector<int64_t, 8> output_size;
+    if (transpose_first) {
+        output_size = array_to_small_vector(shape);
+    } else {
+        auto shape_size = shape.size();
+        for (size_t i = 0; i < perm.size(); i++) {
+            TORCH_CHECK(shape_size > perm[i], "npu_confusion_transpose input invalid, "
+                "shape has size ", shape_size, " but perm[i] is ", perm[i]);
+            output_size.emplace_back(shape[perm[i]]);
+        }
+    }
+
+    at::Tensor result = at::empty(output_size, self.options());
+    at_npu::native::OpCommand cmd;
+    cmd.Name("ConfusionTransposeD")
+        .Input(self)
+        .Output(result)
+        .Attr("perm", perm)
+        .Attr("shape", shape)
+        .Attr("transpose_first", transpose_first)
+        .Run();
+
+    return result;
+}
+
+void check_confusion_transpose_perm(at::IntArrayRef perm, at::IntArrayRef shape)
+{
+    auto input_dim = shape.size();
+    TORCH_CHECK(perm.size() == input_dim, "The length of perm should be the same as shape.");
+    std::vector<bool> seen(input_dim);
+    for (const auto i : c10::irange(input_dim)) {
+        auto dim = at::maybe_wrap_dim(perm[i], input_dim);
+        TORCH_CHECK(!seen[dim], "Repeated dim in perm");
+        seen[dim] = true;
+    }
+}
+
+at::Tensor npu_confusion_transpose_backward(
+    const at::Tensor& grad,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first)
+{
+    c10::SmallVector<int64_t, 8> svec_shape;
+    if (transpose_first) {
+        svec_shape = array_to_small_vector(shape);
+    } else {
+        check_confusion_transpose_perm(perm, shape);
+        for (size_t i = 0; i < perm.size(); i++) {
+            svec_shape.emplace_back(shape[perm[i]]);
+        }
+    }
+    std::vector<int64_t> vec_perm;
+    int64_t perm_len = perm.size();
+    std::vector<int64_t> temp_perm(perm_len, 0);
+    for (int64_t i = 0; i < perm_len; i++) {
+        temp_perm[perm[i]] = i;
+    }
+    vec_perm = std::move(temp_perm);
+    perm = at::IntArrayRef(vec_perm);
+    at::Tensor result = at::empty(shape, grad.options());
+
+    at_npu::native::OpCommand cmd;
+    cmd.Name("ConfusionTransposeD")
+        .Input(grad)
+        .Output(result)
+        .Attr("perm", perm)
+        .Attr("shape", svec_shape)
+        .Attr("transpose_first", transpose_first)
+        .Run();
+    return result;
+}
diff --git a/ads/common/ops/csrc/FastGeluKernelNpu.cpp b/ads/common/ops/csrc/FastGeluKernelNpu.cpp
new file mode 100644
index 00000000..a56dcb7a
--- /dev/null
+++ b/ads/common/ops/csrc/FastGeluKernelNpu.cpp
@@ -0,0 +1,52 @@
+// Copyright (c) 2023 Huawei Technologies Co., Ltd
+// Copyright (c) 2019, Facebook CORPORATION.
+// All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include "torch_npu/csrc/framework/OpCommand.h"
+
+namespace {
+at::Tensor& fast_gelu_backward_npu_nocheck(
+    at::Tensor& grad_input,
+    const at::Tensor& grad,
+    const at::Tensor& self)
+{
+    at_npu::native::OpCommand cmd;
+    cmd.Name("FastGeluGrad")
+        .Input(grad)
+        .Input(self)
+        .Output(grad_input)
+        .Run();
+    return grad_input;
+}
+} // namespace
+
+at::Tensor npu_fast_gelu(const at::Tensor& self)
+{
+    at::Tensor result = at::empty(self.sizes(), self.options());
+    at_npu::native::OpCommand cmd;
+    cmd.Name("FastGelu")
+        .Input(self)
+        .Output(result)
+        .Run();
+    return result;
+}
+
+at::Tensor npu_fast_gelu_backward(
+    const at::Tensor& grad,
+    const at::Tensor& self)
+{
+    at::Tensor grad_input = at::empty(self.sizes(), self.options());
+    fast_gelu_backward_npu_nocheck(grad_input, grad, self);
+    return grad_input;
+}
diff --git a/ads/common/ops/csrc/NpuSilu.cpp b/ads/common/ops/csrc/NpuSilu.cpp
index 0b81c0f4..4f62cad5 100644
--- a/ads/common/ops/csrc/NpuSilu.cpp
+++ b/ads/common/ops/csrc/NpuSilu.cpp
@@ -1,18 +1,7 @@
-#include
-
 #include "torch_npu/csrc/framework/OpCommand.h"
-#include "torch_npu/csrc/framework/utils/OpPreparation.h"
-#include "torch_npu/csrc/framework/utils/NpuUtils.h"
-#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
 #include "functions.h"
 #include "common.h"
 
-using torch::autograd::AutogradContext;
-using torch::autograd::Function;
-using npu_preparation = at_npu::native::OpPreparation;
-using npu_utils = at_npu::native::NpuUtils;
-using tensor_list = std::vector<at::Tensor>;
-
 at::Tensor &silu_out_npu_nocheck(at::Tensor &result, const at::Tensor &self)
 {
     at_npu::native::OpCommand cmd;
@@ -80,4 +69,4 @@ at::Tensor &npu_silu_(at::Tensor &self)
 {
     silu_out_npu(self, self);
     return self;
-}
\ No newline at end of file
+}
diff --git a/ads/common/ops/csrc/OpApiCommon.h b/ads/common/ops/csrc/OpApiCommon.h
new file mode 100644
index 00000000..717543ac
--- /dev/null
+++ b/ads/common/ops/csrc/OpApiCommon.h
@@ -0,0 +1,588 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+#include
+
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/core/npu/NPUStream.h"
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "torch_npu/csrc/framework/interface/EnvVariables.h"
+#include "torch_npu/csrc/framework/utils/CalcuOpUtil.h"
+#include "torch_npu/csrc/framework/utils/OpPreparation.h"
+
+#define NPU_NAME_SPACE at_npu::native
+
+typedef struct aclOpExecutor aclOpExecutor;
+typedef struct aclTensor aclTensor;
+typedef struct aclScalar aclScalar;
+typedef struct aclIntArray aclIntArray;
+typedef struct aclFloatArray aclFloatArray;
+typedef struct aclBoolArray aclBoolArray;
+typedef struct aclTensorList aclTensorList;
+
+typedef aclTensor *(*_aclCreateTensor)(const int64_t *view_dims, uint64_t view_dims_num, aclDataType data_type,
+    const int64_t *stride, int64_t offset, aclFormat format, const int64_t *storage_dims, uint64_t storage_dims_num,
+    void *tensor_data);
+typedef aclScalar *(*_aclCreateScalar)(void *value, aclDataType data_type);
+typedef aclIntArray *(*_aclCreateIntArray)(const int64_t *value, uint64_t size);
+typedef aclFloatArray *(*_aclCreateFloatArray)(const float *value, uint64_t size);
+typedef aclBoolArray *(*_aclCreateBoolArray)(const bool *value, uint64_t size);
+typedef aclTensorList *(*_aclCreateTensorList)(const aclTensor *const *value, uint64_t size);
+
+typedef int (*_aclDestroyTensor)(const aclTensor *tensor);
+typedef int (*_aclDestroyScalar)(const aclScalar *scalar);
+typedef int (*_aclDestroyIntArray)(const aclIntArray *array);
+typedef int (*_aclDestroyFloatArray)(const aclFloatArray *array);
+typedef int (*_aclDestroyBoolArray)(const aclBoolArray *array);
+typedef int (*_aclDestroyTensorList)(const aclTensorList *array);
+
+constexpr int kHashBufSize = 8192;
+constexpr int kHashBufMaxSize = kHashBufSize + 1024;
+extern thread_local char g_hashBuf[kHashBufSize];
+extern thread_local int g_hashOffset;
+
+#define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_)   \
+    _(at::ScalarType::Byte, ACL_UINT8)                \
+    _(at::ScalarType::Char, ACL_INT8)                 \
+    _(at::ScalarType::Short, ACL_INT16)               \
+    _(at::ScalarType::Int, ACL_INT32)                 \
+    _(at::ScalarType::Long, ACL_INT64)                \
+    _(at::ScalarType::Half, ACL_FLOAT16)              \
+    _(at::ScalarType::Float, ACL_FLOAT)               \
+    _(at::ScalarType::Double, ACL_DOUBLE)             \
+    _(at::ScalarType::ComplexHalf, ACL_DT_UNDEFINED)  \
+    _(at::ScalarType::ComplexFloat, ACL_COMPLEX64)    \
+    _(at::ScalarType::ComplexDouble, ACL_COMPLEX128)  \
+    _(at::ScalarType::Bool, ACL_BOOL)                 \
+    _(at::ScalarType::QInt8, ACL_DT_UNDEFINED)        \
+    _(at::ScalarType::QUInt8, ACL_DT_UNDEFINED)       \
+    _(at::ScalarType::QInt32, ACL_DT_UNDEFINED)       \
+    _(at::ScalarType::BFloat16, ACL_BF16)             \
+    _(at::ScalarType::QUInt4x2, ACL_DT_UNDEFINED)     \
+    _(at::ScalarType::QUInt2x4, ACL_DT_UNDEFINED)     \
+    _(at::ScalarType::Undefined, ACL_DT_UNDEFINED)    \
+    _(at::ScalarType::NumOptions, ACL_DT_UNDEFINED)
+
+constexpr aclDataType kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(at::ScalarType::NumOptions) + 1] = {
+#define DEFINE_ENUM(_1, n) n,
+    AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(DEFINE_ENUM)
+#undef DEFINE_ENUM
+};
+
+#define GET_OP_API_FUNC(apiName) reinterpret_cast<_##apiName>(GetOpApiFuncAddr(#apiName))
+
+#define MEMCPY_TO_BUF(data_expression, size_expression)                     \
+    if (g_hashOffset + (size_expression) > kHashBufSize) {                  \
+        g_hashOffset = kHashBufMaxSize;                                     \
+        return;                                                             \
+    }                                                                       \
+    memcpy(g_hashBuf + g_hashOffset, data_expression, size_expression);     \
+    g_hashOffset += size_expression;
+
+inline const char *GetOpApiLibName(void)
+{
+    return "libopapi.so";
+}
+
+inline const char *GetCustOpApiLibName(void)
+{
+    return "libcust_opapi.so";
+}
+
+inline void *GetOpApiFuncAddrInLib(void *handler, const char *libName, const char *apiName)
+{
+    auto funcAddr = dlsym(handler, apiName);
+    if (funcAddr == nullptr) {
+        ASCEND_LOGW("dlsym %s from %s failed, error:%s.", apiName, libName, dlerror());
+    }
+    return funcAddr;
+}
+
+inline void *GetOpApiLibHandler(const char *libName)
+{
+    auto handler = dlopen(libName, RTLD_LAZY);
+    if (handler == nullptr) {
+        ASCEND_LOGW("dlopen %s failed, error:%s.", libName, dlerror());
+    }
+    return handler;
+}
+
+inline void *GetOpApiFuncAddr(const char *apiName)
+{
+    static auto custOpApiHandler = GetOpApiLibHandler(GetCustOpApiLibName());
+    if (custOpApiHandler != nullptr) {
+        auto funcAddr = GetOpApiFuncAddrInLib(custOpApiHandler, GetCustOpApiLibName(), apiName);
+        if (funcAddr != nullptr) {
+            return funcAddr;
+        }
+    }
+
+    static auto opApiHandler = GetOpApiLibHandler(GetOpApiLibName());
+    if (opApiHandler == nullptr) {
+        return nullptr;
+    }
+    return GetOpApiFuncAddrInLib(opApiHandler, GetOpApiLibName(), apiName);
+}
+
+inline c10::Scalar ConvertTensorToScalar(const at::Tensor &tensor)
+{
+    c10::Scalar expScalar;
+    const at::Tensor *aclInput = &tensor;
+    if (aclInput->scalar_type() == at::ScalarType::Double) {
+        double value = *(double *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::Long) {
+        int64_t value = *(int64_t *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::Float) {
+        float value = *(float *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::Int) {
+        int value = *(int *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::Half) {
+        c10::Half value = *(c10::Half *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::Bool) {
+        int8_t value = *(int8_t *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::ComplexDouble) {
+        c10::complex<double> value = *(c10::complex<double> *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::ComplexFloat) {
+        c10::complex<float> value = *(c10::complex<float> *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    } else if (aclInput->scalar_type() == at::ScalarType::BFloat16) {
+        c10::BFloat16 value = *(c10::BFloat16 *)aclInput->data_ptr();
+        c10::Scalar scalar(value);
+        expScalar = scalar;
+    }
+    return expScalar;
+}
+
+inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor)
+{
+    at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory();
+    int deviceIndex = 0;
+    return cpuPinMemTensor.to(
+        c10::Device(at_npu::key::NativeDeviceType, deviceIndex), cpuPinMemTensor.scalar_type(), true, true);
+}
+
+inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type)
+{
+    return CopyTensorHostToDevice(scalar_to_tensor(cpu_scalar).to(scalar_data_type));
+}
+
+inline aclTensor *ConvertType(const at::Tensor &at_tensor)
+{
+    static const auto aclCreateTensor = GET_OP_API_FUNC(aclCreateTensor);
+    if (aclCreateTensor == nullptr) {
+        return nullptr;
+    }
+
+    if (!at_tensor.defined()) {
+        return nullptr;
+    }
+    at::ScalarType scalar_data_type = at_tensor.scalar_type();
+    aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
+    TORCH_CHECK(
+        acl_data_type != ACL_DT_UNDEFINED, std::string(c10::toString(scalar_data_type)) + " has not been supported")
+    c10::SmallVector<int64_t, 5> storageDims;
+    // if acl_data_type is ACL_STRING, storageDims is empty.
+    auto itemsize = at_tensor.itemsize();
+    if (itemsize == 0) {
+        AT_ERROR("When ConvertType, tensor item size cannot be zero.");
+        return nullptr;
+    }
+    if (acl_data_type != ACL_STRING) {
+        storageDims.push_back(at_tensor.storage().nbytes() / itemsize);
+    }
+
+    const auto dimNum = at_tensor.sizes().size();
+    aclFormat format = ACL_FORMAT_ND;
+    switch (dimNum) {
+        case 3:
+            format = ACL_FORMAT_NCL;
+            break;
+        case 4:
+            format = ACL_FORMAT_NCHW;
+            break;
+        case 5:
+            format = ACL_FORMAT_NCDHW;
+            break;
+        default:
+            format = ACL_FORMAT_ND;
+    }
+
+    if (at_tensor.unsafeGetTensorImpl()->is_wrapped_number()) {
+        c10::Scalar expScalar = ConvertTensorToScalar(at_tensor);
+        at::Tensor aclInput = CopyScalarToDevice(expScalar, scalar_data_type);
+        return aclCreateTensor(aclInput.sizes().data(),
+            aclInput.sizes().size(),
+            acl_data_type,
+            aclInput.strides().data(),
+            aclInput.storage_offset(),
+            format,
+            storageDims.data(),
+            storageDims.size(),
+            const_cast<void *>(aclInput.storage().data()));
+    }
+
+    auto acl_tensor = aclCreateTensor(at_tensor.sizes().data(),
+        at_tensor.sizes().size(),
+        acl_data_type,
+        at_tensor.strides().data(),
+        at_tensor.storage_offset(),
+        format,
+        storageDims.data(),
+        storageDims.size(),
+        const_cast<void *>(at_tensor.storage().data()));
+    return acl_tensor;
+}
+
+inline aclScalar *ConvertType(const at::Scalar &at_scalar)
+{
+    static const auto aclCreateScalar = GET_OP_API_FUNC(aclCreateScalar);
+    if (aclCreateScalar == nullptr) {
+        return nullptr;
+    }
+
+    at::ScalarType scalar_data_type = at_scalar.type();
+    aclDataType acl_data_type = kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalar_data_type)];
+    TORCH_CHECK(
+        acl_data_type != ACL_DT_UNDEFINED, std::string(c10::toString(scalar_data_type)) + " has not been supported")
+    aclScalar *acl_scalar = nullptr;
+    switch (scalar_data_type) {
+        case at::ScalarType::Double: {
+            double value = at_scalar.toDouble();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        case at::ScalarType::Long: {
+            int64_t value = at_scalar.toLong();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        case at::ScalarType::Bool: {
+            bool value = at_scalar.toBool();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        case at::ScalarType::ComplexDouble: {
+            auto value = at_scalar.toComplexDouble();
+            acl_scalar = aclCreateScalar(&value, acl_data_type);
+            break;
+        }
+        default:
+            acl_scalar = nullptr;
+            break;
+    }
+    return acl_scalar;
+}
+
+inline aclIntArray *ConvertType(const at::IntArrayRef &at_array)
+{
+    static const auto aclCreateIntArray = GET_OP_API_FUNC(aclCreateIntArray);
+    if (aclCreateIntArray == nullptr) {
+        return nullptr;
+    }
+    auto array = aclCreateIntArray(at_array.data(), at_array.size());
+    return array;
+}
+
+template <std::size_t N>
+inline aclBoolArray *ConvertType(const std::array<bool, N> &value)
+{
+    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
+    if (aclCreateBoolArray == nullptr) {
+        return nullptr;
+    }
+
+    auto array = aclCreateBoolArray(value.data(), value.size());
+    return array;
+}
+
+inline aclBoolArray *ConvertType(const at::ArrayRef<bool> &value)
+{
+    static const auto aclCreateBoolArray = GET_OP_API_FUNC(aclCreateBoolArray);
+    if (aclCreateBoolArray == nullptr) {
+        return nullptr;
+    }
+
+    auto array = aclCreateBoolArray(value.data(), value.size());
+    return array;
+}
+
+inline aclTensorList *ConvertType(const at::TensorList &at_tensor_list)
+{
+    static const auto aclCreateTensorList = GET_OP_API_FUNC(aclCreateTensorList);
+    if (aclCreateTensorList == nullptr) {
+        return nullptr;
+    }
+
+    std::vector<const aclTensor *> tensor_list(at_tensor_list.size());
+    for (size_t i = 0; i < at_tensor_list.size(); i++) {
+        tensor_list[i] = ConvertType(at_tensor_list[i]);
+    }
+    auto acl_tensor_list = aclCreateTensorList(tensor_list.data(), tensor_list.size());
+    return acl_tensor_list;
+}
+
+inline aclTensor *ConvertType(const c10::optional<at::Tensor> &opt_tensor)
+{
+    if (opt_tensor.has_value() && opt_tensor.value().defined()) {
+        return ConvertType(opt_tensor.value());
+    }
+    return nullptr;
+}
+
+inline aclIntArray *ConvertType(const c10::optional<at::IntArrayRef> &opt_array)
+{
+    if (opt_array.has_value()) {
+        return ConvertType(opt_array.value());
+    }
+    return nullptr;
+}
+
+inline aclScalar *ConvertType(const c10::optional<at::Scalar> &opt_scalar)
+{
+    if (opt_scalar.has_value()) {
+        return ConvertType(opt_scalar.value());
+    }
+    return nullptr;
+}
+
+inline aclDataType ConvertType(const at::ScalarType scalarType)
+{
+    return kATenScalarTypeToAclDataTypeTable[static_cast<int64_t>(scalarType)];
+}
+
+template <typename T>
+T ConvertType(T value)
+{
+    return value;
+}
+
+template <typename Tuple, size_t... I>
+auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr, std::index_sequence<I...>)
+{
+    typedef int (*OpApiFunc)(typename std::decay<decltype(std::get<I>(params))>::type...);
+    auto func = reinterpret_cast<OpApiFunc>(opApiAddr);
+    return func;
+}
+
+template <typename Tuple>
+auto ConvertToOpApiFunc(const Tuple &params, void *opApiAddr)
+{
+    static constexpr auto size = std::tuple_size<Tuple>::value;
+    return ConvertToOpApiFunc(params, opApiAddr, std::make_index_sequence<size>{});
+}
+
+inline void Release(aclTensor *p)
+{
+    static const auto aclDestroyTensor = GET_OP_API_FUNC(aclDestroyTensor);
+    if (aclDestroyTensor == nullptr) {
+        return;
+    }
+    aclDestroyTensor(p);
+}
+
+inline void Release(aclScalar *p)
+{
+    static const auto aclDestroyScalar = GET_OP_API_FUNC(aclDestroyScalar);
+    if (aclDestroyScalar == nullptr) {
+        return;
+    }
+    aclDestroyScalar(p);
+}
+
+inline void Release(aclIntArray *p)
+{
+    static const auto aclDestroyIntArray = GET_OP_API_FUNC(aclDestroyIntArray);
+    if (aclDestroyIntArray == nullptr) {
+        return;
+    }
+
+    aclDestroyIntArray(p);
+}
+
+inline void Release(aclBoolArray *p)
+{
+    static const auto aclDestroyBoolArray = GET_OP_API_FUNC(aclDestroyBoolArray);
+    if (aclDestroyBoolArray == nullptr) {
+        return;
+    }
+
+    aclDestroyBoolArray(p);
+}
+
+inline void Release(aclTensorList *p)
+{
+    static const auto aclDestroyTensorList = GET_OP_API_FUNC(aclDestroyTensorList);
+    if (aclDestroyTensorList == nullptr) {
+        return;
+    }
+
+    aclDestroyTensorList(p);
+}
+
+template <typename T>
+void Release(T value)
+{
+    (void)value;
+}
+
+template <typename Tuple, size_t... I>
+void CallRelease(Tuple t, std::index_sequence<I...>)
+{
+    (void)std::initializer_list<int>{(Release(std::get<I>(t)), 0)...};
+}
+
+template <typename Tuple>
+void ReleaseConvertTypes(Tuple &t)
+{
+    static constexpr auto size = std::tuple_size<Tuple>::value;
+    CallRelease(t, std::make_index_sequence<size>{});
+}
+
+template <typename... Ts>
+constexpr auto ConvertTypes(Ts &...args)
+{
+    return std::make_tuple(ConvertType(args)...);
+}
+
+template <typename Function, typename Tuple, size_t... I>
+auto call(Function f, Tuple t, std::index_sequence<I...>)
+{
+    return f(std::get<I>(t)...);
+}
+
+template <typename Function, typename Tuple>
+auto call(Function f, Tuple t)
+{
+    static constexpr auto size = std::tuple_size<Tuple>::value;
+    return call(f, t, std::make_index_sequence<size>{});
+}
+
+template <std::size_t N>
+void AddParamToBuf(const std::array<bool, N> &value)
+{
+    MEMCPY_TO_BUF(value.data(), value.size() * sizeof(bool));
+}
+
+template <typename T>
+void AddParamToBuf(const T &value)
+{
+    MEMCPY_TO_BUF(&value, sizeof(T));
+}
+
+void AddParamToBuf(const at::Tensor &);
+void AddParamToBuf(const at::Scalar &);
+void AddParamToBuf(const at::IntArrayRef &);
+void AddParamToBuf(const at::ArrayRef<bool> &);
+void AddParamToBuf(const at::TensorList &);
+void AddParamToBuf(const c10::optional<at::Tensor> &);
+void AddParamToBuf(const c10::optional<at::IntArrayRef> &);
+void AddParamToBuf(const c10::optional<at::Scalar> &);
+void AddParamToBuf(const at::ScalarType);
+void AddParamToBuf(const std::string &);
+void AddParamToBuf();
+
+template <typename T, typename... Args>
+void AddParamToBuf(const T &arg, Args &...args)
+{
+    AddParamToBuf(arg);
+    AddParamToBuf(args...);
+}
+
+uint64_t CalcHashId();
+typedef int (*InitHugeMemThreadLocal)(void *, bool);
+typedef void (*UnInitHugeMemThreadLocal)(void *, bool);
+typedef void (*ReleaseHugeMem)(void *, bool);
+
+#define DO_COMPATIBILITY(aclnn_api, originCallExpression)                                               \
+    do {                                                                                                \
+        static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize");   \
+        static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api);                                 \
+        if (getWorkspaceSizeFuncAddr == nullptr || opApiFuncAddr == nullptr) {                          \
+            ASCEND_LOGW("%s or %sGetWorkspaceSize not in %s, or %s not found. Will call %s",            \
+                #aclnn_api,                                                                             \
+                #aclnn_api,                                                                             \
+                GetOpApiLibName(),                                                                      \
+                GetOpApiLibName(),                                                                      \
+                #originCallExpression);                                                                 \
+            return originCallExpression;                                                                \
+        }                                                                                               \
+    } while (0)
+
+#define EXEC_NPU_CMD(aclnn_api, ...)                                                                    \
+    do {                                                                                                \
+        static const auto getWorkspaceSizeFuncAddr = GetOpApiFuncAddr(#aclnn_api "GetWorkspaceSize");   \
+        static const auto opApiFuncAddr = GetOpApiFuncAddr(#aclnn_api);                                 \
+        static const auto initMemAddr = GetOpApiFuncAddr("InitHugeMemThreadLocal");                     \
+        static const auto unInitMemAddr = GetOpApiFuncAddr("UnInitHugeMemThreadLocal");                 \
+        static const auto releaseMemAddr = GetOpApiFuncAddr("ReleaseHugeMem");                          \
+        TORCH_CHECK(getWorkspaceSizeFuncAddr != nullptr && opApiFuncAddr != nullptr,                    \
+            #aclnn_api,                                                                                 \
+            " or ",                                                                                     \
+            #aclnn_api "GetWorkspaceSize",                                                              \
+            " not in ",                                                                                 \
+            GetOpApiLibName(),                                                                          \
+            ", or ",                                                                                    \
+            GetOpApiLibName(),                                                                          \
+            " not found.");                                                                             \
+        auto acl_stream = c10_npu::getCurrentNPUStream().stream(false);                                 \
+        uint64_t workspace_size = 0;                                                                    \
+        uint64_t *workspace_size_addr = &workspace_size;                                                \
+        aclOpExecutor *executor = nullptr;                                                              \
+        aclOpExecutor **executor_addr = &executor;                                                      \
+        InitHugeMemThreadLocal initMemFunc = reinterpret_cast<InitHugeMemThreadLocal>(initMemAddr);     \
+        UnInitHugeMemThreadLocal unInitMemFunc = reinterpret_cast<UnInitHugeMemThreadLocal>(unInitMemAddr); \
+        if (initMemFunc) {                                                                              \
+            initMemFunc(nullptr, false);                                                                \
+        }                                                                                               \
+        auto converted_params = ConvertTypes(__VA_ARGS__, workspace_size_addr, executor_addr);          \
+        static auto getWorkspaceSizeFunc = ConvertToOpApiFunc(converted_params, getWorkspaceSizeFuncAddr); \
+        auto workspace_status = call(getWorkspaceSizeFunc, converted_params);                           \
+        TORCH_CHECK(workspace_status == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg()); \
+        void *workspace_addr = nullptr;                                                                 \
+        if (workspace_size != 0) {                                                                      \
+            at::TensorOptions options = at::TensorOptions(torch_npu::utils::get_npu_device_type());     \
+            auto workspace_tensor = at::empty({workspace_size}, options.dtype(at::kByte));              \
+            workspace_addr = const_cast<void *>(workspace_tensor.storage().data());                     \
+        }                                                                                               \
+        auto acl_call = [converted_params, workspace_addr, workspace_size, acl_stream, executor]()->int { \
+            typedef int (*OpApiFunc)(void *, uint64_t, aclOpExecutor *, const aclrtStream);             \
+            OpApiFunc opApiFunc = reinterpret_cast<OpApiFunc>(opApiFuncAddr);                           \
+            auto api_ret = opApiFunc(workspace_addr, workspace_size, executor, acl_stream);             \
+            TORCH_CHECK(api_ret == 0, "call " #aclnn_api " failed, detail:", aclGetRecentErrMsg());     \
+            ReleaseConvertTypes(converted_params);                                                      \
+            ReleaseHugeMem releaseMemFunc = reinterpret_cast<ReleaseHugeMem>(releaseMemAddr);           \
+            if (releaseMemFunc) {                                                                       \
+                releaseMemFunc(nullptr, false);                                                         \
+            }                                                                                           \
+            return api_ret;                                                                             \
+        };                                                                                              \
+        at_npu::native::OpCommand cmd;                                                                  \
+        cmd.Name(#aclnn_api);                                                                           \
+        cmd.SetCustomHandler(acl_call);                                                                 \
+        cmd.Run();                                                                                      \
+        if (unInitMemFunc) {                                                                            \
+            unInitMemFunc(nullptr, false);                                                              \
+        }                                                                                               \
+    } while (false)
diff --git a/ads/common/ops/csrc/RotaryMulKernelNpu.cpp b/ads/common/ops/csrc/RotaryMulKernelNpu.cpp
index 3e814e4e..05569309 100644
--- a/ads/common/ops/csrc/RotaryMulKernelNpu.cpp
+++ b/ads/common/ops/csrc/RotaryMulKernelNpu.cpp
@@ -14,18 +14,10 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include
 #include "torch_npu/csrc/framework/OpCommand.h"
-#include "torch_npu/csrc/framework/utils/OpPreparation.h"
-#include "torch_npu/csrc/framework/utils/NpuUtils.h"
-#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
-#include "torch_npu/csrc/aten/CustomFunctions.h"
 #include "functions.h"
 #include "common.h"
 
-using npu_preparation = at_npu::native::OpPreparation;
-using torch::autograd::Function;
-using torch::autograd::AutogradContext;
 using tensor_tuple = std::tuple<at::Tensor, at::Tensor, at::Tensor>;
 
 namespace {
@@ -42,7 +34,7 @@ at::Tensor &rotary_mul_nocheck(at::Tensor &y, const at::Tensor &x, const at::Ten
     return y;
 }
 
-tensor_tuple rotary_mul_backward_nocheck(at::Tensor &dx, at::Tensor &dr1, at::Tensor &dr2, const at::Tensor &x, 
+tensor_tuple rotary_mul_backward_nocheck(at::Tensor &dx, at::Tensor &dr1, at::Tensor &dr2, const at::Tensor &x,
     const at::Tensor &r1, const at::Tensor &r2, const at::Tensor &dy)
 {
     TORCH_CHECK(x.dim() == 4, "The dim of input tensor [x] shoule equal to four.");
diff --git a/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp b/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp
index db949fc9..0e8aa592 100644
--- a/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp
+++ b/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp
@@ -15,13 +15,8 @@
 // limitations under the License.
 
 #include "torch_npu/csrc/framework/OpCommand.h"
-#include "torch_npu/csrc/framework/utils/OpPreparation.h"
-#include "torch_npu/csrc/framework/utils/NpuUtils.h"
-#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
 #include "functions.h"
 
-using npu_preparation = at_npu::native::OpPreparation;
-using npu_utils = at_npu::native::NpuUtils;
 
 at::Tensor npu_rotated_box_decode(const at::Tensor &self, const at::Tensor &deltas, const at::Tensor &weight)
 {
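Note: EXEC_NPU_CMD above resolves aclnnAbsGetWorkspaceSize/aclnnAbs and friends from libcust_opapi.so or libopapi.so at first use and submits the call on the current NPU stream. A hedged smoke-test sketch for the npu_abs entry point built on it (assumes an Ascend NPU device):

    import torch
    import torch_npu
    from ads.common import npu_abs

    x = torch.randn(4, 5).npu()
    out = npu_abs(x)  # dispatched via EXEC_NPU_CMD(aclnnAbs, self, result)
    assert torch.allclose(out.cpu(), x.cpu().abs())
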
#include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; - namespace { at::Tensor &rotated_iou_npu_nocheck( at::Tensor &iou, diff --git a/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp b/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp index ac476c62..0a957ca9 100644 --- a/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp +++ b/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp @@ -15,14 +15,8 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; - namespace { at::Tensor &rotated_overlaps_npu_nocheck( at::Tensor &overlaps, diff --git a/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp b/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp index c4e94384..f3b11664 100644 --- a/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp +++ b/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp @@ -1,16 +1,8 @@ -#include - #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "common.h" -using namespace at; using namespace std; -using torch::autograd::Function; -using torch::autograd::AutogradContext; -using tensor_list = std::vector; - std::tuple npu_scatter_max( const at::Tensor& updates, const at::Tensor& indices, @@ -21,7 +13,7 @@ std::tuple npu_scatter_max( sizes[0] = indices.max().item().toLong() + 1; at::Tensor result = out.value_or(at::zeros(sizes, updates.options().dtype(at::kFloat))); - at::Tensor argmax = at_npu::native::OpPreparation::ApplyTensor(result, result.options().dtype(at::kInt)); + at::Tensor argmax = at::empty(result.sizes(), result.options().dtype(at::kInt)); at_npu::native::OpCommand cmd; cmd.Name("ScatterMaxWithArgmax") @@ -37,7 +29,7 @@ std::tuple npu_scatter_max( at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segment_ids, const at::Tensor& num_segments) { - c10::SmallVector output_size; + c10::SmallVector output_size; auto num_segments_value = num_segments.item().toLong(); output_size.push_back(num_segments_value); @@ -47,7 +39,7 @@ at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segme copy(x_sizes.begin() + segment_ids_dims, x_sizes.end(), std::back_inserter(output_size)); - at::Tensor out = at_npu::native::OpPreparation::ApplyTensor(x, output_size); + at::Tensor out = at::empty(output_size, x.options()); at_npu::native::OpCommand cmd; cmd.Name("UnsortedSegmentSum") .Input(x) diff --git a/ads/common/ops/csrc/ScatterV1KernelNpu.cpp b/ads/common/ops/csrc/ScatterV1KernelNpu.cpp index 155ea383..f96d4608 100644 --- a/ads/common/ops/csrc/ScatterV1KernelNpu.cpp +++ b/ads/common/ops/csrc/ScatterV1KernelNpu.cpp @@ -15,12 +15,8 @@ // limitations under the License. 
#include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; at::Tensor npu_scatter(const at::Tensor &self, const at::Tensor &indices, const at::Tensor &updates, int64_t dim) { diff --git a/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp b/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp index 5fb1139c..95f4c3ff 100644 --- a/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp +++ b/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp @@ -15,12 +15,8 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; at::Tensor npu_sign_bits_pack(const at::Tensor &self, int64_t size) { diff --git a/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp b/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp index e7a35680..27ae440b 100644 --- a/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp +++ b/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp @@ -15,13 +15,9 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; at::Tensor npu_sign_bits_unpack_compute( const at::Tensor &input, diff --git a/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp b/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp index e936c819..cc8f95df 100644 --- a/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp +++ b/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp @@ -14,19 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; -using torch::autograd::AutogradContext; -using torch::autograd::Function; -using tensor_list = std::vector; - namespace { std::tuple softmax_cross_entropy_with_logits_out_nocheck( at::Tensor &result, diff --git a/ads/common/ops/csrc/StrideAddKernelNpu.cpp b/ads/common/ops/csrc/StrideAddKernelNpu.cpp index ebcfbfda..47922f62 100644 --- a/ads/common/ops/csrc/StrideAddKernelNpu.cpp +++ b/ads/common/ops/csrc/StrideAddKernelNpu.cpp @@ -14,15 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; namespace { at::Tensor &stride_add_out_npu_nocheck( diff --git a/ads/common/ops/csrc/TransposeKernelNpu.cpp b/ads/common/ops/csrc/TransposeKernelNpu.cpp index ad9d2e97..2e8705c2 100644 --- a/ads/common/ops/csrc/TransposeKernelNpu.cpp +++ b/ads/common/ops/csrc/TransposeKernelNpu.cpp @@ -15,14 +15,9 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_utils = at_npu::native::NpuUtils; - namespace { at::Tensor &npu_transpose_out_nocheck( at::Tensor &result, diff --git a/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp b/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp index f3cd4201..df02a325 100644 --- a/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp +++ b/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp @@ -15,14 +15,9 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; namespace { inline void yolo_boxes_encode_check( diff --git a/ads/common/ops/csrc/common.cpp b/ads/common/ops/csrc/common.cpp index 8e4b3037..f6f9cc49 100644 --- a/ads/common/ops/csrc/common.cpp +++ b/ads/common/ops/csrc/common.cpp @@ -1,12 +1,9 @@ #include #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/aten/mirror/NPUMemoryOverlap.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "third_party/acl/inc/acl/acl_base.h" #include "common.h" -using npu_utils = at_npu::native::NpuUtils; using CalcuOpUtil = at_npu::native::CalcuOpUtil; #define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ @@ -192,4 +189,4 @@ bool check_match(const at::Tensor &self) void format_fresh_view(at::Tensor &x, const at::Tensor &y) { x.copy_(y); -} \ No newline at end of file +} diff --git a/ads/common/ops/csrc/common.h b/ads/common/ops/csrc/common.h index 49a75653..95c2b5a1 100644 --- a/ads/common/ops/csrc/common.h +++ b/ads/common/ops/csrc/common.h @@ -1,3 +1,5 @@ +#ifndef __COMMON_H__ +#define __COMMON_H__ #include #include #include @@ -26,4 +28,6 @@ c10::SmallVector convert_array_to_vector(c10::IntArrayRef intArray); c10::SmallVector infersize_stride_add(c10::IntArrayRef shape1_, c10::IntArrayRef shape2_); c10::SmallVector transpose_npu_output_size(const at::Tensor &self, c10::IntArrayRef perm); bool check_match(const at::Tensor &self); -void format_fresh_view(at::Tensor &x, const at::Tensor &y); \ No newline at end of file +void format_fresh_view(at::Tensor &x, const at::Tensor &y); + +#endif // __COMMON_H__ diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index daa82f4d..1e9ab8d9 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -1,3 +1,19 @@ +// Copyright (c) 2023, Huawei 
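Note: a hedged sketch of how the reworked ScatterMax kernel above is expected to be called from Python; the ads_c.npu_scatter_max binding itself lives in the pre-existing part of pybind.cpp and is an assumption here, as are the shapes:

    import torch
    import torch_npu
    import ads_c

    updates = torch.randn(6, 4).npu()
    indices = torch.tensor([0, 1, 1, 2, 0, 2], dtype=torch.int32).npu()
    result, argmax = ads_c.npu_scatter_max(updates, indices, None)
    # result.shape[0] == indices.max() + 1; argmax holds the winning source rows
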
diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h
index daa82f4d..1e9ab8d9 100644
--- a/ads/common/ops/csrc/functions.h
+++ b/ads/common/ops/csrc/functions.h
@@ -1,3 +1,19 @@
+// Copyright (c) 2023, Huawei Technologies. All rights reserved.
+//
+// Licensed under the BSD 3-Clause License (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://opensource.org/licenses/BSD-3-Clause
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#ifndef __FUNCTIONS_H__
+#define __FUNCTIONS_H__
+
 #include
 #include
 #include
@@ -41,4 +57,64 @@ at::Tensor npu_yolo_boxes_encode(
 at::Tensor npu_scatter(const at::Tensor& self, const at::Tensor& indices, const at::Tensor& updates, int64_t dim);
 at::Tensor npu_rotary_mul(const at::Tensor &self, const at::Tensor &r1, const at::Tensor &r2);
 at::Tensor npu_silu(const at::Tensor& self);
-at::Tensor& npu_silu_(at::Tensor& self);
\ No newline at end of file
+at::Tensor& npu_silu_(at::Tensor& self);
+at::Tensor npu_fast_gelu_backward(const at::Tensor& grad, const at::Tensor& self);
+at::Tensor npu_abs(const at::Tensor& self);
+at::Tensor npu_fast_gelu(const at::Tensor& self);
+at::Tensor npu_anchor_response_flags(const at::Tensor& self, at::IntArrayRef featmap_size, at::IntArrayRef stride, int64_t num_base_anchors);
+at::Tensor npu_bounding_box_decode(
+    const at::Tensor& rois,
+    const at::Tensor& deltas,
+    double means0,
+    double means1,
+    double means2,
+    double means3,
+    double stds0,
+    double stds1,
+    double stds2,
+    double stds3,
+    at::IntArrayRef max_shape,
+    double wh_ratio_clip);
+at::Tensor npu_bounding_box_encode(
+    const at::Tensor& anchor_box,
+    const at::Tensor& ground_truth_box,
+    double means0,
+    double means1,
+    double means2,
+    double means3,
+    double stds0,
+    double stds1,
+    double stds2,
+    double stds3);
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_batch_nms(
+    const at::Tensor& self,
+    const at::Tensor& scores,
+    double score_threshold,
+    double iou_threshold,
+    int64_t max_size_per_class,
+    int64_t max_total_size,
+    bool change_coordinate_frame,
+    bool transpose_box);
+at::Tensor npu_confusion_transpose(
+    const at::Tensor& self,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first);
+at::Tensor npu_confusion_transpose_backward(
+    const at::Tensor& grad,
+    at::IntArrayRef perm,
+    at::IntArrayRef shape,
+    bool transpose_first);
+at::Tensor npu_conv_transpose2d(
+    const at::Tensor& input,
+    const at::Tensor& weight,
+    const c10::optional<at::Tensor>& bias_opt,
+    at::IntArrayRef padding,
+    at::IntArrayRef output_padding,
+    at::IntArrayRef stride,
+    at::IntArrayRef dilation,
+    int64_t groups);
+at::Tensor npu_broadcast(const at::Tensor& self, at::IntArrayRef size);
+at::Tensor& npu_broadcast_out(const at::Tensor& self, at::IntArrayRef size, at::Tensor& result);
+
+#endif // __FUNCTIONS_H__
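Note: a hedged usage sketch for npu_batch_nms as declared in functions.h; the (batch, num_anchors, num_classes, 4) box layout follows the BatchMultiClassNonMaxSuppression operator convention and is an assumption, not something this patch pins down:

    import torch
    import torch_npu
    from ads.common import npu_batch_nms

    boxes = torch.rand(1, 1024, 1, 4).half().npu()
    scores = torch.rand(1, 1024, 1).half().npu()
    nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = npu_batch_nms(
        boxes, scores, 0.3, 0.5, 100, 100)
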
&npu_bounding_box_decode); + + // npu_bounding_box_encode + m.def("npu_bounding_box_encode", &npu_bounding_box_encode); + + // npu_batch_nms + m.def("npu_batch_nms", &npu_batch_nms); + + // npu_confusion_transpose + m.def("npu_confusion_transpose", &npu_confusion_transpose); + m.def("npu_confusion_transpose_backward", &npu_confusion_transpose_backward); + + // npu_broadcast + m.def("npu_broadcast", &npu_broadcast); } diff --git a/ads/common/ops/fast_gelu.py b/ads/common/ops/fast_gelu.py new file mode 100644 index 00000000..45557513 --- /dev/null +++ b/ads/common/ops/fast_gelu.py @@ -0,0 +1,23 @@ +import torch +from torch.autograd import Function + +import torch_npu +import ads_c + + +class FastGeluFunction(Function): + @staticmethod + def forward(ctx, self): + out = ads_c.npu_fast_gelu(self) + ctx.save_for_backward(self) + return out + + @staticmethod + def backward(ctx, grad_output): + self = ctx.saved_tensors[0] + + grad = ads_c.npu_fast_gelu_backward(grad_output, self) + + return grad + +fast_gelu = FastGeluFunction.apply diff --git a/ads/common/ops/npu_abs.py b/ads/common/ops/npu_abs.py new file mode 100644 index 00000000..62f02b84 --- /dev/null +++ b/ads/common/ops/npu_abs.py @@ -0,0 +1,5 @@ +import torch +import torch_npu +import ads_c + +npu_abs = ads_c.npu_abs diff --git a/ads/common/ops/npu_anchor_response_flags.py b/ads/common/ops/npu_anchor_response_flags.py new file mode 100644 index 00000000..b75fd77c --- /dev/null +++ b/ads/common/ops/npu_anchor_response_flags.py @@ -0,0 +1,13 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuAnchorResponseFlagsFunction(Function): + @staticmethod + def forward(ctx, self, featmap_size, stride, num_base_anchors): + result = ads_c.npu_anchor_response_flags(self, featmap_size, stride, num_base_anchors) + return result + +npu_anchor_response_flags = NpuAnchorResponseFlagsFunction.apply diff --git a/ads/common/ops/npu_batch_nms.py b/ads/common/ops/npu_batch_nms.py new file mode 100644 index 00000000..4a5b2ef9 --- /dev/null +++ b/ads/common/ops/npu_batch_nms.py @@ -0,0 +1,30 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBatchNmsFunction(Function): + @staticmethod + def forward( + ctx, + self, + scores, + score_threshold, + iou_threshold, + max_size_per_class, + max_total_size, + change_coordinate_frame=False, + transpose_box=False): + result = ads_c.npu_batch_nms( + self, + scores, + score_threshold, + iou_threshold, + max_size_per_class, + max_total_size, + change_coordinate_frame, + transpose_box) + return result + +npu_batch_nms = NpuBatchNmsFunction.apply diff --git a/ads/common/ops/npu_bounding_box_decode.py b/ads/common/ops/npu_bounding_box_decode.py new file mode 100644 index 00000000..67099d20 --- /dev/null +++ b/ads/common/ops/npu_bounding_box_decode.py @@ -0,0 +1,20 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBoundingBodDecodeFunction(Function): + @staticmethod + def forward(ctx, rois, deltas, + means0, means1, means2, means3, + stds0, tds1, stds2, stds3, + max_shape, wh_ratio_clip): + result = ads_c.npu_bounding_box_decode( + rois, deltas, + means0, means1, means2, means3, + stds0, stds1, stds2, stds3, + max_shape, wh_ratio_clip) + return result + +npu_bounding_box_decode = NpuBoundingBodDecodeFunction.apply diff --git a/ads/common/ops/npu_bounding_box_encode.py b/ads/common/ops/npu_bounding_box_encode.py new file mode 100644 index 00000000..6efac4dd --- /dev/null +++ 
b/ads/common/ops/npu_bounding_box_encode.py @@ -0,0 +1,18 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBoundingBoxEncodeFunction(Function): + @staticmethod + def forward(ctx, anchor_box, ground_truth_box, + means0, means1, means2, means3, + stds0, stds1, stds2, stds3): + result = ads_c.npu_bounding_box_encode( + anchor_box, ground_truth_box, + means0, means1, means2, means3, + stds0, stds1, stds2, stds3) + return result + +npu_bounding_box_encode = NpuBoundingBoxEncodeFunction.apply \ No newline at end of file diff --git a/ads/common/ops/npu_broadcast.py b/ads/common/ops/npu_broadcast.py new file mode 100644 index 00000000..b3371b28 --- /dev/null +++ b/ads/common/ops/npu_broadcast.py @@ -0,0 +1,16 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBroadcastFunction(Function): + @staticmethod + def forward(ctx, self, size, out=None): + if out is None: + result = ads_c.npu_broadcast(self, size) + else: + result = ads_c.npu_broadcast_out(self, size, out) + return result + +npu_broadcast = NpuBroadcastFunction.apply \ No newline at end of file diff --git a/ads/common/ops/npu_confusion_transpose.py b/ads/common/ops/npu_confusion_transpose.py new file mode 100644 index 00000000..566d19f1 --- /dev/null +++ b/ads/common/ops/npu_confusion_transpose.py @@ -0,0 +1,24 @@ +import torch +from torch.autograd import Function + +import torch_npu +import ads_c + + +class NpuConfusionTransposeFunction(Function): + @staticmethod + def forward(ctx, self, perm, shape, transpose_first): + out = ads_c.npu_confusion_transpose(self, perm, shape, transpose_first) + # perm, shape and transpose_first are not tensors, so keep them on ctx + # rather than in save_for_backward (which accepts tensors only) + ctx.perm = perm + ctx.self_size = self.size() + ctx.transpose_first = transpose_first + return out + + @staticmethod + def backward(ctx, grad_output): + out = ads_c.npu_confusion_transpose_backward(grad_output, ctx.perm, ctx.self_size, not ctx.transpose_first) + return out, None, None, None + +npu_confusion_transpose = NpuConfusionTransposeFunction.apply diff --git a/ads/common/ops/rotary_mul.py b/ads/common/ops/rotary_mul.py index bf9c2a9b..5079961b 100644 --- a/ads/common/ops/rotary_mul.py +++ b/ads/common/ops/rotary_mul.py @@ -19,4 +19,4 @@ class RotaryMulFunction(Function): result = ads_c.npu_rotary_mul_backward(grad_output, input, r1, r2) return result -npu_rotary_mul = RotaryMulFunction.apply \ No newline at end of file +npu_rotary_mul = RotaryMulFunction.apply diff --git a/ads/common/ops/rotated_iou.py b/ads/common/ops/rotated_iou.py index d88d3e9b..896f001c 100644 --- a/ads/common/ops/rotated_iou.py +++ b/ads/common/ops/rotated_iou.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_rotated_iou = ads_c.npu_rotated_iou \ No newline at end of file +npu_rotated_iou = ads_c.npu_rotated_iou diff --git a/ads/common/ops/rotated_overlaps.py b/ads/common/ops/rotated_overlaps.py index 40753235..c481fe6f 100644 --- a/ads/common/ops/rotated_overlaps.py +++ b/ads/common/ops/rotated_overlaps.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_rotated_overlaps = ads_c.npu_rotated_overlaps \ No newline at end of file +npu_rotated_overlaps = ads_c.npu_rotated_overlaps diff --git a/ads/common/ops/scatter.py b/ads/common/ops/scatter.py index 7d89109c..d9e6de8a 100644 --- a/ads/common/ops/scatter.py +++ b/ads/common/ops/scatter.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_scatter = ads_c.npu_scatter \ No newline at end of file +npu_scatter = ads_c.npu_scatter diff 
--git a/ads/common/ops/sign_bits_pack.py b/ads/common/ops/sign_bits_pack.py index c09d486a..7b1e0040 100644 --- a/ads/common/ops/sign_bits_pack.py +++ b/ads/common/ops/sign_bits_pack.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_sign_bits_pack = ads_c.npu_sign_bits_pack \ No newline at end of file +npu_sign_bits_pack = ads_c.npu_sign_bits_pack diff --git a/ads/common/ops/sign_bits_unpack.py b/ads/common/ops/sign_bits_unpack.py index efa1a2dd..ed374e17 100644 --- a/ads/common/ops/sign_bits_unpack.py +++ b/ads/common/ops/sign_bits_unpack.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_sign_bits_unpack = ads_c.npu_sign_bits_unpack \ No newline at end of file +npu_sign_bits_unpack = ads_c.npu_sign_bits_unpack diff --git a/ads/common/ops/silu.py b/ads/common/ops/silu.py index 8ca866db..bd4251b0 100644 --- a/ads/common/ops/silu.py +++ b/ads/common/ops/silu.py @@ -13,7 +13,7 @@ class SiluFunction(Function): result = func(input) ctx.save_for_backward(input, result) return result - + @staticmethod def backward(ctx, grad_outputs): x0, x1 = ctx.saved_tensors @@ -22,4 +22,4 @@ class SiluFunction(Function): npu_silu = SiluFunction.apply -npu_silu_ = ads_c.npu_silu_ \ No newline at end of file +npu_silu_ = ads_c.npu_silu_ diff --git a/ads/common/ops/softmax_cross_entropy_with_logits.py b/ads/common/ops/softmax_cross_entropy_with_logits.py index f09d2a3e..cd12c5dd 100644 --- a/ads/common/ops/softmax_cross_entropy_with_logits.py +++ b/ads/common/ops/softmax_cross_entropy_with_logits.py @@ -20,4 +20,4 @@ class SoftMaxFunction(Function): result = ads_c.npu_softmax_cross_entropy_with_logits_backward(grad_output, feature, labels) return result -npu_softmax_cross_entropy_with_logits = SoftMaxFunction.apply \ No newline at end of file +npu_softmax_cross_entropy_with_logits = SoftMaxFunction.apply diff --git a/ads/common/ops/stride_add.py b/ads/common/ops/stride_add.py index 24a3946b..586a83c3 100644 --- a/ads/common/ops/stride_add.py +++ b/ads/common/ops/stride_add.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_stride_add = ads_c.npu_stride_add \ No newline at end of file +npu_stride_add = ads_c.npu_stride_add diff --git a/ads/common/ops/transpose.py b/ads/common/ops/transpose.py index 14972299..a27dca7d 100644 --- a/ads/common/ops/transpose.py +++ b/ads/common/ops/transpose.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_transpose = ads_c.npu_transpose \ No newline at end of file +npu_transpose = ads_c.npu_transpose diff --git a/ads/common/ops/yolo_boxes_encode.py b/ads/common/ops/yolo_boxes_encode.py index 585adb58..cb915a0f 100644 --- a/ads/common/ops/yolo_boxes_encode.py +++ b/ads/common/ops/yolo_boxes_encode.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_yolo_boxes_encode = ads_c.npu_yolo_boxes_encode \ No newline at end of file +npu_yolo_boxes_encode = ads_c.npu_yolo_boxes_encode diff --git a/setup.py b/setup.py index 2ff09279..13484c88 100644 --- a/setup.py +++ b/setup.py @@ -8,10 +8,16 @@ source_file = [] source_file += glob.glob(os.path.join("./ads/common/ops/csrc/", "*.cpp")) source_file += glob.glob(os.path.join("./bind/", "*.cpp")) +torch_npu_dir = extension.PYTORCH_NPU_INSTALL_PATH +include_dirs = [] +include_dirs.append(torch_npu_dir + "/include/third_party/acl/inc/") + exts = [] ext1 = extension.NpuExtension( name="ads_c", sources=source_file, + include_dirs=include_dirs, + extra_compile_args=['-D__FILENAME__=\"$$(notdir $$(abspath $$<))\"'], ) exts.append(ext1) diff --git a/tests/test_abs.py 
b/tests/test_abs.py new file mode 100644 index 00000000..a8aa5960 --- /dev/null +++ b/tests/test_abs.py @@ -0,0 +1,51 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import numpy as np +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestAbs(TestCase): + def cpu_op_exec(self, input1): + output = torch.abs(input1) + output = output.numpy() + return output + + def npu_op_exec(self, input1): + output = ads.common.npu_abs(input1) + output = output.to("cpu") + output = output.numpy() + return output + + def test_abs_shape_format_fp16(self, device="npu"): + format_list = [0, 3] + shape_list = [[5]] + shape_format = [ + [np.float16, i, j] for i in format_list for j in shape_list + ] + for item in shape_format: + cpu_input, npu_input = create_common_tensor(item, -10, 10) + cpu_input = cpu_input.to(torch.float32) + cpu_output = self.cpu_op_exec(cpu_input) + npu_output = self.npu_op_exec(npu_input) + cpu_output = cpu_output.astype(np.float16) + self.assertRtolEqual(cpu_output, npu_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_batch_nms.py b/tests/test_batch_nms.py new file mode 100644 index 00000000..11c3245c --- /dev/null +++ b/tests/test_batch_nms.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
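
The batch NMS test that follows exercises the new binding end to end. For orientation, a minimal usage sketch, assuming only the ads.common export added earlier in this series; shapes mirror the test below:

```python
import torch
import torch_npu
import ads.common

# Shapes as in the test below: batch=8, classes=4, one candidate box per class.
boxes = torch.randn(8, 4, 1, 4).npu()
scores = torch.randn(8, 4, 1).npu()

nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = ads.common.npu_batch_nms(
    boxes, scores, 0.3, 0.5, 4, 4)

# Per the BatchMultiClassNonMaxSuppression kernel later in this series:
#   nmsed_boxes:   [8, max_total_size, 4]
#   nmsed_scores:  [8, max_total_size]
#   nmsed_classes: [8, max_total_size]
#   nmsed_num:     [8], int32 count of boxes kept per batch element
```
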
+ +import torch +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestBatchNms(TestCase): + def test_batch_nms_shape_format(self): + boxes = torch.randn(8, 4, 1, 4).npu() + scores = torch.randn(8, 4, 1).npu() + boxes_fp16 = boxes.half() + scores_fp16 = scores.half() + nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = ads.common.npu_batch_nms(boxes, scores, 0.3, 0.5, 4, 4) + boxes1, scores1, classes1, num1 = ads.common.npu_batch_nms(boxes_fp16, scores_fp16, 0.3, 0.5, 4, 4) + expect_nmsed_classes = torch.tensor([[0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000]], dtype=torch.float32) + self.assertRtolEqual(expect_nmsed_classes, nmsed_classes.cpu()) + self.assertRtolEqual(expect_nmsed_classes.half(), classes1.cpu()) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_fast_gelu.py b/tests/test_fast_gelu.py new file mode 100644 index 00000000..c81b5e92 --- /dev/null +++ b/tests/test_fast_gelu.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestFastGelu(TestCase): + + def supported_op_exec(self, input1): + attr = 1.702 + attr_half = attr / 2 + abs_input1 = torch.abs(input1) + numerator = input1 * \ + torch.exp(attr_half * (input1 - abs_input1)) + denominator = 1.0 + torch.exp(- attr * abs_input1) + output = numerator / denominator + return output.cpu().detach() + + def custom_op_exec(self, input1): + output = ads.common.fast_gelu(input1) + return output.cpu().detach() + + def test_fast_gelu(self, device="npu"): + item = [np.float32, 0, [3, 16, 32]] + _, npu_input = create_common_tensor(item, 0, 100) + + supported_output = self.supported_op_exec(npu_input) + custom_output = self.custom_op_exec(npu_input) + self.assertRtolEqual(supported_output, custom_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_fast_gelu_backward.py b/tests/test_fast_gelu_backward.py new file mode 100644 index 00000000..1c7920ef --- /dev/null +++ b/tests/test_fast_gelu_backward.py @@ -0,0 +1,43 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import numpy as np + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + + +class TestFastGelu(TestCase): + def npu_op_exec(self, input1): + input1.requires_grad = True + output = ads.common.fast_gelu(input1) + output.backward(torch.ones_like(output)) + output_grad = input1.grad + output_grad = output_grad.to("cpu") + output_grad = output_grad.detach().numpy() + output = output.cpu().detach().numpy() + return output_grad, output + + def test_fastgelu(self, device="npu"): + input1 = torch.tensor([1., 2., 3., 4.]).npu() + exoutputgrad = torch.tensor([1.0677795, 1.0738151, 1.0245483, 1.0064018]) + exoutput = torch.tensor([0.8458, 1.9357, 2.9819, 3.9956]) + outputgrad, output = self.npu_op_exec(input1) + self.assertRtolEqual(exoutputgrad.numpy(), outputgrad) + self.assertRtolEqual(exoutput.numpy(), output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_npu_anchor_response_flags.py b/tests/test_npu_anchor_response_flags.py new file mode 100644 index 00000000..b656ee94 --- /dev/null +++ b/tests/test_npu_anchor_response_flags.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
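
The gradient values expected by the backward test above follow from differentiating the x * sigmoid(1.702 * x) approximation used in test_fast_gelu. A reference sketch for cross-checking on the host; the FastGeluGrad NPU kernel later in this series remains the implementation under test:

```python
import torch

def fast_gelu_grad_reference(x: torch.Tensor) -> torch.Tensor:
    # d/dx [x * sigmoid(a*x)] = sigmoid(a*x) + a * x * sigmoid(a*x) * (1 - sigmoid(a*x))
    a = 1.702
    s = torch.sigmoid(a * x)
    return s + a * x * s * (1 - s)

# fast_gelu_grad_reference(torch.tensor([1., 2., 3., 4.]))
# -> tensor([1.0678, 1.0738, 1.0245, 1.0064]), matching exoutputgrad in the test above.
```
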
+ +import torch +import numpy as np + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestNpuAnchorResponseFlags(TestCase): + def custom_op_exec(self, gt_bboxes, featmap_size, strides, num_base_anchors): + if gt_bboxes.dtype == torch.float16: + gt_bboxes = gt_bboxes.to(torch.float32) + feat_h, feat_w = featmap_size + gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5) + gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5) + gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / strides[0]).int() + gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / strides[1]).int() + gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x + responsible_grid = torch.zeros(feat_h * feat_w, dtype=torch.uint8).npu() + gt_bboxes_grid_idx = gt_bboxes_grid_idx.long() + responsible_grid[gt_bboxes_grid_idx] = 1 + responsible_grid = responsible_grid[:, None].expand( + responsible_grid.size(0), num_base_anchors).contiguous().view(-1) + return responsible_grid.cpu().numpy() + + def npu_op_exec(self, input_npu, featmap_size, strides, num_base_anchors): + out = ads.common.npu_anchor_response_flags(input_npu, featmap_size, strides, num_base_anchors) + out = out.cpu().numpy() + return out + + def test_npu_anchor_response_flags(self): + shape_format = [ + [[np.float32, -1, [100, 4]], [60, 60], [2, 2], 9], + [[np.float16, -1, [200, 4]], [10, 10], [32, 32], 3], + [[np.float16, -1, [500, 4]], [32, 32], [16, 16], 5] + ] + for item in shape_format: + _, npu_input = create_common_tensor(item[0], 0, 100) + custom_output = self.custom_op_exec(npu_input, *item[1:]) + npu_output = self.npu_op_exec(npu_input, *item[1:]) + self.assertRtolEqual(custom_output, npu_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_npu_bounding_box_decode.py b/tests/test_npu_bounding_box_decode.py new file mode 100644 index 00000000..248fe36c --- /dev/null +++ b/tests/test_npu_bounding_box_decode.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
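
One detail worth making explicit from the anchor response flags test above: the result is a flat tensor with one uint8 entry per (grid cell, base anchor) pair. A small sanity sketch, assuming the ads.common binding added in this patch:

```python
import torch
import torch_npu
import ads.common

# First case from the test above: 60x60 feature map, stride (2, 2), 9 base anchors.
gt_bboxes = (torch.rand(100, 4) * 100).npu()
flags = ads.common.npu_anchor_response_flags(gt_bboxes, [60, 60], [2, 2], 9)

assert flags.dtype == torch.uint8       # the kernel allocates at::kByte
assert flags.numel() == 60 * 60 * 9     # feat_h * feat_w * num_base_anchors
```
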
+ +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + + +class TestBoundingBoxDecode(TestCase): + def npu_bounding_box_decode(self, rois, deltas, means0, means1, means2, means3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): + means = [means0, means1, means2, means3] + stds = [stds0, stds1, stds2, stds3] + means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) + stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) + denorm_deltas = deltas * stds + means + + dx = denorm_deltas[:, 0::4] + dy = denorm_deltas[:, 1::4] + dw = denorm_deltas[:, 2::4] + dh = denorm_deltas[:, 3::4] + max_ratio = torch.abs(torch.log(torch.tensor(wh_ratio_clip))) + + dw = torch.clamp(dw, min=-max_ratio, max=max_ratio) + dh = torch.clamp(dh, min=-max_ratio, max=max_ratio) + + ax = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) + ay = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) + aw = (rois[:, 2] - rois[:, 0] * 0.5).unsqueeze(1).expand_as(dw) + ah = (rois[:, 3] - rois[:, 1] * 0.5).unsqueeze(1).expand_as(dh) + + pw = aw * dw.exp() + ph = ah * dh.exp() + px = torch.addcmul(ax, 1, aw, dx) + py = torch.addcmul(ay, 1, ah, dy) + + x1 = px - pw * 0.5 + 0.5 + y1 = py - ph * 0.5 + 0.5 + x2 = px + pw * 0.5 - 0.5 + y2 = py + ph * 0.5 - 0.5 + + if max_shape is not None: + x1 = torch.clamp(x1, min=0, max=(max_shape[1] - 1)) + y1 = torch.clamp(y1, min=0, max=(max_shape[0] - 1)) + x2 = torch.clamp(x2, min=0, max=(max_shape[1] - 1)) + y2 = torch.clamp(y2, min=0, max=(max_shape[0] - 1)) + boxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) + return boxes + + def custom_op_exec(self, rois, deltas, means0, means1, means2, means3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): + output = self.npu_bounding_box_decode(rois, deltas, means0, means1, + means2, means3, stds0, stds1, + stds2, stds3, max_shape, wh_ratio_clip) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec(self, rois, deltas, means0, means1, means2, means3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): + output = ads.common.npu_bounding_box_decode(rois, deltas, means0, means1, + means2, means3, stds0, stds1, + stds2, stds3, max_shape, wh_ratio_clip) + output = output.to("cpu") + output = output.numpy() + return output + + def test_decode_shape_format_fp32(self): + input1 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float32).to("npu") + input2 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float32).to("npu") + + npu_output = self.npu_op_exec(input1, input2, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + custom_output = self.custom_op_exec(input1, input2, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + self.assertRtolEqual(npu_output, custom_output) + + def test_decode_shape_format_fp16(self): + input1_fp16 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float16).to("npu") + input2_fp16 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float16).to("npu") + + npu_output = self.npu_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + custom_output = self.custom_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + self.assertRtolEqual(npu_output, custom_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_npu_bounding_box_encode.py b/tests/test_npu_bounding_box_encode.py new file mode 100644 index 00000000..b6a406b6 --- /dev/null +++ 
b/tests/test_npu_bounding_box_encode.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + + +class TestBoundingBoxEncode(TestCase): + def npu_bounding_box_encode(self, anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3): + means = [means0, means1, means2, means3] + stds = [stds0, stds1, stds2, stds3] + px = (anchor_box[..., 0] + anchor_box[..., 2]) * 0.5 + py = (anchor_box[..., 1] + anchor_box[..., 3]) * 0.5 + pw = anchor_box[..., 2] - anchor_box[..., 0] + 1.0 + ph = anchor_box[..., 3] - anchor_box[..., 1] + 1.0 + + gx = (ground_truth_box[..., 0] + ground_truth_box[..., 2]) * 0.5 + gy = (ground_truth_box[..., 1] + ground_truth_box[..., 3]) * 0.5 + gw = ground_truth_box[..., 2] - ground_truth_box[..., 0] + 1.0 + gh = ground_truth_box[..., 3] - ground_truth_box[..., 1] + 1.0 + + eps = 1e-7 + dx = (gx - px) / (pw + eps) + dy = (gy - py) / (ph + eps) + dw = torch.log(torch.abs(gw) / torch.abs(pw + eps)) + dh = torch.log(torch.abs(gh) / torch.abs(ph + eps)) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + means = deltas.new_tensor(means).unsqueeze(0) + stds = deltas.new_tensor(stds).unsqueeze(0) + deltas = deltas.sub_(means) .div_(stds) + + return deltas + + def custom_op_exec(self, anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3): + output = self.npu_bounding_box_encode(anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec(self, anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3): + output = ads.common.npu_bounding_box_encode(anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3) + output = output.to("cpu") + output = output.numpy() + return output + + def test_encode_shape_format_fp32(self): + input1 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float32).to("npu") + input2 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float32).to("npu") + + npu_output = self.npu_op_exec(input1, input2, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + custom_output = self.custom_op_exec(input1, input2, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + self.assertRtolEqual(npu_output, custom_output, 1e-3) + + def test_encode_shape_format_fp16(self): + input1_fp16 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float16).to("npu") + input2_fp16 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float16).to("npu") + + npu_output = self.npu_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + custom_output = self.custom_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + self.assertRtolEqual(npu_output, custom_output, 1e-3) + + +if __name__ == "__main__": + 
run_tests() diff --git a/tests/test_npu_broadcast.py b/tests/test_npu_broadcast.py new file mode 100644 index 00000000..d44badc1 --- /dev/null +++ b/tests/test_npu_broadcast.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + +import torch_npu +import ads.common +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestNpuBroadcast(TestCase): + def custom_op_exec(self, input1, shape): + output = torch.broadcast_to(input1, shape) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec(self, input1, size): + output = ads.common.npu_broadcast(input1, size) + output = output.to("cpu") + output = output.numpy() + return output + + def test_npu_broadcast(self): + input1 = [ + torch.tensor([1, 2, 3]).npu(), + torch.tensor([[1], [2], [3]]).npu() + ] + for item in input1: + custom_output = self.custom_op_exec(item, (3, 3)) + npu_output = self.npu_op_exec(item, (3, 3)) + self.assertRtolEqual(custom_output, npu_output) + + +if __name__ == "__main__": + run_tests() -- Gitee From ec577d168f58288bef5e58b5887df3fc189b0976 Mon Sep 17 00:00:00 2001 From: zhanhao Date: Wed, 13 Dec 2023 14:30:53 +0800 Subject: [PATCH 2/3] add op --- ads/common/__init__.py | 7 ++ .../ops/csrc/AnchorResponseFlagsKernelNpu.cpp | 72 ++++++++++++ ads/common/ops/csrc/BatchNms.cpp | 49 ++++++++ .../ops/csrc/BoundingBoxDecodeKernelNpu.cpp | 57 ++++++++++ .../ops/csrc/BoundingBoxEncodeKernelNpu.cpp | 51 +++++++++ ads/common/ops/csrc/BroadCastKernelNpu.cpp | 49 ++++++++ .../ops/csrc/ConfusionTransposeKernelNpu.cpp | 97 ++++++++++++++++ ads/common/ops/csrc/FastGeluKernelNpu.cpp | 52 +++++++++ ads/common/ops/csrc/NpuSilu.cpp | 13 +-- ads/common/ops/csrc/RotaryMulKernelNpu.cpp | 10 +- .../ops/csrc/RotatedBoxDecodeKernelNpu.cpp | 5 - .../ops/csrc/RotatedBoxEncodeKernelNpu.cpp | 5 - ads/common/ops/csrc/RotatedIouKernelNpu.cpp | 6 - .../ops/csrc/RotatedOverlapsKernelNpu.cpp | 6 - ads/common/ops/csrc/ScatterMaxKernelNpu.cpp | 16 +-- ads/common/ops/csrc/ScatterV1KernelNpu.cpp | 4 - ads/common/ops/csrc/SignBitsPackKernelNpu.cpp | 4 - .../ops/csrc/SignBitsUnpackKernelNpu.cpp | 4 - ...SoftmaxCrossEntropyWithLogitsKernelNpu.cpp | 9 -- ads/common/ops/csrc/StrideAddKernelNpu.cpp | 5 - ads/common/ops/csrc/TransposeKernelNpu.cpp | 5 - .../ops/csrc/YoloBoxesEncodeKernelNpu.cpp | 5 - ads/common/ops/csrc/common.cpp | 5 +- ads/common/ops/csrc/common.h | 6 +- ads/common/ops/csrc/functions.h | 76 +++++++++++++ ads/common/ops/csrc/pybind.cpp | 23 ++++ ads/common/ops/fast_gelu.py | 23 ++++ ads/common/ops/npu_anchor_response_flags.py | 13 +++ ads/common/ops/npu_batch_nms.py | 30 +++++ ads/common/ops/npu_bounding_box_decode.py | 20 ++++ ads/common/ops/npu_bounding_box_encode.py | 18 +++ ads/common/ops/npu_broadcast.py | 16 +++ ads/common/ops/npu_confusion_transpose.py | 24 ++++ ads/common/ops/rotary_mul.py | 2 +- ads/common/ops/rotated_iou.py | 2 +- ads/common/ops/rotated_overlaps.py | 2 +- 
ads/common/ops/scatter.py | 2 +- ads/common/ops/sign_bits_pack.py | 2 +- ads/common/ops/sign_bits_unpack.py | 2 +- ads/common/ops/silu.py | 4 +- .../ops/softmax_cross_entropy_with_logits.py | 2 +- ads/common/ops/stride_add.py | 2 +- ads/common/ops/transpose.py | 2 +- ads/common/ops/yolo_boxes_encode.py | 2 +- tests/test_batch_nms.py | 44 +++++++ tests/test_fast_gelu.py | 51 +++++++++ tests/test_fast_gelu_backward.py | 43 +++++++ tests/test_npu_anchor_response_flags.py | 60 ++++++++++ tests/test_npu_bounding_box_decode.py | 107 ++++++++++++++++++ tests/test_npu_bounding_box_encode.py | 92 +++++++++++++++ tests/test_npu_broadcast.py | 48 ++++++++ 51 files changed, 1146 insertions(+), 108 deletions(-) create mode 100644 ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp create mode 100644 ads/common/ops/csrc/BatchNms.cpp create mode 100644 ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/BroadCastKernelNpu.cpp create mode 100644 ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp create mode 100644 ads/common/ops/csrc/FastGeluKernelNpu.cpp create mode 100644 ads/common/ops/fast_gelu.py create mode 100644 ads/common/ops/npu_anchor_response_flags.py create mode 100644 ads/common/ops/npu_batch_nms.py create mode 100644 ads/common/ops/npu_bounding_box_decode.py create mode 100644 ads/common/ops/npu_bounding_box_encode.py create mode 100644 ads/common/ops/npu_broadcast.py create mode 100644 ads/common/ops/npu_confusion_transpose.py create mode 100644 tests/test_batch_nms.py create mode 100644 tests/test_fast_gelu.py create mode 100644 tests/test_fast_gelu_backward.py create mode 100644 tests/test_npu_anchor_response_flags.py create mode 100644 tests/test_npu_bounding_box_decode.py create mode 100644 tests/test_npu_bounding_box_encode.py create mode 100644 tests/test_npu_broadcast.py diff --git a/ads/common/__init__.py b/ads/common/__init__.py index c4dcc62e..8e59829b 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -14,3 +14,10 @@ from .ops.silu import npu_silu from .ops.silu import npu_silu_ from .ops.rotary_mul import npu_rotary_mul from .ops.npu_abs import npu_abs +from .ops.fast_gelu import fast_gelu +from .ops.npu_anchor_response_flags import npu_anchor_response_flags +from .ops.npu_bounding_box_decode import npu_bounding_box_decode +from .ops.npu_bounding_box_encode import npu_bounding_box_encode +from .ops.npu_batch_nms import npu_batch_nms +from .ops.npu_confusion_transpose import npu_confusion_transpose +from .ops.npu_broadcast import npu_broadcast diff --git a/ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp b/ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp new file mode 100644 index 00000000..f414633c --- /dev/null +++ b/ads/common/ops/csrc/AnchorResponseFlagsKernelNpu.cpp @@ -0,0 +1,72 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "common.h" + +namespace { +c10::SmallVector infersize_npu_anchor_response_flags( + at::IntArrayRef featmap_size, + int64_t num_base_anchors) +{ + int64_t output_value = featmap_size[0] * featmap_size[1] * num_base_anchors; + c10::SmallVector output_size = {output_value}; + return output_size; +} + +inline void anchor_response_flags_check( + const at::Tensor& self, + at::IntArrayRef featmap_size, + at::IntArrayRef stride) +{ + TORCH_CHECK( + featmap_size.size() == 2, + "expected feat_map_size equals to 2, but got size ", + featmap_size.size()); + TORCH_CHECK( + self.dim() == 2 && self.size(1) == 4, + "Non-empty 2D gt_bboxes tensor expected but got a tensor with sizes ", + self.sizes()); + TORCH_CHECK( + self.scalar_type() == at::kHalf || self.scalar_type() == at::kFloat, + "float16 or float32 tensor expected but got a tensor with dtype: ", + self.scalar_type()); +} +} // namespace + +at::Tensor npu_anchor_response_flags( + const at::Tensor& self, + at::IntArrayRef featmap_size, + at::IntArrayRef stride, + int64_t num_base_anchors) +{ + anchor_response_flags_check(self, featmap_size, stride); + auto output_size = infersize_npu_anchor_response_flags(featmap_size, num_base_anchors); + auto options = self.options().dtype(at::kByte); + at::Tensor result = at::empty(output_size, options); + + at::Tensor self_cp = self.to(at::kFloat); + + at_npu::native::OpCommand cmd; + cmd.Name("AnchorResponseFlags") + .Input(self_cp) + .Output(result) + .Attr("featmap_size", featmap_size) + .Attr("strides", stride) + .Attr("num_base_anchors", num_base_anchors) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/BatchNms.cpp b/ads/common/ops/csrc/BatchNms.cpp new file mode 100644 index 00000000..a7051437 --- /dev/null +++ b/ads/common/ops/csrc/BatchNms.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
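
A note on the TORCH_CHECK guards in the kernel above: validation failures surface in Python as RuntimeError rather than aborting the process. A hedged sketch of the caller-side behaviour, assuming the pybind exports in this series:

```python
import torch
import torch_npu
import ads.common

bad_boxes = torch.rand(10, 5).npu()  # the kernel requires a 2-D tensor with size(1) == 4
try:
    ads.common.npu_anchor_response_flags(bad_boxes, [60, 60], [2, 2], 9)
except RuntimeError as err:
    # The TORCH_CHECK message text is preserved, e.g.
    # "Non-empty 2D gt_bboxes tensor expected but got a tensor with sizes ..."
    print(err)
```
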
+#include "torch_npu/csrc/framework/OpCommand.h" +#include "common.h" + +std::tuple npu_batch_nms( + const at::Tensor& self, + const at::Tensor& scores, + double score_threshold, + double iou_threshold, + int64_t max_size_per_class, + int64_t max_total_size, + bool change_coordinate_frame, + bool transpose_box) +{ + at::Tensor nmsed_boxes = at::empty({self.size(0), max_total_size, 4}, self.options()); + at::Tensor nmsed_scores = at::empty({self.size(0), max_total_size}, self.options()); + at::Tensor nmsed_classes = at::empty({self.size(0), max_total_size}, self.options()); + at::Tensor nmsed_num = at::empty({self.size(0)}, self.options().dtype(at::kInt)); + at_npu::native::OpCommand cmd; + cmd.Name("BatchMultiClassNonMaxSuppression") + .Input(self) + .Input(scores) + .Output(nmsed_boxes) + .Output(nmsed_scores) + .Output(nmsed_classes) + .Output(nmsed_num) + .Attr("score_threshold", static_cast(score_threshold)) + .Attr("iou_threshold", static_cast(iou_threshold)) + .Attr("max_size_per_class", max_size_per_class) + .Attr("max_total_size", max_total_size) + .Attr("change_coordinate_frame", change_coordinate_frame) + .Attr("transpose_box", transpose_box) + .Run(); + return std::tie(nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num); +} diff --git a/ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp b/ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp new file mode 100644 index 00000000..85fc0764 --- /dev/null +++ b/ads/common/ops/csrc/BoundingBoxDecodeKernelNpu.cpp @@ -0,0 +1,57 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "torch_npu/csrc/framework/OpCommand.h" +#include "common.h" + +at::Tensor npu_bounding_box_decode( + const at::Tensor& rois, + const at::Tensor& deltas, + double means0, + double means1, + double means2, + double means3, + double stds0, + double stds1, + double stds2, + double stds3, + at::IntArrayRef max_shape, + double wh_ratio_clip) +{ + c10::SmallVector output_size = {rois.size(0), 4}; + at::Tensor result = at::empty(output_size, rois.options()); + c10::SmallVector means = { + static_cast(means0), + static_cast(means1), + static_cast(means2), + static_cast(means3)}; + c10::SmallVector stds = { + static_cast(stds0), + static_cast(stds1), + static_cast(stds2), + static_cast(stds3)}; + at_npu::native::OpCommand cmd; + cmd.Name("BoundingBoxDecode") + .Input(rois) + .Input(deltas) + .Output(result) + .Attr("means", means) + .Attr("stds", stds) + .Attr("max_shape", max_shape) + .Attr("wh_ratio_clip", static_cast(wh_ratio_clip)) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp b/ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp new file mode 100644 index 00000000..aa5bad77 --- /dev/null +++ b/ads/common/ops/csrc/BoundingBoxEncodeKernelNpu.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/OpCommand.h" +#include "common.h" + +at::Tensor npu_bounding_box_encode( + const at::Tensor& anchor_box, + const at::Tensor& ground_truth_box, + double means0, + double means1, + double means2, + double means3, + double stds0, + double stds1, + double stds2, + double stds3) +{ + at::Tensor result = at::empty({anchor_box.size(0), 4}, anchor_box.options()); + c10::SmallVector means = { + static_cast(means0), + static_cast(means1), + static_cast(means2), + static_cast(means3)}; + c10::SmallVector stds = { + static_cast(stds0), + static_cast(stds1), + static_cast(stds2), + static_cast(stds3)}; + at_npu::native::OpCommand cmd; + cmd.Name("BoundingBoxEncode") + .Input(anchor_box) + .Input(ground_truth_box) + .Output(result) + .Attr("means", means) + .Attr("stds", stds) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/BroadCastKernelNpu.cpp b/ads/common/ops/csrc/BroadCastKernelNpu.cpp new file mode 100644 index 00000000..b4858034 --- /dev/null +++ b/ads/common/ops/csrc/BroadCastKernelNpu.cpp @@ -0,0 +1,49 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/OpCommand.h" + + +namespace { +at::Tensor& npu_broadcast_out_nocheck(at::Tensor& result, const at::Tensor& self, at::IntArrayRef size) +{ + at_npu::native::OpCommand cmd; + cmd.Name("BroadcastTo") + .Input(self) + .Input(size) + .Output(result) + .Run(); + return result; +} +} // namespace + +at::Tensor& npu_broadcast_out(const at::Tensor& self, at::IntArrayRef size, at::Tensor& result) +{ + npu_broadcast_out_nocheck(result, self, size); + + return result; +} + +at::Tensor npu_broadcast(const at::Tensor& self, at::IntArrayRef size) +{ + at::Tensor self_cp = self.dtype() == at::kBool ? self.to(at::kInt) : self; + at::Tensor result = at::empty(size, self_cp.options()); + npu_broadcast_out_nocheck(result, self_cp, size); + + if (self.dtype() == at::kBool) { + result = result.to(at::kBool); + } + return result; +} diff --git a/ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp b/ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp new file mode 100644 index 00000000..a12d1d7c --- /dev/null +++ b/ads/common/ops/csrc/ConfusionTransposeKernelNpu.cpp @@ -0,0 +1,97 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. 
+// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/OpCommand.h" +#include "common.h" + +at::Tensor npu_confusion_transpose( + const at::Tensor& self, + at::IntArrayRef perm, + at::IntArrayRef shape, + bool transpose_first) +{ + c10::SmallVector output_size; + if (transpose_first) { + output_size = array_to_small_vector(shape); + } else { + auto shape_size = shape.size(); + for (uint i = 0; i < perm.size(); i++) { + TORCH_CHECK(shape_size > perm[i], "npu_confusion_transpose input invalid, " + "shape has size ", + shape_size, " but perm[i] is, ", perm[i]); + output_size.emplace_back(shape[perm[i]]); + } + } + + at::Tensor result = at::empty(output_size, self.options()); + at_npu::native::OpCommand cmd; + cmd.Name("ConfusionTransposeD") + .Input(self) + .Output(result) + .Attr("perm", perm) + .Attr("shape", shape) + .Attr("transpose_first", transpose_first) + .Run(); + + return result; +} + +void check_confusion_transpose_perm(at::IntArrayRef perm, at::IntArrayRef shape) +{ + auto input_dim = shape.size(); + TORCH_CHECK(perm.size() == input_dim, "The length of perm should be the same as shape."); + std::vector seen(input_dim); + for (const auto i : c10::irange(input_dim)) { + auto dim = at::maybe_wrap_dim(perm[i], input_dim); + TORCH_CHECK(!seen[dim], "Repeated dim in perm"); + seen[dim] = true; + } +} + +at::Tensor npu_confusion_transpose_backward( + const at::Tensor& grad, + at::IntArrayRef perm, + at::IntArrayRef shape, + bool transpose_first) +{ + c10::SmallVector svec_shape; + if (transpose_first) { + svec_shape = array_to_small_vector(shape); + } else { + check_confusion_transpose_perm(perm, shape); + for (int i = 0; i < perm.size(); i++) { + svec_shape.emplace_back(shape[perm[i]]); + } + } + std::vector vec_perm; + int64_t perm_len = perm.size(); + int64_t temp_perm[perm_len] = {0}; + for (int64_t i = 0; i < perm_len; i++) { + temp_perm[perm[i]] = i; + } + vec_perm = std::vector(temp_perm, temp_perm+perm_len); + perm = at::IntArrayRef(vec_perm); + at::Tensor result = at::empty(shape, grad.options()); + + at_npu::native::OpCommand cmd; + cmd.Name("ConfusionTransposeD") + .Input(grad) + .Output(result) + .Attr("perm", perm) + .Attr("shape", svec_shape) + .Attr("transpose_first", transpose_first) + .Run(); + return result; +} diff --git a/ads/common/ops/csrc/FastGeluKernelNpu.cpp b/ads/common/ops/csrc/FastGeluKernelNpu.cpp new file mode 100644 index 00000000..a56dcb7a --- /dev/null +++ b/ads/common/ops/csrc/FastGeluKernelNpu.cpp @@ -0,0 +1,52 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#include "torch_npu/csrc/framework/OpCommand.h" + +namespace { +at::Tensor& fast_gelu_backward_npu_nocheck( + at::Tensor& grad_input, + const at::Tensor& grad, + const at::Tensor& self) +{ + at_npu::native::OpCommand cmd; + cmd.Name("FastGeluGrad") + .Input(grad) + .Input(self) + .Output(grad_input) + .Run(); + return grad_input; +} +} // namespace + +at::Tensor npu_fast_gelu(const at::Tensor& self) +{ + at::Tensor result = at::empty(self.sizes(), self.options()); + at_npu::native::OpCommand cmd; + cmd.Name("FastGelu") + .Input(self) + .Output(result) + .Run(); + return result; +} + +at::Tensor npu_fast_gelu_backward( + const at::Tensor& grad, + const at::Tensor& self) +{ + at::Tensor grad_input = at::empty(self.sizes(), self.options()); + fast_gelu_backward_npu_nocheck(grad_input, grad, self); + return grad_input; +} diff --git a/ads/common/ops/csrc/NpuSilu.cpp b/ads/common/ops/csrc/NpuSilu.cpp index 0b81c0f4..4f62cad5 100644 --- a/ads/common/ops/csrc/NpuSilu.cpp +++ b/ads/common/ops/csrc/NpuSilu.cpp @@ -1,18 +1,7 @@ -#include - #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using torch::autograd::AutogradContext; -using torch::autograd::Function; -using npu_preparation = at_npu::native::OpPreparation; -using npu_utils = at_npu::native::NpuUtils; -using tensor_list = std::vector; - at::Tensor &silu_out_npu_nocheck(at::Tensor &result, const at::Tensor &self) { at_npu::native::OpCommand cmd; @@ -80,4 +69,4 @@ at::Tensor &npu_silu_(at::Tensor &self) { silu_out_npu(self, self); return self; -} \ No newline at end of file +} diff --git a/ads/common/ops/csrc/RotaryMulKernelNpu.cpp b/ads/common/ops/csrc/RotaryMulKernelNpu.cpp index 3e814e4e..05569309 100644 --- a/ads/common/ops/csrc/RotaryMulKernelNpu.cpp +++ b/ads/common/ops/csrc/RotaryMulKernelNpu.cpp @@ -14,18 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
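
Related to the NpuSilu.cpp cleanup above: the Python wrapper in ads/common/ops/silu.py saves both the input and the result, which is exactly what is needed to form the SiLU gradient on the host for cross-checking. A reference sketch, assuming silu(x) = x * sigmoid(x):

```python
import torch

def silu_grad_reference(grad_out: torch.Tensor, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # With y = silu(x) = x * sigmoid(x):
    #   dy/dx = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x)) = sigmoid(x) * (1 + x - y)
    s = torch.sigmoid(x)
    return grad_out * s * (1 + x - y)
```
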
-#include #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; -using torch::autograd::Function; -using torch::autograd::AutogradContext; using tensor_tuple = std::tuple; namespace { @@ -42,7 +34,7 @@ at::Tensor &rotary_mul_nocheck(at::Tensor &y, const at::Tensor &x, const at::Ten return y; } -tensor_tuple rotary_mul_backward_nocheck(at::Tensor &dx, at::Tensor &dr1, at::Tensor &dr2, const at::Tensor &x, +tensor_tuple rotary_mul_backward_nocheck(at::Tensor &dx, at::Tensor &dr1, at::Tensor &dr2, const at::Tensor &x, const at::Tensor &r1, const at::Tensor &r2, const at::Tensor &dy) { TORCH_CHECK(x.dim() == 4, "The dim of input tensor [x] shoule equal to four."); diff --git a/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp b/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp index db949fc9..0e8aa592 100644 --- a/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp +++ b/ads/common/ops/csrc/RotatedBoxDecodeKernelNpu.cpp @@ -15,13 +15,8 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; -using npu_utils = at_npu::native::NpuUtils; at::Tensor npu_rotated_box_decode(const at::Tensor &self, const at::Tensor &deltas, const at::Tensor &weight) { diff --git a/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp b/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp index cfe515ba..865b994b 100644 --- a/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp +++ b/ads/common/ops/csrc/RotatedBoxEncodeKernelNpu.cpp @@ -15,13 +15,8 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; - at::Tensor npu_rotated_box_encode( const at::Tensor &self, const at::Tensor >Box, diff --git a/ads/common/ops/csrc/RotatedIouKernelNpu.cpp b/ads/common/ops/csrc/RotatedIouKernelNpu.cpp index 7c8c334a..dc94943d 100644 --- a/ads/common/ops/csrc/RotatedIouKernelNpu.cpp +++ b/ads/common/ops/csrc/RotatedIouKernelNpu.cpp @@ -15,14 +15,8 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; - namespace { at::Tensor &rotated_iou_npu_nocheck( at::Tensor &iou, diff --git a/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp b/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp index ac476c62..0a957ca9 100644 --- a/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp +++ b/ads/common/ops/csrc/RotatedOverlapsKernelNpu.cpp @@ -15,14 +15,8 @@ // limitations under the License. 
#include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; - namespace { at::Tensor &rotated_overlaps_npu_nocheck( at::Tensor &overlaps, diff --git a/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp b/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp index c4e94384..f3b11664 100644 --- a/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp +++ b/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp @@ -1,16 +1,8 @@ -#include - #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" +#include "common.h" -using namespace at; using namespace std; -using torch::autograd::Function; -using torch::autograd::AutogradContext; -using tensor_list = std::vector; - std::tuple npu_scatter_max( const at::Tensor& updates, const at::Tensor& indices, @@ -21,7 +13,7 @@ std::tuple npu_scatter_max( sizes[0] = indices.max().item().toLong() + 1; at::Tensor result = out.value_or(at::zeros(sizes, updates.options().dtype(at::kFloat))); - at::Tensor argmax = at_npu::native::OpPreparation::ApplyTensor(result, result.options().dtype(at::kInt)); + at::Tensor argmax = at::empty(result.sizes(), result.options().dtype(at::kInt)); at_npu::native::OpCommand cmd; cmd.Name("ScatterMaxWithArgmax") @@ -37,7 +29,7 @@ std::tuple npu_scatter_max( at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segment_ids, const at::Tensor& num_segments) { - c10::SmallVector output_size; + c10::SmallVector output_size; auto num_segments_value = num_segments.item().toLong(); output_size.push_back(num_segments_value); @@ -47,7 +39,7 @@ at::Tensor npu_scatter_max_backward(const at::Tensor& x, const at::Tensor& segme copy(x_sizes.begin() + segment_ids_dims, x_sizes.end(), std::back_inserter(output_size)); - at::Tensor out = at_npu::native::OpPreparation::ApplyTensor(x, output_size); + at::Tensor out = at::empty(output_size, x.options()); at_npu::native::OpCommand cmd; cmd.Name("UnsortedSegmentSum") .Input(x) diff --git a/ads/common/ops/csrc/ScatterV1KernelNpu.cpp b/ads/common/ops/csrc/ScatterV1KernelNpu.cpp index 155ea383..f96d4608 100644 --- a/ads/common/ops/csrc/ScatterV1KernelNpu.cpp +++ b/ads/common/ops/csrc/ScatterV1KernelNpu.cpp @@ -15,12 +15,8 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; at::Tensor npu_scatter(const at::Tensor &self, const at::Tensor &indices, const at::Tensor &updates, int64_t dim) { diff --git a/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp b/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp index 5fb1139c..95f4c3ff 100644 --- a/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp +++ b/ads/common/ops/csrc/SignBitsPackKernelNpu.cpp @@ -15,12 +15,8 @@ // limitations under the License. 
#include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" -using npu_preparation = at_npu::native::OpPreparation; at::Tensor npu_sign_bits_pack(const at::Tensor &self, int64_t size) { diff --git a/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp b/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp index e7a35680..27ae440b 100644 --- a/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp +++ b/ads/common/ops/csrc/SignBitsUnpackKernelNpu.cpp @@ -15,13 +15,9 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; at::Tensor npu_sign_bits_unpack_compute( const at::Tensor &input, diff --git a/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp b/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp index e936c819..cc8f95df 100644 --- a/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp +++ b/ads/common/ops/csrc/SoftmaxCrossEntropyWithLogitsKernelNpu.cpp @@ -14,19 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; -using torch::autograd::AutogradContext; -using torch::autograd::Function; -using tensor_list = std::vector; - namespace { std::tuple softmax_cross_entropy_with_logits_out_nocheck( at::Tensor &result, diff --git a/ads/common/ops/csrc/StrideAddKernelNpu.cpp b/ads/common/ops/csrc/StrideAddKernelNpu.cpp index ebcfbfda..47922f62 100644 --- a/ads/common/ops/csrc/StrideAddKernelNpu.cpp +++ b/ads/common/ops/csrc/StrideAddKernelNpu.cpp @@ -14,15 +14,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; namespace { at::Tensor &stride_add_out_npu_nocheck( diff --git a/ads/common/ops/csrc/TransposeKernelNpu.cpp b/ads/common/ops/csrc/TransposeKernelNpu.cpp index ad9d2e97..2e8705c2 100644 --- a/ads/common/ops/csrc/TransposeKernelNpu.cpp +++ b/ads/common/ops/csrc/TransposeKernelNpu.cpp @@ -15,14 +15,9 @@ // limitations under the License. 
#include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "functions.h" #include "common.h" -using npu_utils = at_npu::native::NpuUtils; - namespace { at::Tensor &npu_transpose_out_nocheck( at::Tensor &result, diff --git a/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp b/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp index f3cd4201..df02a325 100644 --- a/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp +++ b/ads/common/ops/csrc/YoloBoxesEncodeKernelNpu.cpp @@ -15,14 +15,9 @@ // limitations under the License. #include "torch_npu/csrc/framework/OpCommand.h" -#include "torch_npu/csrc/framework/utils/OpPreparation.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" -#include "torch_npu/csrc/aten/CustomFunctions.h" #include "functions.h" #include "common.h" -using npu_preparation = at_npu::native::OpPreparation; namespace { inline void yolo_boxes_encode_check( diff --git a/ads/common/ops/csrc/common.cpp b/ads/common/ops/csrc/common.cpp index 8e4b3037..f6f9cc49 100644 --- a/ads/common/ops/csrc/common.cpp +++ b/ads/common/ops/csrc/common.cpp @@ -1,12 +1,9 @@ #include #include "torch_npu/csrc/framework/utils/CalcuOpUtil.h" -#include "torch_npu/csrc/framework/utils/NpuUtils.h" #include "torch_npu/csrc/aten/mirror/NPUMemoryOverlap.h" -#include "torch_npu/csrc/aten/NPUNativeFunctions.h" #include "third_party/acl/inc/acl/acl_base.h" #include "common.h" -using npu_utils = at_npu::native::NpuUtils; using CalcuOpUtil = at_npu::native::CalcuOpUtil; #define AT_ALL_SCALAR_TYPE_AND_ACL_DATATYPE_PAIR(_) \ @@ -192,4 +189,4 @@ bool check_match(const at::Tensor &self) void format_fresh_view(at::Tensor &x, const at::Tensor &y) { x.copy_(y); -} \ No newline at end of file +} diff --git a/ads/common/ops/csrc/common.h b/ads/common/ops/csrc/common.h index 49a75653..95c2b5a1 100644 --- a/ads/common/ops/csrc/common.h +++ b/ads/common/ops/csrc/common.h @@ -1,3 +1,5 @@ +#ifndef __COMMON_H__ +#define __COMMON_H__ #include #include #include @@ -26,4 +28,6 @@ c10::SmallVector convert_array_to_vector(c10::IntArrayRef intArray); c10::SmallVector infersize_stride_add(c10::IntArrayRef shape1_, c10::IntArrayRef shape2_); c10::SmallVector transpose_npu_output_size(const at::Tensor &self, c10::IntArrayRef perm); bool check_match(const at::Tensor &self); -void format_fresh_view(at::Tensor &x, const at::Tensor &y); \ No newline at end of file +void format_fresh_view(at::Tensor &x, const at::Tensor &y); + +#endif // __COMMON_H__ diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index 243774ab..b832c7e2 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -1,3 +1,19 @@ +// Copyright (c) 2023, Huawei Technologies.All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#ifndef __FUNCTIONS_H__ +#define __FUNCTIONS_H__ + #include #include #include @@ -43,3 +59,62 @@ at::Tensor npu_rotary_mul(const at::Tensor &self, const at::Tensor &r1, const at at::Tensor npu_silu(const at::Tensor& self); at::Tensor& npu_silu_(at::Tensor& self); at::Tensor npu_abs(const at::Tensor& self); +at::Tensor npu_fast_gelu_backward(const at::Tensor& grad, const at::Tensor& self); +at::Tensor npu_fast_gelu(const at::Tensor& self); +at::Tensor npu_anchor_response_flags(const at::Tensor& self, at::IntArrayRef featmap_size, at::IntArrayRef stride, int64_t num_base_anchors); +at::Tensor npu_bounding_box_decode( + const at::Tensor& rois, + const at::Tensor& deltas, + double means0, + double means1, + double means2, + double means3, + double stds0, + double stds1, + double stds2, + double stds3, + at::IntArrayRef max_shape, + double wh_ratio_clip); +at::Tensor npu_bounding_box_encode( + const at::Tensor& anchor_box, + const at::Tensor& ground_truth_box, + double means0, + double means1, + double means2, + double means3, + double stds0, + double stds1, + double stds2, + double stds3); +std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_batch_nms( + const at::Tensor& self, + const at::Tensor& scores, + double score_threshold, + double iou_threshold, + int64_t max_size_per_class, + int64_t max_total_size, + bool change_coordinate_frame, + bool transpose_box); +at::Tensor npu_confusion_transpose( + const at::Tensor& self, + at::IntArrayRef perm, + at::IntArrayRef shape, + bool transpose_first); +at::Tensor npu_confusion_transpose_backward( + const at::Tensor& grad, + at::IntArrayRef perm, + at::IntArrayRef shape, + bool transpose_first); +at::Tensor npu_conv_transpose2d( + const at::Tensor& input, + const at::Tensor& weight, + const c10::optional<at::Tensor>& bias_opt, + at::IntArrayRef padding, + at::IntArrayRef output_padding, + at::IntArrayRef stride, + at::IntArrayRef dilation, + int64_t groups); +at::Tensor npu_broadcast(const at::Tensor& self, at::IntArrayRef size); +at::Tensor& npu_broadcast_out(const at::Tensor& self, at::IntArrayRef size, at::Tensor& result); + +#endif // __FUNCTIONS_H__ diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index 4eb1cf6f..b8ebe3f5 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -43,4 +43,27 @@ void init_common(pybind11::module &m) m.def("npu_rotary_mul", &npu_rotary_mul); m.def("npu_abs", &npu_abs); + + // npu_fast_gelu + m.def("npu_fast_gelu", &npu_fast_gelu); + m.def("npu_fast_gelu_backward", &npu_fast_gelu_backward); + + // npu_anchor_response_flags + m.def("npu_anchor_response_flags", &npu_anchor_response_flags); + + // npu_bounding_box_decode + m.def("npu_bounding_box_decode", &npu_bounding_box_decode); + + // npu_bounding_box_encode + m.def("npu_bounding_box_encode", &npu_bounding_box_encode); + + // npu_batch_nms + m.def("npu_batch_nms", &npu_batch_nms); + + // npu_confusion_transpose + m.def("npu_confusion_transpose", &npu_confusion_transpose); + m.def("npu_confusion_transpose_backward", &npu_confusion_transpose_backward); + + // npu_broadcast + m.def("npu_broadcast", &npu_broadcast); } diff --git a/ads/common/ops/fast_gelu.py b/ads/common/ops/fast_gelu.py new file mode 100644 index 00000000..45557513 --- /dev/null +++ b/ads/common/ops/fast_gelu.py @@ -0,0 +1,23 @@ +import torch +from torch.autograd import Function + +import torch_npu +import ads_c + + +class FastGeluFunction(Function): + @staticmethod + def forward(ctx, self): + out = ads_c.npu_fast_gelu(self) +
ctx.save_for_backward(self) + return out + + @staticmethod + def backward(ctx, grad_output): + self = ctx.saved_tensors[0] + + grad = ads_c.npu_fast_gelu_backward(grad_output, self) + + return grad + +fast_gelu = FastGeluFunction.apply diff --git a/ads/common/ops/npu_anchor_response_flags.py b/ads/common/ops/npu_anchor_response_flags.py new file mode 100644 index 00000000..b75fd77c --- /dev/null +++ b/ads/common/ops/npu_anchor_response_flags.py @@ -0,0 +1,13 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuAnchorResponseFlagsFunction(Function): + @staticmethod + def forward(ctx, self, featmap_size, stride, num_base_anchors): + result = ads_c.npu_anchor_response_flags(self, featmap_size, stride, num_base_anchors) + return result + +npu_anchor_response_flags = NpuAnchorResponseFlagsFunction.apply diff --git a/ads/common/ops/npu_batch_nms.py b/ads/common/ops/npu_batch_nms.py new file mode 100644 index 00000000..4a5b2ef9 --- /dev/null +++ b/ads/common/ops/npu_batch_nms.py @@ -0,0 +1,30 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBatchNmsFunction(Function): + @staticmethod + def forward( + ctx, + self, + scores, + score_threshold, + iou_threshold, + max_size_per_class, + max_total_size, + change_coordinate_frame=False, + transpose_box=False): + result = ads_c.npu_batch_nms( + self, + scores, + score_threshold, + iou_threshold, + max_size_per_class, + max_total_size, + change_coordinate_frame, + transpose_box) + return result + +npu_batch_nms = NpuBatchNmsFunction.apply diff --git a/ads/common/ops/npu_bounding_box_decode.py b/ads/common/ops/npu_bounding_box_decode.py new file mode 100644 index 00000000..67099d20 --- /dev/null +++ b/ads/common/ops/npu_bounding_box_decode.py @@ -0,0 +1,20 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBoundingBodDecodeFunction(Function): + @staticmethod + def forward(ctx, rois, deltas, + means0, means1, means2, means3, + stds0, tds1, stds2, stds3, + max_shape, wh_ratio_clip): + result = ads_c.npu_bounding_box_decode( + rois, deltas, + means0, means1, means2, means3, + stds0, stds1, stds2, stds3, + max_shape, wh_ratio_clip) + return result + +npu_bounding_box_decode = NpuBoundingBodDecodeFunction.apply diff --git a/ads/common/ops/npu_bounding_box_encode.py b/ads/common/ops/npu_bounding_box_encode.py new file mode 100644 index 00000000..6efac4dd --- /dev/null +++ b/ads/common/ops/npu_bounding_box_encode.py @@ -0,0 +1,18 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class NpuBoundingBodEncodeFunction(Function): + @staticmethod + def forward(ctx, anchor_box, ground_truth_box, + means0, means1, means2, means3, + stds0, tds1, stds2, stds3): + result = ads_c.npu_bounding_box_encode( + anchor_box, ground_truth_box, + means0, means1, means2, means3, + stds0, stds1, stds2, stds3) + return result + +npu_bounding_box_encode = NpuBoundingBodEncodeFunction.apply \ No newline at end of file diff --git a/ads/common/ops/npu_broadcast.py b/ads/common/ops/npu_broadcast.py new file mode 100644 index 00000000..b3371b28 --- /dev/null +++ b/ads/common/ops/npu_broadcast.py @@ -0,0 +1,16 @@ +import torch +from torch.autograd import Function +import torch_npu +import ads_c + + +class BroadCastFunction(Function): + @staticmethod + def forward(ctx, self, size, out=None): + if out is None: + result = ads_c.npu_broadcast(self, size) + else: + result = ads_c.npu_broadcast_out(self, size, out)
+ return result + +npu_broadcast = BroadCastFunction.apply \ No newline at end of file diff --git a/ads/common/ops/npu_confusion_transpose.py b/ads/common/ops/npu_confusion_transpose.py new file mode 100644 index 00000000..566d19f1 --- /dev/null +++ b/ads/common/ops/npu_confusion_transpose.py @@ -0,0 +1,24 @@ +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads_c + + +class NpuConfusionTransposeFunction(Function): + @staticmethod + def forward(ctx, self, perm, shape, transpose_first): + out = ads_c.npu_confusion_transpose(self, perm, shape, transpose_first) + ctx.perm, ctx.self_shape, ctx.transpose_first = perm, self.size(), transpose_first + + return out + + @staticmethod + def backward(ctx, grad_output): + perm, self_sizes, transpose_first = ctx.perm, ctx.self_shape, ctx.transpose_first + out = ads_c.npu_confusion_transpose_backward(grad_output, perm, self_sizes, not transpose_first) + + return out, None, None, None + +npu_confusion_transpose = NpuConfusionTransposeFunction.apply diff --git a/ads/common/ops/rotary_mul.py b/ads/common/ops/rotary_mul.py index bf9c2a9b..5079961b 100644 --- a/ads/common/ops/rotary_mul.py +++ b/ads/common/ops/rotary_mul.py @@ -19,4 +19,4 @@ class RotaryMulFunction(Function): result = ads_c.npu_rotary_mul_backward(grad_output, input, r1, r2) return result -npu_rotary_mul = RotaryMulFunction.apply \ No newline at end of file +npu_rotary_mul = RotaryMulFunction.apply diff --git a/ads/common/ops/rotated_iou.py b/ads/common/ops/rotated_iou.py index d88d3e9b..896f001c 100644 --- a/ads/common/ops/rotated_iou.py +++ b/ads/common/ops/rotated_iou.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_rotated_iou = ads_c.npu_rotated_iou \ No newline at end of file +npu_rotated_iou = ads_c.npu_rotated_iou diff --git a/ads/common/ops/rotated_overlaps.py b/ads/common/ops/rotated_overlaps.py index 40753235..c481fe6f 100644 --- a/ads/common/ops/rotated_overlaps.py +++ b/ads/common/ops/rotated_overlaps.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_rotated_overlaps = ads_c.npu_rotated_overlaps \ No newline at end of file +npu_rotated_overlaps = ads_c.npu_rotated_overlaps diff --git a/ads/common/ops/scatter.py b/ads/common/ops/scatter.py index 7d89109c..d9e6de8a 100644 --- a/ads/common/ops/scatter.py +++ b/ads/common/ops/scatter.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_scatter = ads_c.npu_scatter \ No newline at end of file +npu_scatter = ads_c.npu_scatter diff --git a/ads/common/ops/sign_bits_pack.py b/ads/common/ops/sign_bits_pack.py index c09d486a..7b1e0040 100644 --- a/ads/common/ops/sign_bits_pack.py +++ b/ads/common/ops/sign_bits_pack.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_sign_bits_pack = ads_c.npu_sign_bits_pack \ No newline at end of file +npu_sign_bits_pack = ads_c.npu_sign_bits_pack diff --git a/ads/common/ops/sign_bits_unpack.py b/ads/common/ops/sign_bits_unpack.py index efa1a2dd..ed374e17 100644 --- a/ads/common/ops/sign_bits_unpack.py +++ b/ads/common/ops/sign_bits_unpack.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_sign_bits_unpack = ads_c.npu_sign_bits_unpack \ No newline at end of file +npu_sign_bits_unpack = ads_c.npu_sign_bits_unpack diff --git a/ads/common/ops/silu.py b/ads/common/ops/silu.py index 8ca866db..bd4251b0 100644 --- a/ads/common/ops/silu.py +++ b/ads/common/ops/silu.py @@ -13,7 +13,7 @@ class SiluFunction(Function): result = func(input) ctx.save_for_backward(input, result) return result - + @staticmethod def backward(ctx,
grad_outputs): x0, x1 = ctx.saved_tensors @@ -22,4 +22,4 @@ class SiluFunction(Function): npu_silu = SiluFunction.apply -npu_silu_ = ads_c.npu_silu_ \ No newline at end of file +npu_silu_ = ads_c.npu_silu_ diff --git a/ads/common/ops/softmax_cross_entropy_with_logits.py b/ads/common/ops/softmax_cross_entropy_with_logits.py index f09d2a3e..cd12c5dd 100644 --- a/ads/common/ops/softmax_cross_entropy_with_logits.py +++ b/ads/common/ops/softmax_cross_entropy_with_logits.py @@ -20,4 +20,4 @@ class SoftMaxFunction(Function): result = ads_c.npu_softmax_cross_entropy_with_logits_backward(grad_output, feature, labels) return result -npu_softmax_cross_entropy_with_logits = SoftMaxFunction.apply \ No newline at end of file +npu_softmax_cross_entropy_with_logits = SoftMaxFunction.apply diff --git a/ads/common/ops/stride_add.py b/ads/common/ops/stride_add.py index 24a3946b..586a83c3 100644 --- a/ads/common/ops/stride_add.py +++ b/ads/common/ops/stride_add.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_stride_add = ads_c.npu_stride_add \ No newline at end of file +npu_stride_add = ads_c.npu_stride_add diff --git a/ads/common/ops/transpose.py b/ads/common/ops/transpose.py index 14972299..a27dca7d 100644 --- a/ads/common/ops/transpose.py +++ b/ads/common/ops/transpose.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_transpose = ads_c.npu_transpose \ No newline at end of file +npu_transpose = ads_c.npu_transpose diff --git a/ads/common/ops/yolo_boxes_encode.py b/ads/common/ops/yolo_boxes_encode.py index 585adb58..cb915a0f 100644 --- a/ads/common/ops/yolo_boxes_encode.py +++ b/ads/common/ops/yolo_boxes_encode.py @@ -2,4 +2,4 @@ import torch import torch_npu import ads_c -npu_yolo_boxes_encode = ads_c.npu_yolo_boxes_encode \ No newline at end of file +npu_yolo_boxes_encode = ads_c.npu_yolo_boxes_encode diff --git a/tests/test_batch_nms.py b/tests/test_batch_nms.py new file mode 100644 index 00000000..11c3245c --- /dev/null +++ b/tests/test_batch_nms.py @@ -0,0 +1,44 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import torch_npu + +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestBatchNms(TestCase): + def test_batch_nms_shape_format(self): + boxes = torch.randn(8, 4, 1, 4).npu() + scores = torch.randn(8, 4, 1).npu() + boxes_fp16 = boxes.half() + scores_fp16 = scores.half() + nmsed_boxes, nmsed_scores, nmsed_classes, nmsed_num = ads.common.npu_batch_nms(boxes, scores, 0.3, 0.5, 4, 4) + boxes1, scores1, classes1, num1 = ads.common.npu_batch_nms(boxes_fp16, scores_fp16, 0.3, 0.5, 4, 4) + expect_nmsed_classes = torch.tensor([[0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000], + [0.0000, 0.0000, 0.0000, 0.0000]], dtype=torch.float32) + self.assertRtolEqual(expect_nmsed_classes, nmsed_classes.cpu()) + self.assertRtolEqual(expect_nmsed_classes.half(), classes1.cpu()) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_fast_gelu.py b/tests/test_fast_gelu.py new file mode 100644 index 00000000..c81b5e92 --- /dev/null +++ b/tests/test_fast_gelu.py @@ -0,0 +1,51 @@ +# Copyright (c) 2023 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestFastGelu(TestCase): + + def supported_op_exec(self, input1): + attr = 1.702 + attr_half = attr / 2 + abs_input1 = torch.abs(input1) + numerator = input1 * \ + torch.exp((attr_half * input1) * (input1 - abs_input1)) + denominator = 1.0 + torch.exp(- attr * abs_input1) + output = numerator / denominator + return output.cpu().detach() + + def custom_op_exec(self, input1): + output = ads.common.fast_gelu(input1) + return output.cpu().detach() + + def test_fast_gelu(self, device="npu"): + item = [np.float32, 0, [3, 16, 32]] + _, npu_input = create_common_tensor(item, 0, 100) + + supported_output = self.supported_op_exec(npu_input) + custom_output = self.custom_op_exec(npu_input) + self.assertRtolEqual(supported_output, custom_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_fast_gelu_backward.py b/tests/test_fast_gelu_backward.py new file mode 100644 index 00000000..1c7920ef --- /dev/null +++ b/tests/test_fast_gelu_backward.py @@ -0,0 +1,43 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import numpy as np + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + + +class TestFastGelu(TestCase): + def npu_op_exec(self, input1): + input1.requires_grad = True + output = ads.common.fast_gelu(input1) + output.backward(torch.ones_like(output)) + output_grad = input1.grad + output_grad = output_grad.to("cpu") + output_grad = output_grad.detach().numpy() + output = output.cpu().detach().numpy() + return output_grad, output + + def test_fastgelu(self, device="npu"): + input1 = torch.tensor([1., 2., 3., 4.]).npu() + exoutputgrad = torch.tensor([1.0677795, 1.0738151, 1.0245483, 1.0064018]) + exoutput = torch.tensor([0.8458, 1.9357, 2.9819, 3.9956]) + outputgrad, output = self.npu_op_exec(input1) + self.assertRtolEqual(exoutputgrad.numpy(), outputgrad) + self.assertRtolEqual(exoutput.numpy(), output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_npu_anchor_response_flags.py b/tests/test_npu_anchor_response_flags.py new file mode 100644 index 00000000..b656ee94 --- /dev/null +++ b/tests/test_npu_anchor_response_flags.py @@ -0,0 +1,60 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch +import numpy as np + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + + +class TestNpuAnchorResponseFlags(TestCase): + def custom_op_exec(self, gt_bboxes, featmap_size, strides, num_base_anchors): + if gt_bboxes.dtype == torch.float16: + gt_bboxes = gt_bboxes.to(torch.float32) + feat_h, feat_w = featmap_size + gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5) + gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5) + gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / strides[0]).int() + gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / strides[1]).int() + gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x + responsible_grid = torch.zeros(feat_h * feat_w, dtype=torch.uint8).npu() + gt_bboxes_grid_idx = gt_bboxes_grid_idx.long() + responsible_grid[gt_bboxes_grid_idx] = 1 + responsible_grid = responsible_grid[:, None].expand( + responsible_grid.size(0), num_base_anchors).contiguous().view(-1) + return responsible_grid.cpu().numpy() + + def npu_op_exec(self, input_npu, featmap_size, strides, num_base_anchors): + out = ads.common.npu_anchor_response_flags(input_npu, featmap_size, strides, num_base_anchors) + out = out.cpu().numpy() + return out + + def test_npu_anchor_response_flags(self): + shape_format = [ + [[np.float32, -1, [100, 4]], [60, 60], [2, 2], 9], + [[np.float16, -1, [200, 4]], [10, 10], [32, 32], 3], + [[np.float16, -1, [500, 4]], [32, 32], [16, 16], 5] + ] + for item in shape_format: + _, npu_input = create_common_tensor(item[0], 0, 100) + custom_output = self.custom_op_exec(npu_input, *item[1:]) + npu_output = self.npu_op_exec(npu_input, *item[1:]) + self.assertRtolEqual(custom_output, npu_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_npu_bounding_box_decode.py b/tests/test_npu_bounding_box_decode.py new file mode 100644 index 00000000..248fe36c --- /dev/null +++ b/tests/test_npu_bounding_box_decode.py @@ -0,0 +1,107 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + + +class TestBoundingBoxDecode(TestCase): + def npu_bounding_box_decode(self, rois, deltas, means0, means1, means2, means3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): + means = [means0, means1, means2, means3] + stds = [stds0, stds1, stds2, stds3] + means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4) + stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4) + denorm_deltas = deltas * stds + means + + dx = denorm_deltas[:, 0::4] + dy = denorm_deltas[:, 1::4] + dw = denorm_deltas[:, 2::4] + dh = denorm_deltas[:, 3::4] + max_ratio = torch.abs(torch.log(torch.tensor(wh_ratio_clip))) + + dw = torch.clamp(dw, min=-max_ratio, max=max_ratio) + dh = torch.clamp(dh, min=-max_ratio, max=max_ratio) + + ax = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx) + ay = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy) + aw = (rois[:, 2] - rois[:, 0] * 0.5).unsqueeze(1).expand_as(dw) + ah = (rois[:, 3] - rois[:, 1] * 0.5).unsqueeze(1).expand_as(dh) + + pw = aw * dw.exp() + ph = ah * dh.exp() + px = torch.addcmul(ax, 1, aw, dx) + py = torch.addcmul(ay, 1, ah, dy) + + x1 = px - pw * 0.5 + 0.5 + y1 = py - ph * 0.5 + 0.5 + x2 = px + pw * 0.5 - 0.5 + y2 = py + ph * 0.5 - 0.5 + + if max_shape is not None: + x1 = torch.clamp(x1, min=0, max=(max_shape[1] - 1)) + y1 = torch.clamp(y1, min=0, max=(max_shape[0] - 1)) + x2 = torch.clamp(x2, min=0, max=(max_shape[1] - 1)) + y2 = torch.clamp(y2, min=0, max=(max_shape[0] - 1)) + boxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas) + return boxes + + def custom_op_exec(self, rois, deltas, means0, means1, means2, means3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): + output = self.npu_bounding_box_decode(rois, deltas, means0, means1, + means2, means3, stds0, stds1, + stds2, stds3, max_shape, wh_ratio_clip) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec(self, rois, deltas, means0, means1, means2, means3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): + output = ads.common.npu_bounding_box_decode(rois, deltas, means0, means1, + means2, means3, stds0, stds1, + stds2, stds3, max_shape, wh_ratio_clip) + output = output.to("cpu") + output = output.numpy() + return output + + def test_decode_shape_format_fp32(self): + input1 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float32).to("npu") + input2 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float32).to("npu") + + npu_output = self.npu_op_exec(input1, input2, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + custom_output = self.custom_op_exec(input1, input2, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + self.assertRtolEqual(npu_output, custom_output) + + def test_decode_shape_format_fp16(self): + input1_fp16 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float16).to("npu") + input2_fp16 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float16).to("npu") + + npu_output = self.npu_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + custom_output = self.custom_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 1, 1, 1, 1, (10, 10), 0.1) + self.assertRtolEqual(npu_output, custom_output) + + +if __name__ == "__main__": + run_tests() diff --git a/tests/test_npu_bounding_box_encode.py b/tests/test_npu_bounding_box_encode.py new file mode 100644 index 00000000..b6a406b6 --- /dev/null +++ 
b/tests/test_npu_bounding_box_encode.py @@ -0,0 +1,92 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +import ads.common + + +class TestBoundingBoxEncode(TestCase): + def npu_bounding_box_encode(self, anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3): + means = [means0, means1, means2, means3] + stds = [stds0, stds1, stds2, stds3] + px = (anchor_box[..., 0] + anchor_box[..., 2]) * 0.5 + py = (anchor_box[..., 1] + anchor_box[..., 3]) * 0.5 + pw = anchor_box[..., 2] - anchor_box[..., 0] + 1.0 + ph = anchor_box[..., 3] - anchor_box[..., 1] + 1.0 + + gx = (ground_truth_box[..., 0] + ground_truth_box[..., 2]) * 0.5 + gy = (ground_truth_box[..., 1] + ground_truth_box[..., 3]) * 0.5 + gw = ground_truth_box[..., 2] - ground_truth_box[..., 0] + 1.0 + gh = ground_truth_box[..., 3] - ground_truth_box[..., 1] + 1.0 + + eps = 1e-7 + dx = (gx - px) / (pw + eps) + dy = (gy - py) / (ph + eps) + dw = torch.log(torch.abs(gw) / torch.abs(pw + eps)) + dh = torch.log(torch.abs(gh) / torch.abs(ph + eps)) + deltas = torch.stack([dx, dy, dw, dh], dim=-1) + + means = deltas.new_tensor(means).unsqueeze(0) + stds = deltas.new_tensor(stds).unsqueeze(0) + deltas = deltas.sub_(means) .div_(stds) + + return deltas + + def custom_op_exec(self, anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3): + output = self.npu_bounding_box_encode(anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec(self, anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3): + output = ads.common.npu_bounding_box_encode(anchor_box, ground_truth_box, means0, means1, + means2, means3, stds0, stds1, stds2, stds3) + output = output.to("cpu") + output = output.numpy() + return output + + def test_encode_shape_format_fp32(self): + input1 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float32).to("npu") + input2 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float32).to("npu") + + npu_output = self.npu_op_exec(input1, input2, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + custom_output = self.custom_op_exec(input1, input2, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + self.assertRtolEqual(npu_output, custom_output, 1e-3) + + def test_encode_shape_format_fp16(self): + input1_fp16 = torch.tensor([[1., 2., 3., 4.], [3., 4., 5., 6.]], + dtype=torch.float16).to("npu") + input2_fp16 = torch.tensor([[5., 6., 7., 8.], [7., 8., 9., 6.]], + dtype=torch.float16).to("npu") + + npu_output = self.npu_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + custom_output = self.custom_op_exec(input1_fp16, input2_fp16, 0, 0, 0, 0, + 0.1, 0.1, 0.2, 0.2) + self.assertRtolEqual(npu_output, custom_output, 1e-3) + + +if __name__ == "__main__": + 
run_tests() diff --git a/tests/test_npu_broadcast.py b/tests/test_npu_broadcast.py new file mode 100644 index 00000000..d44badc1 --- /dev/null +++ b/tests/test_npu_broadcast.py @@ -0,0 +1,48 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + +import torch_npu +import ads.common +from torch_npu.testing.testcase import TestCase, run_tests + + +class TestNpuBroadcast(TestCase): + def custom_op_exec(self, input1, shape): + output = torch.broadcast_to(input1, shape) + output = output.to("cpu") + output = output.numpy() + return output + + def npu_op_exec(self, input1, size): + output = ads.common.npu_broadcast(input1, size) + output = output.to("cpu") + output = output.numpy() + return output + + def test_npu_broadcast(self): + input1 = [ + torch.tensor([1, 2, 3]).npu(), + torch.tensor([[1], [2], [3]]).npu() + ] + for item in input1: + custom_output = self.custom_op_exec(item, (3, 3)) + npu_output = self.npu_op_exec(item, (3, 3)) + self.assertRtolEqual(custom_output, npu_output) + + +if __name__ == "__main__": + run_tests() -- Gitee From 3890023291baed5c3e30a0e7967fd7c4d3672013 Mon Sep 17 00:00:00 2001 From: zhanhao Date: Wed, 13 Dec 2023 14:36:06 +0800 Subject: [PATCH 3/3] add op --- ads/common/ops/npu_bounding_box_decode.py | 2 +- ads/common/ops/npu_bounding_box_encode.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/ads/common/ops/npu_bounding_box_decode.py b/ads/common/ops/npu_bounding_box_decode.py index 67099d20..6f16d8e4 100644 --- a/ads/common/ops/npu_bounding_box_decode.py +++ b/ads/common/ops/npu_bounding_box_decode.py @@ -8,7 +8,7 @@ class NpuBoundingBodDecodeFunction(Function): @staticmethod def forward(ctx, rois, deltas, means0, means1, means2, means3, - stds0, tds1, stds2, stds3, + stds0, stds1, stds2, stds3, max_shape, wh_ratio_clip): result = ads_c.npu_bounding_box_decode( rois, deltas, diff --git a/ads/common/ops/npu_bounding_box_encode.py b/ads/common/ops/npu_bounding_box_encode.py index 6efac4dd..756e1fe0 100644 --- a/ads/common/ops/npu_bounding_box_encode.py +++ b/ads/common/ops/npu_bounding_box_encode.py @@ -8,7 +8,7 @@ class NpuBoundingBodEncodeFunction(Function): @staticmethod def forward(ctx, anchor_box, ground_truth_box, means0, means1, means2, means3, - stds0, tds1, stds2, stds3): + stds0, stds1, stds2, stds3): result = ads_c.npu_bounding_box_encode( anchor_box, ground_truth_box, means0, means1, means2, means3, -- Gitee