From 29e1e4d2e5d89e08a80ac9d96d9a29a72589bb38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=93=B2=E7=BB=AD?= Date: Thu, 18 Apr 2024 21:42:55 +0800 Subject: [PATCH] Type: Enhance MSDA. Team: ADS_Ops_Dev. Description: Enhance MSDA. --- .../csrc/MultiScaleDeformableAttnFunction.cpp | 8 +- ...ulti_scale_deformable_attn_function_v2.cpp | 181 ++++++++---------- 2 files changed, 86 insertions(+), 103 deletions(-) diff --git a/ads/common/ops/csrc/MultiScaleDeformableAttnFunction.cpp b/ads/common/ops/csrc/MultiScaleDeformableAttnFunction.cpp index 1f906a31..675657ea 100644 --- a/ads/common/ops/csrc/MultiScaleDeformableAttnFunction.cpp +++ b/ads/common/ops/csrc/MultiScaleDeformableAttnFunction.cpp @@ -51,7 +51,7 @@ at::Tensor npu_multi_scale_deformable_attn_function(const at::Tensor& value, con ", num_level is ", num_levels, "."); TORCH_CHECK(embed_dims % 8 == 0, "embed_dims must be a multiple of 8, but embed_dims is ", embed_dims, "."); - at::Tensor result = at::empty(output_size, value.options().dtype(at::kFloat)); + at::Tensor result = at::zeros(output_size, value.options().dtype(at::kFloat)); // reset inputs at::Tensor value_trans = at::transpose(value, 1, 2).contiguous(); @@ -104,9 +104,9 @@ std::tuple multi_scale_deformable_attn_grad( location_size[0], location_size[1], location_size[2], location_size[3], location_size[5], location_size[4]}; at::Tensor value1 = value.transpose(1, 2).contiguous(); at::Tensor location1 = location.transpose(4, 5).contiguous(); - at::Tensor result1 = at::empty(grad_value_size, value.options().dtype(at::kFloat)); - at::Tensor result2 = at::empty(grad_sample_loc_size, location.options().dtype(at::kFloat)); - at::Tensor result3 = at::empty(grad_atten_weight_size, attn_weight.options().dtype(at::kFloat)); + at::Tensor result1 = at::zeros(grad_value_size, value.options().dtype(at::kFloat)); + at::Tensor result2 = at::zeros(grad_sample_loc_size, location.options().dtype(at::kFloat)); + at::Tensor result3 = at::zeros(grad_atten_weight_size, attn_weight.options().dtype(at::kFloat)); at::Tensor value_fp = value1.to(at::kFloat); at::Tensor shape_fp = shape.to(at::kInt); diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp index 02aed9cd..caaa4360 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp @@ -27,7 +27,6 @@ public: coreNum = tiling_data->coreNum; tailNum = numHeads * embedDims; - taskNum = numQueries; taskNumPerCore = DivCeil(taskNum, coreNum); @@ -81,8 +80,8 @@ public: pipe->InitBuffer(weightQueue, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(valueUb, batchOffset * 8 * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(cornerWeightUb, batchOffset * 8 * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(valueUb, batchOffset * 4 * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(cornerWeightUb, batchOffset * 4 * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpResUb, 2 * batchOffset * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpResUb2, batchOffset * sizeof(DTYPE_VALUE)); @@ -91,6 +90,7 @@ public: __aicore__ inline void Process() { +#if __CCE_AICORE__ == 220 if (embedDims == 32 && numPoints == 2) { ComputeOpt<2>(); } else if (embedDims == 32 && numPoints == 4) { @@ -100,6 +100,9 @@ public: } else { Compute(); } +#else + Compute(); +#endif } private: @@ -135,10 +138,11 @@ private: LocalTensor floatOneLocal = floatOneUb.Get(); LocalTensor halfLocal = halfUb.Get(); LocalTensor locLocal = locUb.Get(); - - Duplicate(emptyUbLocal, DTYPE_VALUE(0), numHeads * embedDimsOpt); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); + if (inner_clean == 1) { + Duplicate(emptyUbLocal, DTYPE_VALUE(0), numHeads * embedDimsOpt); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + } Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlignOpt); Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlignOptTwice); @@ -167,10 +171,10 @@ private: baseOffset = batch * numHeads * numKeys; moveOffset = (batch * numQueries + query) * numHeads * embedDimsOpt; dataOffset = (batch * numQueries + query) * numHeads * numLevels * NUM_POINTS; - - DataCopy(outputGm[moveOffset], emptyUbLocal, numHeads * embedDimsOpt); - pipe_barrier(PIPE_ALL); - + if (inner_clean == 1) { + DataCopy(outputGm[moveOffset], emptyUbLocal, numHeads * embedDimsOpt); + pipe_barrier(PIPE_ALL); + } SetAtomicAdd(); for (uint32_t level = 0; level < numLevels; level++) { @@ -194,9 +198,9 @@ private: DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], numPointsAlignOpt); SetFlag(eventIdMte2ToV_0); + Duplicate(valueLocal, DTYPE_VALUE(0), 4 * NUM_POINTS * embedDimsOpt); for (uint32_t head = 0; head < numHeads; head++) { - Duplicate(valueLocal, DTYPE_VALUE(0), 4 * NUM_POINTS * embedDimsOpt); srcOffset = head * NUM_POINTS * embedDimsOpt; dstOffset = moveOffset + head * embedDimsOpt; valueOffset = (oriOffset + head * numKeys) * embedDimsOpt; @@ -213,53 +217,30 @@ private: x0 = x1 - 1; y0 = y1 - 1; - #if __CCE_AICORE__ == 220 - if (isInRange(y0, h)) { - if (0 < x1 && x1 < w) { - DataCopyPad(valueLocal[point * embedDimsOpt], - valueGm[valueOffset + (y0 * w + x0) * embedDimsOpt], copyParams, padParams); - } else if (isInRange(x0, w)) { - DataCopy(valueLocal[point * embedDimsOpt], - valueGm[valueOffset + (y0 * w + x0) * embedDimsOpt], embedDimsOpt); - } else if (isInRange(x1, w)) { - DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOptTwice], - valueGm[valueOffset + (y0 * w + x1) * embedDimsOpt], embedDimsOpt); - } - } - if (isInRange(y1, h)) { - if (0 < x1 && x1 < w) { - DataCopyPad(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOpt], - valueGm[valueOffset + (y1 * w + x0) * embedDimsOpt], copyParams, padParams); - } else if (isInRange(x0, w)) { - DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOpt], - valueGm[valueOffset + (y1 * w + x0) * embedDimsOpt], embedDimsOpt); - } else if (isInRange(x1, w)) { - DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOptTriple], - valueGm[valueOffset + (y1 * w + x1) * embedDimsOpt], embedDimsOpt); - } - } - #else - if (isInRange(y0, h)) { - if (isInRange(x0, w)) { - DataCopy(valueLocal[point * embedDimsOpt], - valueGm[valueOffset + (y0 * w + x0) * embedDimsOpt], embedDimsOpt); - } - if (isInRange(x1, w)) { - DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOptTwice], - valueGm[valueOffset + (y0 * w + x1) * embedDimsOpt], embedDimsOpt); - } + if (isInRange(y0, h)) { + if (0 < x1 && x1 < w) { + DataCopyPad(valueLocal[point * embedDimsOpt], + valueGm[valueOffset + (y0 * w + x0) * embedDimsOpt], copyParams, padParams); + } else if (isInRange(x0, w)) { + DataCopy(valueLocal[point * embedDimsOpt], + valueGm[valueOffset + (y0 * w + x0) * embedDimsOpt], embedDimsOpt); + } else if (isInRange(x1, w)) { + DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOptTwice], + valueGm[valueOffset + (y0 * w + x1) * embedDimsOpt], embedDimsOpt); } - if (isInRange(y1, h)) { - if (isInRange(x0, w)) { - DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOpt], - valueGm[valueOffset + (y1 * w + x0) * embedDimsOpt], embedDimsOpt); - } - if (isInRange(x1, w)) { - DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOptTriple], - valueGm[valueOffset + (y1 * w + x1) * embedDimsOpt], embedDimsOpt); - } + } + if (isInRange(y1, h)) { + if (0 < x1 && x1 < w) { + DataCopyPad(valueLocal[(point + NUM_POINTS) * embedDimsOpt], + valueGm[valueOffset + (y1 * w + x0) * embedDimsOpt], copyParams, padParams); + } else if (isInRange(x0, w)) { + DataCopy(valueLocal[(point + NUM_POINTS) * embedDimsOpt], + valueGm[valueOffset + (y1 * w + x0) * embedDimsOpt], embedDimsOpt); + } else if (isInRange(x1, w)) { + DataCopy(valueLocal[point * embedDimsOpt + NUM_POINTS * embedDimsOptTriple], + valueGm[valueOffset + (y1 * w + x1) * embedDimsOpt], embedDimsOpt); } - #endif + } } SetFlag(eventIdMte2ToV_1); @@ -279,7 +260,29 @@ private: WaitFlag(eventIdMte2ToV_0); Mul(weightLocal, weightLocal, attentionWeightLocal, numPointsAlignOpt, 4, {1, 1, 1, 1, 1, 0}); BroadCast(cornerWeightLocal, weightLocal, dstShape_, srcShape_); + + WaitFlag(eventIdMte2ToV_1); + if (numPointsAlignOpt == NUM_POINTS) { + Mul(valueLocal, valueLocal, cornerWeightLocal, 4 * NUM_POINTS * embedDimsOpt); + } else { + Mul(valueLocal, valueLocal, cornerWeightLocal, NUM_POINTS * embedDimsOpt); + Mul(valueLocal[NUM_POINTS * embedDimsOpt], valueLocal[NUM_POINTS * embedDimsOpt], + cornerWeightLocal[numPointsAlignOpt * embedDimsOpt], NUM_POINTS * embedDimsOpt); + Mul(valueLocal[NUM_POINTS * embedDimsOpt * 2], valueLocal[NUM_POINTS * embedDimsOptTwice], + cornerWeightLocal[numPointsAlignOpt * embedDimsOptTwice], NUM_POINTS * embedDimsOpt); + Mul(valueLocal[NUM_POINTS * embedDimsOpt * 3], valueLocal[NUM_POINTS * embedDimsOptTriple], + cornerWeightLocal[numPointsAlignOpt * embedDimsOptTriple], NUM_POINTS * embedDimsOpt); + } + + Add(tmpResLocal, valueLocal, valueLocal[NUM_POINTS * embedDimsOpt * 2], + NUM_POINTS * embedDimsOptTwice); + Add(tmpResLocal2, tmpResLocal, tmpResLocal[NUM_POINTS * embedDimsOpt], + NUM_POINTS * embedDimsOpt); + Add(tmpResLocal3[srcOffset], tmpResLocal2, tmpResLocal2[NUM_POINTS * embedDimsOptHalf], + NUM_POINTS * embedDimsOptHalf); + if (head < numHeads - 1) { + Duplicate(valueLocal, DTYPE_VALUE(0), 4 * NUM_POINTS * embedDimsOpt); weightOffset = weightOffset + numLevels * NUM_POINTS; if (numPointsAlignOpt == NUM_POINTS) { @@ -295,28 +298,6 @@ private: SetFlag(eventIdMte2ToV_0); } - WaitFlag(eventIdMte2ToV_1); - if (numPointsAlignOpt == NUM_POINTS) { - Mul(valueLocal[NUM_POINTS * embedDimsOpt * 4], valueLocal, cornerWeightLocal, - 4 * NUM_POINTS * embedDimsOpt); - } else { - Mul(valueLocal[NUM_POINTS * embedDimsOpt * 4], valueLocal, cornerWeightLocal, - NUM_POINTS * embedDimsOpt); - Mul(valueLocal[NUM_POINTS * embedDimsOpt * 5], valueLocal[NUM_POINTS * embedDimsOpt], - cornerWeightLocal[numPointsAlignOpt * embedDimsOpt], NUM_POINTS * embedDimsOpt); - Mul(valueLocal[NUM_POINTS * embedDimsOpt * 6], valueLocal[NUM_POINTS * embedDimsOptTwice], - cornerWeightLocal[numPointsAlignOpt * embedDimsOptTwice], NUM_POINTS * embedDimsOpt); - Mul(valueLocal[NUM_POINTS * embedDimsOpt * 7], valueLocal[NUM_POINTS * embedDimsOptTriple], - cornerWeightLocal[numPointsAlignOpt * embedDimsOptTriple], NUM_POINTS * embedDimsOpt); - } - - Add(tmpResLocal, valueLocal[NUM_POINTS * embedDimsOpt * 4], - valueLocal[NUM_POINTS * embedDimsOpt * 6], NUM_POINTS * embedDimsOptTwice); - Add(tmpResLocal2, tmpResLocal, tmpResLocal[NUM_POINTS * embedDimsOpt], - NUM_POINTS * embedDimsOpt); - Add(tmpResLocal3[srcOffset], tmpResLocal2, tmpResLocal2[NUM_POINTS * embedDimsOptHalf], - NUM_POINTS * embedDimsOptHalf); - SetFlag(eventIdVToMte3); WaitFlag(eventIdVToMte3); @@ -360,10 +341,11 @@ private: LocalTensor floatOneLocal = floatOneUb.Get(); LocalTensor halfLocal = halfUb.Get(); LocalTensor locLocal = locUb.Get(); - - Duplicate(emptyUbLocal, DTYPE_VALUE(0), embedDims); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); + if (inner_clean == 1) { + Duplicate(emptyUbLocal, DTYPE_VALUE(0), embedDims); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + } Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign * 2); @@ -387,12 +369,13 @@ private: baseOffset = batch * numHeads * numKeys; moveOffset = (batch * numQueries + query) * numHeads * embedDims; dataOffset = (batch * numQueries + query) * numHeads * numLevels * numPoints; - - for (uint32_t head = 0; head < numHeads; head++) { - DataCopy(outputGm[moveOffset + head * embedDims], emptyUbLocal, embedDims); + if (inner_clean == 1) { + for (uint32_t head = 0; head < numHeads; head++) { + DataCopy(outputGm[moveOffset + head * embedDims], emptyUbLocal, embedDims); + } + pipe_barrier(PIPE_ALL); } - pipe_barrier(PIPE_ALL); - + SetAtomicAdd(); for (uint32_t level = 0; level < numLevels; level++) { h = shapesLocal.GetValue(level * 2); w = shapesLocal.GetValue(level * 2 + 1); @@ -401,7 +384,6 @@ private: Duplicate(locLocal, (DTYPE_VALUE)w, numPointsAlign); Duplicate(locLocal[numPointsAlign], (DTYPE_VALUE)h, numPointsAlign); - SetAtomicAdd(); weightOffset = dataOffset + level * numPoints; DataCopy(tmpLocal, locationGm[weightOffset * 2], numPointsAlign); @@ -412,7 +394,7 @@ private: SetFlag(eventIdMte2ToV_0); for (uint32_t head = 0; head < numHeads; head++) { - Duplicate(valueLocal[4 * batchOffset], DTYPE_VALUE(0), 4 * batchOffset); + Duplicate(valueLocal, DTYPE_VALUE(0), 4 * batchOffset); srcOffset = head * batchOffset; dstOffset = moveOffset + head * embedDims; valueOffset = (oriOffset + head * numKeys) * embedDims; @@ -431,25 +413,25 @@ private: if (isInRange(y0, h)) { if (0 < x1 && x1 < w) { - DataCopy(valueLocal[batchOffset * 4 + point * embedDims * 2], + DataCopy(valueLocal[point * embedDims * 2], valueGm[valueOffset + (y0 * w + x0) * embedDims], 2 * embedDims); } else if (isInRange(x0, w)) { - DataCopy(valueLocal[batchOffset * 4 + point * embedDims * 2], + DataCopy(valueLocal[point * embedDims * 2], valueGm[valueOffset + (y0 * w + x0) * embedDims], embedDims); } else if (isInRange(x1, w)) { - DataCopy(valueLocal[batchOffset * 4 + point * embedDims * 2 + embedDims], + DataCopy(valueLocal[point * embedDims * 2 + embedDims], valueGm[valueOffset + (y0 * w + x1) * embedDims], embedDims); } } if (isInRange(y1, h)) { if (0 < x1 && x1 < w) { - DataCopy(valueLocal[batchOffset * 6 + point * embedDims * 2], + DataCopy(valueLocal[batchOffset * 2 + point * embedDims * 2], valueGm[valueOffset + (y1 * w + x0) * embedDims], 2 * embedDims); } else if (isInRange(x0, w)) { - DataCopy(valueLocal[batchOffset * 6 + point * embedDims * 2], + DataCopy(valueLocal[batchOffset * 2 + point * embedDims * 2], valueGm[valueOffset + (y1 * w + x0) * embedDims], embedDims); } else if (isInRange(x1, w)) { - DataCopy(valueLocal[batchOffset * 6 + point * embedDims * 2 + embedDims], + DataCopy(valueLocal[batchOffset * 2 + point * embedDims * 2 + embedDims], valueGm[valueOffset + (y1 * w + x1) * embedDims], embedDims); } } @@ -503,7 +485,7 @@ private: } WaitFlag(eventIdMte2ToV_1); - Mul(valueLocal, valueLocal[batchOffset * 4], cornerWeightLocal, 4 * batchOffset); + Mul(valueLocal, valueLocal, cornerWeightLocal, 4 * batchOffset); if (embedDims != 32) { pipe_barrier(PIPE_ALL); @@ -520,8 +502,8 @@ private: DataCopy(outputGm[dstOffset], tmpResLocal3[srcOffset + point * embedDims], embedDims); } } - SetAtomicNone(); } + SetAtomicNone(); } } GetTPipePtr()->ReleaseEventID(eventIdVToMte3); @@ -573,9 +555,10 @@ private: uint32_t endOffset; uint32_t dataAlign; uint32_t blockNum = 32; + uint32_t inner_clean = 0; DTYPE_VALUE_SPATIAL_SHAPES tmpOffset1, tmpOffset2, baseOffset, valueOffset, weightOffset, oriOffset, pointOffset, - dataOffset, locationOffset, moveOffset, batchOffset, dstOffset, srcOffset, headOffset; + dataOffset, locationOffset, moveOffset, batchOffset, dstOffset, srcOffset, headOffset, valueLocalOffset; DTYPE_VALUE tmp1, tmp2, leftTopWeight, rightTopWeight, leftBottomWeight, rightBottomWeight, attnWeight; DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1; }; -- Gitee