From c3a41aca0bc7a1b0132b45e6c69466541cc455df Mon Sep 17 00:00:00 2001 From: jayshu Date: Fri, 23 Feb 2024 16:42:15 +0800 Subject: [PATCH] add operator constraints --- .../ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp | 7 +++++++ .../op_host/multi_scale_deformable_attention_grad.cpp | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp index e1e74560..8f7529ff 100644 --- a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp +++ b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp @@ -93,6 +93,13 @@ std::tuple multi_scale_deformable_attn_grad( auto ori_dtype = value.scalar_type(); auto value_size = value.sizes(); auto location_size = location.sizes(); + auto channels = value_size[3]; + auto num_points = location_size[4]; + auto num_levels = location_size[3]; + auto data_total = channels + num_points + num_levels; + TORCH_CHECK(data_total < 512, "data_total is over 512: channels ", channels, " num_points is ", + num_points, " num_level is ", num_levels, "."); + TORCH_CHECK(channels % 8 == 0, "channels must be a multiple of eight, but channels is", channels, "."); auto grad_value_size = {value_size[0], value_size[1], value_size[2], value_size[3]}; auto grad_atten_weight_size = {location_size[0], location_size[1], location_size[2], location_size[3], location_size[4]}; auto grad_sample_loc_size = {location_size[0], location_size[1], location_size[2], location_size[3], location_size[5], location_size[4]}; diff --git a/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp index 54497529..148674a3 100644 --- a/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp @@ -31,7 +31,7 @@ namespace optiling auto channels = value_shape.GetDim(3); auto num_query = sampling_loc_shape.GetDim(1); auto num_levels = sampling_loc_shape.GetDim(3); - auto num_point = sampling_loc_shape.GetDim(4); + auto num_point = sampling_loc_shape.GetDim(5); auto task_per_core = (batch_size * num_query - 1) / core_num + 1; auto core_used = (batch_size * num_query - 1) / task_per_core + 1; auto task_tail_core = batch_size * num_query - (core_used - 1) * task_per_core; -- Gitee