diff --git a/mx_driving/csrc/MultiScaleDeformableAttn.cpp b/mx_driving/csrc/MultiScaleDeformableAttn.cpp
index 8d54973a1aaa3a789888433c54cbe3d1ccba670a..3d65199eaf2446f6fd7fee387a2b58d130703756 100644
--- a/mx_driving/csrc/MultiScaleDeformableAttn.cpp
+++ b/mx_driving/csrc/MultiScaleDeformableAttn.cpp
@@ -101,12 +101,12 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> multi_scale_deformable_attn_backw
     }
 
     if (ASCEND_UNLIKELY(value.scalar_type() == at::kHalf)) {
-        at::Tensor grad_value_fp32 = grad_value.to(at::kFloat);
+        at::Tensor grad_output_fp32 = grad_output.to(at::kFloat);
         at::Tensor value_fp32 = value.to(at::kFloat);
         at::Tensor sampling_locations_fp32 = sampling_locations.to(at::kFloat);
         at::Tensor attention_weights_fp32 = attention_weights.to(at::kFloat);
         EXEC_NPU_CMD(aclnnMultiScaleDeformableAttnGrad, value_fp32, value_spatial_shapes, value_level_start_index,
-            sampling_locations_fp32, attention_weights_fp32, grad_value_fp32, grad_value, grad_sampling_loc,
+            sampling_locations_fp32, attention_weights_fp32, grad_output_fp32, grad_value, grad_sampling_loc,
             grad_attn_weight);
         return std::make_tuple(
             grad_value.to(at::kHalf), grad_sampling_loc.to(at::kHalf), grad_attn_weight.to(at::kHalf));
diff --git a/tests/torch/test_multi_scale_deformable_attn.py b/tests/torch/test_multi_scale_deformable_attn.py
index 1265bac41ac8f0cf23ab690042866712c663b61b..9d7491a8dc7f74b3b6f8a3bafb17a9e2a2cc7b45 100644
--- a/tests/torch/test_multi_scale_deformable_attn.py
+++ b/tests/torch/test_multi_scale_deformable_attn.py
@@ -202,6 +202,16 @@ class TestMultiScaleDeformableAttnFunction(TestCase):
         self.assertRtolEqual(cpu_results.grad_attention_weights, npu_results.grad_attention_weights)
         self.assertRtolEqual(cpu_results.grad_sampling_locations, npu_results.grad_sampling_locations)
 
+    def test_fp16(self):
+        shape = [6, 9680, 32, 8, 4, 4]
+        cpu_inputs, npu_inputs = self.gen_inputs(shape, torch.float16)
+        cpu_results = self.cpu_to_exec(cpu_inputs)
+        npu_results = self.npu_to_exec(npu_inputs)
+        self.assertRtolEqual(cpu_results.output, npu_results.output)
+        self.assertRtolEqual(cpu_results.grad_value, npu_results.grad_value)
+        self.assertRtolEqual(cpu_results.grad_attention_weights, npu_results.grad_attention_weights)
+        self.assertRtolEqual(cpu_results.grad_sampling_locations, npu_results.grad_sampling_locations)
+
 
 if __name__ == "__main__":
     run_tests()
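
For reference, here is a minimal Python sketch (not the project's actual API) of the half-precision fallback the fixed C++ branch implements: the upstream gradient, not the presumably zero-initialized grad_value output buffer, must be upcast to fp32 before the fp32-only NPU kernel runs, and the computed gradients are cast back to fp16. fp32_backward_kernel is a hypothetical stand-in for aclnnMultiScaleDeformableAttnGrad; the value_spatial_shapes and value_level_start_index arguments are omitted for brevity.

import torch

def backward_half(value, sampling_locations, attention_weights, grad_output,
                  fp32_backward_kernel):
    """Hypothetical sketch of the fp16 path in multi_scale_deformable_attn backward."""
    assert value.dtype == torch.float16
    # The bug: the pre-fix code upcast grad_value (the output buffer,
    # presumably zero-initialized) instead of grad_output (the incoming
    # upstream gradient), so the kernel saw zeros where the upstream
    # gradient belonged.
    grad_output_fp32 = grad_output.float()  # the fixed conversion
    value_fp32 = value.float()
    sampling_locations_fp32 = sampling_locations.float()
    attention_weights_fp32 = attention_weights.float()
    grad_value, grad_sampling_loc, grad_attn_weight = fp32_backward_kernel(
        value_fp32, sampling_locations_fp32, attention_weights_fp32, grad_output_fp32)
    # Cast the gradients back to the caller's original half precision.
    return (grad_value.half(), grad_sampling_loc.half(), grad_attn_weight.half())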