diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp deleted file mode 100644 index c59529fa7c78e206cb7b07cfdb8cc2a7cb200af0..0000000000000000000000000000000000000000 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp +++ /dev/null @@ -1,359 +0,0 @@ - -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - * - * This sample is a very basic sample that implements vector add on Ascend plaform. - */ -#include "kernel_operator.h" -using namespace AscendC; -constexpr int32_t BUFFER_NUM = 2; - -class KernelMultiScaleDeformableAttnFunctionV2 -{ -public: - __aicore__ inline KernelMultiScaleDeformableAttnFunctionV2() {} - __aicore__ inline void Init(GM_ADDR value, - GM_ADDR value_spatial_shapes, - GM_ADDR value_level_start_index, - GM_ADDR sampling_locations, - GM_ADDR attention_weights, - GM_ADDR output, MultiScaleDeformableAttnFunctionV2TilingData *tiling_data) - { - ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); - dataAlign = blockNum / sizeof(DTYPE_VALUE); - batchSize = tiling_data->batchSize; - numKeys = tiling_data->numKeys; - numHeads = tiling_data->numHeads; - embedDims = tiling_data->embedDims; - - numLevels = tiling_data->numLevels; - numQueries = tiling_data->numQueries; - numPoints = tiling_data->numPoints; - coreNum = tiling_data->coreNum; - - taskNum = batchSize * numQueries; - taskNumPerCore = DivCeil(taskNum, coreNum); - - embedDimsAlign = AlignUp(embedDims, dataAlign); - numPointsAlign = AlignUp(numPoints, dataAlign); - numLevelsAlign = AlignUp(numLevels, dataAlign); - - curBlockIdx = GetBlockIdx(); - startOffset = curBlockIdx * taskNumPerCore; - endOffset = (curBlockIdx + 1) * taskNumPerCore; - if (endOffset > taskNum) - { - endOffset = taskNum; - } - - valueGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(value), batchSize * numKeys * numHeads * embedDims); - locationGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(sampling_locations), batchSize * numQueries * numHeads * numLevels * numPoints * 2); - attentionWeightsGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(attention_weights), batchSize * numQueries * numHeads * numLevels * numPoints); - outputGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(output), batchSize * numQueries * numHeads * embedDims); - - valueSpatialShapesGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_spatial_shapes), numLevels * 2); - valueLevelStartIndexGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_level_start_index), numLevels); - - pipe.InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(locationQueue, BUFFER_NUM, AlignUp(numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(outputQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpResUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpResUb2, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - - pipe.InitBuffer(leftTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(leftTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - } - - __aicore__ inline void Process() - { - for (uint32_t taskIdx = startOffset; taskIdx < endOffset; taskIdx++) - { - batch = taskIdx / numQueries; - query = taskIdx % numQueries; - pipe_barrier(PIPE_ALL); - Compute(batch, query); - } - } - -private: - __aicore__ inline bool isInRange(DTYPE_VALUE_SPATIAL_SHAPES x, DTYPE_VALUE_SPATIAL_SHAPES upper) - { - return 0 <= x && x < upper; - } - - __aicore__ inline void Compute(uint32_t batch, uint32_t query) - { - LocalTensor tmpResLocal = tmpResUb.Get(); - LocalTensor tmpResLocal2 = tmpResUb2.Get(); - - LocalTensor leftTopValueLocal = leftTopValueUb.Get(); - LocalTensor leftBottomValueUbLocal = leftBottomValueUb.Get(); - LocalTensor rightTopValueUbLocal = rightTopValueUb.Get(); - LocalTensor rightBottomValueUbLocal = rightBottomValueUb.Get(); - - LocalTensor leftTopWeiightLocal = leftTopWieightQueue.Get(); - LocalTensor leftBottomWeightLocal = leftBottomWieightQueue.Get(); - LocalTensor rightTopWeiightLocal = rightTopWieightQueue.Get(); - LocalTensor rightBottomWeightLocal = rightBottomWieightQueue.Get(); - - LocalTensor shapesLocal = shapeQueue.AllocTensor(); - LocalTensor offsetLocal = offsetQueue.AllocTensor(); - - LocalTensor locationLocal = locationQueue.AllocTensor(); - LocalTensor attentionWeightLocal = attentionWeightsUb.AllocTensor(); - - LocalTensor resLocal = outputQueue.AllocTensor(); - - LocalTensor xLocal = tmpXUb.Get(); - LocalTensor yLocal = tmpYUb.Get(); - - LocalTensor param0Local = tmpParam0Ub.Get(); - LocalTensor param1Local = tmpParam1Ub.Get(); - - LocalTensor x1Local = tmpIntX1Ub.Get(); - LocalTensor y1Local = tmpIntY1Ub.Get(); - - LocalTensor x0Local = tmpIntX0Ub.Get(); - LocalTensor y0Local = tmpIntY0Ub.Get(); - - LocalTensor tmpLocal1 = tmpUb1.Get(); - LocalTensor tmpLocal2 = tmpUb2.Get(); - LocalTensor tmpLocal3 = tmpUb3.Get(); - LocalTensor tmpLocal4 = tmpUb4.Get(); - - LocalTensor intOneLocal = intOneUb.Get(); - LocalTensor floatOneLocal = floatOneUb.Get(); - - Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); - Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); - DataCopyParams copyParams{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; - - DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); - DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); - moveOffset = batch * numQueries * numHeads * embedDims + query * numHeads * embedDims; - pipe_barrier(PIPE_ALL); - - for (uint32_t head = 0; head < numHeads; head++) - { - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - } - pipe_barrier(PIPE_ALL); - - for (uint32_t head = 0; head < numHeads; head++) - { - weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels + head * numLevels) * numPoints; - - pipe_barrier(PIPE_ALL); - - DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numLevels * numPoints * 2, dataAlign)); - DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numLevels * numPoints, dataAlign)); - - pipe_barrier(PIPE_ALL); - for (uint32_t level = 0; level < numLevels; level++) - { - h = shapesLocal.GetValue(level * 2); - w = shapesLocal.GetValue(level * 2 + 1); - for (uint32_t point = 0; point < numPoints; point++) - { - locationOffset = (level * numPoints + point) * 2; - xLocal.SetValue(point, locationLocal.GetValue(locationOffset)); - yLocal.SetValue(point, locationLocal.GetValue(locationOffset + 1)); - } - - pipe_barrier(PIPE_ALL); - - Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); - Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); - Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); - Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); - Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Sub(x0Local, x1Local, intOneLocal, numPointsAlign); - Sub(y0Local, y1Local, intOneLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); - Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); - Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Abs(param0Local, tmpLocal1, numPointsAlign); - Abs(param1Local, tmpLocal2, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); - Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Mul(leftTopWeiightLocal, xLocal, yLocal, numPointsAlign); - Mul(leftBottomWeightLocal, xLocal, param1Local, numPointsAlign); - Mul(rightTopWeiightLocal, param0Local, yLocal, numPointsAlign); - Mul(rightBottomWeightLocal, param0Local, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); - - for (uint32_t point = 0; point < numPoints; point++) - { - Duplicate(leftTopValueLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(leftBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightTopValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - - x0 = x0Local.GetValue(point); - y0 = y0Local.GetValue(point); - x1 = x1Local.GetValue(point); - y1 = y1Local.GetValue(point); - - valueOffset = batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head; - pipe_barrier(PIPE_ALL); - - if (isInRange(x0, w)) - { - if (isInRange(y0, h)) - { - DataCopy(leftTopValueLocal, valueGm[(valueOffset + (y0 * w + x0) * numHeads) * embedDims], embedDimsAlign); - } - if (isInRange(y1, h)) - { - DataCopy(leftBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x0) * numHeads) * embedDims], embedDimsAlign); - } - } - if (isInRange(x1, w)) - { - if (isInRange(y0, h)) - { - DataCopy(rightTopValueUbLocal, valueGm[(valueOffset + (y0 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } - if (isInRange(y1, h)) - { - DataCopy(rightBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } - } - pipe_barrier(PIPE_ALL); - - Muls(leftTopValueLocal, leftTopValueLocal, leftTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(rightTopValueUbLocal, rightTopValueUbLocal, rightTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(leftBottomValueUbLocal, leftBottomValueUbLocal, leftBottomWeightLocal.GetValue(point), embedDimsAlign); - Muls(rightBottomValueUbLocal, rightBottomValueUbLocal, rightBottomWeightLocal.GetValue(point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, leftTopValueLocal, rightTopValueUbLocal, embedDimsAlign); - Add(tmpResLocal2, leftBottomValueUbLocal, rightBottomValueUbLocal, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, tmpResLocal, tmpResLocal2, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Muls(tmpResLocal, tmpResLocal, attentionWeightLocal.GetValue(level * numPoints + point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(resLocal, resLocal, tmpResLocal, embedDimsAlign); - } - pipe_barrier(PIPE_ALL); - - SetAtomicAdd(); - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - SetAtomicNone(); - } - } - locationQueue.FreeTensor(locationLocal); - attentionWeightsUb.FreeTensor(attentionWeightLocal); - outputQueue.FreeTensor(resLocal); - shapeQueue.FreeTensor(shapesLocal); - offsetQueue.FreeTensor(offsetLocal); - } - -private: - TPipe pipe; - GlobalTensor valueGm, locationGm, attentionWeightsGm, outputGm; - GlobalTensor valueSpatialShapesGm, valueLevelStartIndexGm; - - TQue locationQueue, attentionWeightsUb, shapeQueue, offsetQueue; - TQue outputQueue; - - TBuf tmpResUb, tmpResUb2, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; - TBuf intOneUb, floatOneUb, leftTopValueUb, leftBottomValueUb, rightTopValueUb, rightBottomValueUb; - TBuf leftTopWieightQueue, leftBottomWieightQueue, rightTopWieightQueue, rightBottomWieightQueue; - - uint32_t batchSize; - uint32_t numKeys; - uint32_t numHeads; - uint32_t embedDims; - - uint32_t numLevels; - uint32_t numQueries; - uint32_t numPoints; - uint32_t coreNum; - - uint32_t embedDimsAlign; - uint32_t numPointsAlign; - uint32_t numLevelsAlign; - - uint32_t batch; - uint32_t query; - uint32_t head; - - uint32_t taskNum; - uint32_t taskNumPerCore; - uint32_t curBlockIdx; - uint32_t startOffset; - uint32_t endOffset; - uint32_t dataAlign; - uint32_t blockNum = 32; - - DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, moveOffset; -}; - -extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM_ADDR value, - GM_ADDR value_spatial_shapes, - GM_ADDR value_level_start_index, - GM_ADDR sampling_locations, - GM_ADDR attention_weights, - GM_ADDR output, GM_ADDR workspace, GM_ADDR tiling) -{ - GET_TILING_DATA(tiling_data, tiling); - KernelMultiScaleDeformableAttnFunctionV2 op; - op.Init(value, value_spatial_shapes, value_level_start_index, - sampling_locations, attention_weights, output, &tiling_data); - op.Process(); -} diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp index c59529fa7c78e206cb7b07cfdb8cc2a7cb200af0..0419e7919f46a33a5279c43659077a9905eaf412 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp @@ -1,12 +1,11 @@ - /* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. * * This sample is a very basic sample that implements vector add on Ascend plaform. */ #include "kernel_operator.h" using namespace AscendC; -constexpr int32_t BUFFER_NUM = 2; +constexpr int32_t BUFFER_NUM = 1; class KernelMultiScaleDeformableAttnFunctionV2 { @@ -17,8 +16,11 @@ public: GM_ADDR value_level_start_index, GM_ADDR sampling_locations, GM_ADDR attention_weights, - GM_ADDR output, MultiScaleDeformableAttnFunctionV2TilingData *tiling_data) + GM_ADDR output, + MultiScaleDeformableAttnFunctionV2TilingData *tiling_data, + TPipe *tmpPipe) { + pipe = tmpPipe; ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); dataAlign = blockNum / sizeof(DTYPE_VALUE); batchSize = tiling_data->batchSize; @@ -31,7 +33,9 @@ public: numPoints = tiling_data->numPoints; coreNum = tiling_data->coreNum; - taskNum = batchSize * numQueries; + tailNum = numHeads * embedDims; + + taskNum = numQueries; taskNumPerCore = DivCeil(taskNum, coreNum); embedDimsAlign = AlignUp(embedDims, dataAlign); @@ -54,53 +58,48 @@ public: valueSpatialShapesGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_spatial_shapes), numLevels * 2); valueLevelStartIndexGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_level_start_index), numLevels); - pipe.InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(locationQueue, BUFFER_NUM, AlignUp(numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(outputQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(locationQueue, BUFFER_NUM, AlignUp(numHeads * numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numHeads * numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(outputQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(emptyUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpResUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpResUb2, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpParam1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(leftTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(leftTopWieightQueue, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(valueUb, BUFFER_NUM, numPoints * 4 * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb2, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb3, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); } __aicore__ inline void Process() { for (uint32_t taskIdx = startOffset; taskIdx < endOffset; taskIdx++) { - batch = taskIdx / numQueries; - query = taskIdx % numQueries; - pipe_barrier(PIPE_ALL); - Compute(batch, query); + SetAtomicAdd(); + Compute(taskIdx); + SetAtomicNone(); } } @@ -110,215 +109,208 @@ private: return 0 <= x && x < upper; } - __aicore__ inline void Compute(uint32_t batch, uint32_t query) + __aicore__ inline void Compute(uint32_t query) { - LocalTensor tmpResLocal = tmpResUb.Get(); - LocalTensor tmpResLocal2 = tmpResUb2.Get(); - - LocalTensor leftTopValueLocal = leftTopValueUb.Get(); - LocalTensor leftBottomValueUbLocal = leftBottomValueUb.Get(); - LocalTensor rightTopValueUbLocal = rightTopValueUb.Get(); - LocalTensor rightBottomValueUbLocal = rightBottomValueUb.Get(); - - LocalTensor leftTopWeiightLocal = leftTopWieightQueue.Get(); - LocalTensor leftBottomWeightLocal = leftBottomWieightQueue.Get(); - LocalTensor rightTopWeiightLocal = rightTopWieightQueue.Get(); - LocalTensor rightBottomWeightLocal = rightBottomWieightQueue.Get(); + LocalTensor locationLocal = locationQueue.AllocTensor(); + LocalTensor attentionWeightLocal = attentionWeightsUb.AllocTensor(); LocalTensor shapesLocal = shapeQueue.AllocTensor(); LocalTensor offsetLocal = offsetQueue.AllocTensor(); - LocalTensor locationLocal = locationQueue.AllocTensor(); - LocalTensor attentionWeightLocal = attentionWeightsUb.AllocTensor(); + DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); + DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - LocalTensor resLocal = outputQueue.AllocTensor(); + DataCopyParams copyParams{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; - LocalTensor xLocal = tmpXUb.Get(); - LocalTensor yLocal = tmpYUb.Get(); + LocalTensor valueLocal = valueUb.Get(); - LocalTensor param0Local = tmpParam0Ub.Get(); - LocalTensor param1Local = tmpParam1Ub.Get(); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->AllocEventID()); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->AllocEventID()); - LocalTensor x1Local = tmpIntX1Ub.Get(); - LocalTensor y1Local = tmpIntY1Ub.Get(); + for (uint32_t batch = 0; batch < batchSize; batch++) + { + LocalTensor emptyUbLocal = emptyUb.Get(); - LocalTensor x0Local = tmpIntX0Ub.Get(); - LocalTensor y0Local = tmpIntY0Ub.Get(); + LocalTensor weightLocal = leftTopWieightQueue.Get(); - LocalTensor tmpLocal1 = tmpUb1.Get(); - LocalTensor tmpLocal2 = tmpUb2.Get(); - LocalTensor tmpLocal3 = tmpUb3.Get(); - LocalTensor tmpLocal4 = tmpUb4.Get(); + LocalTensor xLocal = tmpXUb.Get(); + LocalTensor yLocal = tmpYUb.Get(); - LocalTensor intOneLocal = intOneUb.Get(); - LocalTensor floatOneLocal = floatOneUb.Get(); + LocalTensor tmpResLocal = tmpResUb.Get(); + LocalTensor tmpResLocal2 = tmpResUb2.Get(); + LocalTensor tmpResLocal3 = tmpResUb3.Get(); - Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); - Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); - DataCopyParams copyParams{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; + LocalTensor param0Local = tmpParam0Ub.Get(); + LocalTensor param1Local = tmpParam1Ub.Get(); - DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); - DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); - moveOffset = batch * numQueries * numHeads * embedDims + query * numHeads * embedDims; - pipe_barrier(PIPE_ALL); + LocalTensor x1Local = tmpIntX1Ub.Get(); + LocalTensor y1Local = tmpIntY1Ub.Get(); - for (uint32_t head = 0; head < numHeads; head++) - { - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - } - pipe_barrier(PIPE_ALL); + LocalTensor x0Local = tmpIntX0Ub.Get(); + LocalTensor y0Local = tmpIntY0Ub.Get(); - for (uint32_t head = 0; head < numHeads; head++) - { - weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels + head * numLevels) * numPoints; + LocalTensor tmpLocal1 = tmpUb1.Get(); + LocalTensor tmpLocal2 = tmpUb2.Get(); + LocalTensor tmpLocal3 = tmpUb3.Get(); + LocalTensor tmpLocal4 = tmpUb4.Get(); + + LocalTensor intOneLocal = intOneUb.Get(); + LocalTensor floatOneLocal = floatOneUb.Get(); - pipe_barrier(PIPE_ALL); + Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); + Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); - DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numLevels * numPoints * 2, dataAlign)); - DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numLevels * numPoints, dataAlign)); + Duplicate(emptyUbLocal, DTYPE_VALUE(0), embedDimsAlign); + moveOffset = batch * numQueries * numHeads * embedDims + query * numHeads * embedDims; - pipe_barrier(PIPE_ALL); - for (uint32_t level = 0; level < numLevels; level++) + for (uint32_t head = 0; head < numHeads; head++) { - h = shapesLocal.GetValue(level * 2); - w = shapesLocal.GetValue(level * 2 + 1); - for (uint32_t point = 0; point < numPoints; point++) - { - locationOffset = (level * numPoints + point) * 2; - xLocal.SetValue(point, locationLocal.GetValue(locationOffset)); - yLocal.SetValue(point, locationLocal.GetValue(locationOffset + 1)); - } + DataCopyPad(outputGm[moveOffset + head * embedDims], emptyUbLocal, copyParams); + } - pipe_barrier(PIPE_ALL); + weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels) * numPoints; - Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); - Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - pipe_barrier(PIPE_ALL); + DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numHeads * numLevels * numPoints * 2, dataAlign)); + DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numHeads * numLevels * numPoints, dataAlign)); - Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); - Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); - pipe_barrier(PIPE_ALL); + for (uint32_t head = 0; head < numHeads; head++) + { + for (uint32_t level = 0; level < numLevels; level++) + { + h = shapesLocal.GetValue(level * 2); + w = shapesLocal.GetValue(level * 2 + 1); - Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); - Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); - pipe_barrier(PIPE_ALL); + weightOffset = (head * numLevels + level) * numPoints; + locationOffset = weightOffset * 2; + for (uint32_t point = 0; point < numPoints; point++) + { + xLocal.SetValue(point, locationLocal.GetValue(locationOffset + point * 2)); + yLocal.SetValue(point, locationLocal.GetValue(locationOffset + point * 2 + 1)); + } - Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); - Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - pipe_barrier(PIPE_ALL); + Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); + Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - Sub(x0Local, x1Local, intOneLocal, numPointsAlign); - Sub(y0Local, y1Local, intOneLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); + Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); + Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); - Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); - Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); - pipe_barrier(PIPE_ALL); + Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); + Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); - Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); - Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); + Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); + Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - Abs(param0Local, tmpLocal1, numPointsAlign); - Abs(param1Local, tmpLocal2, numPointsAlign); - pipe_barrier(PIPE_ALL); + Sub(x0Local, x1Local, intOneLocal, numPointsAlign); + Sub(y0Local, y1Local, intOneLocal, numPointsAlign); - Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); - Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); + Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); + Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); - Mul(leftTopWeiightLocal, xLocal, yLocal, numPointsAlign); - Mul(leftBottomWeightLocal, xLocal, param1Local, numPointsAlign); - Mul(rightTopWeiightLocal, param0Local, yLocal, numPointsAlign); - Mul(rightBottomWeightLocal, param0Local, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); + Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); + Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); + Abs(param0Local, tmpLocal1, numPointsAlign); + Abs(param1Local, tmpLocal2, numPointsAlign); - for (uint32_t point = 0; point < numPoints; point++) - { - Duplicate(leftTopValueLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(leftBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightTopValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); + Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); + Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); + + Mul(weightLocal, xLocal, yLocal, numPointsAlign); + Mul(weightLocal[numPointsAlign], xLocal, param1Local, numPointsAlign); + Mul(weightLocal[numPointsAlign * 2], param0Local, yLocal, numPointsAlign); + Mul(weightLocal[numPointsAlign * 3], param0Local, param1Local, numPointsAlign); - x0 = x0Local.GetValue(point); - y0 = y0Local.GetValue(point); - x1 = x1Local.GetValue(point); - y1 = y1Local.GetValue(point); + Mul(weightLocal, weightLocal, attentionWeightLocal[weightOffset], numPointsAlign); + Mul(weightLocal[numPointsAlign], weightLocal[numPointsAlign], attentionWeightLocal[weightOffset], numPointsAlign); + Mul(weightLocal[numPointsAlign * 2], weightLocal[numPointsAlign * 2], attentionWeightLocal[weightOffset], numPointsAlign); + Mul(weightLocal[numPointsAlign * 3], weightLocal[numPointsAlign * 3], attentionWeightLocal[weightOffset], numPointsAlign); - valueOffset = batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head; - pipe_barrier(PIPE_ALL); + valueOffset = (batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head) * embedDims; - if (isInRange(x0, w)) + Duplicate(valueLocal, DTYPE_VALUE(0), 4 * numPoints * embedDimsAlign); + for (uint32_t point = 0; point < numPoints; point++) { - if (isInRange(y0, h)) + x0 = x0Local.GetValue(point); + y0 = y0Local.GetValue(point); + x1 = x1Local.GetValue(point); + y1 = y1Local.GetValue(point); + + if (isInRange(x0, w)) { - DataCopy(leftTopValueLocal, valueGm[(valueOffset + (y0 * w + x0) * numHeads) * embedDims], embedDimsAlign); + if (isInRange(y0, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4], valueGm[valueOffset + (y0 * w + x0) * tailNum], embedDimsAlign); + } + if (isInRange(y1, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], valueGm[valueOffset + (y1 * w + x0) * tailNum], embedDimsAlign); + } } - if (isInRange(y1, h)) + if (isInRange(x1, w)) { - DataCopy(leftBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x0) * numHeads) * embedDims], embedDimsAlign); + if (isInRange(y0, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueGm[valueOffset + (y0 * w + x1) * tailNum], embedDimsAlign); + } + if (isInRange(y1, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], valueGm[valueOffset + (y1 * w + x1) * tailNum], embedDimsAlign); + } } } - if (isInRange(x1, w)) + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + + for (uint32_t point = 0; point < numPoints; point++) { - if (isInRange(y0, h)) - { - DataCopy(rightTopValueUbLocal, valueGm[(valueOffset + (y0 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } - if (isInRange(y1, h)) - { - DataCopy(rightBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } + leftTopWeight = weightLocal.GetValue(point); + leftBottomWeight = weightLocal.GetValue(numPointsAlign + point); + rightTopWeiight = weightLocal.GetValue(numPointsAlign * 2 + point); + rightBottomWeight = weightLocal.GetValue(numPointsAlign * 3 + point); + + Muls(valueLocal[point * embedDimsAlign * 4], valueLocal[point * embedDimsAlign * 4], leftTopWeight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], leftBottomWeight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], rightTopWeiight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], rightBottomWeight, embedDimsAlign); + + Add(tmpResLocal[point * embedDimsAlign], valueLocal[point * embedDimsAlign * 4], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], embedDimsAlign); + Add(tmpResLocal2[point * embedDimsAlign], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], embedDimsAlign); + Add(tmpResLocal3[point * embedDimsAlign], tmpResLocal[point * embedDimsAlign], tmpResLocal2[point * embedDimsAlign], embedDimsAlign); + + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopyPad(outputGm[moveOffset + head * embedDims], tmpResLocal3[point * embedDimsAlign], copyParams); } - pipe_barrier(PIPE_ALL); - - Muls(leftTopValueLocal, leftTopValueLocal, leftTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(rightTopValueUbLocal, rightTopValueUbLocal, rightTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(leftBottomValueUbLocal, leftBottomValueUbLocal, leftBottomWeightLocal.GetValue(point), embedDimsAlign); - Muls(rightBottomValueUbLocal, rightBottomValueUbLocal, rightBottomWeightLocal.GetValue(point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, leftTopValueLocal, rightTopValueUbLocal, embedDimsAlign); - Add(tmpResLocal2, leftBottomValueUbLocal, rightBottomValueUbLocal, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, tmpResLocal, tmpResLocal2, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Muls(tmpResLocal, tmpResLocal, attentionWeightLocal.GetValue(level * numPoints + point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(resLocal, resLocal, tmpResLocal, embedDimsAlign); } - pipe_barrier(PIPE_ALL); - - SetAtomicAdd(); - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - SetAtomicNone(); } } locationQueue.FreeTensor(locationLocal); attentionWeightsUb.FreeTensor(attentionWeightLocal); - outputQueue.FreeTensor(resLocal); + shapeQueue.FreeTensor(shapesLocal); offsetQueue.FreeTensor(offsetLocal); + + GetTPipePtr()->ReleaseEventID(eventIdVToMte3); + GetTPipePtr()->ReleaseEventID(eventIdMte2ToV); } private: - TPipe pipe; + TPipe *pipe; GlobalTensor valueGm, locationGm, attentionWeightsGm, outputGm; GlobalTensor valueSpatialShapesGm, valueLevelStartIndexGm; TQue locationQueue, attentionWeightsUb, shapeQueue, offsetQueue; TQue outputQueue; - TBuf tmpResUb, tmpResUb2, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; - TBuf intOneUb, floatOneUb, leftTopValueUb, leftBottomValueUb, rightTopValueUb, rightBottomValueUb; - TBuf leftTopWieightQueue, leftBottomWieightQueue, rightTopWieightQueue, rightBottomWieightQueue; + TBuf tmpResUb, tmpResUb2, tmpResUb3, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; + TBuf intOneUb, floatOneUb, leftTopWieightQueue, emptyUb; + TBuf valueUb; uint32_t batchSize; uint32_t numKeys; uint32_t numHeads; uint32_t embedDims; + uint32_t tailNum; uint32_t numLevels; uint32_t numQueries; @@ -341,6 +333,7 @@ private: uint32_t dataAlign; uint32_t blockNum = 32; + DTYPE_VALUE leftTopWeight, rightTopWeiight, leftBottomWeight, rightBottomWeight, attnWeight; DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, moveOffset; }; @@ -351,9 +344,10 @@ extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM GM_ADDR attention_weights, GM_ADDR output, GM_ADDR workspace, GM_ADDR tiling) { + TPipe pipe; // GET_TILING_DATA(tiling_data, tiling); KernelMultiScaleDeformableAttnFunctionV2 op; op.Init(value, value_spatial_shapes, value_level_start_index, - sampling_locations, attention_weights, output, &tiling_data); + sampling_locations, attention_weights, output, &tiling_data, &pipe); op.Process(); }