diff --git a/ads/common/__init__.py b/ads/common/__init__.py index ca4f48b01b5dbf8ce013e8be03a890032c33ebc6..d20c765a535e69ec59d355969e54dbda46b6f365 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -30,3 +30,4 @@ from .ops.npu_multi_scale_deformable_attn_function import npu_multi_scale_deform from .ops.dynamic_voxelization import voxelization from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal +from .ops.npu_nms3d import npu_nms3d diff --git a/ads/common/ops/csrc/Nms3dOpApi.cpp b/ads/common/ops/csrc/Nms3dOpApi.cpp new file mode 100644 index 0000000000000000000000000000000000000000..300e67d20953aef86c7f06cd2d47c559670a83b4 --- /dev/null +++ b/ads/common/ops/csrc/Nms3dOpApi.cpp @@ -0,0 +1,22 @@ +#include +#include "csrc/OpApiCommon.h" +#include "functions.h" + +// Host entry for 3D NMS: aclnnNms3d receives an int16 mask of shape (box_num, mask_num), then aclnnGatherNms3dMask receives that mask together with keep and num_out. +// Returns (keep, num_out); both are int16 tensors, of length box_num and 1. +std::tuple nms3d(const at::Tensor &boxes, double threshold) +{ + // keep and num_out are at::kShort, so a box index or count above 32767 cannot be represented; reject such inputs instead of letting them wrap. + TORCH_CHECK(boxes.dim() == 2 && boxes.size(1) == 7, "Input boxes shape should be (N, 7)"); + TORCH_CHECK(boxes.size(0) <= 32767, "nms3d supports at most 32767 boxes, but got ", boxes.size(0)); + int32_t box_num = boxes.size(0); + int32_t data_align = 16; + int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; /* box_num rounded up to a multiple of data_align */ + at::Tensor mask = at::empty({ box_num, mask_num }, boxes.options().dtype(at::kShort)); + EXEC_NPU_CMD(aclnnNms3d, boxes, threshold, mask); + + at::Tensor keep = at::zeros({ box_num }, mask.options()); + at::Tensor num_out = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, num_out); + return std::tie(keep, num_out); +} diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index cf901d9cc56948605f33ae8be666b10a79e384fb..58c4ad8b39d1722a153d0b35e0d795d2d5f3834d 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -160,4 +160,6 @@ at::Tensor DynamicVoxelization( const double coorsMinZ); std::tuple nms3d_normal(const at::Tensor &boxes, double nms_overlap_thresh); + +std::tuple nms3d(const at::Tensor &boxes, double threshold); #endif // COMMON_OPS_CSRC_FUNCTIONS_H_ diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index 
3d2c805cd16c06284a63aeec820fca005e481ffd..cd2eb126419f9f0a94eda032fda7f017e14d2016 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -91,7 +91,10 @@ void init_common(pybind11::module &m) // dyn_voxelization m.def("dynamic_voxelization", &DynamicVoxelization); - + // nms3d_normal m.def("nms3d_normal", &nms3d_normal); + + // ads_nms3d + m.def("nms3d", &nms3d); } diff --git a/ads/common/ops/kernels/op_host/gather_nms3d_mask_tiling.h b/ads/common/ops/kernels/op_host/gather_nms3d_mask_tiling.h index bf7a47fa6968862409bb4072c77eb6705c3aa3ba..b0e0034337359ffee064043a38f482f5deb3efa3 100644 --- a/ads/common/ops/kernels/op_host/gather_nms3d_mask_tiling.h +++ b/ads/common/ops/kernels/op_host/gather_nms3d_mask_tiling.h @@ -1,18 +1,18 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. - */ -#ifndef GATHER_NMS3D_MASK_TILING_H -#define GATHER_NMS3D_MASK_TILING_H - -#include "register/tilingdata_base.h" - -namespace optiling { -BEGIN_TILING_DATA_DEF(GatherNms3dMaskTilingData) - TILING_DATA_FIELD_DEF(uint32_t, box_num); - TILING_DATA_FIELD_DEF(uint32_t, mask_num); -END_TILING_DATA_DEF; - -REGISTER_TILING_DATA_CLASS(GatherNms3dMask, GatherNms3dMaskTilingData) -} - -#endif // GATHER_NMS3D_MASK_TILING_H +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. 
+ */ +#ifndef GATHER_NMS3D_MASK_TILING_H +#define GATHER_NMS3D_MASK_TILING_H + +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(GatherNms3dMaskTilingData) + TILING_DATA_FIELD_DEF(uint32_t, box_num); + TILING_DATA_FIELD_DEF(uint32_t, mask_num); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(GatherNms3dMask, GatherNms3dMaskTilingData) +} + +#endif // GATHER_NMS3D_MASK_TILING_H diff --git a/ads/common/ops/kernels/op_host/nms3d.cpp b/ads/common/ops/kernels/op_host/nms3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..79c6b34fd69ceb968daa3d16aa8dc497e1da6060 --- /dev/null +++ b/ads/common/ops/kernels/op_host/nms3d.cpp @@ -0,0 +1,90 @@ +#include "nms3d_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + +using namespace std; + +namespace optiling { +static ge::graphStatus Nms3dTilingFunc(gert::TilingContext* context) // picks the tiling key from the boxes dtype and fills Nms3dTilingData +{ + Nms3dTilingData tiling; + + auto platformInfo = context->GetPlatformInfo(); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); + uint32_t coreNum = ascendcPlatform.GetCoreNumAiv(); // read on every call; a function-local static would keep the value seen by the first call + + auto boxShape = context->GetInputShape(0)->GetStorageShape(); + auto maskShape = context->GetOutputShape(0)->GetStorageShape(); + auto dtype = context->GetInputDesc(0)->GetDataType(); + auto attrs = context->GetAttrs(); + uint32_t boxNum = boxShape.GetDim(0); + uint32_t maskNum = maskShape.GetDim(1); + uint32_t dataAlign = 16; // boxes are handled in groups of 16 + if (ge::DT_FLOAT == dtype) { + context->SetTilingKey(1); + } else if (ge::DT_FLOAT16 == dtype) { + context->SetTilingKey(2); + } else { + return ge::GRAPH_FAILED; + } + + uint32_t usedCoreNum = std::min((boxNum - 1) / dataAlign + 1, coreNum); + uint32_t loopTime = (boxNum - 1) / (usedCoreNum * dataAlign) + 1; + uint32_t tailSum = boxNum - usedCoreNum * (loopTime - 1) * dataAlign; + uint32_t tailNum = (tailSum - 1) % dataAlign + 1; // NOTE(review): boxNum == 0 wraps the unsigned (boxNum - 1) above, and the kernel addresses groups as core_id * eachSum + dataAlign * j - confirm the tail split matches that layout + float nms_overlap_thresh = *(attrs->GetAttrPointer(0)); + 
+ context->SetBlockDim(usedCoreNum); + tiling.set_usedCoreNum(usedCoreNum); + tiling.set_boxNum(boxNum); + tiling.set_loopTime(loopTime); + tiling.set_eachSum(loopTime * dataAlign); + tiling.set_tailSum(tailSum); + tiling.set_tailNum(tailNum); + tiling.set_maskNum(maskNum); + tiling.set_overlapThresh(nms_overlap_thresh); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = 0; // no workspace is requested + return ge::GRAPH_SUCCESS; +} +} + +namespace ge { +static ge::graphStatus Nms3dInferShape(gert::InferShapeContext* context) +{ + return GRAPH_SUCCESS; // no output shape is set here +} +} + +namespace ops { +class Nms3d : public OpDef { +public: + explicit Nms3d(const char* name) : OpDef(name) + { + this->Input("boxes") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("mask") + .ParamType(REQUIRED) + .DataType({ge::DT_INT16, ge::DT_INT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("threshold") + .AttrType(REQUIRED) + .Float(); + + this->SetInferShape(ge::Nms3dInferShape); + + this->AICore() + .SetTiling(optiling::Nms3dTilingFunc); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(Nms3d); +} \ No newline at end of file diff --git a/ads/common/ops/kernels/op_host/nms3d_tiling.h b/ads/common/ops/kernels/op_host/nms3d_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..2db9787fe1e8c3c4e177ff3e95b8d2675895af2d --- /dev/null +++ b/ads/common/ops/kernels/op_host/nms3d_tiling.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ */ +#ifndef NMS3D_TILING_H +#define NMS3D_TILING_H + +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(Nms3dTilingData) + TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum) // number of cores used (block dim) + TILING_DATA_FIELD_DEF(uint32_t, boxNum) // count of boxes + TILING_DATA_FIELD_DEF(uint32_t, loopTime) // loop iterations per core + TILING_DATA_FIELD_DEF(uint32_t, eachSum) // boxes assigned to each core, = loopTime * 16 + TILING_DATA_FIELD_DEF(uint32_t, tailSum) // boxes left for the final loop iteration + TILING_DATA_FIELD_DEF(uint32_t, tailNum) // boxes in the last group of the last core + TILING_DATA_FIELD_DEF(uint32_t, maskNum) // mask row length: boxNum padded to a multiple of 16 int16 values (32 bytes) + TILING_DATA_FIELD_DEF(float, overlapThresh) +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(Nms3d, Nms3dTilingData) +} // namespace optiling + +#endif // NMS3D_TILING_H \ No newline at end of file diff --git a/ads/common/ops/kernels/op_kernel/gather_nms3d_mask.cpp b/ads/common/ops/kernels/op_kernel/gather_nms3d_mask.cpp index 70da3e7f6bd1b54d6f05c76135f611331eb59291..e973dccea92a8473e978a583bda25e73d6abc90a 100644 --- a/ads/common/ops/kernels/op_kernel/gather_nms3d_mask.cpp +++ b/ads/common/ops/kernels/op_kernel/gather_nms3d_mask.cpp @@ -1,114 +1,114 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. - * - * This sample is a very basic sample that implements vector add on Ascend plaform. 
- */ -#include "kernel_operator.h" -using namespace AscendC; -constexpr int32_t BUFFER_NUM = 2; -constexpr int32_t BUF_SIZE_UNIT = 32; -constexpr int32_t NUM_SIZE = 1; - -class KernelGatherNms3dMask { -public: - __aicore__ inline KernelGatherNms3dMask() {} - __aicore__ inline void Init(GM_ADDR mask, GM_ADDR keep, GM_ADDR num_out, GatherNms3dMaskTilingData *tiling_data) - { - ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); - box_num = tiling_data->box_num; - mask_num = tiling_data->mask_num; - - int32_t assign_num = (box_num * sizeof(int16_t) + BUF_SIZE_UNIT - 1) / BUF_SIZE_UNIT; - mask_size = assign_num * BUF_SIZE_UNIT / sizeof(int16_t); - - maskGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (mask), box_num * mask_num); - keepGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (keep), box_num); - numOutGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (num_out), NUM_SIZE); - - pipe.InitBuffer(inQueueMask, BUFFER_NUM, mask_size * sizeof(int16_t)); - pipe.InitBuffer(maskBuf, mask_size * sizeof(int16_t)); - pipe.InitBuffer(keepBuf, mask_size * sizeof(int16_t)); - pipe.InitBuffer(numOutBuf, BUF_SIZE_UNIT); - } - __aicore__ inline void Process() - { - InitCmp(); - for (int32_t i = 0; i < box_num; ++i) { - if (maskTemp.GetValue(i) == 1) { - SaveKeep(i); - CopyIn(i); - Compute(i); - } - } - EndCmp(); - } - -private: - __aicore__ inline void InitCmp() - { - maskTemp = maskBuf.Get(); - keepTemp = keepBuf.Get(); - Duplicate(maskTemp, static_cast(1), mask_size); - Duplicate(keepTemp, static_cast(0), mask_size); - DataCopyParams copyParams{1, static_cast(box_num * sizeof(int16_t)), 0, 0}; - DataCopyPadParams padParams{false, 0, 2, 0}; - DataCopyPad(maskTemp, maskGm, copyParams, padParams); - } - __aicore__ inline void CopyIn(int32_t idx) - { - LocalTensor maskLocal = inQueueMask.AllocTensor(); - Duplicate(maskLocal, static_cast(1), mask_size); - DataCopyParams copyParams{1, static_cast(box_num * sizeof(int16_t)), 0, 0}; - DataCopyPadParams 
padParams{false, 0, 0, 2}; - DataCopyPad(maskLocal, maskGm[idx * mask_num], copyParams, padParams); - inQueueMask.EnQue(maskLocal); - } - __aicore__ inline void Compute(int32_t idx) - { - LocalTensor maskLocal = inQueueMask.DeQue(); - maskTemp = maskLocal & maskTemp; - pipe_barrier(PIPE_ALL); - inQueueMask.FreeTensor(maskLocal); - } - __aicore__ inline void SaveKeep(int32_t idx) - { - keepTemp.SetValue(keep_num, idx); - keep_num = keep_num + 1; - } - __aicore__ inline void EndCmp() - { - DataCopyParams copyMaskParams{1, static_cast(box_num * sizeof(int16_t)), 0, 0}; - DataCopyPad(keepGm, keepTemp, copyMaskParams); - LocalTensor numOutLocal = numOutBuf.Get(); - numOutLocal.SetValue(0, keep_num); - DataCopyParams copyNumParams{1, static_cast(NUM_SIZE * sizeof(int16_t)), 0, 0}; - DataCopyPad(numOutGm, numOutLocal, copyNumParams); - } - -private: - TPipe pipe; - TQue inQueueMask; - - GlobalTensor maskGm; - GlobalTensor keepGm; - GlobalTensor numOutGm; - - LocalTensor maskTemp; - LocalTensor keepTemp; - - TBuf maskBuf, keepBuf, numOutBuf; - - uint32_t box_num; - uint32_t mask_num; - uint32_t mask_size; - uint32_t keep_num = 0; -}; - -extern "C" __global__ __aicore__ -void gather_nms3d_mask(GM_ADDR mask, GM_ADDR keep, GM_ADDR num_out, GM_ADDR workspace, GM_ADDR tiling) -{ - GET_TILING_DATA(tiling_data, tiling); - KernelGatherNms3dMask op; - op.Init(mask, keep, num_out, &tiling_data); - op.Process(); -} +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2024. All rights reserved. + * + * This sample is a very basic sample that implements vector add on Ascend plaform. 
+ */ +#include "kernel_operator.h" +using namespace AscendC; +constexpr int32_t BUFFER_NUM = 2; +constexpr int32_t BUF_SIZE_UNIT = 32; +constexpr int32_t NUM_SIZE = 1; + +class KernelGatherNms3dMask { +public: + __aicore__ inline KernelGatherNms3dMask() {} + __aicore__ inline void Init(GM_ADDR mask, GM_ADDR keep, GM_ADDR num_out, GatherNms3dMaskTilingData *tiling_data) + { + ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); + box_num = tiling_data->box_num; + mask_num = tiling_data->mask_num; + + int32_t assign_num = (box_num * sizeof(int16_t) + BUF_SIZE_UNIT - 1) / BUF_SIZE_UNIT; + mask_size = assign_num * BUF_SIZE_UNIT / sizeof(int16_t); + + maskGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (mask), box_num * mask_num); + keepGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (keep), box_num); + numOutGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (num_out), NUM_SIZE); + + pipe.InitBuffer(inQueueMask, BUFFER_NUM, mask_size * sizeof(int16_t)); + pipe.InitBuffer(maskBuf, mask_size * sizeof(int16_t)); + pipe.InitBuffer(keepBuf, mask_size * sizeof(int16_t)); + pipe.InitBuffer(numOutBuf, BUF_SIZE_UNIT); + } + __aicore__ inline void Process() + { + InitCmp(); + for (int32_t i = 0; i < box_num; ++i) { + if (maskTemp.GetValue(i) == 1) { + SaveKeep(i); + CopyIn(i); + Compute(i); + } + } + EndCmp(); + } + +private: + __aicore__ inline void InitCmp() + { + maskTemp = maskBuf.Get(); + keepTemp = keepBuf.Get(); + Duplicate(maskTemp, static_cast(1), mask_size); + Duplicate(keepTemp, static_cast(0), mask_size); + DataCopyParams copyParams{1, static_cast(box_num * sizeof(int16_t)), 0, 0}; + DataCopyPadParams padParams{false, 0, 2, 0}; + DataCopyPad(maskTemp, maskGm, copyParams, padParams); + } + __aicore__ inline void CopyIn(int32_t idx) + { + LocalTensor maskLocal = inQueueMask.AllocTensor(); + Duplicate(maskLocal, static_cast(1), mask_size); + DataCopyParams copyParams{1, static_cast(box_num * sizeof(int16_t)), 0, 0}; + DataCopyPadParams 
padParams{false, 0, 0, 2}; + DataCopyPad(maskLocal, maskGm[idx * mask_num], copyParams, padParams); + inQueueMask.EnQue(maskLocal); + } + __aicore__ inline void Compute(int32_t idx) + { + LocalTensor maskLocal = inQueueMask.DeQue(); + maskTemp = maskLocal & maskTemp; + pipe_barrier(PIPE_ALL); + inQueueMask.FreeTensor(maskLocal); + } + __aicore__ inline void SaveKeep(int32_t idx) + { + keepTemp.SetValue(keep_num, idx); + keep_num = keep_num + 1; + } + __aicore__ inline void EndCmp() + { + DataCopyParams copyMaskParams{1, static_cast(box_num * sizeof(int16_t)), 0, 0}; + DataCopyPad(keepGm, keepTemp, copyMaskParams); + LocalTensor numOutLocal = numOutBuf.Get(); + numOutLocal.SetValue(0, keep_num); + DataCopyParams copyNumParams{1, static_cast(NUM_SIZE * sizeof(int16_t)), 0, 0}; + DataCopyPad(numOutGm, numOutLocal, copyNumParams); + } + +private: + TPipe pipe; + TQue inQueueMask; + + GlobalTensor maskGm; + GlobalTensor keepGm; + GlobalTensor numOutGm; + + LocalTensor maskTemp; + LocalTensor keepTemp; + + TBuf maskBuf, keepBuf, numOutBuf; + + uint32_t box_num; + uint32_t mask_num; + uint32_t mask_size; + uint32_t keep_num = 0; +}; + +extern "C" __global__ __aicore__ +void gather_nms3d_mask(GM_ADDR mask, GM_ADDR keep, GM_ADDR num_out, GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tiling_data, tiling); + KernelGatherNms3dMask op; + op.Init(mask, keep, num_out, &tiling_data); + op.Process(); +} diff --git a/ads/common/ops/kernels/op_kernel/nms3d.cpp b/ads/common/ops/kernels/op_kernel/nms3d.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a83cdf6fdb4d844dd181978fa391acbfcf57169b --- /dev/null +++ b/ads/common/ops/kernels/op_kernel/nms3d.cpp @@ -0,0 +1,414 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. 
+ * + */ +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_utils.h" +#define M_PI 3.14159265358979323846 /* pi */ + +using namespace AscendC; +constexpr int32_t BUFFER_NUM = 2; +constexpr float EPS = 1e-8; +struct Point { + float x, y; + __aicore__ Point() { + + } + __aicore__ Point(float _x, float _y) { + x = _x; + y = _y; + } + + __aicore__ void set(float _x, float _y) { + x = _x; + y = _y; + } + + __aicore__ Point operator+(const Point& b) const { + return Point(x + b.x, y + b.y); + } + + __aicore__ Point operator-(const Point& b) const { + return Point(x - b.x, y - b.y); + } + +}; +template +class KernelNms3d { +public: + __aicore__ inline KernelNms3d() {} + __aicore__ inline void Init(GM_ADDR boxes, GM_ADDR mask, const Nms3dTilingData* __restrict tiling_data) // copies the tiling fields into members and sizes every queue/buffer from dataAlign + { + ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); + usedCoreNum = tiling_data->usedCoreNum; + eachSum = tiling_data->eachSum; + boxNum = tiling_data->boxNum; + tailSum = tiling_data->tailSum; + tailNum = tiling_data->tailNum; + maskNum = tiling_data->maskNum; + loopTime = tiling_data->loopTime; + overlapThresh = tiling_data->overlapThresh; + + uint32_t core_id = GetBlockIdx(); + isLastCore = (core_id == (tiling_data->usedCoreNum - 1)); + + boxGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(boxes), boxNum * 7); + maskGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t*>(mask), maskNum * boxNum); + + pipe.InitBuffer(inQueueCur, BUFFER_NUM, dataAlign * sizeof(T)); + pipe.InitBuffer(inQueueBox, BUFFER_NUM, dataAlign * 7 * sizeof(T)); + pipe.InitBuffer(outQueueMask, BUFFER_NUM, dataAlign * sizeof(int16_t)); + pipe.InitBuffer(oneMask, BUFFER_NUM, dataAlign * sizeof(int16_t)); + + pipe.InitBuffer(comBuf, dataAlign * sizeof(float)); + pipe.InitBuffer(p1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(p2Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(q1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(q2Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(angleBuf, 
dataAlign * sizeof(T)); + pipe.InitBuffer(sinBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(cosBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(pointBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(min1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(min2Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(max1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(max2Buf, dataAlign * sizeof(T)); + if constexpr (sizeof(T) == sizeof(half)) { + pipe.InitBuffer(calcBuf, dataAlign * 2 * 7 * sizeof(float)); + curTemp = calcBuf.Get(dataAlign * 2 * 7); + boxTemp = curTemp[8]; + } + } + __aicore__ inline void Process(GM_ADDR boxes, GM_ADDR mask) // for every box i, walks this core's groups of dataAlign boxes and writes one mask segment per group + { + uint32_t core_id = GetBlockIdx(); + LocalTensor oneLocal = oneMask.AllocTensor(); + Duplicate(oneLocal, static_cast(1), dataAlign); + for (size_t i = 0; i < boxNum; ++i) { + for (size_t j = 0; j < loopTime; ++j) { + uint32_t start = core_id * eachSum + dataAlign * j; + if (i >= start + dataAlign) { + DataCopy(maskGm[i * maskNum + start], oneLocal, dataAlign); // the whole group precedes box i: write all ones without computing IoU + continue; + } + bool is_last = (isLastCore) && (j == loopTime - 1); + CopyIn(i, start, is_last); + Compute(i, start, is_last); + CopyOut(i, start); + } + } + oneMask.FreeTensor(oneLocal); + } + +private: + __aicore__ inline void CopyIn(int32_t cur_box, int32_t com_box, bool is_last) + { + LocalTensor curLocal = inQueueCur.AllocTensor(); + LocalTensor boxLocal = inQueueBox.AllocTensor(); + DataCopy(curLocal, boxGm[cur_box * 7], dataAlign); + DataCopy(boxLocal, boxGm[com_box * 7], dataAlign * 7); + inQueueCur.EnQue(curLocal); + inQueueBox.EnQue(boxLocal); + } + __aicore__ inline void Compute(int32_t cur_box, int32_t com_box, bool is_last) + { + uint32_t cmp_num = is_last ? 
tailNum : dataAlign; // only the final group of the last core is compared partially + if constexpr (sizeof(T) == sizeof(half)) { + LocalTensor curLocal = inQueueCur.DeQue(); + LocalTensor boxLocal = inQueueBox.DeQue(); + Cast(curTemp, curLocal, RoundMode::CAST_NONE, dataAlign); + Cast(boxTemp, boxLocal, RoundMode::CAST_NONE, 7 * dataAlign); + inQueueCur.FreeTensor(curLocal); + inQueueBox.FreeTensor(boxLocal); + } else { + curTemp = inQueueCur.DeQue(); + boxTemp = inQueueBox.DeQue(); + } + + PipeBarrier(); + LocalTensor outLocal = outQueueMask.AllocTensor(); + for (size_t i = 0; i < cmp_num; i++) { + if (cur_box >= com_box + i) { + outLocal.SetValue(i, 1); + continue; + } + LocalTensor comLocal = comBuf.Get(); + for (size_t k = 0; k < 7; k++) { + comLocal.SetValue(k, static_cast(boxTemp.GetValue(i * 7 + k))); + } + auto flag = iou_bev(curTemp, comLocal); + if (flag > overlapThresh) + outLocal.SetValue(i, 0); + else + outLocal.SetValue(i, 1); + } + PipeBarrier(); + outQueueMask.EnQue(outLocal); + if constexpr (sizeof(T) != sizeof(half)) { + inQueueCur.FreeTensor(curTemp); + inQueueBox.FreeTensor(boxTemp); + } + } + __aicore__ inline void CopyOut(int32_t cur_box, int32_t com_box) + { + LocalTensor outLocal = outQueueMask.DeQue(); + DataCopy(maskGm[cur_box * maskNum + com_box], outLocal, dataAlign); + outQueueMask.FreeTensor(outLocal); + } + +private: + __aicore__ inline float cross(const Point &a, const Point &b) { + return a.x * b.y - a.y * b.x; + } + + __aicore__ inline float cross(const Point &p1, const Point &p2, const Point &p0 ) { + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); + } + + __aicore__ int check_rect_cross(const Point &p1, const Point &p2, const Point &q1, const Point &q2 ) { + int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) && + min(q1.x, q2.x) <= max(p1.x, p2.x) && + min(p1.y, p2.y) <= max(q1.y, q2.y) && + min(q1.y, q2.y) <= max(p1.y, p2.y); + return ret; + } + + __aicore__ inline int check_in_box2d(const LocalTensor &box, const Point &p) { + const float MARGIN = 1e-2; + float center_x 
= box.GetValue(0); + float center_y = box.GetValue(1); + LocalTensor angleLocal = angleBuf.Get(); + LocalTensor sinLocal = sinBuf.Get(); + LocalTensor cosLocal = cosBuf.Get(); + angleLocal.SetValue(0, -box.GetValue(6)); + Sin(sinLocal, angleLocal); + Cos(cosLocal, angleLocal); + float angle_cos = cosLocal.GetValue(0), angle_sin = sinLocal.GetValue(0); + float rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin); + float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos; + + return (abs(rot_x) < box.GetValue(3) / 2 + MARGIN && abs(rot_y) < box.GetValue(4) / 2 + MARGIN); + } + + __aicore__ inline int intersection(const Point &p1, const Point &p0, const Point &q1, const Point &q0, Point &ans_point) { + if (check_rect_cross(p0, p1, q0, q1) == 0) { + return 0; + } + float s1 = cross(q0, p1, p0); + float s2 = cross(p1, q1, p0); + float s3 = cross(p0, q1, q0); + float s4 = cross(q1, p1, q0); + if (!(s1 * s2 > static_cast(0.0) && s3 * s4 > static_cast(0.0))){ + return 0; + } + float s5 = cross(q1, p1, p0); + if (abs(s5 - s1) > EPS) { + ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1); + ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1); + } else { + float a0 = p0.y - p1.y, b0 = p1.x - p0.x, c0 = p0.x * p1.y - p1.x * p0.y; + float a1 = q0.y - q1.y, b1 = q1.x - q0.x, c1 = q0.x * q1.y - q1.x * q0.y; + float D = a0 * b1 - a1 * b0; + + ans_point.x = (b0 * c1 - b1 * c0) / D; + ans_point.y = (a1 * c0 - a0 * c1) / D; + } + + return 1; + } + + __aicore__ inline void rotate_around_center(const Point ¢er, + const float angle_cos, + const float angle_sin, Point &p) { + float new_x = + (p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x; + float new_y = + (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; + p.set(new_x, new_y); + } + + __aicore__ inline int point_cmp(const Point &a, const Point &b, const Point ¢er) { + return math_atan2(a.y - center.y, a.x - center.x) > math_atan2(b.y - center.y, b.x - 
center.x); + } + + __aicore__ inline float math_atan2(float a, float b) { + float atan2val = 1000; // same sentinel as the (0, 0) branch, so an input matching no branch (NaN) no longer returns an uninitialized value + if (b > 0) { + atan2val = math_atan(a / b); + } else if ((b < 0) && (a >= 0)) { + atan2val = math_atan(a / b) + static_cast(M_PI); + } else if ((b < 0) && (a < 0)) { + atan2val = math_atan(a / b) - static_cast(M_PI); + } else if ((b == 0) && (a > 0)) { + atan2val = static_cast(M_PI) / 2; + } else if ((b == 0) && (a < 0)) { + atan2val = 0 - (static_cast(M_PI) / 2); + } else if ((b == 0) && (a == 0)) { + atan2val = 1000; + } + return atan2val; + } + + __aicore__ inline float math_atan(const float x) { + LocalTensor angleLocal = angleBuf.Get(); + LocalTensor atanLocal = sinBuf.Get(); + angleLocal.SetValue(0, x); + Atan(atanLocal, angleLocal, 1); + return atanLocal.GetValue(0); + } + + __aicore__ inline float box_overlap(const LocalTensor &boxATensor, const LocalTensor &boxBTensor) { + // params box_a: [x, y, z, dx, dy, dz, heading] + // params box_b: [x, y, z, dx, dy, dz, heading] + + float a_angle = boxATensor.GetValue(6), b_angle = boxBTensor.GetValue(6); + float a_dx_half = boxATensor.GetValue(3) / 2, b_dx_half = boxBTensor.GetValue(3) / 2, + a_dy_half = boxATensor.GetValue(4) / 2, b_dy_half = boxBTensor.GetValue(4) / 2; + float a_x1 = boxATensor.GetValue(0) - a_dx_half, a_y1 = boxATensor.GetValue(1) - a_dy_half; + float a_x2 = boxATensor.GetValue(0) + a_dx_half, a_y2 = boxATensor.GetValue(1) + a_dy_half; + float b_x1 = boxBTensor.GetValue(0) - b_dx_half, b_y1 = boxBTensor.GetValue(1) - b_dy_half; + float b_x2 = boxBTensor.GetValue(0) + b_dx_half, b_y2 = boxBTensor.GetValue(1) + b_dy_half; + + Point center_a(boxATensor.GetValue(0), boxATensor.GetValue(1)); + Point center_b(boxBTensor.GetValue(0), boxBTensor.GetValue(1)); + + Point box_a_corners[5]; + box_a_corners[0].set(a_x1, a_y1); + box_a_corners[1].set(a_x2, a_y1); + box_a_corners[2].set(a_x2, a_y2); + box_a_corners[3].set(a_x1, a_y2); + + Point box_b_corners[5]; + box_b_corners[0].set(b_x1, b_y1); + 
box_b_corners[1].set(b_x2, b_y1); + box_b_corners[2].set(b_x2, b_y2); + box_b_corners[3].set(b_x1, b_y2); + + // get oriented corners + LocalTensor angleLocal = angleBuf.Get(); + LocalTensor sinLocal = sinBuf.Get(); + LocalTensor cosLocal = cosBuf.Get(); + angleLocal.SetValue(0, a_angle); + angleLocal.SetValue(1, b_angle); + Sin(sinLocal, angleLocal); + Cos(cosLocal, angleLocal); + float a_angle_cos = cosLocal.GetValue(0), a_angle_sin = sinLocal.GetValue(0); + float b_angle_cos = cosLocal.GetValue(1), b_angle_sin = sinLocal.GetValue(1); + + for (int k = 0; k < 4; k++) { + rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]); + rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]); + } + + box_a_corners[4] = box_a_corners[0]; // repeat corner 0 so edge k runs from corners[k] to corners[k + 1] + box_b_corners[4] = box_b_corners[0]; + + // get intersection of lines + Point cross_points[16]; + Point poly_center; + int cnt = 0, flag = 0; + + poly_center.set(0, 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + flag = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j], + cross_points[cnt]); + if (flag) { + poly_center = poly_center + cross_points[cnt]; + cnt++; + } + } + } + + // check corners + for (int k = 0; k < 4; k++) { + if (check_in_box2d(boxATensor, box_b_corners[k])) { + poly_center = poly_center + box_b_corners[k]; + cross_points[cnt] = box_b_corners[k]; + cnt++; + } + if (check_in_box2d(boxBTensor, box_a_corners[k])) { + poly_center = poly_center + box_a_corners[k]; + cross_points[cnt] = box_a_corners[k]; + cnt++; + } + } + if (cnt != 0) { + poly_center.x /= cnt; + poly_center.y /= cnt; + } + + // sort the points of polygon + Point temp; + for (int j = 0; j < cnt - 1; j++) { + for (int i = 0; i < cnt - j - 1; i++) { + if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) { + temp = cross_points[i]; + cross_points[i] = cross_points[i + 1]; + cross_points[i + 1] = temp; + } + } + } + + // get the overlap areas + 
float area = 0; + for (int k = 0; k < cnt - 1; k++) { + area += cross(cross_points[k] - cross_points[0], + cross_points[k + 1] - cross_points[0]); + } + + return abs(area) / static_cast(2.0); + } + + __aicore__ inline float iou_bev(const LocalTensor &boxATensor, const LocalTensor &boxBTensor) { + // params box_a: [x, y, z, dx, dy, dz, heading] + // params box_b: [x, y, z, dx, dy, dz, heading] + float sa = boxATensor.GetValue(3) * boxATensor.GetValue(4); + float sb = boxBTensor.GetValue(3) * boxBTensor.GetValue(4); + float s_overlap = box_overlap(boxATensor, boxBTensor); + return s_overlap / max(sa + sb - s_overlap, EPS); + } + + + +private: + TPipe pipe; + TQue inQueueCur, inQueueBox; + TQue outQueueMask, oneMask; + TBuf calcBuf; + TBuf comBuf; + + TBuf p1Buf, p2Buf, q1Buf, q2Buf; + TBuf angleBuf, sinBuf, cosBuf, pointBuf; + TBuf min1Buf, min2Buf, max1Buf, max2Buf; + + GlobalTensor boxGm; + GlobalTensor maskGm; + LocalTensor curTemp, boxTemp; + uint32_t usedCoreNum; + uint32_t loopTime; + uint32_t eachSum; + uint32_t boxNum; + uint32_t tailSum; + uint32_t tailNum; + uint32_t maskNum; + uint32_t dataAlign = 16; + float overlapThresh; + bool isLastCore; +}; + +extern "C" __global__ __aicore__ void nms3d(GM_ADDR boxes, GM_ADDR mask, GM_ADDR workspace, GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + const Nms3dTilingData* __restrict tilingDevice = &tilingData; + if (TILING_KEY_IS(1)) { + KernelNms3d op; + op.Init(boxes, mask, tilingDevice); + op.Process(boxes, mask); + } else if (TILING_KEY_IS(2)) { + KernelNms3d op; + op.Init(boxes, mask, tilingDevice); + op.Process(boxes, mask); + } +} \ No newline at end of file diff --git a/ads/common/ops/npu_nms3d.py b/ads/common/ops/npu_nms3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f039ec21d959ecb29a0b4a81f8889a65876398cb --- /dev/null +++ b/ads/common/ops/npu_nms3d.py @@ -0,0 +1,20 @@ +import torch +import torch_npu +from torch.autograd import Function +import ads_c + + +class 
class Nms3dFunction(Function):
    """Rotated 3-D box NMS backed by the ``ads_c.nms3d`` NPU operator."""

    @staticmethod
    def forward(ctx, boxes, scores, iou_threshold: float):
        """Return the indices of the boxes kept by NMS, best score first.

        Args:
            ctx: autograd context (unused).
            boxes: tensor of shape (N, 7); the kernel documents each row as
                ``[x, y, z, dx, dy, dz, heading]``.
            scores: tensor with one score per box (first dimension N).
            iou_threshold: overlap threshold forwarded to ``ads_c.nms3d``.

        Returns:
            Contiguous int64 tensor of indices into the original ``boxes``.

        Raises:
            ValueError: if ``boxes`` is not (N, 7) or ``scores`` does not have
                N entries.
        """
        # Explicit raise instead of `assert`: assertions vanish under `python -O`.
        if boxes.dim() != 2 or boxes.shape[1] != 7:
            raise ValueError('Input boxes shape should be (N, 7)')
        if scores.shape[0] != boxes.shape[0]:
            raise ValueError('Input scores should have one entry per box')
        if boxes.shape[0] == 0:
            # Nothing to suppress; skip the kernel launch altogether.
            return torch.zeros(0, dtype=torch.long, device=boxes.device)

        # The operator is handed boxes sorted by descending score.
        order = scores.sort(0, descending=True)[1]
        boxes = boxes[order].contiguous()

        keep, num_out = ads_c.nms3d(boxes, iou_threshold)
        # `keep` indexes the sorted boxes; map it back through `order`.
        return order[keep[:num_out].long()].contiguous()


npu_nms3d = Nms3dFunction.apply
EPS = 1e-8


class Point:
    """Mutable 2-D point, also used as a 2-D vector, for the CPU reference."""

    def __init__(self, x=0.0, y=0.0):
        self.x = x
        self.y = y

    def set(self, _x: float, _y: float):
        """Overwrite both coordinates in place."""
        self.x = _x
        self.y = _y

    def __add__(self, other):
        return Point(self.x + other.x, self.y + other.y)

    def __sub__(self, other):
        return Point(self.x - other.x, self.y - other.y)


def cross(p1: Point, p2: Point, p0: Point) -> float:
    """Return the z-component of a 2-D cross product.

    With ``p0`` set to ``None`` the arguments are treated as vectors and the
    result is ``p1 x p2``; otherwise it is ``(p1 - p0) x (p2 - p0)``.
    """
    if p0 is None:
        return p1.x * p2.y - p1.y * p2.x
    return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y)


def check_rect_cross(p1: Point, p2: Point, q1: Point, q2: Point) -> bool:
    """Return True if the bounding boxes of segments p1-p2 and q1-q2 overlap."""
    return (
        min(p1.x, p2.x) <= max(q1.x, q2.x)
        and min(q1.x, q2.x) <= max(p1.x, p2.x)
        and min(p1.y, p2.y) <= max(q1.y, q2.y)
        and min(q1.y, q2.y) <= max(p1.y, p2.y)
    )


def _polar_angle(p: Point, center: Point) -> float:
    """Return the angle of ``p`` as seen from ``center``."""
    return atan2(p.y - center.y, p.x - center.x)


def box_overlap(box_a: List[float], box_b: List[float]):
    """Return the x-y overlap area of two rotated boxes.

    Each box is ``[x, y, z, dx, dy, dz, heading]``; only x, y, dx, dy and
    heading are used.
    """
    a_angle, b_angle = box_a[6], box_b[6]
    a_dx_half, b_dx_half = box_a[3] / 2, box_b[3] / 2
    a_dy_half, b_dy_half = box_a[4] / 2, box_b[4] / 2
    a_x1, a_y1 = box_a[0] - a_dx_half, box_a[1] - a_dy_half
    a_x2, a_y2 = box_a[0] + a_dx_half, box_a[1] + a_dy_half
    b_x1, b_y1 = box_b[0] - b_dx_half, box_b[1] - b_dy_half
    b_x2, b_y2 = box_b[0] + b_dx_half, box_b[1] + b_dy_half

    center_a = Point(box_a[0], box_a[1])
    center_b = Point(box_b[0], box_b[1])

    # Distinct Point objects per corner (no `[Point()] * n` aliasing).
    box_a_corners = [Point(a_x1, a_y1), Point(a_x2, a_y1),
                     Point(a_x2, a_y2), Point(a_x1, a_y2)]
    box_b_corners = [Point(b_x1, b_y1), Point(b_x2, b_y1),
                     Point(b_x2, b_y2), Point(b_x1, b_y2)]

    # get oriented corners
    a_angle_cos, a_angle_sin = cos(a_angle), sin(a_angle)
    b_angle_cos, b_angle_sin = cos(b_angle), sin(b_angle)
    for k in range(4):
        rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k])
        rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k])
    # Repeat the first corner so edge k runs from corner k to corner k + 1.
    box_a_corners.append(box_a_corners[0])
    box_b_corners.append(box_b_corners[0])

    # Vertices of the overlap polygon: edge crossings first, then enclosed corners.
    cross_points = []
    poly_center = Point(0, 0)
    for i in range(4):
        for j in range(4):
            flag, ans_point = intersection(box_a_corners[i + 1], box_a_corners[i],
                                           box_b_corners[j + 1], box_b_corners[j])
            if flag:
                poly_center = poly_center + ans_point
                cross_points.append(ans_point)

    # check corners
    for k in range(4):
        if check_in_box2d(box_a, box_b_corners[k]):
            poly_center = poly_center + box_b_corners[k]
            cross_points.append(box_b_corners[k])
        if check_in_box2d(box_b, box_a_corners[k]):
            poly_center = poly_center + box_a_corners[k]
            cross_points.append(box_a_corners[k])

    cnt = len(cross_points)
    if cnt == 0:
        return 0.0
    poly_center.x /= cnt
    poly_center.y /= cnt

    # sort the points of polygon
    # list.sort is stable, so ties keep insertion order exactly like the
    # adjacent-swap sort with a strict `>` comparison it replaces.
    cross_points.sort(key=lambda p: _polar_angle(p, poly_center))

    # get the overlap areas (triangle fan anchored at the first vertex)
    area = 0
    for k in range(cnt - 1):
        v1 = cross_points[k] - cross_points[0]
        v2 = cross_points[k + 1] - cross_points[0]
        area += cross(v1, v2, None)
    return fabs(area) / 2.0


