diff --git a/ads/common/ops/kernels/op_host/multi_scale_deformable_attn_function_v2.cpp b/ads/common/ops/kernels/op_host/multi_scale_deformable_attn_function_v2.cpp index 246b3e548b6a6c56c928abd565e992efb3918348..736a42eff5419ee5b30818b39f750bce348bdc48 100644 --- a/ads/common/ops/kernels/op_host/multi_scale_deformable_attn_function_v2.cpp +++ b/ads/common/ops/kernels/op_host/multi_scale_deformable_attn_function_v2.cpp @@ -123,6 +123,10 @@ namespace ops this->AICore() .SetTiling(optiling::TilingFuncForMultiScaleDeformableAttnFunctionV2); + OpAICoreConfig aiConfig; + aiConfig.ExtendCfgInfo("enableVectorCore.flag", "false"); + aiConfig.DynamicCompileStaticFlag(true); + this->AICore().AddConfig("ascend310p", aiConfig); this->AICore().AddConfig("ascend910b"); } }; diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp index ffaf649e8faada74f806dfd2a43ed575440b9c5f..e3d150023ac332ca4ee47ffeaeb8e286a827b8d0 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp @@ -1,8 +1,8 @@ /* -* Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. -* -* This sample is a very basic sample that implements vector add on Ascend plaform. -*/ + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. + * + * This sample is a very basic sample that implements vector add on Ascend plaform. + */ #include "kernel_operator.h" using namespace AscendC; constexpr int32_t BUFFER_NUM = 1; @@ -83,16 +83,14 @@ public: pipe->InitBuffer(valueUb, BUFFER_NUM, batchOffset * 4 * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpResUb, BUFFER_NUM, batchOffset * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpResUb2, BUFFER_NUM, batchOffset * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpResUb3, BUFFER_NUM, batchOffset * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb3, BUFFER_NUM, numHeads * batchOffset * sizeof(DTYPE_VALUE)); } __aicore__ inline void Process() { for (uint32_t taskIdx = startOffset; taskIdx < endOffset; taskIdx++) { - SetAtomicAdd(); Compute(taskIdx); - SetAtomicNone(); } } @@ -117,7 +115,6 @@ private: event_t eventIdVToMte3 = static_cast(GetTPipePtr()->AllocEventID()); event_t eventIdMte2ToV = static_cast(GetTPipePtr()->AllocEventID()); - event_t eventIdMte3ToV = static_cast(GetTPipePtr()->AllocEventID()); for (uint32_t batch = 0; batch < batchSize; batch++) { @@ -151,12 +148,21 @@ private: for (uint32_t head = 0; head < numHeads; head++) { - dstOffset = moveOffset + head * embedDims; - for (uint32_t level = 0; level < numLevels; level++) + DataCopy(outputGm[moveOffset + head * embedDims], emptyUbLocal, embedDims); + } + pipe_barrier(PIPE_ALL); + + for (uint32_t level = 0; level < numLevels; level++) + { + h = shapesLocal.GetValue(level * 2); + w = shapesLocal.GetValue(level * 2 + 1); + + SetAtomicAdd(); + for (uint32_t head = 0; head < numHeads; head++) { - srcOffset = level * batchOffset; - h = shapesLocal.GetValue(level * 2); - w = shapesLocal.GetValue(level * 2 + 1); + srcOffset = head * batchOffset; + dstOffset = moveOffset + head * embedDims; + weightOffset = (head * numLevels + level) * numPoints; DataCopy(attentionWeightLocal, attentionWeightsGm[dataOffset + weightOffset], AlignUp(numPoints, dataAlign)); SetFlag(eventIdMte2ToV); @@ -209,6 +215,13 @@ private: { DataCopy(valueLocal[batchOffset + point * embedDims], valueGm[valueOffset + (y1 * w + x0) * tailNum], embedDims); } + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + leftTopWeight = weightLocal.GetValue(point); + leftBottomWeight = weightLocal.GetValue(numPointsAlign + point); + + Muls(valueLocal[point * embedDims], valueLocal[point * embedDims], leftTopWeight, embedDims); + Muls(valueLocal[batchOffset + point * embedDims], valueLocal[batchOffset + point * embedDims], leftBottomWeight, embedDims); } if (isInRange(x1, w)) { @@ -220,42 +233,38 @@ private: { DataCopy(valueLocal[batchOffset * 3 + point * embedDims], valueGm[valueOffset + (y1 * w + x1) * tailNum], embedDims); } + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + + rightTopWeiight = weightLocal.GetValue(numPointsAlign * 2 + point); + rightBottomWeight = weightLocal.GetValue(numPointsAlign * 3 + point); + + Muls(valueLocal[batchOffset * 2 + point * embedDims], valueLocal[batchOffset * 2 + point * embedDims], rightTopWeiight, embedDims); + Muls(valueLocal[batchOffset * 3 + point * embedDims], valueLocal[batchOffset * 3 + point * embedDims], rightBottomWeight, embedDims); } } - SetFlag(eventIdMte2ToV); - WaitFlag(eventIdMte2ToV); - for (uint32_t point = 0; point < numPoints; point++) - { - leftTopWeight = weightLocal.GetValue(point); - leftBottomWeight = weightLocal.GetValue(numPointsAlign + point); - rightTopWeiight = weightLocal.GetValue(numPointsAlign * 2 + point); - rightBottomWeight = weightLocal.GetValue(numPointsAlign * 3 + point); - - Muls(valueLocal[point * embedDims], valueLocal[point * embedDims], leftTopWeight, embedDims); - Muls(valueLocal[batchOffset + point * embedDims], valueLocal[batchOffset + point * embedDims], leftBottomWeight, embedDims); - Muls(valueLocal[batchOffset * 2 + point * embedDims], valueLocal[batchOffset * 2 + point * embedDims], rightTopWeiight, embedDims); - Muls(valueLocal[batchOffset * 3 + point * embedDims], valueLocal[batchOffset * 3 + point * embedDims], rightBottomWeight, embedDims); - } if (embedDims != 32) { pipe_barrier(PIPE_ALL); } + Add(tmpResLocal, valueLocal, valueLocal[batchOffset], batchOffset); Add(tmpResLocal2, valueLocal[batchOffset * 2], valueLocal[batchOffset * 3], batchOffset); - Add(tmpResLocal3, tmpResLocal, tmpResLocal2, batchOffset); + Add(tmpResLocal3[srcOffset], tmpResLocal, tmpResLocal2, batchOffset); SetFlag(eventIdVToMte3); WaitFlag(eventIdVToMte3); + for (uint32_t point = 0; point < numPoints; point++) { - DataCopy(outputGm[dstOffset], tmpResLocal3[point * embedDims], embedDims); + DataCopy(outputGm[dstOffset], tmpResLocal3[srcOffset + point * embedDims], embedDims); } } + SetAtomicNone(); } } GetTPipePtr()->ReleaseEventID(eventIdVToMte3); GetTPipePtr()->ReleaseEventID(eventIdMte2ToV); - GetTPipePtr()->ReleaseEventID(eventIdMte3ToV); } private: @@ -297,7 +306,8 @@ private: uint32_t blockNum = 32; DTYPE_VALUE tmp1, tmp2, leftTopWeight, rightTopWeiight, leftBottomWeight, rightBottomWeight, attnWeight; - DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, dataOffset, locationOffset, moveOffset, batchOffset, dstOffset, srcOffset; + DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1; + DTYPE_VALUE_SPATIAL_SHAPES valueOffset, weightOffset, dataOffset, locationOffset, moveOffset, batchOffset, dstOffset, srcOffset, headOffset; }; extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM_ADDR value, @@ -311,6 +321,6 @@ extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM GET_TILING_DATA(tiling_data, tiling); KernelMultiScaleDeformableAttnFunctionV2 op; op.Init(value, value_spatial_shapes, value_level_start_index, - sampling_locations, attention_weights, output, &tiling_data, &pipe); + sampling_locations, attention_weights, output, &tiling_data, &pipe); op.Process(); -} \ No newline at end of file +}