From 9e52e0fc5260ccac749b52a3e6201a1d1889673d Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Mon, 25 Mar 2024 11:15:02 +0800 Subject: [PATCH] fix precision of msdagrad --- .../ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp index 1bc2cdc3..89873d75 100644 --- a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp +++ b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp @@ -109,9 +109,9 @@ std::tuple multi_scale_deformable_attn_grad( auto grad_sample_loc_size = {location_size[0], location_size[1], location_size[2], location_size[3], location_size[5], location_size[4]}; at::Tensor value1 = value.transpose(1, 2).contiguous(); at::Tensor location1 = location.transpose(4, 5).contiguous(); - at::Tensor result1 = at::empty(grad_value_size, value.options().dtype(at::kFloat)); - at::Tensor result2 = at::empty(grad_sample_loc_size, location.options().dtype(at::kFloat)); - at::Tensor result3 = at::empty(grad_atten_weight_size, attn_weight.options().dtype(at::kFloat)); + at::Tensor result1 = at::zeros(grad_value_size, value.options().dtype(at::kFloat)); + at::Tensor result2 = at::zeros(grad_sample_loc_size, location.options().dtype(at::kFloat)); + at::Tensor result3 = at::zeros(grad_atten_weight_size, attn_weight.options().dtype(at::kFloat)); at::Tensor value_fp = value1.to(at::kFloat); at::Tensor shape_fp = shape.to(at::kInt); -- Gitee