def rotate_around_center(center: Point, angle_cos: float, angle_sin: float, p: Point) -> Point:
    """Rotate ``p`` about ``center`` in place and return the same object."""
    new_x = (p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x
    new_y = (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y
    p.set(new_x, new_y)
    return p


def check_in_box2d(box: List[float], p: Point):
    """Return True if ``p`` lies inside ``box`` in the x-y plane, with a small margin."""
    # params: box (7) [x, y, z, dx, dy, dz, heading]
    MARGIN = 1e-2

    center_x = box[0]
    center_y = box[1]
    # rotate the point in the opposite direction of box
    angle_cos = cos(-box[6])
    angle_sin = sin(-box[6])
    rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin)
    rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos

    return fabs(rot_x) < box[3] / 2 + MARGIN and fabs(rot_y) < box[4] / 2 + MARGIN


def point_cmp(a: Point, b: Point, center: Point):
    """Return True if ``a`` has a larger polar angle around ``center`` than ``b``."""
    return _polar_angle(a, center) > _polar_angle(b, center)


def intersection(p1: Point, p0: Point, q1: Point, q0: Point):
    """Intersect segment p0-p1 with segment q0-q1.

    Returns ``(1, point)`` when the segments properly cross and ``(0, Point())``
    otherwise (touching end points and collinear overlaps do not count).
    """
    ans_point = Point()
    # fast exclusion
    if not check_rect_cross(p0, p1, q0, q1):
        return 0, ans_point

    # check cross standing
    s1 = cross(q0, p1, p0)
    s2 = cross(p1, q1, p0)
    s3 = cross(p0, q1, q0)
    s4 = cross(q1, p1, q0)

    if not (s1 * s2 > 0 and s3 * s4 > 0):
        return 0, ans_point

    # calculate intersection of two lines
    s5 = cross(q1, p1, p0)
    if fabs(s5 - s1) > EPS:
        ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1)
        ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1)
    else:
        # Fall back to solving the two line equations a*x + b*y + c = 0.
        a0 = p0.y - p1.y
        b0 = p1.x - p0.x
        c0 = p0.x * p1.y - p1.x * p0.y
        a1 = q0.y - q1.y
        b1 = q1.x - q0.x
        c1 = q0.x * q1.y - q1.x * q0.y

        D = a0 * b1 - a1 * b0

        ans_point.x = (b0 * c1 - b1 * c0) / D
        ans_point.y = (a1 * c0 - a0 * c1) / D
    return 1, ans_point


