diff --git a/include/csrc/functions.h b/include/csrc/functions.h index ccb8c98c7079a7dcd0b1715131bc66ac1255a833..7ac4a2a64392d780fcd9eb8b4bf2cd3908975ef9 100644 --- a/include/csrc/functions.h +++ b/include/csrc/functions.h @@ -193,7 +193,7 @@ std::tuple nms3d_normal(const at::Tensor& boxes, double std::tuple nms3d(const at::Tensor& boxes, double threshold); -std::tuple nms3d_on_sight(const at::Tensor& boxes, double threshold); +std::tuple nms3d_on_sight(const at::Tensor& boxes, const at::Tensor& threshold); at::Tensor npu_rotated_overlaps(const at::Tensor& self, const at::Tensor& query_boxes, bool trans); diff --git a/kernels/op_host/nms3d_on_sight.cpp b/kernels/op_host/nms3d_on_sight.cpp index a99c0402043699f26414494b61980c658cea002d..88deecdf538ace2d40b71976f06e98cafc299813 100644 --- a/kernels/op_host/nms3d_on_sight.cpp +++ b/kernels/op_host/nms3d_on_sight.cpp @@ -44,11 +44,11 @@ static ge::graphStatus Nms3dOnSightTilingFunc(gert::TilingContext *context) auto boxShape = context->GetInputShape(0)->GetStorageShape(); // [7, N] auto maskShape = context->GetOutputShape(0)->GetStorageShape(); // [N, N_aligned] auto dtype = context->GetInputDesc(0)->GetDataType(); - auto attrs = context->GetAttrs(); + // auto attrs = context->GetAttrs(); - if (attrs == nullptr) { - return ge::GRAPH_FAILED; - } + // if (attrs == nullptr) { + // return ge::GRAPH_FAILED; + // } // 预留fp16的接口支持,目前不支持 if (ge::DT_FLOAT == dtype) { @@ -71,7 +71,7 @@ static ge::graphStatus Nms3dOnSightTilingFunc(gert::TilingContext *context) } else { return ge::GRAPH_FAILED; } - float threshold = *(attrs->GetAttrPointer(0)); + // float threshold = *(attrs->GetAttrPointer(0)); context->SetBlockDim(usedCoreNum); tiling.set_usedCoreNum(usedCoreNum); @@ -79,9 +79,9 @@ static ge::graphStatus Nms3dOnSightTilingFunc(gert::TilingContext *context) tiling.set_loopTime(loopTime); tiling.set_assignBox(assignBox); tiling.set_alignedN(alignedN); - tiling.set_threshold(threshold); - MX_DRIVING_LOGI("Nms3dOnSight tiling: usedCoreNum=%d, boxNum=%d, loopTime=%d, alignedN=%d, threshold=%f, assignBox=%d", - usedCoreNum, boxNum, loopTime, alignedN, threshold, assignBox); + // tiling.set_threshold(threshold); + // MX_DRIVING_LOGI("Nms3dOnSight tiling: usedCoreNum=%d, boxNum=%d, loopTime=%d, alignedN=%d, threshold=%f, assignBox=%d", + // usedCoreNum, boxNum, loopTime, alignedN, threshold, assignBox); // 待拆解功能 tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), @@ -116,12 +116,17 @@ public: .DataType({ge::DT_FLOAT}) .Format({ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("threshold") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); this->Output("mask") .ParamType(REQUIRED) .DataType({ge::DT_INT16}) .Format({ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND}); - this->Attr("threshold").AttrType(REQUIRED).Float(); + // this->Attr("threshold").AttrType(REQUIRED).Float(); this->SetInferShape(ge::Nms3dOnSightInferShape) .SetInferDataType(ge::Nms3dOnSightInferDataType); diff --git a/kernels/op_host/nms3d_on_sight_tiling.h b/kernels/op_host/nms3d_on_sight_tiling.h index 4cad31100caffbcb60e3bce863f6e6dd7558a354..22a2bc98a46fc4c0161fe09d4089612ffb3e367a 100644 --- a/kernels/op_host/nms3d_on_sight_tiling.h +++ b/kernels/op_host/nms3d_on_sight_tiling.h @@ -13,7 +13,7 @@ BEGIN_TILING_DATA_DEF(Nms3dOnSightTilingData) TILING_DATA_FIELD_DEF(uint32_t, loopTime) // loop times TILING_DATA_FIELD_DEF(uint32_t, assignBox) // boxesNum align 256B TILING_DATA_FIELD_DEF(uint32_t, alignedN) // boxesNum align 32B - TILING_DATA_FIELD_DEF(float, threshold) + // TILING_DATA_FIELD_DEF(float, threshold) END_TILING_DATA_DEF; REGISTER_TILING_DATA_CLASS(Nms3dOnSight, Nms3dOnSightTilingData) diff --git a/kernels/op_kernel/nms3d_on_sight.cpp b/kernels/op_kernel/nms3d_on_sight.cpp index 0b204ee6e67756b56f25de8520361d102feb7003..c8ed39085e0a934d50d68f1278f49ed6e02e2e9a 100644 --- a/kernels/op_kernel/nms3d_on_sight.cpp +++ b/kernels/op_kernel/nms3d_on_sight.cpp @@ -15,7 +15,7 @@ class KernelNms3dOnSight { public: __aicore__ inline KernelNms3dOnSight() {} - __aicore__ inline void Init(GM_ADDR boxes, GM_ADDR mask, const Nms3dOnSightTilingData* __restrict tiling_data) + __aicore__ inline void Init(GM_ADDR boxes, GM_ADDR threshold, GM_ADDR mask, const Nms3dOnSightTilingData* __restrict tiling_data) { // 所有的计算在kernel侧以alignedN进行,这样在搬入搬出时都能保证对齐 ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); @@ -23,12 +23,14 @@ public: boxNum = tiling_data->boxNum; alignedN = tiling_data->alignedN; loopTime = tiling_data->loopTime; - threshold = tiling_data->threshold; + // threshold = tiling_data->threshold; assignBox = tiling_data->assignBox; uint32_t core_id = GetBlockIdx(); boxGm.SetGlobalBuffer(reinterpret_cast<__gm__ T*>(boxes), static_cast(alignedN) * 7); + thresholdGm.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(threshold), static_cast(8)); + // threshold_ = thresholdGm[static_cast(0)]; maskGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t*>(mask), static_cast(alignedN) * boxNum); pipe.InitBuffer(inQueueX, BUFFER_NUM_INPUT, assignBox * sizeof(T)); @@ -50,12 +52,17 @@ public: pipe.InitBuffer(distBuf2, alignedN * sizeof(T)); pipe.InitBuffer(fovBuf, alignedN * sizeof(T)); pipe.InitBuffer(selBuf, alignedN * sizeof(T)); + pipe.InitBuffer(thresholdBuf2, 8 * sizeof(float)); // 计算过程中用到的缓存tensor pipe.InitBuffer(comBuf, alignedN * sizeof(T)); pipe.InitBuffer(comXBuf, alignedN * sizeof(T)); pipe.InitBuffer(comYBuf, alignedN * sizeof(T)); pipe.InitBuffer(comRBuf, alignedN * sizeof(T)); + + LocalTensor thresholdLocal = thresholdBuf2.Get(); + DataCopy(thresholdLocal, thresholdGm[static_cast(0)], 8); + threshold_ = thresholdLocal.GetValue(0); } __aicore__ inline void Process() @@ -131,7 +138,7 @@ private: // 声明一个distLocal的tensor[1, N]和threshold的tensor[1, N], 数据类型为T,计算得到的值和threshold(float)进行比较 LocalTensor distLocal = distBuf.Get(); LocalTensor thresholdLocal = thresholdBuf.Get(); - Duplicate(thresholdLocal, static_cast(threshold), alignedN); + Duplicate(thresholdLocal, static_cast(threshold_), alignedN); DistBev(curX, curY, curR, xTemp, yTemp, rTemp, distLocal); @@ -274,13 +281,14 @@ private: TQue inQueueX, inQueueY, inQueueR; TQue outQueueMask; TBuf xBuf, yBuf, rBuf; - TBuf distBuf, thresholdBuf; + TBuf distBuf, thresholdBuf, thresholdBuf2; TBuf comBuf, comXBuf, comYBuf, comRBuf; TBuf maskBuf; TBuf upBuf, downBuf, distBuf1, distBuf2, fovBuf, selBuf; GlobalTensor boxGm; + GlobalTensor thresholdGm; GlobalTensor maskGm; LocalTensor xTemp, yTemp, rTemp; LocalTensor inX, inY, inR; @@ -292,17 +300,17 @@ private: uint32_t tailNum; uint32_t alignedN; uint32_t assignBox; - float threshold; + float threshold_; bool isLastCore; }; -extern "C" __global__ __aicore__ void nms3d_on_sight(GM_ADDR boxes, GM_ADDR mask, GM_ADDR workspace, GM_ADDR tiling) +extern "C" __global__ __aicore__ void nms3d_on_sight(GM_ADDR boxes, GM_ADDR threshold, GM_ADDR mask, GM_ADDR workspace, GM_ADDR tiling) { GET_TILING_DATA(tilingData, tiling); const Nms3dOnSightTilingData* __restrict tilingDevice = &tilingData; if (TILING_KEY_IS(1)) { KernelNms3dOnSight op; - op.Init(boxes, mask, tilingDevice); + op.Init(boxes, threshold, mask, tilingDevice); op.Process(); } } \ No newline at end of file diff --git a/mx_driving/_C/__init__.pyi b/mx_driving/_C/__init__.pyi index 39a5a34b4baefc58bed1e048dadd9711ed48d5cc..3b8fb82ad40b81b4a00edfceda7e6cd117863de4 100644 --- a/mx_driving/_C/__init__.pyi +++ b/mx_driving/_C/__init__.pyi @@ -221,7 +221,7 @@ def npu_prepare_subm_conv3d( ) -> Tuple[torch.Tensor, torch.Tensor]: ... def nms3d_normal(boxes: torch.Tensor, nms_overlap_thresh: float) -> Tuple[torch.Tensor, torch.Tensor]: ... def nms3d(boxes: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: ... -def nms3d_on_sight(boxes: torch.Tensor, threshold: float) -> Tuple[torch.Tensor, torch.Tensor]: ... +def nms3d_on_sight(boxes: torch.Tensor, threshold: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: ... def npu_rotated_overlaps(self: torch.Tensor, query_boxes: torch.Tensor, trans: bool) -> torch.Tensor: ... def npu_rotated_iou( boxes: torch.Tensor, diff --git a/mx_driving/csrc/Nms3dOnSight.cpp b/mx_driving/csrc/Nms3dOnSight.cpp index f8252cd7bf235344b9b08b8255cd991756374549..654c0bde8d3df2e7248f42ec92f8a06466b41d69 100644 --- a/mx_driving/csrc/Nms3dOnSight.cpp +++ b/mx_driving/csrc/Nms3dOnSight.cpp @@ -17,7 +17,7 @@ #include "csrc/OpApiCommon.h" #include "csrc/functions.h" -std::tuple nms3d_on_sight(const at::Tensor& boxes, double threshold) +std::tuple nms3d_on_sight(const at::Tensor& boxes, const at::Tensor& threshold) { int32_t box_num = boxes.size(0); int32_t data_align = 16; diff --git a/mx_driving/ops/nms3d_on_sight.py b/mx_driving/ops/nms3d_on_sight.py index 7359b17998ab1803bce40ca8d6b6cb4390541430..052f461c5ae740ea888fda6fe05150f77f7e335a 100644 --- a/mx_driving/ops/nms3d_on_sight.py +++ b/mx_driving/ops/nms3d_on_sight.py @@ -15,13 +15,13 @@ import mx_driving._C class Nms3dOnSightFunction(Function): @staticmethod - def forward(ctx, boxes, scores, threshold: float): + def forward(ctx, boxes, scores, threshold): if boxes.shape[1] != 7: raise Exception('Input boxes shape should be (N, 7)') order = scores.sort(0, descending=True)[1] boxes = boxes[order].contiguous() - - keep, num_out = mx_driving._C.nms3d_on_sight(boxes, -threshold**2) + threshold = -torch.pow(threshold, 2) + keep, num_out = mx_driving._C.nms3d_on_sight(boxes, threshold) return order[keep[:num_out].long()].contiguous()