From b4757a7ba40cf8e751b29ba7a1813989f9cf572f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=BF=9E=E4=B9=83=E6=B6=B5?= Date: Wed, 10 Sep 2025 22:07:45 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Ecylinder=5Fquery=20=E7=AE=97?= =?UTF-8?q?=E5=AD=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/csrc/functions.h | 3 + kernels/op_host/cylinder_query.cpp | 233 ++++++++++++++++ kernels/op_host/cylinder_query_tiling.h | 27 ++ kernels/op_kernel/cylinder_query.cpp | 340 ++++++++++++++++++++++++ mx_driving/_C/__init__.pyi | 10 + mx_driving/__init__.py | 2 + mx_driving/csrc/CylinderQuery.cpp | 51 ++++ mx_driving/csrc/pybind.cpp | 3 + mx_driving/ops/cylinder_query.py | 42 +++ tests/torch/test_cylinder_query.py | 264 ++++++++++++++++++ 10 files changed, 975 insertions(+) create mode 100644 kernels/op_host/cylinder_query.cpp create mode 100644 kernels/op_host/cylinder_query_tiling.h create mode 100644 kernels/op_kernel/cylinder_query.cpp create mode 100644 mx_driving/csrc/CylinderQuery.cpp create mode 100644 mx_driving/ops/cylinder_query.py create mode 100644 tests/torch/test_cylinder_query.py diff --git a/include/csrc/functions.h b/include/csrc/functions.h index 27f71309..f2bd9e2d 100644 --- a/include/csrc/functions.h +++ b/include/csrc/functions.h @@ -304,4 +304,7 @@ at::Tensor graph_softmax(const at::Tensor& src, const at::Tensor& index, int N); at::Tensor graph_softmax_grad(const at::Tensor& index, const at::Tensor& softmax_out, const at::Tensor& grad_output, int32_t node_num); +at::Tensor cylinder_query(double radius, double hmin, double hmax, int nsample, const at::Tensor& new_xyz, + const at::Tensor& xyz, const at::Tensor& rot); + #endif // CSRC_FUNCTIONS_H_ diff --git a/kernels/op_host/cylinder_query.cpp b/kernels/op_host/cylinder_query.cpp new file mode 100644 index 00000000..773569ff --- /dev/null +++ b/kernels/op_host/cylinder_query.cpp @@ -0,0 +1,233 @@ +#include "ge/utils.h" +#include "cylinder_query_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/tiling_api.h" +#include "tiling/platform/platform_ascendc.h" + +#define Ceil32(num) (((num) + 31) / 32 * 32) + +namespace { + constexpr uint32_t BUFFER_NUM = 1; + constexpr uint32_t FLOAT_BYTE_SIZE = 4; + + // 输入Tensor的下标 + constexpr uint32_t INPUT_NEW_XYZ_IDX = 0; + + // Attr下标 + const size_t BATCH_SIZE_INDEX = 0; + const size_t POINT_CLOUD_SIZE_INDEX = 1; + const size_t QUERY_POINT_SIZE_INDEX = 2; + const size_t RADIUS_INDEX = 3; + const size_t HMIN_INDEX = 4; + const size_t HMAX_INDEX = 5; + const size_t NSAMPLE_INDEX = 6; + + // 输出Tensor的下标 + constexpr uint32_t OUTPUT_QUERY_RES_IDX = 0; + + // 最小数据块中的点数 / 字节数 + constexpr uint32_t BLOCK_POINT_SIZE = 8; + constexpr uint32_t BLOCK_BYTE_SIZE = 96; + + constexpr uint32_t QUERY_RES_DIM_NUM = 3; +} + +namespace ge { +static ge::graphStatus InferShapeForCylinderQuery(gert::InferShapeContext* context) +{ + const gert::RuntimeAttrs *attr = context->GetAttrs(); + gert::Shape *queryResShape = context->GetOutputShape(OUTPUT_QUERY_RES_IDX); + if (queryResShape == nullptr) { + return ge::GRAPH_FAILED; + } + + auto batchPtr = attr->GetAttrPointer(BATCH_SIZE_INDEX); + auto queryPointSize = attr->GetAttrPointer(QUERY_POINT_SIZE_INDEX); + auto nsamplePtr = attr->GetAttrPointer(NSAMPLE_INDEX); + auto pointCloudSizePtr = attr->GetAttrPointer(POINT_CLOUD_SIZE_INDEX); + + if (!batchPtr || !queryPointSize || !nsamplePtr || !pointCloudSizePtr) { + return ge::GRAPH_FAILED; + } + + queryResShape->SetDimNum(QUERY_RES_DIM_NUM); + *queryResShape = {*batchPtr, *queryPointSize, *pointCloudSizePtr}; + + return GRAPH_SUCCESS; +} + +static ge::graphStatus InferDataTypeForCylinderQuery(gert::InferDataTypeContext* context) +{ + context->SetOutputDataType(OUTPUT_QUERY_RES_IDX, ge::DataType::DT_FLOAT); + return GRAPH_SUCCESS; +} +} + +namespace optiling { +static ge::graphStatus TilingForCylinderQuery(gert::TilingContext* context) +{ + CylinderQueryTilingData tiling; + + if (context == nullptr) { + return ge::GRAPH_FAILED; + } + + // 硬件信息 + auto platformInfo = context->GetPlatformInfo(); + CHECK_NULLPTR(platformInfo); + auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo); + + // 输入属性 + const gert::RuntimeAttrs *attr = context->GetAttrs(); // 属性 + uint32_t B = *(attr->GetAttrPointer(BATCH_SIZE_INDEX)); + uint32_t N = *(attr->GetAttrPointer(POINT_CLOUD_SIZE_INDEX)); + uint32_t M = *(attr->GetAttrPointer(QUERY_POINT_SIZE_INDEX)); + float radius = *(attr->GetAttrPointer(RADIUS_INDEX)); + float hmin = *(attr->GetAttrPointer(HMIN_INDEX)); + float hmax = *(attr->GetAttrPointer(HMAX_INDEX)); + uint32_t nsample = *(attr->GetAttrPointer(NSAMPLE_INDEX)); + + // 计算单次可处理数据量最大值 + uint64_t ubSize; + ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize); + uint32_t xyzBlockNum = (N + BLOCK_POINT_SIZE - 1) / BLOCK_POINT_SIZE; // 点云的总块数 + uint32_t tileBlockNum = (ubSize / 2 - 1000 - 8 * nsample) / 340; // 一次最多可以放入的点云的块数 + uint32_t tileDataNum = (tileBlockNum * BLOCK_BYTE_SIZE) / (FLOAT_BYTE_SIZE * 3); // tileBlockNum中对应的点的数量 + // 计算实际输入的分块情况 + uint32_t inputLengthAlign32 = ((N + BLOCK_POINT_SIZE - 1) / BLOCK_POINT_SIZE) * BLOCK_POINT_SIZE * (FLOAT_BYTE_SIZE * 3); // 向上对齐 + auto aivNum = ascendcPlatform.GetCoreNumAiv(); + if (aivNum == 0) { + return ge::GRAPH_FAILED; + } + aivNum = (aivNum < B * M) ? aivNum : B * M; + aivNum = (aivNum >= 1) ? aivNum : 1; + + uint32_t smallTileNum = xyzBlockNum / tileBlockNum; + uint32_t finalSmallTileNum = (xyzBlockNum % tileBlockNum == 0) ? smallTileNum : smallTileNum + 1; // 遍历点云过程中需要循环的次数 + // 最后一次需要计算的点云点个数 + uint32_t smallTileDataNum = N - (smallTileNum * tileDataNum); + smallTileDataNum = smallTileDataNum == 0 ? tileDataNum : smallTileDataNum; + uint32_t smallTileBlockNum = Ceil(smallTileDataNum, BLOCK_POINT_SIZE); // 最后一次循环中参与计算的元素块数 + smallTileBlockNum = (smallTileBlockNum == 0)? tileBlockNum: smallTileBlockNum; // smallTileBlockNum表示的是最后一次循环数据块的数量,而不是简单的取余操作 + + uint32_t totalQueryPiont = B * M; // 总查询点的数量 + uint32_t totalTask = B * M; // 总的task数量 + uint32_t coreTask = Ceil(totalQueryPiont, aivNum); // 平均每个大core的task任务 + uint32_t bigCoreCount = (totalQueryPiont % aivNum == 0)? aivNum : (totalQueryPiont % aivNum); + uint32_t tailTaskNum = Ceil(totalTask, coreTask); // 尾核的任务数量 + + bool dtype = context->GetInputDesc(INPUT_NEW_XYZ_IDX)->GetDataType() == ge::DT_FLOAT; + + context->SetBlockDim(aivNum); + + tiling.set_batchSize(B); + tiling.set_pointCloudSize(N); + tiling.set_queryPointSize(M); + tiling.set_radius(radius); + tiling.set_hmin(hmin); + tiling.set_hmax(hmax); + tiling.set_nsample(nsample); + + tiling.set_coreTask(coreTask); + tiling.set_tailTaskNum(tailTaskNum); + + tiling.set_bigCoreCount(bigCoreCount); + tiling.set_finalSmallTileNum(finalSmallTileNum); + tiling.set_smallTileDataNum(smallTileDataNum); + tiling.set_tileDataNum(tileDataNum); + tiling.set_tileBlockNum(tileBlockNum); + tiling.set_smallTileBlockNum(smallTileBlockNum); + + if (context->GetRawTilingData() == nullptr) { + return ge::GRAPH_FAILED; + } + + auto platform = platform_ascendc::PlatformAscendC(platformInfo); + + // workspace + uint32_t sysWorkspaceSize = platform.GetLibApiWorkSpaceSize(); + size_t* currentWorkspace = context->GetWorkspaceSizes(1); + CHECK_NULLPTR(currentWorkspace); + currentWorkspace[0] = sysWorkspaceSize; + + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} +} + +namespace ops { +class CylinderQuery : public OpDef { +public: + explicit CylinderQuery(const char* name) : OpDef(name) + { + // Tensor输入 + this->Input("new_xyz") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}) + .AutoContiguous(); + + this->Input("xyz") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}) + .AutoContiguous(); + + this->Input("rot") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}) + .AutoContiguous(); + + this->Input("origin_index") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}) + .AutoContiguous(); + + // 属性输入 + this->Attr("batch_size") + .AttrType(REQUIRED) + .Int(); + this->Attr("point_cloud_size") + .AttrType(REQUIRED) + .Int(); + this->Attr("query_point_size") + .AttrType(REQUIRED) + .Int(); + this->Attr("radius") + .AttrType(REQUIRED) + .Float(); + this->Attr("hmin") + .AttrType(REQUIRED) + .Float(); + this->Attr("hmax") + .AttrType(REQUIRED) + .Float(); + this->Attr("nsample") + .AttrType(REQUIRED) + .Int(); + + // Tensor输出 + this->Output("out") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + + this->SetInferShape(ge::InferShapeForCylinderQuery) + .SetInferDataType(ge::InferDataTypeForCylinderQuery); + + this->AICore().SetTiling(optiling::TilingForCylinderQuery); + this->AICore().AddConfig("ascend910b"); + this->AICore().AddConfig("ascend910_93"); + } +}; +OP_ADD(CylinderQuery); +} // namespace ops \ No newline at end of file diff --git a/kernels/op_host/cylinder_query_tiling.h b/kernels/op_host/cylinder_query_tiling.h new file mode 100644 index 00000000..cd71fcce --- /dev/null +++ b/kernels/op_host/cylinder_query_tiling.h @@ -0,0 +1,27 @@ +#ifndef CYLINDER_QUERY_TILING_H +#define CYLINDER_QUERY_TILING_H +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(CylinderQueryTilingData) +TILING_DATA_FIELD_DEF(uint32_t, batchSize); +TILING_DATA_FIELD_DEF(uint32_t, pointCloudSize); +TILING_DATA_FIELD_DEF(uint32_t, queryPointSize); +TILING_DATA_FIELD_DEF(uint32_t, nsample); +TILING_DATA_FIELD_DEF(float, radius); +TILING_DATA_FIELD_DEF(float, hmin); +TILING_DATA_FIELD_DEF(float, hmax); +TILING_DATA_FIELD_DEF(uint32_t, coreTask); // 每个核心的任务数 +TILING_DATA_FIELD_DEF(uint32_t, bigCoreCount); +TILING_DATA_FIELD_DEF(uint32_t, tailTaskNum); // 尾核的任务数 + +TILING_DATA_FIELD_DEF(uint32_t, finalSmallTileNum); // 遍历点云过程中需要循环的次数 +TILING_DATA_FIELD_DEF(uint32_t, tileDataNum); // 单次搬运点云数据中点的个数,8对齐,将96作为一个数据块,对应八个点(8 * 3 * 4) +TILING_DATA_FIELD_DEF(uint32_t, tileBlockNum); // 单次搬运中数据块个数(最大值) +TILING_DATA_FIELD_DEF(uint32_t, smallTileDataNum); // 最后一次搬运要处理的点云点个数 +TILING_DATA_FIELD_DEF(uint32_t, smallTileBlockNum); // 最后一次搬运中数据块的个数 + +END_TILING_DATA_DEF; +REGISTER_TILING_DATA_CLASS(CylinderQuery, CylinderQueryTilingData) +} // namespace optiling +#endif // CYLINDER_QUERY_TILING_H diff --git a/kernels/op_kernel/cylinder_query.cpp b/kernels/op_kernel/cylinder_query.cpp new file mode 100644 index 00000000..ccf8e4ef --- /dev/null +++ b/kernels/op_kernel/cylinder_query.cpp @@ -0,0 +1,340 @@ +#include "kernel_operator.h" +#include "boxes_operator_utils.h" + +using namespace AscendC; + +constexpr uint32_t DIMENSION_3D = 3; +constexpr uint32_t ROT_SIZE = 9; +constexpr uint32_t UB_ALIGNED_BYTE_SIZE = 32; +constexpr uint32_t VERTICES_CORR = 2; +constexpr uint32_t INT32_BYTE_SIZE = 4; +constexpr uint32_t OUTPUT_IDX_COUNT = 9; +constexpr uint32_t MASK_ALIGNED = 32; + +constexpr uint32_t VERTICE_XY_ALIGNED = 64; + +constexpr uint32_t POINT_MEM = 12; // 一个点占用12B + +// 最小任务块中的查询点数 / 字节数 +constexpr uint32_t BLOCK_POINT_SIZE = 8; +constexpr uint32_t BLOCK_BYTE_SIZE = 96; +constexpr int32_t GATHER_MASK_NUM = 96; +constexpr int32_t REPEAT_STRIDE_0 = 3; +constexpr int32_t REPEAT_STRIDE_1 = 0; +constexpr int32_t ELE_NUM_PER_REPEAT = 24; + +#define Ceil32(num) (((num) + 31) / 32 * 32) +#define Ceil64(num) (((num) + 63) / 64 * 64) +#define CeilDiv8(num) (((num) + 7) / 8) + +class CylinderQuery { +public: + __aicore__ inline void Init(TPipe *pipe, GM_ADDR newXyz, GM_ADDR xyz, GM_ADDR rot, GM_ADDR origin_index, + GM_ADDR res, const CylinderQueryTilingData* tiling) + { + this->pipe_ = pipe; + this->blkIdx_ = GetBlockIdx(); + InitTiling(tiling); + InitUB(); + InitMask(); + InitGM(newXyz, xyz, rot, origin_index, res); + InitEvent(); + } + __aicore__ inline void Process() + { + for (int i = 0; i < this->coreTask_; i++) { + uint32_t offset = this->taskOffset_ + i; + this->batchIdx_ = offset / this->queryPointSize_; + CopyNewXyzIn(offset); + SetFlag(eventSV_); + WaitFlag(eventSV_); + this->processDataNum_ = this->tileDataNum_; // 本轮运算实际参与计算的元素个数 + + for (int j = 0; j < this->finalSmallTileNum_; j++) { + SetFlag(eventMTE3MTE2_); + WaitFlag(eventMTE3MTE2_); + int resOffset = j * processDataNum_; + this->processBlockNum_ = this->tileBlockNum_; + if (j == finalSmallTileNum_ - 1) { + this->processBlockNum_ = this->smallTileBlockNum_; + this->processDataNum_ = smallTileDataNum_; + } + InitRes(resOffset); + CopyIn(this->tileDataNum_ * j); + SetFlag(eventMTE2V_); + WaitFlag(eventMTE2V_); + Compute(i, j); + SetFlag(eventVMTE3_); + WaitFlag(eventVMTE3_); + CopyOut(offset, this->tileDataNum_ * j, this->processBlockNum_); + } + } + } + +private: + __aicore__ inline void InitTiling(const CylinderQueryTilingData* tiling) + { + this->coreTask_ = tiling->coreTask; + if (blkIdx_ < tiling->bigCoreCount) { + this->taskOffset_ = blkIdx_ * coreTask_; + } else { + this->taskOffset_ = tiling->bigCoreCount * coreTask_ + + (blkIdx_ - tiling->bigCoreCount) * (coreTask_ - 1); + this->coreTask_ = this->coreTask_ - 1; + } + this->tileBlockNum_ = tiling->tileBlockNum; + this->smallTileBlockNum_ = tiling->smallTileBlockNum; + // copy tiling数据与计算 + this->radius2_ = tiling->radius * tiling->radius; + this->hmin_ = tiling->hmin; + this->hmax_ = tiling->hmax; + this->nsample_ = tiling->nsample; + this->pointCloudSize_ = tiling->pointCloudSize; + this->queryPointSize_ = tiling->queryPointSize; + this->finalSmallTileNum_ = tiling->finalSmallTileNum; + this->smallTileDataNum_ = tiling->smallTileDataNum; // 最后一次循环计算的元素个数 + this->tileDataNum_ = tiling->tileDataNum; // 每次循环计算的元素个数 + } + + // 将cpu侧ternsor搬运到kernel侧 + __aicore__ inline void InitGM(GM_ADDR newXyz, GM_ADDR xyz, GM_ADDR rot, GM_ADDR origin_index, GM_ADDR res) + { + this->newXyzGm_.SetGlobalBuffer((__gm__ float*) newXyz); + this->xyzGm_.SetGlobalBuffer((__gm__ float*) xyz); + this->rotGm_.SetGlobalBuffer((__gm__ float*) rot); + this->resGm_.SetGlobalBuffer((__gm__ float*) res); + this->originIndexGm_.SetGlobalBuffer((__gm__ float*) origin_index); + } + + __aicore__ inline void InitUB() + { + // total: 336x + 736 + nsample_ * 8 + // 96x + 96x + 9 * FLOAT_BYTE_SIZE = 192x + 9 * FLOAT_BYTE_SIZE + pipe_->InitBuffer(xyzBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE); // 2(3xf + 95) + pipe_->InitBuffer(xBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE / DIMENSION_3D); // 2xf + pipe_->InitBuffer(yBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE / DIMENSION_3D); // 2xf + pipe_->InitBuffer(zBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE / DIMENSION_3D); // 2xf + pipe_->InitBuffer(rotBuf_, 2 * ROT_SIZE * FLOAT_BYTE_SIZE); // 18f + + // 32x * 5 = 96x + pipe_->InitBuffer(rotXBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE / DIMENSION_3D); // 2(xf + 255) + pipe_->InitBuffer(rotYBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE / DIMENSION_3D); // 2(xf + 255) + pipe_->InitBuffer(rotZBuf_, 2 * tileBlockNum_ * BLOCK_BYTE_SIZE / DIMENSION_3D); // 2(xf + 255) + + // 16x + // compare的时候count个元素占用的字节需要时256字节,即float向64个对齐 + pipe_->InitBuffer(maskD2Buf_, 2 * 1 * CeilDiv8(Ceil64(this->tileDataNum_))); // 2((x + 63) /32) + pipe_->InitBuffer(maskHBuf_, 2 * 1 * CeilDiv8(Ceil64(this->tileDataNum_))); // 2((x + 63) /32) + + // nsample_ * INT32_BYTE_SIZE + (x * 8 + 31) * 4 = nsample_ * 4 + 32 * x + 124 + pipe_->InitBuffer(scr1PatternBuf_, 2 * nsample_ * INT32_BYTE_SIZE); + pipe_->InitBuffer(resBuf_, 2 * Ceil32(tileDataNum_) * FLOAT_BYTE_SIZE); + + // 96 * 3 * 2 = 576 + pipe_->InitBuffer(bufferMaskXBuf, 2 * BLOCK_BYTE_SIZE); + pipe_->InitBuffer(bufferMaskYBuf, 2 * BLOCK_BYTE_SIZE); + pipe_->InitBuffer(bufferMaskZBuf, 2 * BLOCK_BYTE_SIZE); + + xyzLocal_ = xyzBuf_.Get(); + xLocal_ = xBuf_.Get(); + yLocal_ = yBuf_.Get(); + zLocal_ = zBuf_.Get(); + rotLocal_ = rotBuf_.Get(); + + rotXLocal_ = rotXBuf_.Get(); + rotYLocal_ = rotYBuf_.Get(); + rotZLocal_ = rotZBuf_.Get(); + + maskD2Local_ = maskD2Buf_.Get(); + maskHLocal_ = maskHBuf_.Get(); + + resLocal_ = resBuf_.Get(); + } + + __aicore__ inline void InitRes(int offset) + { + DataCopyExtParams originIndexDataCopyParams{static_cast(1), Ceil32(this->processDataNum_) * FLOAT_BYTE_SIZE, 0, 0, 0}; + DataCopyPad(resLocal_, originIndexGm_[offset], originIndexDataCopyParams, this->originIndexPadParams_); + } + + __aicore__ inline void InitMask() + { + xPattern_ = bufferMaskXBuf.Get(); + yPattern_ = bufferMaskYBuf.Get(); + zPattern_ = bufferMaskZBuf.Get(); + + // Set pattern values for x to select first element of three + xPattern_.SetValue(0, 0b1001001001001001001001001001001); + xPattern_.SetValue(1, 0b10010010010010010010010010010010); + xPattern_.SetValue(2, 0b100100100100100100100100100100); + + // // Set pattern values for y to select second element of three + yPattern_.SetValue(0, 0b10010010010010010010010010010010); + yPattern_.SetValue(1, 0b100100100100100100100100100100); + yPattern_.SetValue(2, 0b1001001001001001001001001001001); + + // // Set pattern values for z to select third element of three + zPattern_.SetValue(0, 0b100100100100100100100100100100); + zPattern_.SetValue(1, 0b1001001001001001001001001001001); + zPattern_.SetValue(2, 0b10010010010010010010010010010010); + } + + __aicore__ inline void InitEvent() + { + eventMTE2V_ = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_V)); + eventVMTE3_ = static_cast(GetTPipePtr()->FetchEventID(HardEvent::V_MTE3)); + eventMTE2S_ = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE2_S)); + eventMTE3MTE2_ = static_cast(GetTPipePtr()->FetchEventID(HardEvent::MTE3_MTE2)); + eventSV_ = static_cast(GetTPipePtr()->FetchEventID(HardEvent::S_V)); + } + + // 根据查询点偏移量读取数据 + __aicore__ inline void CopyNewXyzIn(uint32_t offset); + __aicore__ inline void CopyIn(uint32_t offset); + __aicore__ inline void CopyOut(uint32_t taskOffset, uint32_t dataOffset, uint32_t taskCount); + __aicore__ inline void Compute(int taskOffset, int xyzOffset); + +private: + TPipe* pipe_; + int32_t eventMTE2V_, eventVMTE3_, eventMTE2S_, eventMTE3MTE2_, eventSV_; + uint32_t blkIdx_; + uint32_t processBlockNum_, processDataNum_; + uint32_t coreTask_, taskOffset_; + uint32_t finalSmallTileNum_, smallTileDataNum_, tileDataNum_; + uint32_t tileBlockNum_, smallTileBlockNum_; + + uint32_t batchIdx_; // 记录查询点在哪一个batch + + GlobalTensor newXyzGm_; + GlobalTensor xyzGm_; + GlobalTensor rotGm_; + GlobalTensor resGm_; + GlobalTensor originIndexGm_; + + float radius2_; + float hmin_, hmax_; + uint32_t nsample_, pointCloudSize_, queryPointSize_; + + TBuf xyzBuf_, xBuf_, yBuf_, zBuf_, rotBuf_, resBuf_, scr1PatternBuf_; + TBuf bufferMaskXBuf, bufferMaskYBuf, bufferMaskZBuf; + TBuf rotXBuf_, rotYBuf_, rotZBuf_; + TBuf maskD2Buf_, maskHBuf_; + + LocalTensor xyzLocal_, xLocal_, yLocal_, zLocal_, rotLocal_; + LocalTensor rotXLocal_, rotYLocal_, rotZLocal_; + LocalTensor maskD2Local_, maskHLocal_; + float r0, r1, r2, r3, r4, r5, r6, r7, r8; + + LocalTensor resLocal_; + LocalTensor xPattern_, yPattern_, zPattern_; + + uint32_t mask_ = 0; + float newX_, newY_, newZ_; + + // 不填充数据 + DataCopyPadExtParams newXyzPadParams_{false, 0, 0, 0}; + DataCopyPadExtParams xyzPadParams_{false, 0, 0, 0}; + DataCopyPadExtParams rotPadParams_{false, 0, 0, 0}; + DataCopyPadExtParams originIndexPadParams_{false, 0, 0, 0}; +}; + + +__aicore__ inline void CylinderQuery::CopyNewXyzIn(uint32_t offset) +{ + this->newX_ = newXyzGm_.GetValue(offset * DIMENSION_3D + 0); + this->newY_ = newXyzGm_.GetValue(offset * DIMENSION_3D + 1); + this->newZ_ = newXyzGm_.GetValue(offset * DIMENSION_3D + 2); + DataCopyExtParams rotDataCopyParams{static_cast(1), ROT_SIZE * FLOAT_BYTE_SIZE, 0, 0, 0}; + // MTE2 + DataCopyPad(rotLocal_, rotGm_[static_cast(offset * ROT_SIZE)], rotDataCopyParams, this->rotPadParams_); + SetFlag(eventMTE2S_); + WaitFlag(eventMTE2S_); + + r0 = rotLocal_.GetValue(0); + r1 = rotLocal_.GetValue(1); + r2 = rotLocal_.GetValue(2); + r3 = rotLocal_.GetValue(3); + r4 = rotLocal_.GetValue(4); + r5 = rotLocal_.GetValue(5); + r6 = rotLocal_.GetValue(6); + r7 = rotLocal_.GetValue(7); + r8 = rotLocal_.GetValue(8); +} + +// offset:查询点偏移量 +__aicore__ inline void CylinderQuery::CopyIn(uint32_t offset) +{ + DataCopyExtParams xyzDataCopyParams{static_cast(1), processDataNum_ * 3 * FLOAT_BYTE_SIZE, 0, 0, 0}; + DataCopyPad(xyzLocal_, xyzGm_[DIMENSION_3D * (static_cast(offset) + this->batchIdx_ * this->pointCloudSize_)], xyzDataCopyParams, this->xyzPadParams_); +} + +__aicore__ inline void CylinderQuery::CopyOut(uint32_t taskOffset, uint32_t dataOffset, uint32_t blockCount) +{ + DataCopyExtParams xyzDataCopyParams{static_cast(1), processDataNum_ * FLOAT_BYTE_SIZE, 0, 0, 0}; + DataCopyPad(resGm_[static_cast(dataOffset) + taskOffset * pointCloudSize_], + resLocal_, xyzDataCopyParams); +} + +__aicore__ inline void CylinderQuery::Compute(int taskOffset, int xyzOffset) +{ + // 需要先将x,y,z分别取出 + uint32_t processDataNumAlign8 = this->processBlockNum_ * BLOCK_POINT_SIZE; + uint32_t processDataNumAlign64 = Ceil64(processDataNumAlign8); + + bool reduceMode = true; + uint32_t mask = BLOCK_POINT_SIZE * 3; + uint8_t src1Pattern = 2; + + uint8_t src0BlockStride = 1; + uint16_t repeatTimes = this->processBlockNum_; + uint8_t src0RepeatStride = REPEAT_STRIDE_0; + uint8_t src1RepeatStride = REPEAT_STRIDE_1; + + uint64_t rsvdCnt = 0; + GatherMask(xLocal_, xyzLocal_, xPattern_, reduceMode, mask, + {1, repeatTimes, src0RepeatStride, src1RepeatStride}, rsvdCnt); + GatherMask(yLocal_, xyzLocal_, yPattern_, reduceMode, mask, + {1, repeatTimes, src0RepeatStride, src1RepeatStride}, rsvdCnt); + GatherMask(zLocal_, xyzLocal_, zPattern_, reduceMode, mask, + {1, repeatTimes, src0RepeatStride, src1RepeatStride}, rsvdCnt); + // 计算相对位置 + + Adds(xLocal_, xLocal_, -newX_, processDataNumAlign8); + Adds(yLocal_, yLocal_, -newY_, processDataNumAlign8); + Adds(zLocal_, zLocal_, -newZ_, processDataNumAlign8); + + Muls(rotXLocal_, xLocal_, r0, processDataNumAlign8); + Axpy(rotXLocal_, yLocal_, r3, processDataNumAlign8); + Axpy(rotXLocal_, zLocal_, r6, processDataNumAlign8); + + Muls(rotYLocal_, xLocal_, r1, processDataNumAlign8); + Axpy(rotYLocal_, yLocal_, r4, processDataNumAlign8); + Axpy(rotYLocal_, zLocal_, r7, processDataNumAlign8); + + Muls(rotZLocal_, xLocal_, r2, processDataNumAlign8); + Axpy(rotZLocal_, yLocal_, r5, processDataNumAlign8); + Axpy(rotZLocal_, zLocal_, r8, processDataNumAlign8); + + // 节省空间服用tensor + Mul(rotYLocal_, rotYLocal_, rotYLocal_, processDataNumAlign8); + Mul(rotZLocal_, rotZLocal_, rotZLocal_, processDataNumAlign8); + Add(rotYLocal_, rotYLocal_, rotZLocal_, processDataNumAlign8); + CompareScalar(maskD2Local_, rotYLocal_, this->radius2_, AscendC::CMPMODE::LT, processDataNumAlign64); + CompareScalar(maskHLocal_, rotXLocal_, this->hmin_, AscendC::CMPMODE::GT, processDataNumAlign64); + And(maskD2Local_, maskD2Local_, maskHLocal_, CeilDiv8(processDataNumAlign64)); + CompareScalar(maskHLocal_, rotXLocal_, this->hmax_, AscendC::CMPMODE::LT, processDataNumAlign64); + And(maskD2Local_, maskD2Local_, maskHLocal_, CeilDiv8(processDataNumAlign64)); + + Select(resLocal_, maskD2Local_, resLocal_, float(int32_t(pointCloudSize_)), AscendC::SELMODE::VSEL_TENSOR_SCALAR_MODE, this->processDataNum_); +} + + +extern "C" __global__ __aicore__ void cylinder_query(GM_ADDR newXyz, GM_ADDR xyz, GM_ADDR rot, GM_ADDR origin_index, GM_ADDR out, + GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(cylinderQueryTiling, tiling); + TPipe pipe; + CylinderQuery op; + op.Init(&pipe, newXyz, xyz, rot, origin_index, out, &cylinderQueryTiling); + op.Process(); +} \ No newline at end of file diff --git a/mx_driving/_C/__init__.pyi b/mx_driving/_C/__init__.pyi index 6a8f30d6..90d49b78 100644 --- a/mx_driving/_C/__init__.pyi +++ b/mx_driving/_C/__init__.pyi @@ -447,6 +447,15 @@ def graph_softmax_grad( softmax_out: torch.Tensor, grad_output: torch.Tensor, ) -> torch.Tensor: ... +def cylinder_query( + radius: float, + hmin: float, + hmax: float, + nsample: int, + new_xyz: torch.Tensor, + xyz: torch.Tensor, + rot: torch.Tensor, +) -> torch.Tensor: ... __all__ = [ "knn", @@ -504,4 +513,5 @@ __all__ = [ "radius", "graph_softmax", "graph_softmax_grad", + "cylinder_query", ] diff --git a/mx_driving/__init__.py b/mx_driving/__init__.py index baa35e40..8630d2b6 100644 --- a/mx_driving/__init__.py +++ b/mx_driving/__init__.py @@ -66,6 +66,7 @@ __all__ = [ "radius", "npu_unique", "graph_softmax", + "cylinder_query", ] import os @@ -135,6 +136,7 @@ from .ops.radius import radius from .ops.min_area_polygons import min_area_polygons from .ops.npu_unique import npu_unique from .ops.graph_softmax import graph_softmax +from .ops.cylinder_query import cylinder_query def _set_env(): diff --git a/mx_driving/csrc/CylinderQuery.cpp b/mx_driving/csrc/CylinderQuery.cpp new file mode 100644 index 00000000..2878a158 --- /dev/null +++ b/mx_driving/csrc/CylinderQuery.cpp @@ -0,0 +1,51 @@ +// Copyright (c) 2024 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "csrc/OpApiCommon.h" +#include "csrc/functions.h" + +at::Tensor cylinder_query(double radius, double hmin, double hmax, int nsample, const at::Tensor& new_xyz, + const at::Tensor& xyz, const at::Tensor& rot) +{ + TORCH_CHECK_NPU(new_xyz); + TORCH_CHECK_NPU(xyz); + TORCH_CHECK_NPU(rot); + TORCH_CHECK(new_xyz.dim() == 3, "new_xyz must be a 3D Tensor, but got: ", new_xyz.dim()); + TORCH_CHECK(xyz.dim() == 3, "xyz must be a 3D Tensor, but got: ", xyz.dim()); + TORCH_CHECK(rot.dim() == 3, "rot must be a 3D Tensor, but got: ", rot.dim()); + + TORCH_CHECK(rot.size(0) == new_xyz.size(0), "The batch sizes of rot and new_xyz must be equal."); + TORCH_CHECK(rot.size(0) == xyz.size(0), "The batch sizes of rot and xyz must be equal."); + + TORCH_CHECK(new_xyz.size(2) == 3, "new_xyz Coordinates should be represented by 3 numbers, bug got: ", new_xyz.size(2)); + TORCH_CHECK(xyz.size(2) == 3, "xyz Coordinates should be represented by 3 numbers, bug got: ", xyz.size(2)); + TORCH_CHECK(rot.size(2) == 9, "The size of the last dimension in rot should be 9, bug got: ", xyz.size(2)); + + TORCH_CHECK(rot.size(1) == new_xyz.size(1), "The number of rot and new_xyz must be equal."); + + TORCH_CHECK(hmin < hmax, "The value of hmin needs to be less than the value of hmax."); + TORCH_CHECK(nsample <= xyz.size(1), "The value of nsample should be greater than the number of points in the tensor xyz."); + TORCH_CHECK(nsample > 0, "The value of nsample should be greater than 0."); + + uint32_t B = static_cast(new_xyz.size(0)); + uint32_t N = static_cast(xyz.size(1)); + uint32_t M = static_cast(new_xyz.size(1)); + + at::Tensor origin_index = at::arange(0, xyz.size(1), new_xyz.options().dtype(at::kFloat)); + at::Tensor out = at::empty({B, M, N}, new_xyz.options()); + EXEC_NPU_CMD(aclnnCylinderQuery, new_xyz, xyz, rot, origin_index, B, N, M, radius, hmin, hmax, nsample, out); + return out; +} \ No newline at end of file diff --git a/mx_driving/csrc/pybind.cpp b/mx_driving/csrc/pybind.cpp index 4727c880..febf8542 100644 --- a/mx_driving/csrc/pybind.cpp +++ b/mx_driving/csrc/pybind.cpp @@ -261,4 +261,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // graph_softmax_grad m.def("graph_softmax_grad", &graph_softmax_grad); + + // cylinder_query + m.def("cylinder_query", &cylinder_query); } diff --git a/mx_driving/ops/cylinder_query.py b/mx_driving/ops/cylinder_query.py new file mode 100644 index 00000000..1e64a689 --- /dev/null +++ b/mx_driving/ops/cylinder_query.py @@ -0,0 +1,42 @@ +""" +Copyright (c) OpenMMLab. All rights reserved. +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. +Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. +Modification by: Huawei Developers +Modification date: 2025-09-10 +Modification Description: +Modification 1. Add support for Ascend NPU +""" + +from typing import Tuple, Union +import torch +from torch.autograd import Function +import torch_npu +import mx_driving._C + + +class CylinderQuery(Function): + @staticmethod + def forward(ctx, radius, hmin, hmax, nsample, new_xyz, xyz, rot): + rot = rot.reshape(rot.shape[0], rot.shape[1], 9) + group_idx = mx_driving._C.cylinder_query(radius, hmin, hmax, nsample, new_xyz, xyz, rot) + out = CylinderQuery.sortRes(group_idx, nsample) + return out + + @staticmethod + def backward(ctx, gradout): + return () + + @classmethod + def sortRes(cls, group_idx, nsample): + b = group_idx.shape[0] + m = group_idx.shape[1] + n = group_idx.shape[2] + group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] + mask = group_idx >= n + group_first = torch.where(mask, 0, group_idx) + group_idx = torch.where(mask, group_first[..., 0:1], group_idx) + return group_idx.to(dtype=torch.int32) + + +cylinder_query = CylinderQuery.apply \ No newline at end of file diff --git a/tests/torch/test_cylinder_query.py b/tests/torch/test_cylinder_query.py new file mode 100644 index 00000000..e5c0c4df --- /dev/null +++ b/tests/torch/test_cylinder_query.py @@ -0,0 +1,264 @@ +import unittest + +import torch +import torch_npu +import numpy as np +from data_cache import golden_data_cache +from torch_npu.testing.testcase import TestCase, run_tests +from mx_driving.ops.cylinder_query import cylinder_query + + +# 'pylint: disable=too-many-arguments,huawei-too-many-arguments +@golden_data_cache(__file__) +def gen_data(B, N, M): + new_xyz = np.random.randn(B, M, 3).astype(np.float32) + xyz = np.random.randn(B, N, 3).astype(np.float32) + rot = np.random.randn(B, M, 9).astype(np.float32) + return new_xyz, xyz, rot + + +class TestCylinderQuert(TestCase): + # 'pylint: disable=too-many-arguments,huawei-too-many-arguments + def cpu_forward_op(self, + radius, + hmin, + hmax, + nsample, + new_xyz, + xyz, + rot): + B = xyz.shape[0] + N = xyz.shape[1] + M = new_xyz.shape[1] + + xyzTrans = xyz[:, None, :, :] - new_xyz[:, :, None, :] # (b, m, n, 3) + rot = rot.reshape(B, M, 3, 3) + xyzTrans = (rot[:, :, None, :, :] * xyzTrans[:, :, :, :, None]).sum(-2) + d2 = xyzTrans[:, :, :, 1] * xyzTrans[:, :, :, 1] + xyzTrans[:, :, :, 2] * xyzTrans[:, :, :, 2] + h = xyzTrans[:, :, :, 0] + radius2 = radius ** 2 + + not_in_cylinder = np.logical_or(np.logical_or(d2 >= radius2, h <= hmin), h >= hmax) + + group_idx = np.arange(N, dtype=np.int32).reshape(1, 1, N) + group_idx = np.tile(group_idx, (B, M, 1)) + group_idx[not_in_cylinder] = N + + group_idx = np.sort(group_idx, axis=-1)[..., :nsample] + group_first = group_idx[..., 0, np.newaxis] # 对应view(B, M, 1) + group_first = np.tile(group_first, (1, 1, nsample)) + group_first[group_first == N] = 0 + mask = (group_idx == N) + group_idx[mask] = group_first[mask] + return group_idx + + def test_cylinder_query_return_right_value_when_shape_is_all_one(self): + B = 1 + N = 1 + M = 1 + radius = 10 + hmin = -100 + hmax = 100 + nsample = 1 + new_xyz, xyz, rot = gen_data(B, N, M) + expected_output = self.cpu_forward_op(radius, hmin, hmax, nsample, new_xyz, xyz, rot) + + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + output = output.cpu().numpy() + self.assertRtolEqual(expected_output, output) + + def test_cylinder_query_should_return_right_value_when_shape_is_align_to_8(self): + B = 8 + N = 128 + M = 32 + radius = 10 + hmin = -100 + hmax = 100 + nsample = 8 + new_xyz, xyz, rot = gen_data(B, N, M) + expected_output = self.cpu_forward_op(radius, hmin, hmax, nsample, new_xyz, xyz, rot) + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + output = output.cpu().numpy() + self.assertRtolEqual(expected_output, output) + + def test_cylinder_query_should_return_right_value_when_shape_is_not_align(self): + B = 7 + N = 129 + M = 31 + radius = 10 + hmin = -100 + hmax = 100 + nsample = 9 + new_xyz, xyz, rot = gen_data(B, N, M) + expected_output = self.cpu_forward_op(radius, hmin, hmax, nsample, new_xyz, xyz, rot) + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + output = output.cpu().numpy() + self.assertRtolEqual(expected_output, output) + + def test_cylinder_query_should_return_right_value_when_N_is_1000(self): + B = 7 + N = 1000 + M = 31 + radius = 10 + hmin = -100 + hmax = 100 + nsample = 9 + new_xyz, xyz, rot = gen_data(B, N, M) + + expected_output = self.cpu_forward_op(radius, hmin, hmax, nsample, new_xyz, xyz, rot) + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + output = output.cpu().numpy() + self.assertRtolEqual(expected_output, output) + + def test_cylinder_query_should_return_right_value_when_N_is_20000_and_M_is_1024_and_nsample_is_64(self): + B = 7 + N = 20000 + M = 1024 + radius = 10 + hmin = -0.5 + hmax = 0.5 + nsample = 64 + new_xyz, xyz, rot = gen_data(B, N, M) + + expected_output = self.cpu_forward_op(radius, hmin, hmax, nsample, new_xyz, xyz, rot) + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + output = output.cpu().numpy() + self.assertRtolEqual(expected_output, output) + + def test_cylinder_query_should_raise_error_value_when_nsample_is_larger_than_N(self): + B = 7 + N = 129 + M = 31 + radius = 10 + hmin = -100 + hmax = 100 + nsample = 110 + new_xyz, xyz, rot = gen_data(B, N, M) + + try: + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + except Exception as e: + assert "The value of nsample should be greater than the number of points in the tensor xyz." in str(e) + + def test_cylinder_query_should_raise_error_value_when_hmin_is_equal_to_hmax(self): + B = 7 + N = 129 + M = 31 + radius = 10 + hmin = 11 + hmax = 11 + nsample = 110 + new_xyz, xyz, rot = gen_data(B, N, M) + + try: + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + except Exception as e: + assert "The value of hmin needs to be less than the value of hmax." in str(e) + + def test_cylinder_query_should_raise_error_value_when_hmin_is_larger_than_hmax(self): + B = 7 + N = 129 + M = 31 + radius = 10 + hmin = 11 + hmax = 10 + nsample = 110 + new_xyz, xyz, rot = gen_data(B, N, M) + + try: + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + except Exception as e: + assert "The value of hmin needs to be less than the value of hmax." in str(e) + + def test_cylinder_query_should_raise_error_value_when_nsample_is_zero(self): + B = 7 + N = 129 + M = 31 + radius = 10 + hmin = -100 + hmax = 100 + nsample = 0 + new_xyz, xyz, rot = gen_data(B, N, M) + + try: + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + except Exception as e: + assert "The value of nsample should be greater than 0." in str(e) + + def test_cylinder_query_should_raise_error_value_when_nsample_is_less_than_zero(self): + B = 7 + N = 129 + M = 31 + radius = 10 + hmin = -100 + hmax = 100 + nsample = -1 + new_xyz, xyz, rot = gen_data(B, N, M) + + try: + output = cylinder_query(radius, + hmin, + hmax, + nsample, + torch.from_numpy(new_xyz).npu(), + torch.from_numpy(xyz).npu(), + torch.from_numpy(rot).npu()) + except Exception as e: + assert "The value of nsample should be greater than 0." in str(e) + +if __name__ == "__main__": + run_tests() \ No newline at end of file -- Gitee