From 2fa0aadb4dd64a55c4dc52d5814ddca6ac4a6f7d Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Fri, 8 Mar 2024 18:44:37 +0800 Subject: [PATCH 1/9] refine msdagrad --- .../multi_scale_deformable_attention_grad.cpp | 395 ++++++++---------- 1 file changed, 184 insertions(+), 211 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index 8a038713..29cf6802 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -22,10 +22,6 @@ #include "kernel_tiling/kernel_tiling.h" using namespace AscendC; -#include "kernel_operator.h" -#include "kernel_tiling/kernel_tiling.h" -using namespace AscendC; - namespace { constexpr static int32_t BUFFER_NUM = 1; }; @@ -34,9 +30,9 @@ class MultiScaleDeformableAttentionGrad { public: __aicore__ inline MultiScaleDeformableAttentionGrad(){}; __aicore__ inline void Init(GM_ADDR value_gm, GM_ADDR spatial_shapes_gm, GM_ADDR level_start_index_gm, - GM_ADDR sampling_loc_gm, GM_ADDR attn_weight_gm, GM_ADDR grad_output_gm, - GM_ADDR grad_value_gm, GM_ADDR grad_sampling_loc_gm, GM_ADDR grad_attn_weight_gm, - MultiScaleDeformableAttentionGradTilingData *tiling_data, TPipe *tmpPipe) + GM_ADDR sampling_loc_gm, GM_ADDR attn_weight_gm, GM_ADDR grad_output_gm, + GM_ADDR grad_value_gm, GM_ADDR grad_sampling_loc_gm, GM_ADDR grad_attn_weight_gm, + MultiScaleDeformableAttentionGradTilingData *tiling_data, TPipe *tmpPipe) { pipe = tmpPipe; curBlockIdx = GetBlockIdx(); @@ -50,8 +46,6 @@ public: numPoints = tiling_data->numPoints; batchSize = tiling_data->batchSize; coreNum = tiling_data->coreNum; - - wStride = numHeads * embedDims; taskNum = numQueries; taskNumPerCore = DivCeil(taskNum, coreNum); @@ -60,15 +54,31 @@ public: numPointsAlign = AlignUp(numPoints, dataAlign); numLevelsAlign = AlignUp(numLevels, dataAlign); - batchOffset = numPoints * embedDimsAlign; - - curBlockIdx = GetBlockIdx(); startOffset = curBlockIdx * taskNumPerCore; endOffset = (curBlockIdx + 1) * taskNumPerCore; if (endOffset > taskNum) { endOffset = taskNum; } + // offsets + gradOutStride0 = embedDims; + gradOutStride1 = numHeads * gradOutStride0; + gradOutStride2 = numQueries * gradOutStride1; + weightStride0 = numLevels * numPoints; + weightStride1 = numHeads * weightStride0; + weightStride2 = numQueries * weightStride1; + valueStride0 = embedDims; + valueStride1 = numHeads * valueStride0; + valueStride2 = numKeys * valueStride1; + + eventIdVToMte3 = static_cast(pipe->AllocEventID()); + eventIdMte2ToV = static_cast(pipe->AllocEventID()); + eventIdMte3ToV = static_cast(pipe->AllocEventID()); + + copyParamsA = {1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; + copyParamsB = {1, (uint16_t)(numPoints * sizeof(DTYPE_VALUE)), 0, 0}; + sumParams = {numPoints, embedDimsAlign, embedDims}; + valueGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(value_gm), batchSize * numKeys * numHeads * embedDims); @@ -91,22 +101,17 @@ public: batchSize * numQueries * numHeads * numLevels * 2 * numPoints); gradWeightGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(grad_attn_weight_gm), batchSize * numQueries * numHeads * numLevels * numPoints); + } - pipe->InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); + __aicore__ inline void InitBuffer() + { + pipe->InitBuffer(shapeUb, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(offsetUb, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(locationQueue, BUFFER_NUM, + pipe->InitBuffer(locationUb, BUFFER_NUM, AlignUp(numHeads * numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); pipe->InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numHeads * numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(gradQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - - pipe->InitBuffer(gradValueQueue, BUFFER_NUM, - AlignUp(numHeads * numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(gradLocationQueue, BUFFER_NUM, - AlignUp(numHeads * numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(gradWeightQueue, BUFFER_NUM, - AlignUp(numHeads * numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); pipe->InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(topGradUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); @@ -115,9 +120,7 @@ public: pipe->InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(weightSumUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(weightQueue, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe->InitBuffer(valueUb, BUFFER_NUM, batchOffset * 4 * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(weightUb, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(locWUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(locHUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); @@ -158,7 +161,6 @@ public: pipe->InitBuffer(topGradValueUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(gradWeightUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmp1Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmp2Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmp3Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); @@ -170,15 +172,85 @@ public: pipe->InitBuffer(tmp9Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmp10Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpAUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpBUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(midUb, BUFFER_NUM, 4 * numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(gradSampleXLocUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(gradSampleYLocUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); } + + __aicore__ inline void GetLocalTensor() + { + locationLocal = locationUb.Get(); + attentionWeightLocal = attentionWeightsUb.Get(); + shapesLocal = shapeUb.Get(); + offsetLocal = offsetUb.Get(); + weightLocal = weightUb.Get(); + xLocal = tmpXUb.Get(); + yLocal = tmpYUb.Get(); + weightSumLocal = weightSumUb.Get(); + floatOneLocal = floatOneUb.Get(); + topGradLocal = topGradUb.Get(); + lwLocal = lwUb.Get(); + lhLocal = lhUb.Get(); + locWLocal = locWUb.Get(); + locHLocal = locHUb.Get(); + + hImLocal = hImUb.Get(); + wImLocal = wImUb.Get(); + hLowLocal = hLowUb.Get(); + wLowLocal = wLowUb.Get(); + hHighLocal = hHighUb.Get(); + wHighLocal = wHighUb.Get(); + + hLowFloatLocal = hLowFloatUb.Get(); + wLowFloatLocal = wLowFloatUb.Get(); + + hHighPtrOffsetLocal = hHighPtrOffsetUb.Get(); + hLowPtrOffsetLocal = hLowPtrOffsetUb.Get(); + wHighPtrOffsetLocal = wHighPtrOffsetUb.Get(); + wLowPtrOffsetLocal = wLowPtrOffsetUb.Get(); + w1Local = w1Ub.Get(); + w2Local = w2Ub.Get(); + w3Local = w3Ub.Get(); + w4Local = w4Ub.Get(); + + v1Local = v1Ub.Get(); + v2Local = v2Ub.Get(); + v3Local = v3Ub.Get(); + v4Local = v4Ub.Get(); + + hwLocal = hwUb.Get(); + hhLocal = hhUb.Get(); + + gradHWeightLocal = gradHWeightUb.Get(); + gradWWeightLocal = gradWWeightUb.Get(); + topGradValueLocal = topGradValueUb.Get(); + gradWeightLocal = gradWeightUb.Get(); + + tmp1Local = tmp1Ub.Get(); + tmp2Local = tmp2Ub.Get(); + tmp3Local = tmp3Ub.Get(); + tmp4Local = tmp4Ub.Get(); + tmp5Local = tmp5Ub.Get(); + tmp6Local = tmp6Ub.Get(); + tmp7Local = tmp7Ub.Get(); + tmp8Local = tmp8Ub.Get(); + tmp9Local = tmp9Ub.Get(); + tmp10Local = tmp10Ub.Get(); + + tmpLocal = tmpUb.Get(); + midLocal = midUb.Get(); + + gradSampleXLocLocal = gradSampleXLocUb.Get(); + gradSampleYLocLocal = gradSampleYLocUb.Get(); + } + __aicore__ inline void Process() { + DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); + DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); + Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); for (uint32_t taskIdx = startOffset; taskIdx < endOffset; taskIdx++) { SetAtomicAdd(); Compute(taskIdx); @@ -186,101 +258,66 @@ public: } } -private: - __aicore__ inline void Compute(uint32_t query) + __aicore__ inline void ReleaseEventID() { - LocalTensor locationLocal = locationQueue.Get(); - LocalTensor attentionWeightLocal = attentionWeightsUb.Get(); - - LocalTensor shapesLocal = shapeQueue.Get(); - LocalTensor offsetLocal = offsetQueue.Get(); - - DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); - DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - - DataCopyParams copyParamsA{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; - DataCopyParams copyParamsB{1, (uint16_t)(numPoints * sizeof(DTYPE_VALUE)), 0, 0}; - - LocalTensor valueLocal = valueUb.Get(); + pipe->ReleaseEventID(eventIdVToMte3); + pipe->ReleaseEventID(eventIdMte2ToV); + pipe->ReleaseEventID(eventIdMte3ToV); + } - event_t eventIdVToMte3 = static_cast(GetTPipePtr()->AllocEventID()); - event_t eventIdMte2ToV = static_cast(GetTPipePtr()->AllocEventID()); - event_t eventIdMte3ToV = static_cast(GetTPipePtr()->AllocEventID()); +private: + template + __aicore__ inline void ComputeGrad(uint32_t point, uint32_t midId, + LocalTensor &hPtrOffsetLocal, + LocalTensor &wPtrOffsetLocal, + LocalTensor &wLocal, LocalTensor &vLocal, + LocalTensor &distanceHLocal, + LocalTensor &distanceWLocal) + { + offsetMid = (point + midId * numPoints) * embedDimsAlign; + ptr = hPtrOffsetLocal.GetValue(point) + wPtrOffsetLocal.GetValue(point) + basePtr; + DataCopy(vLocal[point * embedDimsAlign], valueGm[offsetValue + ptr], embedDimsAlign); + SetFlagHardEvent::MTE2_V(eventIdMte2ToV); + WaitFlagHardEvent::MTE2_V(eventIdMte2ToV); + Muls(tmpLocal, vLocal[point * embedDimsAlign], distanceWLocal.GetValue(point), embedDims); + if (AddH) { + Add(gradHWeightLocal[point * embedDimsAlign], gradHWeightLocal[point * embedDimsAlign], tmpLocal, + embedDims); + } else { + Sub(gradHWeightLocal[point * embedDimsAlign], gradHWeightLocal[point * embedDimsAlign], tmpLocal, + embedDims); + } + Muls(tmpLocal, vLocal[point * embedDimsAlign], distanceHLocal.GetValue(point), embedDims); + if (AddW) { + Add(gradWWeightLocal[point * embedDimsAlign], gradWWeightLocal[point * embedDimsAlign], tmpLocal, + embedDims); + } else { + Sub(gradWWeightLocal[point * embedDimsAlign], gradWWeightLocal[point * embedDimsAlign], tmpLocal, + embedDims); + } + Muls(midLocal[offsetMid], topGradValueLocal[point * embedDimsAlign], wLocal.GetValue(point), embedDims); + SetFlagHardEvent::V_MTE3(eventIdVToMte3); + WaitFlagHardEvent::V_MTE3(eventIdVToMte3); + DataCopyPad(gradValueGm[offsetValue + ptr], midLocal[offsetMid], copyParamsA); + } + __aicore__ inline void Compute(uint32_t query) + { for (uint32_t batch = 0; batch < batchSize; batch++) { - LocalTensor weightLocal = weightQueue.Get(); - LocalTensor xLocal = tmpXUb.Get(); - LocalTensor yLocal = tmpYUb.Get(); - LocalTensor weightSumLocal = weightSumUb.Get(); - LocalTensor floatOneLocal = floatOneUb.Get(); - LocalTensor topGradLocal = topGradUb.Get(); - LocalTensor lwLocal = lwUb.Get(); - LocalTensor lhLocal = lhUb.Get(); - LocalTensor locWLocal = locWUb.Get(); - LocalTensor locHLocal = locHUb.Get(); - - LocalTensor hImLocal = hImUb.Get(); - LocalTensor wImLocal = wImUb.Get(); - LocalTensor hLowLocal = hLowUb.Get(); - LocalTensor wLowLocal = wLowUb.Get(); - LocalTensor hHighLocal = hHighUb.Get(); - LocalTensor wHighLocal = wHighUb.Get(); - - LocalTensor hLowFloatLocal = hLowFloatUb.Get(); - LocalTensor wLowFloatLocal = wLowFloatUb.Get(); - - LocalTensor hHighPtrOffsetLocal = hHighPtrOffsetUb.Get(); - LocalTensor hLowPtrOffsetLocal = hLowPtrOffsetUb.Get(); - LocalTensor wHighPtrOffsetLocal = wHighPtrOffsetUb.Get(); - LocalTensor wLowPtrOffsetLocal = wLowPtrOffsetUb.Get(); - LocalTensor w1Local = w1Ub.Get(); - LocalTensor w2Local = w2Ub.Get(); - LocalTensor w3Local = w3Ub.Get(); - LocalTensor w4Local = w4Ub.Get(); - - LocalTensor v1Local = v1Ub.Get(); - LocalTensor v2Local = v2Ub.Get(); - LocalTensor v3Local = v3Ub.Get(); - LocalTensor v4Local = v4Ub.Get(); - - LocalTensor hwLocal = hwUb.Get(); - LocalTensor hhLocal = hhUb.Get(); - - LocalTensor gradHWeightLocal = gradHWeightUb.Get(); - LocalTensor gradWWeightLocal = gradWWeightUb.Get(); - LocalTensor topGradValueLocal = topGradValueUb.Get(); - LocalTensor gradWeightLocal = gradWeightUb.Get(); - - LocalTensor tmpLocal = tmpUb.Get(); - LocalTensor tmp1Local = tmp1Ub.Get(); - LocalTensor tmp2Local = tmp2Ub.Get(); - LocalTensor tmp3Local = tmp3Ub.Get(); - LocalTensor tmp4Local = tmp4Ub.Get(); - LocalTensor tmp5Local = tmp5Ub.Get(); - LocalTensor tmp6Local = tmp6Ub.Get(); - LocalTensor tmp7Local = tmp7Ub.Get(); - LocalTensor tmp8Local = tmp8Ub.Get(); - LocalTensor tmp9Local = tmp9Ub.Get(); - LocalTensor tmp10Local = tmp10Ub.Get(); - - LocalTensor tmpALocal = tmpAUb.Get(); - LocalTensor tmpBLocal = tmpBUb.Get(); - LocalTensor midLocal = midUb.Get(); - - LocalTensor gradSampleXLocLocal = gradSampleXLocUb.Get(); - LocalTensor gradSampleYLocLocal = gradSampleYLocUb.Get(); - - Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); for (uint32_t head = 0; head < numHeads; head++) { - offsetWeight = (batch * numQueries * numHeads + query * numHeads + head) * numLevels * numPoints; + offsetWeight = batch * weightStride2 + query * weightStride1 + head * weightStride0; offsetLocation = 2 * offsetWeight; - DataCopy(topGradLocal, gradOutputGm[batch * numQueries * wStride + query * wStride + head * embedDims], + basePtr = head * embedDims; + DataCopy(topGradLocal, + gradOutputGm[batch * gradOutStride2 + query * gradOutStride1 + head * gradOutStride0], embedDimsAlign); for (uint32_t level = 0; level < numLevels; level++) { levelStartId = offsetLocal.GetValue(level); h = shapesLocal.GetValue(level * 2); w = shapesLocal.GetValue(level * 2 + 1); - offsetValue = batch * numKeys * numHeads * embedDims + levelStartId * numHeads * embedDims; + offsetValue = batch * valueStride2 + levelStartId * valueStride1; + wStride = numHeads * embedDims; + hStride = w * wStride; DataCopy(locWLocal, locationGm[offsetLocation + level * numPoints * 2], numPointsAlign); DataCopy(locHLocal, locationGm[offsetLocation + level * numPoints * 2 + numPoints], numPointsAlign); DataCopy(attentionWeightLocal, attentionWeightsGm[offsetWeight + level * numPoints], @@ -304,13 +341,11 @@ private: Sub(hhLocal, floatOneLocal, lhLocal, numPointsAlign); Sub(hwLocal, floatOneLocal, lwLocal, numPointsAlign); - wStride = numHeads * embedDims; - hStride = w * wStride; + Muls(hLowPtrOffsetLocal, hLowLocal, hStride, numPointsAlign); Adds(hHighPtrOffsetLocal, hLowPtrOffsetLocal, hStride, numPointsAlign); Muls(wLowPtrOffsetLocal, wLowLocal, wStride, numPointsAlign); Adds(wHighPtrOffsetLocal, wLowPtrOffsetLocal, wStride, numPointsAlign); - basePtr = head * embedDims; Mul(w1Local, hhLocal, hwLocal, numPointsAlign); Mul(w2Local, hhLocal, lwLocal, numPointsAlign); @@ -334,97 +369,22 @@ private: attentionWeightLocal.GetValue(point), embedDimsAlign); if (hLowLocal.GetValue(point) >= 0) { if (wLowLocal.GetValue(point) >= 0) { - ptr = hLowPtrOffsetLocal.GetValue(point) + wLowPtrOffsetLocal.GetValue(point) + - basePtr; - DataCopy(v1Local[point * embedDimsAlign], valueGm[offsetValue + ptr], - embedDimsAlign); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Muls(tmpALocal, v1Local[point * embedDimsAlign], hwLocal.GetValue(point), - embedDims); - Muls(tmpBLocal, v1Local[point * embedDimsAlign], hhLocal.GetValue(point), - embedDims); - Sub(gradHWeightLocal[point * embedDimsAlign], - gradHWeightLocal[point * embedDimsAlign], tmpALocal, embedDims); - Sub(gradWWeightLocal[point * embedDimsAlign], - gradWWeightLocal[point * embedDimsAlign], tmpBLocal, embedDims); - Muls(midLocal[point * embedDimsAlign], topGradValueLocal[point * embedDimsAlign], - w1Local.GetValue(point), embedDims); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); - DataCopyPad(gradValueGm[offsetValue + ptr], midLocal[point * embedDimsAlign], - copyParamsA); + ComputeGrad(point, 0, hLowPtrOffsetLocal, wLowPtrOffsetLocal, w1Local, + v1Local, hhLocal, hwLocal); } if (wHighLocal.GetValue(point) < w) { - ptr = hLowPtrOffsetLocal.GetValue(point) + wHighPtrOffsetLocal.GetValue(point) + - basePtr; - DataCopy(v2Local[point * embedDimsAlign], valueGm[offsetValue + ptr], - embedDimsAlign); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Muls(tmpALocal, v2Local[point * embedDimsAlign], lwLocal.GetValue(point), - embedDims); - Muls(tmpBLocal, v2Local[point * embedDimsAlign], hhLocal.GetValue(point), - embedDims); - Sub(gradHWeightLocal[point * embedDimsAlign], - gradHWeightLocal[point * embedDimsAlign], tmpALocal, embedDims); - Add(gradWWeightLocal[point * embedDimsAlign], - gradWWeightLocal[point * embedDimsAlign], tmpBLocal, embedDims); - Muls(midLocal[point * embedDimsAlign + numPoints * embedDimsAlign], - topGradValueLocal[point * embedDimsAlign], w2Local.GetValue(point), embedDims); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); - DataCopyPad(gradValueGm[offsetValue + ptr], - midLocal[point * embedDimsAlign + numPoints * embedDimsAlign], - copyParamsA); + ComputeGrad(point, 1, hLowPtrOffsetLocal, wHighPtrOffsetLocal, w2Local, + v2Local, hhLocal, lwLocal); } } if (hHighLocal.GetValue(point) < h) { if (wLowLocal.GetValue(point) >= 0) { - ptr = hHighPtrOffsetLocal.GetValue(point) + wLowPtrOffsetLocal.GetValue(point) + - basePtr; - DataCopy(v3Local[point * embedDimsAlign], valueGm[offsetValue + ptr], - embedDimsAlign); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Muls(tmpALocal, v3Local[point * embedDimsAlign], hwLocal.GetValue(point), - embedDims); - Muls(tmpBLocal, v3Local[point * embedDimsAlign], lhLocal.GetValue(point), - embedDims); - Add(gradHWeightLocal[point * embedDimsAlign], - gradHWeightLocal[point * embedDimsAlign], tmpALocal, embedDims); - Sub(gradWWeightLocal[point * embedDimsAlign], - gradWWeightLocal[point * embedDimsAlign], tmpBLocal, embedDims); - Muls(midLocal[point * embedDimsAlign + numPoints * embedDimsAlign * 2], - topGradValueLocal[point * embedDimsAlign], w3Local.GetValue(point), embedDims); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); - DataCopyPad(gradValueGm[offsetValue + ptr], - midLocal[point * embedDimsAlign + numPoints * embedDimsAlign * 2], - copyParamsA); + ComputeGrad(point, 2, hHighPtrOffsetLocal, wLowPtrOffsetLocal, w3Local, + v3Local, lhLocal, hwLocal); } if (wHighLocal.GetValue(point) < w) { - ptr = hHighPtrOffsetLocal.GetValue(point) + wHighPtrOffsetLocal.GetValue(point) + - basePtr; - DataCopy(v4Local[point * embedDimsAlign], valueGm[offsetValue + ptr], - embedDimsAlign); - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - Muls(tmpALocal, v4Local[point * embedDimsAlign], lwLocal.GetValue(point), - embedDims); - Muls(tmpBLocal, v4Local[point * embedDimsAlign], lhLocal.GetValue(point), - embedDims); - Add(gradHWeightLocal[point * embedDimsAlign], - gradHWeightLocal[point * embedDimsAlign], tmpALocal, embedDims); - Add(gradWWeightLocal[point * embedDimsAlign], - gradWWeightLocal[point * embedDimsAlign], tmpBLocal, embedDims); - Muls(midLocal[point * embedDimsAlign + numPoints * embedDimsAlign * 3], - topGradValueLocal[point * embedDimsAlign], w4Local.GetValue(point), embedDims); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); - DataCopyPad(gradValueGm[offsetValue + ptr], - midLocal[point * embedDimsAlign + numPoints * embedDimsAlign * 3], - copyParamsA); + ComputeGrad(point, 3, hHighPtrOffsetLocal, wHighPtrOffsetLocal, w4Local, + v4Local, lhLocal, lwLocal); } } SetFlag(eventIdMte3ToV); @@ -451,7 +411,6 @@ private: Muls(gradSampleXLocLocal, tmp9Local, (DTYPE_VALUE)w, numPoints * embedDimsAlign); Mul(tmp10Local, topGradValueLocal, gradHWeightLocal, numPoints * embedDimsAlign); Muls(gradSampleYLocLocal, tmp10Local, (DTYPE_VALUE)h, numPoints * embedDimsAlign); - SumParams sumParams{numPoints, embedDimsAlign, embedDims}; Sum(xLocal, gradSampleXLocLocal, sumParams); Sum(yLocal, gradSampleYLocLocal, sumParams); Sum(weightSumLocal, gradWeightLocal, sumParams); @@ -464,9 +423,6 @@ private: } } } - GetTPipePtr()->ReleaseEventID(eventIdVToMte3); - GetTPipePtr()->ReleaseEventID(eventIdMte2ToV); - GetTPipePtr()->ReleaseEventID(eventIdMte3ToV); } private: @@ -475,18 +431,17 @@ private: gradWeightGm; GlobalTensor valueSpatialShapesGm, valueLevelStartIndexGm; - TBuf locationQueue, attentionWeightsUb, shapeQueue, offsetQueue, gradQueue; - TBuf gradValueQueue, gradLocationQueue, gradWeightQueue; + TBuf locationUb, attentionWeightsUb, shapeUb, offsetUb; TBuf tmpXUb, tmpYUb, weightSumUb; - TBuf intOneUb, floatOneUb, weightQueue, emptyUb, topGradUb; + TBuf floatOneUb, weightUb, topGradUb; TBuf valueUb, locWUb, locHUb, hImUb, wImUb, hLowUb, wLowUb, hHighUb, wHighUb, hLowFloatUb, wLowFloatUb, hHighFloatUb, wHighFloatUb, hHighPtrOffsetUb, hLowPtrOffsetUb, wHighPtrOffsetUb, wLowPtrOffsetUb; TBuf lwUb, lhUb, hwUb, hhUb, w1Ub, w2Ub, w3Ub, w4Ub, v1Ub, v2Ub, v3Ub, v4Ub; TBuf tmpUb, tmp1Ub, tmp2Ub, tmp3Ub, tmp4Ub, tmp5Ub, tmp6Ub, tmp7Ub, tmp8Ub, tmp9Ub, tmp10Ub, - tmpAUb, tmpBUb, midUb; + midUb; TBuf gradHWeightUb, gradWWeightUb, topGradValueUb, gradWeightUb, gradSampleXLocUb, gradSampleYLocUb; @@ -507,6 +462,7 @@ private: uint32_t batch; uint32_t query; uint32_t head; + uint32_t point; uint32_t taskNum; uint32_t taskNumPerCore; @@ -516,11 +472,26 @@ private: uint32_t dataAlign; uint32_t blockBytes = 32; - DTYPE_VALUE tmp1, tmp2, leftTopWeight, rightTopWeiight, leftBottomWeight, rightBottomWeight, attnWeight; - DTYPE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, batchOffset, levelStartId, - offsetValue; + uint32_t gradOutStride0, gradOutStride1, gradOutStride2; + uint32_t weightStride0, weightStride1, weightStride2; + uint32_t valueStride0, valueStride1, valueStride2; + + DTYPE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, levelStartId, offsetValue; DTYPE_SPATIAL_SHAPES offsetWeight, offsetLocation, wStride, hStride, basePtr, ptr; + LocalTensor shapesLocal, offsetLocal, hLowLocal, wLowLocal, hHighLocal, wHighLocal, + hHighPtrOffsetLocal, hLowPtrOffsetLocal, wHighPtrOffsetLocal, wLowPtrOffsetLocal; + LocalTensor weightLocal, xLocal, yLocal, weightSumLocal, floatOneLocal, topGradLocal, lwLocal, lhLocal, + locWLocal, locHLocal, hImLocal, wImLocal, hLowFloatLocal, wLowFloatLocal, w1Local, w2Local, w3Local, w4Local, + v1Local, v2Local, v3Local, v4Local, hwLocal, hhLocal, gradHWeightLocal, gradWWeightLocal, topGradValueLocal, + gradWeightLocal, tmp1Local, tmp2Local, tmp3Local, tmp4Local, tmp5Local, tmp6Local, tmp7Local, tmp8Local, + tmp9Local, tmp10Local, tmpLocal, midLocal, gradSampleXLocLocal, gradSampleYLocLocal, locationLocal, + attentionWeightLocal; + + event_t eventIdVToMte3, eventIdMte2ToV, eventIdMte3ToV; + + DataCopyParams copyParamsA, copyParamsB; + SumParams sumParams; }; // core func @@ -535,6 +506,8 @@ extern "C" __global__ __aicore__ void multi_scale_deformable_attention_grad( MultiScaleDeformableAttentionGrad op; op.Init(value_gm, spatial_shapes_gm, level_start_index_gm, sampling_loc_gm, attn_weight_gm, grad_output_gm, grad_value_gm, grad_sampling_loc_gm, grad_attn_weight_gm, &tiling_datas, &pipe); - + op.InitBuffer(); + op.GetLocalTensor(); op.Process(); + op.ReleaseEventID(); } -- Gitee From ac395deee5be7e2b0f4cbc0a7b9ac0b50fe48367 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Sat, 9 Mar 2024 14:18:55 +0800 Subject: [PATCH 2/9] 1 --- .../multi_scale_deformable_attention_grad.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index 29cf6802..c1a78109 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -274,11 +274,11 @@ private: LocalTensor &distanceHLocal, LocalTensor &distanceWLocal) { - offsetMid = (point + midId * numPoints) * embedDimsAlign; + uint32_t offsetMid = (point + midId * numPoints) * embedDimsAlign; ptr = hPtrOffsetLocal.GetValue(point) + wPtrOffsetLocal.GetValue(point) + basePtr; DataCopy(vLocal[point * embedDimsAlign], valueGm[offsetValue + ptr], embedDimsAlign); - SetFlagHardEvent::MTE2_V(eventIdMte2ToV); - WaitFlagHardEvent::MTE2_V(eventIdMte2ToV); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); Muls(tmpLocal, vLocal[point * embedDimsAlign], distanceWLocal.GetValue(point), embedDims); if (AddH) { Add(gradHWeightLocal[point * embedDimsAlign], gradHWeightLocal[point * embedDimsAlign], tmpLocal, @@ -296,8 +296,8 @@ private: embedDims); } Muls(midLocal[offsetMid], topGradValueLocal[point * embedDimsAlign], wLocal.GetValue(point), embedDims); - SetFlagHardEvent::V_MTE3(eventIdVToMte3); - WaitFlagHardEvent::V_MTE3(eventIdVToMte3); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); DataCopyPad(gradValueGm[offsetValue + ptr], midLocal[offsetMid], copyParamsA); } -- Gitee From 89c6fc40e9cc40e5f924a5582bad53ad9d7f1d27 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Sat, 9 Mar 2024 14:36:45 +0800 Subject: [PATCH 3/9] 2 --- .../multi_scale_deformable_attention_grad.cpp | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index c1a78109..2decdfb6 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -267,7 +267,7 @@ public: private: template - __aicore__ inline void ComputeGrad(uint32_t point, uint32_t midId, + __aicore__ inline void ComputeGrad(uint32_t midId, LocalTensor &hPtrOffsetLocal, LocalTensor &wPtrOffsetLocal, LocalTensor &wLocal, LocalTensor &vLocal, @@ -303,15 +303,15 @@ private: __aicore__ inline void Compute(uint32_t query) { - for (uint32_t batch = 0; batch < batchSize; batch++) { - for (uint32_t head = 0; head < numHeads; head++) { + for (batch = 0; batch < batchSize; batch++) { + for (head = 0; head < numHeads; head++) { offsetWeight = batch * weightStride2 + query * weightStride1 + head * weightStride0; offsetLocation = 2 * offsetWeight; basePtr = head * embedDims; DataCopy(topGradLocal, gradOutputGm[batch * gradOutStride2 + query * gradOutStride1 + head * gradOutStride0], embedDimsAlign); - for (uint32_t level = 0; level < numLevels; level++) { + for (level = 0; level < numLevels; level++) { levelStartId = offsetLocal.GetValue(level); h = shapesLocal.GetValue(level * 2); w = shapesLocal.GetValue(level * 2 + 1); @@ -362,28 +362,28 @@ private: Duplicate(v3Local, (DTYPE_VALUE)0, numPoints * embedDimsAlign); Duplicate(v4Local, (DTYPE_VALUE)0, numPoints * embedDimsAlign); - for (uint32_t point = 0; point < numPoints; point++) { + for (point = 0; point < numPoints; point++) { if (hImLocal.GetValue(point) > -1 && wImLocal.GetValue(point) > -1 && hImLocal.GetValue(point) < h && wImLocal.GetValue(point) < w) { Muls(topGradValueLocal[point * embedDimsAlign], topGradLocal, attentionWeightLocal.GetValue(point), embedDimsAlign); if (hLowLocal.GetValue(point) >= 0) { if (wLowLocal.GetValue(point) >= 0) { - ComputeGrad(point, 0, hLowPtrOffsetLocal, wLowPtrOffsetLocal, w1Local, + ComputeGrad(0, hLowPtrOffsetLocal, wLowPtrOffsetLocal, w1Local, v1Local, hhLocal, hwLocal); } if (wHighLocal.GetValue(point) < w) { - ComputeGrad(point, 1, hLowPtrOffsetLocal, wHighPtrOffsetLocal, w2Local, + ComputeGrad(1, hLowPtrOffsetLocal, wHighPtrOffsetLocal, w2Local, v2Local, hhLocal, lwLocal); } } if (hHighLocal.GetValue(point) < h) { if (wLowLocal.GetValue(point) >= 0) { - ComputeGrad(point, 2, hHighPtrOffsetLocal, wLowPtrOffsetLocal, w3Local, + ComputeGrad(2, hHighPtrOffsetLocal, wLowPtrOffsetLocal, w3Local, v3Local, lhLocal, hwLocal); } if (wHighLocal.GetValue(point) < w) { - ComputeGrad(point, 3, hHighPtrOffsetLocal, wHighPtrOffsetLocal, w4Local, + ComputeGrad(3, hHighPtrOffsetLocal, wHighPtrOffsetLocal, w4Local, v4Local, lhLocal, lwLocal); } } @@ -462,6 +462,7 @@ private: uint32_t batch; uint32_t query; uint32_t head; + uint32_t level; uint32_t point; uint32_t taskNum; -- Gitee From d3d8d6f56c640c066a6ce00b3298c4dc7a90d115 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Sat, 9 Mar 2024 15:37:58 +0800 Subject: [PATCH 4/9] 2 --- .../multi_scale_deformable_attention_grad.cpp | 62 ++++++++----------- 1 file changed, 25 insertions(+), 37 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index 2decdfb6..e89a4f4a 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -161,16 +161,11 @@ public: pipe->InitBuffer(topGradValueUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(gradWeightUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp1Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp2Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp3Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp4Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(w1v1Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(w2v2Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(w3v3Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(w4v4Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmp5Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp6Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp7Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp8Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp9Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmp10Ub, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(midUb, BUFFER_NUM, 4 * numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); @@ -228,16 +223,11 @@ public: topGradValueLocal = topGradValueUb.Get(); gradWeightLocal = gradWeightUb.Get(); - tmp1Local = tmp1Ub.Get(); - tmp2Local = tmp2Ub.Get(); - tmp3Local = tmp3Ub.Get(); - tmp4Local = tmp4Ub.Get(); + w1v1Local = w1v1Ub.Get(); + w2v2Local = w2v2Ub.Get(); + w3v3Local = w3v3Ub.Get(); + w4v4Local = w4v4Ub.Get(); tmp5Local = tmp5Ub.Get(); - tmp6Local = tmp6Ub.Get(); - tmp7Local = tmp7Ub.Get(); - tmp8Local = tmp8Ub.Get(); - tmp9Local = tmp9Ub.Get(); - tmp10Local = tmp10Ub.Get(); tmpLocal = tmpUb.Get(); midLocal = midUb.Get(); @@ -389,28 +379,28 @@ private: } SetFlag(eventIdMte3ToV); WaitFlag(eventIdMte3ToV); - Muls(tmp1Local[point * embedDimsAlign], v1Local[point * embedDimsAlign], + Muls(w1v1Local[point * embedDimsAlign], v1Local[point * embedDimsAlign], w1Local.GetValue(point), embedDimsAlign); - Muls(tmp2Local[point * embedDimsAlign], v2Local[point * embedDimsAlign], + Muls(w2v2Local[point * embedDimsAlign], v2Local[point * embedDimsAlign], w2Local.GetValue(point), embedDimsAlign); - Muls(tmp3Local[point * embedDimsAlign], v3Local[point * embedDimsAlign], + Muls(w3v3Local[point * embedDimsAlign], v3Local[point * embedDimsAlign], w3Local.GetValue(point), embedDimsAlign); - Muls(tmp4Local[point * embedDimsAlign], v4Local[point * embedDimsAlign], + Muls(w4v4Local[point * embedDimsAlign], v4Local[point * embedDimsAlign], w4Local.GetValue(point), embedDimsAlign); - Add(tmp5Local[point * embedDimsAlign], tmp1Local[point * embedDimsAlign], - tmp2Local[point * embedDimsAlign], embedDimsAlign); - Add(tmp6Local[point * embedDimsAlign], tmp3Local[point * embedDimsAlign], - tmp4Local[point * embedDimsAlign], embedDimsAlign); - Add(tmp7Local[point * embedDimsAlign], tmp5Local[point * embedDimsAlign], - tmp6Local[point * embedDimsAlign], embedDimsAlign); + Add(w1v1Local[point * embedDimsAlign], w1v1Local[point * embedDimsAlign], + w2v2Local[point * embedDimsAlign], embedDimsAlign); + Add(w1v1Local[point * embedDimsAlign], w1v1Local[point * embedDimsAlign], + w3v3Local[point * embedDimsAlign], embedDimsAlign); + Add(w1v1Local[point * embedDimsAlign], w1v1Local[point * embedDimsAlign], + w4v4Local[point * embedDimsAlign], embedDimsAlign); Mul(gradWeightLocal[point * embedDimsAlign], topGradLocal, - tmp7Local[point * embedDimsAlign], embedDimsAlign); + w1v1Local[point * embedDimsAlign], embedDimsAlign); } } - Mul(tmp9Local, topGradValueLocal, gradWWeightLocal, numPoints * embedDimsAlign); - Muls(gradSampleXLocLocal, tmp9Local, (DTYPE_VALUE)w, numPoints * embedDimsAlign); - Mul(tmp10Local, topGradValueLocal, gradHWeightLocal, numPoints * embedDimsAlign); - Muls(gradSampleYLocLocal, tmp10Local, (DTYPE_VALUE)h, numPoints * embedDimsAlign); + Mul(tmp5Local, topGradValueLocal, gradWWeightLocal, numPoints * embedDimsAlign); + Muls(gradSampleXLocLocal, tmp5Local, (DTYPE_VALUE)w, numPoints * embedDimsAlign); + Mul(tmp5Local, topGradValueLocal, gradHWeightLocal, numPoints * embedDimsAlign); + Muls(gradSampleYLocLocal, tmp5Local, (DTYPE_VALUE)h, numPoints * embedDimsAlign); Sum(xLocal, gradSampleXLocLocal, sumParams); Sum(yLocal, gradSampleYLocLocal, sumParams); Sum(weightSumLocal, gradWeightLocal, sumParams); @@ -440,8 +430,7 @@ private: TBuf lwUb, lhUb, hwUb, hhUb, w1Ub, w2Ub, w3Ub, w4Ub, v1Ub, v2Ub, v3Ub, v4Ub; - TBuf tmpUb, tmp1Ub, tmp2Ub, tmp3Ub, tmp4Ub, tmp5Ub, tmp6Ub, tmp7Ub, tmp8Ub, tmp9Ub, tmp10Ub, - midUb; + TBuf tmpUb, w1v1Ub, w2v2Ub, w3v3Ub, w4v4Ub, tmp5Ub, midUb; TBuf gradHWeightUb, gradWWeightUb, topGradValueUb, gradWeightUb, gradSampleXLocUb, gradSampleYLocUb; @@ -485,8 +474,7 @@ private: LocalTensor weightLocal, xLocal, yLocal, weightSumLocal, floatOneLocal, topGradLocal, lwLocal, lhLocal, locWLocal, locHLocal, hImLocal, wImLocal, hLowFloatLocal, wLowFloatLocal, w1Local, w2Local, w3Local, w4Local, v1Local, v2Local, v3Local, v4Local, hwLocal, hhLocal, gradHWeightLocal, gradWWeightLocal, topGradValueLocal, - gradWeightLocal, tmp1Local, tmp2Local, tmp3Local, tmp4Local, tmp5Local, tmp6Local, tmp7Local, tmp8Local, - tmp9Local, tmp10Local, tmpLocal, midLocal, gradSampleXLocLocal, gradSampleYLocLocal, locationLocal, + gradWeightLocal, w1v1Local, w2v2Local, w3v3Local, w4v4Local, tmp5Local, tmpLocal, midLocal, gradSampleXLocLocal, gradSampleYLocLocal, locationLocal, attentionWeightLocal; event_t eventIdVToMte3, eventIdMte2ToV, eventIdMte3ToV; -- Gitee From 4af8ed92768eb17889b13740b3913d28dadf5d5b Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Sat, 9 Mar 2024 16:11:44 +0800 Subject: [PATCH 5/9] 3 --- .../multi_scale_deformable_attention_grad.cpp | 68 ++++++++----------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index e89a4f4a..b0d6b766 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -36,6 +36,7 @@ public: { pipe = tmpPipe; curBlockIdx = GetBlockIdx(); + blockBytes = 32; dataAlign = blockBytes / sizeof(DTYPE_VALUE); numKeys = tiling_data->numKeys; @@ -120,7 +121,6 @@ public: pipe->InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(weightSumUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(weightUb, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(locWUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(locHUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); @@ -180,7 +180,6 @@ public: attentionWeightLocal = attentionWeightsUb.Get(); shapesLocal = shapeUb.Get(); offsetLocal = offsetUb.Get(); - weightLocal = weightUb.Get(); xLocal = tmpXUb.Get(); yLocal = tmpYUb.Get(); weightSumLocal = weightSumUb.Get(); @@ -257,8 +256,7 @@ public: private: template - __aicore__ inline void ComputeGrad(uint32_t midId, - LocalTensor &hPtrOffsetLocal, + __aicore__ inline void ComputeGrad(uint32_t midId, LocalTensor &hPtrOffsetLocal, LocalTensor &wPtrOffsetLocal, LocalTensor &wLocal, LocalTensor &vLocal, LocalTensor &distanceHLocal, @@ -424,8 +422,8 @@ private: TBuf locationUb, attentionWeightsUb, shapeUb, offsetUb; TBuf tmpXUb, tmpYUb, weightSumUb; - TBuf floatOneUb, weightUb, topGradUb; - TBuf valueUb, locWUb, locHUb, hImUb, wImUb, hLowUb, wLowUb, hHighUb, wHighUb, hLowFloatUb, + TBuf floatOneUb, topGradUb; + TBuf locWUb, locHUb, hImUb, wImUb, hLowUb, wLowUb, hHighUb, wHighUb, hLowFloatUb, wLowFloatUb, hHighFloatUb, wHighFloatUb, hHighPtrOffsetUb, hLowPtrOffsetUb, wHighPtrOffsetUb, wLowPtrOffsetUb; TBuf lwUb, lhUb, hwUb, hhUb, w1Ub, w2Ub, w3Ub, w4Ub, v1Ub, v2Ub, v3Ub, v4Ub; @@ -434,51 +432,43 @@ private: TBuf gradHWeightUb, gradWWeightUb, topGradValueUb, gradWeightUb, gradSampleXLocUb, gradSampleYLocUb; - uint32_t batchSize; - uint32_t numKeys; - uint32_t numHeads; - uint32_t embedDims; - - uint32_t numLevels; - uint32_t numQueries; - uint32_t numPoints; uint32_t coreNum; + uint32_t batchSize, numKeys, numHeads, embedDims, numLevels, numQueries, numPoints; + uint32_t embedDimsAlign, numPointsAlign, numLevelsAlign; - uint32_t embedDimsAlign; - uint32_t numPointsAlign; - uint32_t numLevelsAlign; - - uint32_t batch; - uint32_t query; - uint32_t head; - uint32_t level; - uint32_t point; + uint32_t batch, query, head, level, point; - uint32_t taskNum; - uint32_t taskNumPerCore; uint32_t curBlockIdx; - uint32_t startOffset; - uint32_t endOffset; - uint32_t dataAlign; - uint32_t blockBytes = 32; + uint32_t taskNum, taskNumPerCore; + uint32_t startOffset, endOffset; + + uint32_t dataAlign, blockBytes; uint32_t gradOutStride0, gradOutStride1, gradOutStride2; uint32_t weightStride0, weightStride1, weightStride2; uint32_t valueStride0, valueStride1, valueStride2; - DTYPE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, levelStartId, offsetValue; - + DTYPE_SPATIAL_SHAPES h, w, valueOffset, weightOffset, locationOffset, levelStartId, offsetValue; DTYPE_SPATIAL_SHAPES offsetWeight, offsetLocation, wStride, hStride, basePtr, ptr; - LocalTensor shapesLocal, offsetLocal, hLowLocal, wLowLocal, hHighLocal, wHighLocal, - hHighPtrOffsetLocal, hLowPtrOffsetLocal, wHighPtrOffsetLocal, wLowPtrOffsetLocal; - LocalTensor weightLocal, xLocal, yLocal, weightSumLocal, floatOneLocal, topGradLocal, lwLocal, lhLocal, - locWLocal, locHLocal, hImLocal, wImLocal, hLowFloatLocal, wLowFloatLocal, w1Local, w2Local, w3Local, w4Local, - v1Local, v2Local, v3Local, v4Local, hwLocal, hhLocal, gradHWeightLocal, gradWWeightLocal, topGradValueLocal, - gradWeightLocal, w1v1Local, w2v2Local, w3v3Local, w4v4Local, tmp5Local, tmpLocal, midLocal, gradSampleXLocLocal, gradSampleYLocLocal, locationLocal, - attentionWeightLocal; + + LocalTensor shapesLocal, offsetLocal; + LocalTensor hLowLocal, wLowLocal, hHighLocal, wHighLocal; + LocalTensor hHighPtrOffsetLocal, hLowPtrOffsetLocal, wHighPtrOffsetLocal, wLowPtrOffsetLocal; + LocalTensor floatOneLocal; + LocalTensor xLocal, yLocal; + LocalTensor lwLocal, lhLocal, hwLocal, hhLocal; + LocalTensor locWLocal, locHLocal; + LocalTensor hImLocal, wImLocal; + LocalTensor hLowFloatLocal, wLowFloatLocal; + LocalTensor w1Local, w2Local, w3Local, w4Local; + LocalTensor v1Local, v2Local, v3Local, v4Local; + LocalTensor w1v1Local, w2v2Local, w3v3Local, w4v4Local; + LocalTensor weightSumLocal, midLocal, tmpLocal, tmp5Local; + LocalTensor gradHWeightLocal, gradWWeightLocal, gradWeightLocal, topGradValueLocal; + LocalTensor gradSampleXLocLocal, gradSampleYLocLocal; + LocalTensor topGradLocal, locationLocal, attentionWeightLocal; event_t eventIdVToMte3, eventIdMte2ToV, eventIdMte3ToV; - DataCopyParams copyParamsA, copyParamsB; SumParams sumParams; }; -- Gitee From 62044966bf27449ce814f517c3501c0f3a70f405 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Sat, 9 Mar 2024 16:21:35 +0800 Subject: [PATCH 6/9] 4 --- .../op_kernel/multi_scale_deformable_attention_grad.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index b0d6b766..773d92cc 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -423,8 +423,8 @@ private: TBuf tmpXUb, tmpYUb, weightSumUb; TBuf floatOneUb, topGradUb; - TBuf locWUb, locHUb, hImUb, wImUb, hLowUb, wLowUb, hHighUb, wHighUb, hLowFloatUb, - wLowFloatUb, hHighFloatUb, wHighFloatUb, hHighPtrOffsetUb, hLowPtrOffsetUb, wHighPtrOffsetUb, wLowPtrOffsetUb; + TBuf locWUb, locHUb, hImUb, wImUb, hLowUb, wLowUb, hHighUb, wHighUb, hLowFloatUb, wLowFloatUb, + hHighFloatUb, wHighFloatUb, hHighPtrOffsetUb, hLowPtrOffsetUb, wHighPtrOffsetUb, wLowPtrOffsetUb; TBuf lwUb, lhUb, hwUb, hhUb, w1Ub, w2Ub, w3Ub, w4Ub, v1Ub, v2Ub, v3Ub, v4Ub; @@ -441,7 +441,7 @@ private: uint32_t curBlockIdx; uint32_t taskNum, taskNumPerCore; uint32_t startOffset, endOffset; - + uint32_t dataAlign, blockBytes; uint32_t gradOutStride0, gradOutStride1, gradOutStride2; @@ -450,7 +450,7 @@ private: DTYPE_SPATIAL_SHAPES h, w, valueOffset, weightOffset, locationOffset, levelStartId, offsetValue; DTYPE_SPATIAL_SHAPES offsetWeight, offsetLocation, wStride, hStride, basePtr, ptr; - + LocalTensor shapesLocal, offsetLocal; LocalTensor hLowLocal, wLowLocal, hHighLocal, wHighLocal; LocalTensor hHighPtrOffsetLocal, hLowPtrOffsetLocal, wHighPtrOffsetLocal, wLowPtrOffsetLocal; -- Gitee From dd4f1f653f4b299178e0e80e733b19b51c77ed7e Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Tue, 12 Mar 2024 15:59:13 +0800 Subject: [PATCH 7/9] 4 --- .../kernels/op_kernel/multi_scale_deformable_attention_grad.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index 773d92cc..cb67bd6b 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -32,7 +32,7 @@ public: __aicore__ inline void Init(GM_ADDR value_gm, GM_ADDR spatial_shapes_gm, GM_ADDR level_start_index_gm, GM_ADDR sampling_loc_gm, GM_ADDR attn_weight_gm, GM_ADDR grad_output_gm, GM_ADDR grad_value_gm, GM_ADDR grad_sampling_loc_gm, GM_ADDR grad_attn_weight_gm, - MultiScaleDeformableAttentionGradTilingData *tiling_data, TPipe *tmpPipe) + const MultiScaleDeformableAttentionGradTilingData *tiling_data, TPipe *tmpPipe) { pipe = tmpPipe; curBlockIdx = GetBlockIdx(); -- Gitee From 4e4f9cb777434810bf7b4511fcabdf09adbd6115 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Wed, 13 Mar 2024 09:49:49 +0800 Subject: [PATCH 8/9] 1 --- .../op_kernel/multi_scale_deformable_attention_grad.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index cb67bd6b..a7631a4a 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -308,10 +308,10 @@ private: hStride = w * wStride; DataCopy(locWLocal, locationGm[offsetLocation + level * numPoints * 2], numPointsAlign); DataCopy(locHLocal, locationGm[offsetLocation + level * numPoints * 2 + numPoints], numPointsAlign); - DataCopy(attentionWeightLocal, attentionWeightsGm[offsetWeight + level * numPoints], - numPointsAlign); SetFlag(eventIdMte2ToV); WaitFlag(eventIdMte2ToV); + DataCopy(attentionWeightLocal, attentionWeightsGm[offsetWeight + level * numPoints], + numPointsAlign); Muls(hImLocal, locHLocal, (DTYPE_VALUE)h, numPointsAlign); Muls(wImLocal, locWLocal, (DTYPE_VALUE)w, numPointsAlign); Adds(hImLocal, hImLocal, DTYPE_VALUE(-0.5), numPointsAlign); @@ -350,6 +350,9 @@ private: Duplicate(v3Local, (DTYPE_VALUE)0, numPoints * embedDimsAlign); Duplicate(v4Local, (DTYPE_VALUE)0, numPoints * embedDimsAlign); + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + for (point = 0; point < numPoints; point++) { if (hImLocal.GetValue(point) > -1 && wImLocal.GetValue(point) > -1 && hImLocal.GetValue(point) < h && wImLocal.GetValue(point) < w) { -- Gitee From 813143843d1a92333647f4803c8cb5890a328887 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Wed, 13 Mar 2024 17:59:48 +0800 Subject: [PATCH 9/9] 1 --- .../op_kernel/multi_scale_deformable_attention_grad.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index a7631a4a..dbf4afc6 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -378,8 +378,6 @@ private: v4Local, lhLocal, lwLocal); } } - SetFlag(eventIdMte3ToV); - WaitFlag(eventIdMte3ToV); Muls(w1v1Local[point * embedDimsAlign], v1Local[point * embedDimsAlign], w1Local.GetValue(point), embedDimsAlign); Muls(w2v2Local[point * embedDimsAlign], v2Local[point * embedDimsAlign], @@ -396,6 +394,8 @@ private: w4v4Local[point * embedDimsAlign], embedDimsAlign); Mul(gradWeightLocal[point * embedDimsAlign], topGradLocal, w1v1Local[point * embedDimsAlign], embedDimsAlign); + SetFlag(eventIdMte3ToV); + WaitFlag(eventIdMte3ToV); } } Mul(tmp5Local, topGradValueLocal, gradWWeightLocal, numPoints * embedDimsAlign); -- Gitee