From d7aff32b7fafe5bbc50480de6f38c775a0bcb4b5 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Tue, 25 Feb 2025 15:30:02 +0800 Subject: [PATCH] generalize dfa --- kernels/op_kernel/deformable_aggregation.cpp | 31 +++++++++----------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/kernels/op_kernel/deformable_aggregation.cpp b/kernels/op_kernel/deformable_aggregation.cpp index 647120da..b6972bec 100644 --- a/kernels/op_kernel/deformable_aggregation.cpp +++ b/kernels/op_kernel/deformable_aggregation.cpp @@ -26,8 +26,8 @@ public: coreNum_ = tiling_data->coreNum; numChannels_ = numEmbeds_ / numGroups_; - weightBufSize_ = AlignUp(numPoints_ * numCams_ * numScales_ * numGroups_, blockAlign_); - locBufSize_ = AlignUp(numPoints_ * numCams_ * 2, blockAlign_); + weightBufSize_ = AlignUp(numCams_ * numScales_ * numGroups_, blockAlign_); + locBufSize_ = AlignUp(numCams_ * 2, blockAlign_); scaleStartBufSize_ = AlignUp(numCams_ * numScales_, blockAlign_); spatialShapeBufSize_ = AlignUp(numCams_ * numScales_ * 2, blockAlign_); @@ -95,19 +95,19 @@ public: { uint32_t batchIdx = taskIdx / numAnchors_; uint32_t anchorIdx = taskIdx % numAnchors_; - uint32_t locationOffsetGm = (batchIdx * numAnchors_ + - anchorIdx) * numPoints_ * numCams_ * 2; - uint32_t weightOffsetGm = (batchIdx * numAnchors_ + - anchorIdx) * numPoints_ * numCams_ * numScales_ * numGroups_; uint32_t refOffsetGm = (batchIdx * numAnchors_ + anchorIdx) * numEmbeds_; - SetFlag(0); - WaitFlag(0); - DataCopy(weightLocal_, weightsGm_[weightOffsetGm], weightBufSize_); - DataCopy(locationLocal_, samplingLocationGm_[locationOffsetGm], locBufSize_); Duplicate(resLocal_, 0.0f, cAligned_); for (uint32_t pointIdx = 0; pointIdx < numPoints_; ++pointIdx) { + uint32_t locationOffsetGm = (batchIdx * numAnchors_ * numPoints_ + + anchorIdx * numPoints_ + pointIdx) * numCams_ * 2; + uint32_t weightOffsetGm = (batchIdx * numAnchors_ * numPoints_ + + anchorIdx * numPoints_ + pointIdx) * numCams_ * numScales_ * numGroups_; + SetFlag(0); + WaitFlag(0); + DataCopy(weightLocal_, weightsGm_[weightOffsetGm], weightBufSize_); + DataCopy(locationLocal_, samplingLocationGm_[locationOffsetGm], locBufSize_); for (uint32_t camIdx = 0; camIdx < numCams_; ++camIdx) { - uint32_t locationOffsetLocal = (pointIdx * numCams_ + camIdx) * 2; + uint32_t locationOffsetLocal = camIdx * 2; DTYPE_F locW = locationLocal_.GetValue(locationOffsetLocal); if (locW <= 0 || locW >= 1) { continue; @@ -117,9 +117,7 @@ public: continue; } for (uint32_t scaleIdx = 0; scaleIdx < numScales_; ++scaleIdx) { - uint32_t weightOffsetLocal = (pointIdx * numCams_ * numScales_ + - camIdx * numScales_ + - scaleIdx) * numGroups_; + uint32_t weightOffsetLocal = (camIdx * numScales_ + scaleIdx) * numGroups_; uint32_t scaleStartOffset = camIdx * numScales_ + scaleIdx; uint32_t spatialShapeOffset = scaleStartOffset * 2; uint32_t scaleStartIdx = scaleStartLocal_.GetValue(scaleStartOffset); @@ -190,8 +188,7 @@ public: Axpy(vLocal_[v1Offset_ * cAligned_], vLocal_[v4Offset_ * cAligned_], w4, cAligned_); BroadCast(weightMulLocal_, weightLocal_[weightOffsetLocal], dstShape_, srcShape_); - Mul(vLocal_[v1Offset_ * cAligned_], vLocal_[v1Offset_ * cAligned_], weightMulLocal_, cAligned_); - Add(resLocal_, resLocal_, vLocal_[v1Offset_ * cAligned_], cAligned_); + MulAddDst(resLocal_, vLocal_[v1Offset_ * cAligned_], weightMulLocal_, cAligned_); } } } @@ -237,4 +234,4 @@ extern "C" __global__ __aicore__ void deformable_aggregation(GM_ADDR mc_ms_feat, op.Init(mc_ms_feat, spatial_shape, scale_start_index, sampling_location, weights, out, &tiling_data, &pipe); op.GetLocalTensor(); op.Process(); -} \ No newline at end of file +} -- Gitee