diff --git a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp index 1bc2cdc3232ed8fd68efb24f161943187867102d..89873d75a3eec6e5c60c7da57a69b39b22a75839 100644 --- a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp +++ b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp @@ -109,9 +109,9 @@ std::tuple multi_scale_deformable_attn_grad( auto grad_sample_loc_size = {location_size[0], location_size[1], location_size[2], location_size[3], location_size[5], location_size[4]}; at::Tensor value1 = value.transpose(1, 2).contiguous(); at::Tensor location1 = location.transpose(4, 5).contiguous(); - at::Tensor result1 = at::empty(grad_value_size, value.options().dtype(at::kFloat)); - at::Tensor result2 = at::empty(grad_sample_loc_size, location.options().dtype(at::kFloat)); - at::Tensor result3 = at::empty(grad_atten_weight_size, attn_weight.options().dtype(at::kFloat)); + at::Tensor result1 = at::zeros(grad_value_size, value.options().dtype(at::kFloat)); + at::Tensor result2 = at::zeros(grad_sample_loc_size, location.options().dtype(at::kFloat)); + at::Tensor result3 = at::zeros(grad_atten_weight_size, attn_weight.options().dtype(at::kFloat)); at::Tensor value_fp = value1.to(at::kFloat); at::Tensor shape_fp = shape.to(at::kInt);