From 740f88c85ce81fcbcabd8f0752341563227ac2ed Mon Sep 17 00:00:00 2001
From: jayshu
Date: Mon, 4 Mar 2024 15:07:27 +0800
Subject: [PATCH] test

---
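Notes:

This change reworks how the attention-weight gradient is accumulated. Instead
of staging a full per-channel gradient tensor of
num_heads * num_levels * num_point * channel_align elements in UB and reducing
it in bulk with WholeReduceSum inside compute_mode_zero,
process_grad_weight_with_point now receives cur_nl and data_weight_ptr, writes
the per-point product of val_local and grad_output_local into a single
channel_align-sized grad_weight_local, reduces it to a scalar with ReduceSum,
and accumulates that scalar directly into grad_attn_weight_gm under
SetAtomicAdd via DataCopyPad. Three small UB buffers (buffer_work_loc,
buffer_np, buffer_block) support the reduction, and BUFFER_NUM drops from 2 to
1, which keeps the UB footprint in budget.

The host-side C++ sketch below is only an illustration of the quantity the new
per-point path accumulates, not kernel code; every name in it
(accumulate_grad_attn_weight, val, grad_out, ...) is hypothetical:

    #include <cstddef>
    #include <vector>

    // Reference for one query: for each (head, level, point), the gradient of
    // the attention weight is the dot product of the bilinearly sampled value
    // with that head's upstream gradient, accumulated into the flat
    // [heads * levels * points] gradient tensor.
    void accumulate_grad_attn_weight(
        const std::vector<float> &val,        // [heads * levels * points * channels]
        const std::vector<float> &grad_out,   // [heads * channels]
        std::vector<float> &grad_attn_weight, // [heads * levels * points]
        std::size_t heads, std::size_t levels, std::size_t points, std::size_t channels)
    {
        for (std::size_t h = 0; h < heads; ++h) {
            for (std::size_t l = 0; l < levels; ++l) {
                for (std::size_t p = 0; p < points; ++p) {
                    const std::size_t idx = (h * levels + l) * points + p;
                    float acc = 0.0f; // the ReduceSum over grad_weight_local on device
                    for (std::size_t c = 0; c < channels; ++c) {
                        acc += val[idx * channels + c] * grad_out[h * channels + c];
                    }
                    grad_attn_weight[idx] += acc; // SetAtomicAdd + DataCopyPad on device
                }
            }
        }
    }

On the device, the `+=` is realized by the atomic-add DataCopyPad of np_local,
in which only the cur_np entry is non-zero for each call.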
 .../multi_scale_deformable_attention_grad.cpp | 82 ++++++++++---------
 1 file changed, 44 insertions(+), 38 deletions(-)

diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp
index f9223ead..20abf3fc 100644
--- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp
+++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_grad.cpp
@@ -35,7 +35,7 @@ namespace
     constexpr static int32_t GV_OUTPUT_INDEX = 6;
     constexpr static int32_t GSL_OUTPUT_INDEX = 7;
     constexpr static int32_t GAW_OUTPUT_INDEX = 8;
-    constexpr static int32_t BUFFER_NUM = 2;
+    constexpr static int32_t BUFFER_NUM = 1;
     constexpr static int32_t DOUB = 2;
     constexpr static int32_t T_BLOCK = 8;
     constexpr static uint16_t DST_BLK_STRIDE = 1;
@@ -54,8 +54,8 @@ public:
     __aicore__ inline void init_local_tensor();
     __aicore__ inline void process_grad_value_with_point(int32_t cur_nh, int32_t cur_nl, int32_t base_ptr, int32_t value_ptr_offset,
-                                                         int32_t h, int32_t w);
-    __aicore__ inline void process_grad_weight_with_point(int32_t cur_nh, int32_t cur_np);
+                                                         int32_t h, int32_t w, int32_t data_weight_ptr);
+    __aicore__ inline void process_grad_weight_with_point(int32_t cur_nh, int32_t cur_np, int32_t cur_nl, int32_t data_weight_ptr);
     __aicore__ inline void process();
     __aicore__ inline void compute_mode_zero();
     __aicore__ inline int32_t ceil(int32_t a, int32_t b);
@@ -135,6 +135,9 @@ private:
     TBuf<TPosition::VECCALC> buffer_grad_weight;
     TBuf<TPosition::VECCALC> buffer_grad_weight_full;
     TBuf<TPosition::VECCALC> buffer_grad_sample_loc;
+    TBuf<TPosition::VECCALC> buffer_work_loc;
+    TBuf<TPosition::VECCALC> buffer_np;
+    TBuf<TPosition::VECCALC> buffer_block;
     int32_t spatial_size;
     int32_t cur_block_idx;
     int32_t cur_core_task_num;
@@ -195,6 +198,9 @@ private:
     LocalTensor<int32_t> spatial_shapes_local;
     LocalTensor<T> grad_weight_full_local;
     LocalTensor<T> grad_sample_loc_local;
+    LocalTensor<T> work_local;
+    LocalTensor<T> np_local;
+    LocalTensor<T> block_local;
 };

 template <typename T>
@@ -283,9 +289,12 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::init_buffer()
     pipe.InitBuffer(buffer_w3_v3, channel_align * sizeof(T));
     pipe.InitBuffer(buffer_w4_v4, channel_align * sizeof(T));
     pipe.InitBuffer(buffer_val, point_channel_align * sizeof(T));
-    pipe.InitBuffer(buffer_grad_weight, point_channel_align * sizeof(T));
-    pipe.InitBuffer(buffer_grad_weight_full, (num_heads * num_levels * num_point * channel_align) * sizeof(T));
+    pipe.InitBuffer(buffer_grad_weight, channel_align * sizeof(T));
+    pipe.InitBuffer(buffer_grad_weight_full, (num_heads * num_levels * num_point) * sizeof(T));
     pipe.InitBuffer(buffer_grad_sample_loc, (num_heads * num_levels * num_point * channel_align * DOUB) * sizeof(T));
+    pipe.InitBuffer(buffer_work_loc, channel_align * sizeof(T));
+    pipe.InitBuffer(buffer_np, num_point_align * sizeof(T));
+    pipe.InitBuffer(buffer_block, num_point_align * sizeof(T));
 }

 template <typename T>
@@ -333,9 +342,12 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::init_local_tensor()
     w3_v3_local = buffer_w3_v3.Get<T>(channel_align);
     w4_v4_local = buffer_w4_v4.Get<T>(channel_align);
     val_local = buffer_val.Get<T>(point_channel_align);
-    grad_weight_local = buffer_grad_weight.Get<T>(point_channel_align);
-    grad_weight_full_local = buffer_grad_weight_full.Get<T>((num_heads * num_levels * num_point * channel_align) * sizeof(T));
-    grad_sample_loc_local = buffer_grad_sample_loc.Get<T>((num_heads * num_levels * num_point * channel_align * DOUB) * sizeof(T));
+    grad_weight_local = buffer_grad_weight.Get<T>(channel_align);
+    grad_weight_full_local = buffer_grad_weight_full.Get<T>(num_heads * num_levels * num_point);
+    grad_sample_loc_local = buffer_grad_sample_loc.Get<T>(num_heads * num_levels * num_point * channel_align * DOUB);
+    work_local = buffer_work_loc.Get<T>(channel_align);
+    np_local = buffer_np.Get<T>(num_point_align);
+    block_local = buffer_block.Get<T>(num_point_align);
 }

 template <typename T>
@@ -349,7 +361,7 @@
 template <typename T>
 __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::process_grad_value_with_point(int32_t cur_nh, int32_t cur_nl,
                                                                                            int32_t base_ptr, int32_t value_ptr_offset,
-                                                                                           int32_t h, int32_t w)
+                                                                                           int32_t h, int32_t w, int32_t data_weight_ptr)
 {
     for (int32_t cur_np = 0; cur_np < num_point; cur_np++)
     {
@@ -388,7 +400,7 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::process_grad_value_
             cal_grad_value(v4_local, h_high_ptr_offset_local, w_high_ptr_offset_local, lw_local, lh_local, w4_local,
                            cur_np, base_ptr, value_ptr_offset, false, false);
         }
-        process_grad_weight_with_point(cur_nh, cur_np);
+        process_grad_weight_with_point(cur_nh, cur_np, cur_nl, data_weight_ptr);
     }
 }

@@ -396,7 +408,9 @@
 template <typename T>
 __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::process_grad_weight_with_point(int32_t cur_nh,
-                                                                                            int32_t cur_np)
+                                                                                            int32_t cur_np,
+                                                                                            int32_t cur_nl,
+                                                                                            int32_t data_weight_ptr)
 {
     auto w1 = w1_local.GetValue(cur_np);
     auto w2 = w2_local.GetValue(cur_np);
@@ -419,8 +433,23 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::process_grad_weight
     pipe_barrier(PIPE_V);
     Add(val_local[cur_np * channels], val_local[cur_np * channels], w4_v4_local, channel_align);
     pipe_barrier(PIPE_V);
-    Mul(grad_weight_local[cur_np * channels], val_local[cur_np * channels], grad_output_local[cur_nh * channels],
+    Mul(grad_weight_local, val_local[cur_np * channels], grad_output_local[cur_nh * channels],
         channel_align);
+
+    // new code
+    Duplicate(np_local, 0.0, num_point_align);
+    pipe_barrier(PIPE_ALL);
+    ReduceSum(block_local, grad_weight_local, work_local, channels);
+    pipe_barrier(PIPE_ALL);
+    auto cur = block_local.GetValue(0);
+    pipe_barrier(PIPE_ALL);
+    np_local.SetValue(cur_np, cur);
+    pipe_barrier(PIPE_ALL);
+    SetAtomicAdd<float>();
+    DataCopyParams copy_params{1, (uint16_t)(num_point_align * sizeof(float)), 0, 0};
+    DataCopyPad(grad_attn_weight_gm[data_weight_ptr + (cur_nh * num_levels + cur_nl) * num_point], np_local, copy_params);
+    pipe_barrier(PIPE_ALL);
+    SetAtomicNone();
 }

 template <typename T>
@@ -526,9 +555,8 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::post_process_levels
     muls_template(grad_h_weight_local, grad_h_weight_local, (float)h, point_channel_align);
     set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
     wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
-    Copy(grad_sample_loc_local[(cur_nh * num_levels + cur_nl) * num_point * 2 * channel_align], grad_w_weight_local, 64, num_point * channel_align / 64, { 1, 1, 8, 8 });
-    Copy(grad_sample_loc_local[((cur_nh * num_levels + cur_nl) * 2 + 1) * num_point * channel_align], grad_h_weight_local, 64, num_point * channel_align / 64, { 1, 1, 8, 8 });
-    Copy(grad_weight_full_local[(cur_nh * num_levels + cur_nl) * num_point * channel_align], grad_weight_local, 64, num_point * channel_align / 64, { 1, 1, 8, 8 });
+    Copy(grad_sample_loc_local[(cur_nh * num_levels + cur_nl) * num_point * 2 * channels], grad_w_weight_local, 64, num_point * channels / 64, { 1, 1, 8, 8 });
+    Copy(grad_sample_loc_local[((cur_nh * num_levels + cur_nl) * 2 + 1) * num_point * channels], grad_h_weight_local, 64, num_point * channels / 64, { 1, 1, 8, 8 });
 }

 template <typename T>
@@ -554,7 +582,7 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::process_levels(int3
         wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
         pre_process_levels(cur_nh, cur_b, cur_nl, cur_q, sl_size, h_stride, w_stride, loc_h_offset, loc_w_offset, w, h,
                            sample_location_offset);
-        process_grad_value_with_point(cur_nh, cur_nl, base_ptr, value_ptr_offset, h, w);
+        process_grad_value_with_point(cur_nh, cur_nl, base_ptr, value_ptr_offset, h, w, data_weight_ptr);
         pipe_barrier(PIPE_V);
         post_process_levels(w, h, cur_nl, data_weight_ptr, data_loc_w_ptr, cur_nh);
     }
@@ -606,53 +634,33 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::compute_mode_zero()
     if (channels > 64)
     {
         auto mask = channels / t_per_block;
-        ran = num_heads * num_levels * num_point * t_per_block / time;
         ran1 = num_heads * num_levels * num_point * DOUB * t_per_block / time;
-        remain = num_heads * num_levels * num_point * t_per_block % time;
         remain1 = num_heads * num_levels * num_point * DOUB * t_per_block % time;
-        for (auto i = 0; i < ran; i++)
-        {
-            WholeReduceSum(attn_weight_local[i * time], grad_weight_full_local[i * time * mask], mask, time, 1, 1, (mask - 1) / t_per_block + 1);
-        }
         for (auto i = 0; i < ran1; i++)
        {
             WholeReduceSum(sample_location_local[i * time], grad_sample_loc_local[i * time * mask], mask, time, 1, 1, (mask - 1) / t_per_block + 1);
         }
         pipe_barrier(PIPE_V);
-        WholeReduceSum(attn_weight_local[ran * time], grad_weight_full_local[ran * time * mask], mask, remain, 1, 1, (mask - 1) / t_per_block + 1);
         WholeReduceSum(sample_location_local[ran1 * time], grad_sample_loc_local[ran1 * time * mask], mask, remain1, 1, 1, (mask - 1) / t_per_block + 1);
         pipe_barrier(PIPE_V);
-        ran = num_heads * num_levels * num_point / time;
         ran1 = num_heads * num_levels * num_point * DOUB / time;
-        remain = num_heads * num_levels * num_point % time;
         remain1 = num_heads * num_levels * num_point * DOUB % time;
-        for (auto i = 0; i < ran; i++)
-        {
-            WholeReduceSum(attn_weight_local[i * time], attn_weight_local[i * time * t_per_block], t_per_block, time, 1, 1, 1);
-        }
         for (auto i = 0; i < ran1; i++)
         {
             WholeReduceSum(sample_location_local[i * time], sample_location_local[i * time * t_per_block], t_per_block, time, 1, 1, 1);
         }
         pipe_barrier(PIPE_V);
-        WholeReduceSum(attn_weight_local[ran * time], attn_weight_local[ran * time * t_per_block], t_per_block, remain, 1, 1, 1);
         WholeReduceSum(sample_location_local[ran1 * time], sample_location_local[ran1 * time * t_per_block], t_per_block, remain1, 1, 1, 1);
         set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
         wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
     }
     else
     {
-        for (auto i = 0; i < ran; i++)
-        {
-            WholeReduceSum(attn_weight_local[i * time], grad_weight_full_local[i * time * channels], channels, time, 1, 1, (channels - 1) / t_per_block + 1);
-        }
-        pipe_barrier(PIPE_V);
         for (auto i = 0; i < ran1; i++)
         {
             WholeReduceSum(sample_location_local[i * time], grad_sample_loc_local[i * time * channels], channels, time, 1, 1, (channels - 1) / t_per_block + 1);
         }
         pipe_barrier(PIPE_V);
-        WholeReduceSum(attn_weight_local[ran * time], grad_weight_full_local[ran * time * channels], channels, remain, 1, 1, (channels - 1) / t_per_block + 1);
         WholeReduceSum(sample_location_local[ran1 * time], grad_sample_loc_local[ran1 * time * channels], channels, remain1, 1, 1, (channels - 1) / t_per_block + 1);
         set_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
         wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID3);
@@ -660,9 +668,7 @@ __aicore__ inline void MultiScaleDeformableAttentionGrad<T>::compute_mode_zero()
     SetAtomicAdd<float>();
 #ifndef __GET_CODE_CHANNEL__
-    DataCopyParams copy_params{1, (uint16_t)(num_heads * num_levels * num_point * sizeof(float)), 0, 0};
     DataCopyParams copy_params1{1, (uint16_t)(num_heads * num_levels * num_point * DOUB * sizeof(float)), 0, 0};
-    DataCopyPad(grad_attn_weight_gm[data_weight_ptr], attn_weight_local, copy_params);
     DataCopyPad(grad_sampling_loc_gm[data_loc_w_ptr], sample_location_local, copy_params1);
 #endif
     SetAtomicNone();
--
Gitee