def iou_bev(box_a: List[float], box_b: List[float]):
    """Return the bird's-eye-view IoU of two boxes."""
    # params box_a: [x, y, z, dx, dy, dz, heading]
    # params box_b: [x, y, z, dx, dy, dz, heading]
    sa = box_a[3] * box_a[4]
    sb = box_b[3] * box_b[4]
    s_overlap = box_overlap(box_a, box_b)
    # Clamp the union so two degenerate boxes do not divide by zero.
    return s_overlap / max(sa + sb - s_overlap, EPS)
class TestNms3d(TestCase):
    """Compare ``ads.common.npu_nms3d`` with a pure-NumPy reference NMS."""

    def cpu_to_exec(self, boxes, scores, threshold=0.0):
        """Run the reference NMS and return kept indices into the original boxes."""
        boxes = boxes.numpy()
        scores = scores.numpy()
        # Highest score first.
        order = scores.argsort()[::-1]
        boxes = boxes.take(order, 0)
        keep, num_out = self.cpu_nms_forward(boxes, threshold)
        # `keep` indexes the sorted boxes; map it back through `order`.
        keep = order[keep[:num_out]].astype(np.int64)
        return torch.from_numpy(keep)

    def cpu_nms_forward(self, boxes, nms_overlap_thresh=0.0):
        """Greedy NMS over boxes already sorted by descending score.

        Returns ``(keep, num_out)``: the first ``num_out`` entries of ``keep``
        are the surviving indices; the remaining entries are -1.
        """
        box_num = boxes.shape[0]
        suppressed = np.zeros(box_num, dtype=bool)
        keep = np.full(box_num, -1, dtype=np.int64)
        num_out = 0
        for i in range(box_num):
            if suppressed[i]:
                continue
            keep[num_out] = i
            num_out += 1
            for j in range(i + 1, box_num):
                if iou_bev(boxes[i], boxes[j]) > nms_overlap_thresh:
                    suppressed[j] = True
        return keep, num_out

    def npu_to_exec(self, boxes, scores, threshold=0.0):
        """Run the NPU operator and bring the result back to the host."""
        keep = ads.common.npu_nms3d(boxes, scores, threshold)
        return keep.cpu()

    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `npu_nms3d` is only supported on 910B, skip this ut!")
    def test_nms3d(self):
        # Each entry: [boxes spec, scores spec, IoU threshold].
        shape_format = [
            [[np.float32, -1, [5, 7]], [np.float32, -1, [5]], 0.1],
            [[np.float32, -1, [500, 7]], [np.float32, -1, [500]], 0.2],
            [[np.float32, -1, [1024, 7]], [np.float32, -1, [1024]], 0.3],
            [[np.float16, -1, [100, 7]], [np.float32, -1, [100]], 0.1],
            [[np.float16, -1, [500, 7]], [np.float32, -1, [500]], 0.2],
            [[np.float16, -1, [1024, 7]], [np.float32, -1, [1024]], 0.3],
        ]
        for boxes_spec, scores_spec, threshold in shape_format:
            boxes_cpu, boxes_npu = create_common_tensor(boxes_spec, 0, 10)
            scores_cpu, scores_npu = create_common_tensor(scores_spec, 0, 1)
            out_cpu = self.cpu_to_exec(boxes_cpu, scores_cpu, threshold)
            out_npu = self.npu_to_exec(boxes_npu, scores_npu, threshold)
            self.assertRtolEqual(out_cpu, out_npu)


if __name__ == '__main__':
    run_tests()