From e775a9ae19f49f01542d2a2a196b29b274426d5a Mon Sep 17 00:00:00 2001 From: jayshu Date: Wed, 28 Feb 2024 17:16:53 +0800 Subject: [PATCH] enhance op multi_scale_deformable_attention_grad --- .../multi_scale_deformable_attention_grad.cpp | 252 ++++++++++-------- 1 file changed, 135 insertions(+), 117 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp index 0b75de39..f9223ead 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp @@ -18,7 +18,6 @@ * \file multi_scale_deformable_attention_grad.h * \brief */ - #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" using namespace AscendC; @@ -36,7 +35,7 @@ namespace constexpr static int32_t GV_OUTPUT_INDEX = 6; constexpr static int32_t GSL_OUTPUT_INDEX = 7; constexpr static int32_t GAW_OUTPUT_INDEX = 8; - constexpr static int32_t BUFFER_NUM = 1; + constexpr static int32_t BUFFER_NUM = 2; constexpr static int32_t DOUB = 2; constexpr static int32_t T_BLOCK = 8; constexpr static uint16_t DST_BLK_STRIDE = 1; @@ -45,7 +44,6 @@ namespace constexpr static uint8_t SRC_REP_STRIDE = 8; }; - template class MultiScaleDeformableAttentionGrad { @@ -85,7 +83,7 @@ public: int32_t loc_w_offset, int32_t w, int32_t h, int32_t sample_location_offset); __aicore__ inline void post_process_levels(int32_t w, int32_t h, int32_t cur_nl, int32_t data_weight_ptr, - int32_t data_loc_w_ptr); + int32_t data_loc_w_ptr, int32_t cur_nh); __aicore__ inline void cal_grad_value(LocalTensor &v_ub, LocalTensor &offset1_ub, LocalTensor &offset2_ub, LocalTensor &h_w_w1_ub, LocalTensor &h_w_w2_ub, LocalTensor &w_weight_ub, @@ -135,6 +133,8 @@ private: TBuf buffer_w4_v4; TBuf buffer_val; TBuf buffer_grad_weight; + TBuf buffer_grad_weight_full; + TBuf buffer_grad_sample_loc; int32_t spatial_size; int32_t cur_block_idx; int32_t cur_core_task_num; @@ -193,6 +193,8 @@ private: LocalTensor sample_location_local; LocalTensor level_start_index_local; LocalTensor spatial_shapes_local; + LocalTensor grad_weight_full_local; + LocalTensor grad_sample_loc_local; }; template @@ -206,6 +208,7 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::init(GM_ADDR input_ num_levels = tiling_data->num_levels; num_query = tiling_data->num_query; num_point = tiling_data->num_point; + int32_t doub = 2; int32_t batch_size = tiling_data->batch_size; cur_core_task_num = tiling_data->task_per_core; @@ -239,9 +242,9 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::init_buffer() pipe.InitBuffer(in_queue_grad_output, 1, top_grad_ub_size * sizeof(T)); pipe.InitBuffer(in_queue_lsi, 1, ceil(num_levels, int32_per_block) * sizeof(int32_t)); pipe.InitBuffer(in_queue_ss, 1, ceil(num_levels * DOUB, int32_per_block) * sizeof(int32_t)); - pipe.InitBuffer(in_queue_sl, 1, ceil(num_heads * num_levels * num_point * DOUB, + pipe.InitBuffer(in_queue_sl, 1, ceil(num_heads * num_levels * num_point * DOUB * t_per_block, t_per_block) * sizeof(T)); - pipe.InitBuffer(in_queue_aw, 1, ceil(num_heads * num_levels * num_point, + pipe.InitBuffer(in_queue_aw, 1, ceil(num_heads * num_levels * num_point * t_per_block, t_per_block) * sizeof(T)); pipe.InitBuffer(buffer_h_im, num_point_align * sizeof(T)); pipe.InitBuffer(buffer_w_im, num_point_align * sizeof(T)); @@ -281,6 +284,8 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::init_buffer() pipe.InitBuffer(buffer_w4_v4, channel_align * sizeof(T)); pipe.InitBuffer(buffer_val, point_channel_align * sizeof(T)); pipe.InitBuffer(buffer_grad_weight, point_channel_align * sizeof(T)); + pipe.InitBuffer(buffer_grad_weight_full, (num_heads * num_levels * num_point * channel_align) * sizeof(T)); + pipe.InitBuffer(buffer_grad_sample_loc, (num_heads * num_levels * num_point * channel_align * DOUB) * sizeof(T)); } template @@ -329,14 +334,15 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::init_local_tensor() w4_v4_local = buffer_w4_v4.Get(channel_align); val_local = buffer_val.Get(point_channel_align); grad_weight_local = buffer_grad_weight.Get(point_channel_align); + grad_weight_full_local = buffer_grad_weight_full.Get((num_heads * num_levels * num_point * channel_align) * sizeof(T)); + grad_sample_loc_local = buffer_grad_sample_loc.Get((num_heads * num_levels * num_point * channel_align * DOUB) * sizeof(T)); } template __aicore__ inline void MultiScaleDeformableAttentionGrad::process() { - // FOR ADD A NEW TILING MODE + // SAVE FOR NEXT TILING MODE compute_mode_zero(); - pipe_barrier(PIPE_ALL); } template @@ -347,57 +353,44 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::process_grad_value_ { for (int32_t cur_np = 0; cur_np < num_point; cur_np++) { - pipe_barrier(PIPE_ALL); auto h_im = h_im_local.GetValue(cur_np); auto w_im = w_im_local.GetValue(cur_np); - pipe_barrier(PIPE_ALL); if ((float)-1.0 < h_im && h_im < (float)h && w_im > (float)-1.0 && w_im < (float)w) { - pipe_barrier(PIPE_ALL); auto attn_weight = attn_weight_local.GetValue((cur_nh * num_levels + cur_nl) * num_point + cur_np); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_S, PIPE_V, EVENT_ID1); + wait_flag(PIPE_S, PIPE_V, EVENT_ID1); muls_template(top_grad_value_local[cur_np * channels], grad_output_local[cur_nh * channels], - attn_weight, channel_align); - pipe_barrier(PIPE_ALL); + attn_weight, channel_align); + set_flag(PIPE_V, PIPE_S, EVENT_ID2); + wait_flag(PIPE_V, PIPE_S, EVENT_ID2); auto h_low = h_low_local.GetValue(cur_np); auto h_high = h_high_local.GetValue(cur_np); auto w_low = w_low_local.GetValue(cur_np); auto w_high = w_high_local.GetValue(cur_np); - pipe_barrier(PIPE_ALL); if (h_low >= 0 && w_low >= 0) { - pipe_barrier(PIPE_ALL); cal_grad_value(v1_local, h_low_ptr_offset_local, w_low_ptr_offset_local, hw_local, hh_local, w1_local, cur_np, base_ptr, value_ptr_offset, true, true); - pipe_barrier(PIPE_ALL); } - pipe_barrier(PIPE_ALL); if (h_low >= 0 && w_high < w) { - pipe_barrier(PIPE_ALL); cal_grad_value(v2_local, h_low_ptr_offset_local, w_high_ptr_offset_local, lw_local, hh_local, w2_local, cur_np, base_ptr, value_ptr_offset, true, false); - pipe_barrier(PIPE_ALL); } - pipe_barrier(PIPE_ALL); if (h_high < h && w_low >= 0) { - pipe_barrier(PIPE_ALL); cal_grad_value(v3_local, h_high_ptr_offset_local, w_low_ptr_offset_local, hw_local, lh_local, w3_local, cur_np, base_ptr, value_ptr_offset, false, true); - pipe_barrier(PIPE_ALL); } - pipe_barrier(PIPE_ALL); if (h_high < h && w_high < w) { - pipe_barrier(PIPE_ALL); cal_grad_value(v4_local, h_high_ptr_offset_local, w_high_ptr_offset_local, lw_local, lh_local, w4_local, cur_np, base_ptr, value_ptr_offset, false, false); - pipe_barrier(PIPE_ALL); } - pipe_barrier(PIPE_ALL); process_grad_weight_with_point(cur_nh, cur_np); } + } } @@ -409,25 +402,25 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::process_grad_weight auto w2 = w2_local.GetValue(cur_np); auto w3 = w3_local.GetValue(cur_np); auto w4 = w4_local.GetValue(cur_np); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_S, PIPE_V, EVENT_ID2); + wait_flag(PIPE_S, PIPE_V, EVENT_ID2); muls_template(w1_v1_local, v1_local[cur_np * channels], w1, channel_align); muls_template(w2_v2_local, v2_local[cur_np * channels], w2, channel_align); muls_template(w3_v3_local, v3_local[cur_np * channels], w3, channel_align); muls_template(w4_v4_local, v4_local[cur_np * channels], w4, channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); #ifndef __GET_CODE_CHANNEL__ DataCopy(val_local[cur_np * channels], w1_v1_local, channel_align); #endif - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); Add(val_local[cur_np * channels], val_local[cur_np * channels], w2_v2_local, channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); Add(val_local[cur_np * channels], val_local[cur_np * channels], w3_v3_local, channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); Add(val_local[cur_np * channels], val_local[cur_np * channels], w4_v4_local, channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); Mul(grad_weight_local[cur_np * channels], val_local[cur_np * channels], grad_output_local[cur_nh * channels], channel_align); - pipe_barrier(PIPE_ALL); } template @@ -452,46 +445,40 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::pre_process_levels( adds_template(h_im_local, h_im_local, (float)(-0.5), num_point_align); muls_template(w_im_local, sample_location_local[loc_w_offset], (float)w, num_point_align); adds_template(w_im_local, w_im_local, (float)(-0.5), num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); Cast(h_low_local, h_im_local, RoundMode::CAST_FLOOR, num_point_align); Cast(w_low_local, w_im_local, RoundMode::CAST_FLOOR, num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); adds_template_int32(h_high_local, h_low_local, 1, num_point_align); adds_template_int32(w_high_local, w_low_local, 1, num_point_align); - pipe_barrier(PIPE_ALL); Cast(h_low_t_local, h_low_local, RoundMode::CAST_NONE, num_point_align); Cast(w_low_t_local, w_low_local, RoundMode::CAST_NONE, num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); Sub(lh_local, h_im_local, h_low_t_local, num_point_align); Sub(lw_local, w_im_local, w_low_t_local, num_point_align); + pipe_barrier(PIPE_V); muls_template(neg_lh_local, lh_local, (float)-1.0, num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); adds_template(hh_local, neg_lh_local, (float)1.0, num_point_align); - pipe_barrier(PIPE_ALL); muls_template(neg_lw_local, lw_local, (float)-1.0, num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); adds_template(hw_local, neg_lw_local, (float)1.0, num_point_align); - pipe_barrier(PIPE_ALL); muls_template_int32(h_low_ptr_offset_local, h_low_local, h_stride, num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); adds_template_int32(h_high_ptr_offset_local, h_low_ptr_offset_local, h_stride, num_point_align); - pipe_barrier(PIPE_ALL); muls_template_int32(w_low_ptr_offset_local, w_low_local, w_stride, num_point_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); adds_template_int32(w_high_ptr_offset_local, w_low_ptr_offset_local, w_stride, num_point_align); - pipe_barrier(PIPE_ALL); Mul(w1_local, hh_local, hw_local, num_point_align); Mul(w2_local, hh_local, lw_local, num_point_align); Mul(w3_local, lh_local, hw_local, num_point_align); Mul(w4_local, lh_local, lw_local, num_point_align); - pipe_barrier(PIPE_ALL); Duplicate(grad_w_weight_local, 0.0, point_channel_align); Duplicate(grad_h_weight_local, 0.0, point_channel_align); Duplicate(v1_local, 0.0, point_channel_align); Duplicate(v2_local, 0.0, point_channel_align); Duplicate(v3_local, 0.0, point_channel_align); Duplicate(v4_local, 0.0, point_channel_align); - pipe_barrier(PIPE_ALL); } template @@ -507,23 +494,18 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::muls_template_int32 int32_t repeats_tail = repeats % max_repeat; int32_t tail = calCount % mask; int32_t tensor_offset = 0; - pipe_barrier(PIPE_ALL); for (int32_t loop_idx = 0; loop_idx < loop; loop_idx++) { Muls(dstLocal[loop_idx * max_repeat * mask], srcLocal[loop_idx * max_repeat * mask], scalarValue, mask, max_repeat, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset = loop * max_repeat * mask; - pipe_barrier(PIPE_ALL); if (repeats_tail >= 1) { Muls(dstLocal[tensor_offset], srcLocal[tensor_offset], scalarValue, mask, repeats_tail, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset += repeats_tail * mask; - pipe_barrier(PIPE_ALL); if (tail >= 1) { Muls(dstLocal[tensor_offset], srcLocal[tensor_offset], scalarValue, tail); @@ -533,29 +515,20 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::muls_template_int32 template __aicore__ inline void MultiScaleDeformableAttentionGrad::post_process_levels(int32_t w, int32_t h, int32_t cur_nl, int32_t data_weight_ptr, - int32_t data_loc_w_ptr) + int32_t data_loc_w_ptr, + int32_t cur_nh) { Mul(grad_w_weight_local, top_grad_value_local, grad_w_weight_local, point_channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); muls_template(grad_w_weight_local, grad_w_weight_local, (float)w, point_channel_align); - pipe_barrier(PIPE_ALL); Mul(grad_h_weight_local, top_grad_value_local, grad_h_weight_local, point_channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); muls_template(grad_h_weight_local, grad_h_weight_local, (float)h, point_channel_align); - pipe_barrier(PIPE_ALL); - WholeReduceSum(w3_v3_local, grad_h_weight_local, channels, num_point, 1, 1, (channels - 1) / t_per_block + 1); - WholeReduceSum(w4_v4_local, grad_w_weight_local, channels, num_point, 1, 1, (channels - 1) / t_per_block + 1); - WholeReduceSum(w2_v2_local, grad_weight_local, channels, num_point, 1, 1, (channels - 1) / t_per_block + 1); - pipe_barrier(PIPE_ALL); - SetAtomicAdd(); - #ifndef __GET_CODE_CHANNEL__ - DataCopyParams copy_params{1, (uint16_t)(num_point * sizeof(float)), 0, 0}; - DataCopyPad(grad_attn_weight_gm[data_weight_ptr + cur_nl * num_point], w2_v2_local, copy_params); - DataCopyPad(grad_sampling_loc_gm[data_loc_w_ptr + cur_nl * 2 * num_point], w4_v4_local, copy_params); - DataCopyPad(grad_sampling_loc_gm[data_loc_w_ptr + cur_nl * 2 * num_point + num_point], w3_v3_local, copy_params); - #endif - SetAtomicNone(); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + Copy(grad_sample_loc_local[(cur_nh * num_levels + cur_nl) * num_point * 2 * channel_align], grad_w_weight_local, 64, num_point * channel_align / 64, { 1, 1, 8, 8 }); + Copy(grad_sample_loc_local[((cur_nh * num_levels + cur_nl) * 2 + 1) * num_point * channel_align], grad_h_weight_local, 64, num_point * channel_align / 64, { 1, 1, 8, 8 }); + Copy(grad_weight_full_local[(cur_nh * num_levels + cur_nl) * num_point * channel_align], grad_weight_local, 64, num_point * channel_align / 64, { 1, 1, 8, 8 }); } template @@ -570,23 +543,20 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::process_levels(int3 { for (int32_t cur_nl = 0; cur_nl < num_levels; cur_nl++) { - pipe_barrier(PIPE_ALL); auto level_start_id = level_start_index_local.GetValue(cur_nl); - pipe_barrier(PIPE_ALL); auto value_ptr_offset = data_value_ptr_init_offset + level_start_id * qid_stride; - pipe_barrier(PIPE_ALL); auto h = spatial_shapes_local.GetValue(2 * cur_nl); auto w = spatial_shapes_local.GetValue(2 * cur_nl + 1); - pipe_barrier(PIPE_ALL); auto h_stride = w * w_stride; auto loc_w_offset = (cur_nh * num_levels + cur_nl) * DOUB * num_point; auto loc_h_offset = loc_w_offset + num_point; - pipe_barrier(PIPE_ALL); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); pre_process_levels(cur_nh, cur_b, cur_nl, cur_q, sl_size, h_stride, w_stride, loc_h_offset, loc_w_offset, w, h, sample_location_offset); process_grad_value_with_point(cur_nh, cur_nl, base_ptr, value_ptr_offset, h, w); - pipe_barrier(PIPE_ALL); - post_process_levels(w, h, cur_nl, data_weight_ptr, data_loc_w_ptr); + pipe_barrier(PIPE_V); + post_process_levels(w, h, cur_nl, data_weight_ptr, data_loc_w_ptr, cur_nh); } } @@ -596,21 +566,17 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::compute_mode_zero() auto qid_stride = num_heads * channels; auto sl_size = num_heads * num_levels * num_point * DOUB; auto w_stride = num_heads * channels; - pipe_barrier(PIPE_ALL); #ifndef __GET_CODE_CHANNEL__ DataCopy(level_start_index_local, level_start_index_gm, ceil(num_levels, int32_per_block)); DataCopy(spatial_shapes_local, spatial_shapes_gm, ceil(num_levels * DOUB, int32_per_block)); #endif - pipe_barrier(PIPE_ALL); for (int32_t b_nq_ind = start_task_id; b_nq_ind < start_task_id + cur_core_task_num; b_nq_ind++) { - pipe_barrier(PIPE_ALL); int32_t cur_q = b_nq_ind % num_query; int32_t cur_b = b_nq_ind / num_query; auto data_value_ptr_init_offset = cur_b * spatial_size * qid_stride; auto grad_output_offset = (cur_b * num_query + cur_q) * num_heads * channels; auto sample_location_offset = (cur_b * num_query + cur_q) * sl_size; - pipe_barrier(PIPE_ALL); #ifndef __GET_CODE_CHANNEL__ DataCopy(grad_output_local, grad_output_gm[grad_output_offset], ceil(num_heads * channels, t_per_block)); DataCopy(attn_weight_local, attn_weight_gm[sample_location_offset / DOUB], ceil(sl_size / DOUB, t_per_block)); @@ -621,23 +587,87 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::compute_mode_zero() DataCopy(sample_location_local, sampling_loc_gm[sample_location_offset], ceil(sl_size, t_per_block)); #endif } - pipe_barrier(PIPE_ALL); + auto data_weight_ptr = (cur_b * num_query * num_heads + cur_q * num_heads) * num_levels * num_point; + auto data_loc_w_ptr = DOUB * data_weight_ptr; for (int32_t cur_nh = 0; cur_nh < num_heads; cur_nh++) { - pipe_barrier(PIPE_ALL); - auto data_weight_ptr = (cur_b * num_query * num_heads + cur_q * num_heads + - cur_nh) * num_levels * num_point; - auto data_loc_w_ptr = DOUB * data_weight_ptr; auto base_ptr = cur_nh * channels; - pipe_barrier(PIPE_ALL); process_levels(cur_nh, base_ptr, cur_b, cur_q, sl_size, qid_stride, data_value_ptr_init_offset, w_stride, sample_location_offset, data_weight_ptr, data_loc_w_ptr); - pipe_barrier(PIPE_ALL); } + pipe_barrier(PIPE_V); + int32_t time = 248; + auto ran = num_heads * num_levels * num_point / time; + auto ran1 = num_heads * num_levels * num_point * DOUB / time; + auto remain = num_heads * num_levels * num_point % time; + auto remain1 = num_heads * num_levels * num_point * DOUB % time; + + if (channels > 64) + { + auto mask = channels / t_per_block; + ran = num_heads * num_levels * num_point * t_per_block / time; + ran1 = num_heads * num_levels * num_point * DOUB * t_per_block / time; + remain = num_heads * num_levels * num_point * t_per_block % time; + remain1 = num_heads * num_levels * num_point * DOUB * t_per_block % time; + for (auto i = 0; i < ran; i++) + { + WholeReduceSum(attn_weight_local[i * time], grad_weight_full_local[i * time * mask], mask, time, 1, 1, (mask - 1) / t_per_block + 1); + } + for (auto i = 0; i < ran1; i++) + { + WholeReduceSum(sample_location_local[i * time], grad_sample_loc_local[i * time * mask], mask, time, 1, 1, (mask - 1) / t_per_block + 1); + } + pipe_barrier(PIPE_V); + WholeReduceSum(attn_weight_local[ran * time], grad_weight_full_local[ran * time * mask], mask, remain, 1, 1, (mask - 1) / t_per_block + 1); + WholeReduceSum(sample_location_local[ran1 * time], grad_sample_loc_local[ran1 * time * mask], mask, remain1, 1, 1, (mask - 1) / t_per_block + 1); + pipe_barrier(PIPE_V); + + ran = num_heads * num_levels * num_point / time; + ran1 = num_heads * num_levels * num_point * DOUB / time; + remain = num_heads * num_levels * num_point % time; + remain1 = num_heads * num_levels * num_point * DOUB % time; + for (auto i = 0; i < ran; i++) + { + WholeReduceSum(attn_weight_local[i * time], attn_weight_local[i * time * t_per_block], t_per_block, time, 1, 1, 1); + } + for (auto i = 0; i < ran1; i++) + { + WholeReduceSum(sample_location_local[i * time], sample_location_local[i * time * t_per_block], t_per_block, time, 1, 1, 1); + } + pipe_barrier(PIPE_V); + WholeReduceSum(attn_weight_local[ran * time], attn_weight_local[ran * time * t_per_block], t_per_block, remain, 1, 1, 1); + WholeReduceSum(sample_location_local[ran1 * time], sample_location_local[ran1 * time * t_per_block], t_per_block, remain1, 1, 1, 1); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + } else + { + for (auto i = 0; i < ran; i++) + { + WholeReduceSum(attn_weight_local[i * time], grad_weight_full_local[i * time * channels], channels, time, 1, 1, (channels - 1) / t_per_block + 1); + } + pipe_barrier(PIPE_V); + for (auto i = 0; i < ran1; i++) + { + WholeReduceSum(sample_location_local[i * time], grad_sample_loc_local[i * time * channels], channels, time, 1, 1, (channels - 1) / t_per_block + 1); + } + pipe_barrier(PIPE_V); + WholeReduceSum(attn_weight_local[ran * time], grad_weight_full_local[ran * time * channels], channels, remain, 1, 1, (channels - 1) / t_per_block + 1); + WholeReduceSum(sample_location_local[ran1 * time], grad_sample_loc_local[ran1 * time * channels], channels, remain1, 1, 1, (channels - 1) / t_per_block + 1); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + } + + SetAtomicAdd(); + #ifndef __GET_CODE_CHANNEL__ + DataCopyParams copy_params{1, (uint16_t)(num_heads * num_levels * num_point * sizeof(float)), 0, 0}; + DataCopyParams copy_params1{1, (uint16_t)(num_heads * num_levels * num_point * DOUB * sizeof(float)), 0, 0}; + DataCopyPad(grad_attn_weight_gm[data_weight_ptr], attn_weight_local, copy_params); + DataCopyPad(grad_sampling_loc_gm[data_loc_w_ptr], sample_location_local, copy_params1); + #endif + SetAtomicNone(); pipe_barrier(PIPE_ALL); } - pipe_barrier(PIPE_ALL); in_queue_grad_output.FreeTensor(grad_output_local); in_queue_lsi.FreeTensor(level_start_index_local); in_queue_ss.FreeTensor(spatial_shapes_local); @@ -656,14 +686,16 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::cal_grad_value(Loca auto h_w_w1 = h_w_w1_ub.GetValue(cur_np); auto h_w_w2 = h_w_w2_ub.GetValue(cur_np); auto w_weight = w_weight_ub.GetValue(cur_np); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); #ifndef __GET_CODE_CHANNEL__ DataCopy(v_ub[cur_np * channels], value_gm[value_ptr_offset + ptr], channel_align); #endif - pipe_barrier(PIPE_ALL); + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); muls_template(v_w1_local, v_ub[cur_np * channels], h_w_w1, channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); if (neg_h) { Sub(grad_h_weight_local[cur_np * channels], grad_h_weight_local[cur_np * channels], v_w1_local, channel_align); @@ -671,9 +703,8 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::cal_grad_value(Loca { Add(grad_h_weight_local[cur_np * channels], grad_h_weight_local[cur_np * channels], v_w1_local, channel_align); } - pipe_barrier(PIPE_ALL); muls_template(v_w2_local, v_ub[cur_np * channels], h_w_w2, channel_align); - pipe_barrier(PIPE_ALL); + pipe_barrier(PIPE_V); if (neg_w) { Sub(grad_w_weight_local[cur_np * channels], grad_w_weight_local[cur_np * channels], v_w2_local, channel_align); @@ -681,16 +712,15 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::cal_grad_value(Loca { Add(grad_w_weight_local[cur_np * channels], grad_w_weight_local[cur_np * channels], v_w2_local, channel_align); } - pipe_barrier(PIPE_ALL); muls_template(mid_local, top_grad_value_local[cur_np * channels], w_weight, channel_align); - pipe_barrier(PIPE_ALL); + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3); SetAtomicAdd(); DataCopyParams copy_params3{1, (uint16_t)(channels * sizeof(float)), 0, 0}; #ifndef __GET_CODE_CHANNEL__ DataCopyPad(grad_value_gm[value_ptr_offset + ptr], mid_local, copy_params3); #endif SetAtomicNone(); - pipe_barrier(PIPE_ALL); } template @@ -706,21 +736,17 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::adds_template_int32 int32_t repeats_tail = repeats % max_repeat; int32_t tail = calCount % mask; int32_t tensor_offset = 0; - pipe_barrier(PIPE_ALL); for (int32_t loop_idx = 0; loop_idx < loop; loop_idx++) { - Adds(dstLocal[loop_idx * max_repeat * mask], srcLocal[loop_idx * max_repeat * mask], scalarValue, - mask, max_repeat, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); + Adds(dstLocal[loop_idx * max_repeat * mask], srcLocal[loop_idx * max_repeat * mask], scalarValue, mask, max_repeat, {DST_BLK_STRIDE, SRC_BLK_STRIDE, + DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset = loop * max_repeat * mask; - pipe_barrier(PIPE_ALL); if (repeats_tail >= 1) { Adds(dstLocal[tensor_offset], srcLocal[tensor_offset], scalarValue, mask, repeats_tail, - {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); + {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset += repeats_tail * mask; pipe_barrier(PIPE_ALL); if (tail >= 1) @@ -752,21 +778,17 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::muls_template(const int32_t repeats_tail = repeats % max_repeat; int32_t tail = calCount % mask; int32_t tensor_offset = 0; - pipe_barrier(PIPE_ALL); for (int32_t loop_idx = 0; loop_idx < loop; loop_idx++) { Muls(dstLocal[loop_idx * max_repeat * mask], srcLocal[loop_idx * max_repeat * mask], scalarValue, mask, max_repeat, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset = loop * max_repeat * mask; - pipe_barrier(PIPE_ALL); if (repeats_tail >= 1) { Muls(dstLocal[tensor_offset], srcLocal[tensor_offset], scalarValue, mask, repeats_tail, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset += repeats_tail * mask; pipe_barrier(PIPE_ALL); if (tail >= 1) @@ -777,7 +799,8 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::muls_template(const template __aicore__ inline void MultiScaleDeformableAttentionGrad::adds_template(const LocalTensor &dstLocal, - const LocalTensor &srcLocal, T scalarValue, const int32_t calCount) + const LocalTensor &srcLocal, + T scalarValue, const int32_t calCount) { int32_t unit = 256; int32_t max_repeat = 64; @@ -787,23 +810,18 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad::adds_template(const int32_t repeats_tail = repeats % max_repeat; int32_t tail = calCount % mask; int32_t tensor_offset = 0; - pipe_barrier(PIPE_ALL); for (int32_t loop_idx = 0; loop_idx < loop; loop_idx++) { - Adds(dstLocal[loop_idx * max_repeat * mask], srcLocal[loop_idx * max_repeat * mask], scalarValue, - mask, max_repeat, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); + Adds(dstLocal[loop_idx * max_repeat * mask], srcLocal[loop_idx * max_repeat * mask], scalarValue, mask, max_repeat, + {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset = loop * max_repeat * mask; - pipe_barrier(PIPE_ALL); if (repeats_tail >= 1) { Adds(dstLocal[tensor_offset], srcLocal[tensor_offset], scalarValue, mask, repeats_tail, {DST_BLK_STRIDE, SRC_BLK_STRIDE, DST_REP_STRIDE, SRC_REP_STRIDE}); } - pipe_barrier(PIPE_ALL); tensor_offset += repeats_tail * mask; - pipe_barrier(PIPE_ALL); if (tail >= 1) { Adds(dstLocal[tensor_offset], srcLocal[tensor_offset], scalarValue, tail); -- Gitee