From 009921fb747cae4f00d0c320d79a1c5769279617 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=93=B2=E7=BB=AD?= Date: Mon, 22 Jan 2024 19:59:34 +0800 Subject: [PATCH] Type: Feature Enhance Team: Pytorch_Ops_Dev. Description: Scenario interception: reject inputs whose embed_dims + num_points + num_levels is 512 or more. --- ...tiScaleDeformableAttnFunctionKernelNpu.cpp | 9 +++++ .../ops/kernels/ads_op/CMakePresets.json | 2 +- .../multi_scale_deformable_attn_function.cpp | 33 +++++++++++-------- 3 files changed, 30 insertions(+), 14 deletions(-) diff --git a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp index 1cad3d44..31ff3bc3 100644 --- a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp +++ b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp @@ -35,6 +35,15 @@ at::Tensor npu_multi_scale_deformable_attn_function(const at::Tensor& value, auto location_size = sampling_locations.sizes(); auto output_size = {value_size[0], location_size[1], value_size[2] * value_size[3]}; + auto embed_dims = value_size[3]; + auto num_points = location_size[4]; + auto num_levels = location_size[3]; + auto data_total = embed_dims + num_points + num_levels; + + TORCH_CHECK( + data_total < 512, + "data_total is over 512: embed_dims ", embed_dims, " num_points is ", num_points, " num_level is ", num_levels, "." 
); + at::Tensor result = at::empty(output_size, value.options().dtype(at::kFloat)); // reset inputs diff --git a/ads/common/ops/kernels/ads_op/CMakePresets.json b/ads/common/ops/kernels/ads_op/CMakePresets.json index add05853..a23c07b8 100644 --- a/ads/common/ops/kernels/ads_op/CMakePresets.json +++ b/ads/common/ops/kernels/ads_op/CMakePresets.json @@ -27,7 +27,7 @@ }, "ASCEND_COMPUTE_UNIT": { "type": "STRING", - "value": "ascend310p;ascend910;ascend910b" + "value": "ascend910b" }, "ENABLE_TEST": { "type": "BOOL", diff --git a/ads/common/ops/kernels/ads_op/op_kernel/multi_scale_deformable_attn_function.cpp b/ads/common/ops/kernels/ads_op/op_kernel/multi_scale_deformable_attn_function.cpp index 1ddf8390..72e70e86 100644 --- a/ads/common/ops/kernels/ads_op/op_kernel/multi_scale_deformable_attn_function.cpp +++ b/ads/common/ops/kernels/ads_op/op_kernel/multi_scale_deformable_attn_function.cpp @@ -167,21 +167,18 @@ private: { DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); } - set_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE3, PIPE_S, EVENT_ID0); + pipe_barrier(PIPE_ALL); for (uint32_t head = 0; head < numHeads; head++) { weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels + head * numLevels) * numPoints; - set_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); + pipe_barrier(PIPE_ALL); DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numLevels * numPoints * 2, dataAlign)); DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numLevels * numPoints, dataAlign)); - set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + pipe_barrier(PIPE_ALL); for (uint32_t level = 0; level < numLevels; level++) { h = shapesLocal.GetValue(level * 2); @@ -193,40 +190,49 @@ private: yLocal.SetValue(point, locationLocal.GetValue(locationOffset + 1)); } - set_flag(PIPE_S, PIPE_V, EVENT_ID0); - wait_flag(PIPE_S, PIPE_V, 
EVENT_ID0); + pipe_barrier(PIPE_ALL); Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); + pipe_barrier(PIPE_ALL); Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); + pipe_barrier(PIPE_ALL); Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); + pipe_barrier(PIPE_ALL); Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); + pipe_barrier(PIPE_ALL); Sub(x0Local, x1Local, intOneLocal, numPointsAlign); Sub(y0Local, y1Local, intOneLocal, numPointsAlign); + pipe_barrier(PIPE_ALL); Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); + pipe_barrier(PIPE_ALL); Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); + pipe_barrier(PIPE_ALL); Abs(param0Local, tmpLocal1, numPointsAlign); Abs(param1Local, tmpLocal2, numPointsAlign); + pipe_barrier(PIPE_ALL); Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); + pipe_barrier(PIPE_ALL); Mul(leftTopWeiightLocal, xLocal, yLocal, numPointsAlign); Mul(leftBottomWeightLocal, xLocal, param1Local, numPointsAlign); Mul(rightTopWeiightLocal, param0Local, yLocal, numPointsAlign); Mul(rightBottomWeightLocal, param0Local, param1Local, numPointsAlign); + pipe_barrier(PIPE_ALL); Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); @@ -267,22 +273,23 @@ private: DataCopy(rightBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x1) * numHeads) * embedDims], embedDimsAlign); } } - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + pipe_barrier(PIPE_ALL); Muls(leftTopValueLocal, leftTopValueLocal, leftTopWeiightLocal.GetValue(point), embedDimsAlign); 
Muls(rightTopValueUbLocal, rightTopValueUbLocal, rightTopWeiightLocal.GetValue(point), embedDimsAlign); Muls(leftBottomValueUbLocal, leftBottomValueUbLocal, leftBottomWeightLocal.GetValue(point), embedDimsAlign); Muls(rightBottomValueUbLocal, rightBottomValueUbLocal, rightBottomWeightLocal.GetValue(point), embedDimsAlign); + pipe_barrier(PIPE_ALL); Add(tmpResLocal, leftTopValueLocal, rightTopValueUbLocal, embedDimsAlign); Add(tmpResLocal2, leftBottomValueUbLocal, rightBottomValueUbLocal, embedDimsAlign); + pipe_barrier(PIPE_ALL); Add(tmpResLocal, tmpResLocal, tmpResLocal2, embedDimsAlign); + pipe_barrier(PIPE_ALL); Muls(tmpResLocal, tmpResLocal, attentionWeightLocal.GetValue(level * numPoints + point), embedDimsAlign); + pipe_barrier(PIPE_ALL); Add(resLocal, resLocal, tmpResLocal, embedDimsAlign); } - - set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + pipe_barrier(PIPE_ALL); SetAtomicAdd(); DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); -- Gitee