From c78c29849dc1c844eb227e59b4aadce8f3c1aec1 Mon Sep 17 00:00:00 2001
From: huangyuan
Date: Mon, 1 Sep 2025 11:40:12 +0800
Subject: [PATCH] fix NaN/inf input handling in the MSDA ops

---
 kernels/op_kernel/msda.h                      | 31 +++++++++++--
 .../op_kernel/multi_scale_deformable_attn.cpp |  2 +-
 .../multi_scale_deformable_attn_grad.cpp      | 19 +++++---
 .../torch/test_multi_scale_deformable_attn.py | 43 ++++++++++++++++---
 4 files changed, 78 insertions(+), 17 deletions(-)

diff --git a/kernels/op_kernel/msda.h b/kernels/op_kernel/msda.h
index de86a205..85cdafcd 100644
--- a/kernels/op_kernel/msda.h
+++ b/kernels/op_kernel/msda.h
@@ -232,7 +232,7 @@ protected:
     }
 
     __aicore__ inline void ComputeLocation(uint32_t taskIdx, const LocalTensor& locationFloat,
-        const LocalTensor& locationInt, const LocalTensor& shapeFloat,
+        const LocalTensor& attentionWeight, const LocalTensor& locationInt, const LocalTensor& shapeFloat,
         const LocalTensor& shapeInt, const LocalTensor& locFloat, const LocalTensor& locInt,
         const LocalTensor& offsetInt, const LocalTensor& validFlag);
 
@@ -294,7 +294,7 @@
 template
 __aicore__ inline void MSDABaseKernel::ComputeLocation(uint32_t taskIdx,
-    const LocalTensor& locationFloat, const LocalTensor& locationInt,
+    const LocalTensor& locationFloat, const LocalTensor& attentionWeight, const LocalTensor& locationInt,
     const LocalTensor& shapeFloat, const LocalTensor& shapeInt, const LocalTensor& locFloat,
     const LocalTensor& locInt, const LocalTensor& offsetInt, const LocalTensor& validFlag)
 {
@@ -309,7 +309,30 @@ __aicore__ inline void MSDABaseKernel::ComputeLocati
     ResetMask();
     Mul(locationFloat, locationFloat, shapeFloat, MASK_PLACEHOLDER, 2 * taskRpt_, {1, 1, 1, 8, 8, 8});
-    Adds(locFloat, locationFloat, 0.5f, MASK_PLACEHOLDER, 2 * taskRpt_, {1, 1, 8, 8});
+    Adds(locFloat, locationFloat, -0.5f, MASK_PLACEHOLDER, 2 * taskRpt_, {1, 1, 8, 8});
+    // fix: mask out NaN/inf sampling location inputs
+    CompareScalar(validFlag[4 * validFlagMaskLen_], locFloat, -1.f,
+        CMPMODE::GT, MASK_PLACEHOLDER, taskRpt_, {1, 1, 8, 8});
+    CompareScalar(validFlag[5 * validFlagMaskLen_], locFloat[alignedOneTaskNum_], -1.f,
+        CMPMODE::GT, MASK_PLACEHOLDER, taskRpt_, {1, 1, 8, 8});
+    Compare(validFlag[6 * validFlagMaskLen_], locFloat, shapeFloat,
+        CMPMODE::LT, MASK_PLACEHOLDER, taskRpt_, {1, 1, 1, 8, 8, 8});
+    Compare(validFlag[7 * validFlagMaskLen_], locFloat[alignedOneTaskNum_], shapeFloat[alignedOneTaskNum_],
+        CMPMODE::LT, MASK_PLACEHOLDER, taskRpt_, {1, 1, 1, 8, 8, 8});
+    And(validFlag[4 * validFlagMaskLen_].ReinterpretCast(),
+        validFlag[4 * validFlagMaskLen_].ReinterpretCast(),
+        validFlag[6 * validFlagMaskLen_].ReinterpretCast(), MASK_PLACEHOLDER, 2, {1, 1, 1, 8, 8, 8});
+    And(validFlag[5 * validFlagMaskLen_].ReinterpretCast(),
+        validFlag[4 * validFlagMaskLen_].ReinterpretCast(),
+        validFlag[5 * validFlagMaskLen_].ReinterpretCast(), MASK_PLACEHOLDER, 1, {1, 1, 1, 8, 8, 8});
+    Select(locFloat, validFlag[5 * validFlagMaskLen_], locFloat,
+        -2.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, taskRpt_, {1, 1, 1, 8, 8, 8});
+    Select(locFloat[alignedOneTaskNum_], validFlag[5 * validFlagMaskLen_], locFloat[alignedOneTaskNum_],
+        -2.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, taskRpt_, {1, 1, 1, 8, 8, 8});
+    Select(attentionWeight, validFlag[5 * validFlagMaskLen_], attentionWeight,
+        0.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, taskRpt_, {1, 1, 1, 8, 8, 8});
+    // fix end
+    Adds(locFloat, locFloat, 1.0f, MASK_PLACEHOLDER, 2 * taskRpt_, {1, 1, 8, 8});
     Cast(locInt, locFloat, RoundMode::CAST_FLOOR, MASK_PLACEHOLDER, 2 * taskRpt_, {1, 1, 8, 8});
     // fix the precision issue of the floor operation (0.9999f -> 1.0f)
     Cast(
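The guard above works because every ordered comparison against NaN evaluates to false: a NaN or infinite coordinate never sets its valid bit, so the two location Selects clamp the lane to the -2.0f sentinel (safely out of range for the downstream bilinear-interpolation masking) and the third Select zeroes the matching attention weight. A minimal PyTorch sketch of the same masking; the helper name is hypothetical and x/y are packed into one tensor here, whereas the kernel keeps them in two buffer halves with separate flags that are AND-ed together:

```python
import torch

def mask_invalid_locations(loc, attn_weight, h, w):
    """Sketch of the validity masking ComputeLocation now performs.

    loc:         (..., 2) sampling locations in pixel space, already
                 shifted by -0.5 like the kernel's locFloat.
    attn_weight: (...,) attention weight per sample.
    """
    x, y = loc[..., 0], loc[..., 1]
    # CompareScalar(GT, -1) and Compare(LT, shape), combined with And;
    # NaN compares false against every bound, so it can never be valid.
    valid = (x > -1.0) & (x < w) & (y > -1.0) & (y < h)
    # Select with VSEL_TENSOR_SCALAR_MODE: invalid lanes get the -2.0 sentinel.
    loc = torch.where(valid.unsqueeze(-1), loc, torch.full_like(loc, -2.0))
    # A masked sample must also stop contributing to the weighted sum.
    attn_weight = torch.where(valid, attn_weight, torch.zeros_like(attn_weight))
    return loc, attn_weight
```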
@@ -632,6 +655,6 @@ private:
     __aicore__ inline void ComputeGrad(const LocalTensor& production, const LocalTensor& locFloat,
         const LocalTensor& weight, const LocalTensor& attentionWeight,
         const LocalTensor& gradLocation, const LocalTensor& gradAttentionWeight,
-        const LocalTensor& gatherOffset, uint32_t taskIdx);
+        const LocalTensor& gatherOffset, const LocalTensor& validFlag, uint32_t taskIdx);
 };
 #endif // MSDA_H
diff --git a/kernels/op_kernel/multi_scale_deformable_attn.cpp b/kernels/op_kernel/multi_scale_deformable_attn.cpp
index fe100587..6efeb6e4 100644
--- a/kernels/op_kernel/multi_scale_deformable_attn.cpp
+++ b/kernels/op_kernel/multi_scale_deformable_attn.cpp
@@ -194,7 +194,7 @@ __aicore__ inline void MultiScaleDeformableAttnKernel::Proces
         UpdateParams(this->endOffset_ - taskIdx);
     }
     this->CopyInSample(locationFloat[2 * this->alignedOneTaskNum_], attentionWeight, taskIdx);
-    this->ComputeLocation(taskIdx, locationFloat, locationInt, shapeFloat, shapeInt, locFloat, locInt, offsetInt,
+    this->ComputeLocation(taskIdx, locationFloat, attentionWeight, locationInt, shapeFloat, shapeInt, locFloat, locInt, offsetInt,
         validFlag.ReinterpretCast());
     ComputeBilinearInterpolation(validFlag, shapeInt, locationInt, locInt, shapeFloat, production, value,
         locFloat, weight, attentionWeight, cornerWeightBrc, output);
diff --git a/kernels/op_kernel/multi_scale_deformable_attn_grad.cpp b/kernels/op_kernel/multi_scale_deformable_attn_grad.cpp
index 126ad630..761696d4 100644
--- a/kernels/op_kernel/multi_scale_deformable_attn_grad.cpp
+++ b/kernels/op_kernel/multi_scale_deformable_attn_grad.cpp
@@ -194,8 +194,8 @@ __aicore__ inline void MultiScaleDeformableAttnGradKernel::Co
 template
 __aicore__ inline void MultiScaleDeformableAttnGradKernel::ComputeGrad(
     const LocalTensor& production, const LocalTensor& locFloat, const LocalTensor& weight,
-    const LocalTensor& attentionWeight, const LocalTensor& gradLocation,
-    const LocalTensor& gradAttentionWeight, const LocalTensor& gatherOffset, uint32_t taskIdx)
+    const LocalTensor& attentionWeight, const LocalTensor& gradLocation, const LocalTensor& gradAttentionWeight,
+    const LocalTensor& gatherOffset, const LocalTensor& validFlag, uint32_t taskIdx)
 {
     uint64_t sampleOffset = taskIdx * this->oneQueryNum_;
     Mul(production, weight, production, MASK_PLACEHOLDER, 4 * this->taskRpt_, {1, 1, 1, 8, 8, 8});
@@ -204,6 +204,10 @@ __aicore__ inline void MultiScaleDeformableAttnGradKernel::Co
     WaitFlag(1);
     Add(gradAttentionWeight, production, production[this->alignedOneTaskNum_], MASK_PLACEHOLDER,
         this->taskRpt_, {1, 1, 1, 8, 8, 8});
+    // fix: zero the attention-weight gradient of NaN/inf samples
+    Select(gradAttentionWeight, validFlag[5 * this->validFlagMaskLen_ / 8], gradAttentionWeight,
+        0.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, this->taskRpt_, {1, 1, 1, 8, 8, 8});
+    // fix end
     Sub(gradLocation, weight[3 * this->alignedOneTaskNum_], weight[this->alignedOneTaskNum_],
         MASK_PLACEHOLDER, this->taskRpt_, {1, 1, 1, 8, 8, 8});
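Zeroing the attention weight in the forward pass is necessary but not sufficient for the backward pass: under IEEE 754 arithmetic, 0 * NaN and 0 * inf are both NaN, so the gradient lanes of invalid samples have to be overwritten (which is what the Select here does) rather than multiplied by a zero mask. A standalone PyTorch check of that arithmetic:

```python
import torch

nan, inf = float('nan'), float('inf')

# Zeroing a weight is not enough: the product with a poisoned operand is NaN.
w = torch.tensor([0.0, 0.0])
g = torch.tensor([nan, inf])
assert torch.isnan(w * g).all()          # 0 * NaN = NaN, 0 * inf = NaN

# Overwriting the lane (what Select does) removes the poison entirely.
valid = torch.tensor([False, False])
masked = torch.where(valid, w * g, torch.zeros_like(g))
assert (masked == 0).all()
```

The flag consulted here is the combined x-and-y validity that ComputeLocation stored at slot 5, reused through the shared validFlag buffer.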
@@ -216,6 +220,12 @@ __aicore__ inline void MultiScaleDeformableAttnGradKernel::Co
     Mul(gradLocation, locFloat, gradLocation, MASK_PLACEHOLDER, 4 * this->taskRpt_, {1, 1, 1, 8, 8, 8});
     Add(gradLocation[2 * this->alignedOneTaskNum_], gradLocation, gradLocation[2 * this->alignedOneTaskNum_],
         MASK_PLACEHOLDER, 2 * this->taskRpt_, {1, 1, 1, 8, 8, 8});
+    // fix: zero the location gradient of NaN/inf samples
+    Select(gradLocation[2 * this->alignedOneTaskNum_], validFlag[5 * this->validFlagMaskLen_ / 8],
+        gradLocation[2 * this->alignedOneTaskNum_], 0.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, this->taskRpt_, {1, 1, 1, 8, 8, 8});
+    Select(gradLocation[3 * this->alignedOneTaskNum_], validFlag[5 * this->validFlagMaskLen_ / 8],
+        gradLocation[3 * this->alignedOneTaskNum_], 0.0f, SELMODE::VSEL_TENSOR_SCALAR_MODE, 64, this->taskRpt_, {1, 1, 1, 8, 8, 8});
+    // fix end
     Gather(gradLocation, gradLocation[2 * this->alignedOneTaskNum_], gatherOffset, 0, 64,
         2 * this->taskRpt_, 8);
     SetFlag(1);
@@ -272,12 +282,11 @@ __aicore__ inline void MultiScaleDeformableAttnGradKernel::Pr
         }
         this->CopyInSample(locationFloat[2 * this->alignedOneTaskNum_], attentionWeight, taskIdx);
         CopyInGradOut(gradOut, taskIdx);
-        this->ComputeLocation(taskIdx, locationFloat, locationInt, shapeFloat, shapeInt, locFloat, locInt, offsetInt,
+        this->ComputeLocation(taskIdx, locationFloat, attentionWeight, locationInt, shapeFloat, shapeInt, locFloat, locInt, offsetInt,
             validFlag.ReinterpretCast());
         ComputeBilinearInterpolation(validFlag, shapeInt, locationInt, locInt, shapeFloat, production, value,
             locFloat, weight, attentionWeight, cornerWeightBrc, gradOut, gradValue);
-        ComputeGrad(
-            production, locFloat, weight, attentionWeight, gradLocation, gradAttentionWeight, gatherOffset, taskIdx);
+        ComputeGrad(production, locFloat, weight, attentionWeight, gradLocation, gradAttentionWeight, gatherOffset, validFlag, taskIdx);
     }
     WaitFlag(this->copyEvt_);
     WaitFlag(0);
diff --git a/tests/torch/test_multi_scale_deformable_attn.py b/tests/torch/test_multi_scale_deformable_attn.py
index 9d7491a8..d0044c8a 100644
--- a/tests/torch/test_multi_scale_deformable_attn.py
+++ b/tests/torch/test_multi_scale_deformable_attn.py
@@ -1,6 +1,7 @@
 import unittest
 from collections import namedtuple
 
+import numpy as np
 import torch
 import torch_npu
 from data_cache import golden_data_cache
@@ -11,7 +12,7 @@ import mx_driving
 
 # pylint: disable=too-many-return-values
 @golden_data_cache(__file__)
-def cpu_gen_inputs(shape):
+def cpu_gen_inputs(shape, dtype):
     bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape
     shapes = torch.tensor([60, 40] * num_levels).reshape(num_levels, 2)
     num_keys = sum((H * W).item() for H, W in shapes)
@@ -22,6 +23,11 @@ def cpu_gen_inputs(shape):
     offset = torch.cat((shapes.new_zeros((1,)), shapes.prod(1).cumsum(0)[:-1]))
     grad_output = torch.rand(bs, num_queries, num_heads * embed_dims) * 1e-3
 
+    value = value.to(dtype)
+    sampling_locations = sampling_locations.to(dtype)
+    attention_weights = attention_weights.to(dtype)
+    grad_output = grad_output.to(dtype)
+
     return shapes, num_keys, value, sampling_locations, attention_weights, offset, grad_output
 
 
@@ -75,9 +81,14 @@ Inputs = namedtuple("Inputs", ["value", "shapes", "offset", "sampling_locations"
 
 
 class TestMultiScaleDeformableAttnFunction(TestCase):
-    def gen_inputs(self, shape, dtype):
+    def gen_inputs(self, shape, dtype, data=None):
         bs, num_queries, embed_dims, num_heads, num_levels, num_points = shape
-        shapes, num_keys, value, sampling_locations, attention_weights, offset, grad_output = cpu_gen_inputs(shape)
+        shapes, _, value, sampling_locations, attention_weights, offset, grad_output = cpu_gen_inputs(shape, dtype)
+        if data is not None:
+            value.fill_(data)
+            sampling_locations.fill_(data)
+            attention_weights.fill_(data)
+            grad_output.fill_(data)
 
         cpu_value = value.double()
         cpu_shapes = shapes.long()
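The generator now casts to the requested dtype before returning, and the new data argument lets the robustness tests below overwrite every floating-point input with a single poison value. Tensor.fill_ mutates in place, so the shapes and offsets produced by the cached cpu_gen_inputs stay valid. A standalone illustration (the tensor shapes are arbitrary, not the test's):

```python
import torch

value = torch.rand(2, 4, 8, 32)
sampling_locations = torch.rand(2, 4, 8, 1, 4, 2)

# fill_ rewrites elements in place: dtype, shape and strides are untouched,
# which is why gen_inputs can poison the cached tensors after generation.
for t in (value, sampling_locations):
    t.fill_(float('nan'))
assert torch.isnan(value).all() and value.shape == (2, 4, 8, 32)
```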
@@ -207,10 +218,28 @@ class TestMultiScaleDeformableAttnFunction(TestCase):
         cpu_inputs, npu_inputs = self.gen_inputs(shape, torch.float16)
         cpu_results = self.cpu_to_exec(cpu_inputs)
         npu_results = self.npu_to_exec(npu_inputs)
-        self.assertRtolEqual(cpu_results.output, npu_results.output)
-        self.assertRtolEqual(cpu_results.grad_value, npu_results.grad_value)
-        self.assertRtolEqual(cpu_results.grad_attention_weights, npu_results.grad_attention_weights)
-        self.assertRtolEqual(cpu_results.grad_sampling_locations, npu_results.grad_sampling_locations)
+        self.assertRtolEqual(cpu_results.output.astype(np.float16), npu_results.output)
+        self.assertRtolEqual(cpu_results.grad_value.astype(np.float16), npu_results.grad_value)
+        self.assertRtolEqual(cpu_results.grad_attention_weights.astype(np.float16), npu_results.grad_attention_weights)
+        self.assertRtolEqual(cpu_results.grad_sampling_locations.astype(np.float16), npu_results.grad_sampling_locations)
+
+    def test_nan(self):
+        shape = [6, 9680, 32, 8, 4, 4]
+        _, npu_inputs = self.gen_inputs(shape, torch.float32, float('nan'))
+        npu_results = self.npu_to_exec(npu_inputs)
+        self.assertRtolEqual(np.zeros_like(npu_results.output), npu_results.output)
+        self.assertRtolEqual(np.zeros_like(npu_results.grad_value), npu_results.grad_value)
+        self.assertRtolEqual(np.zeros_like(npu_results.grad_attention_weights), npu_results.grad_attention_weights)
+        self.assertRtolEqual(np.zeros_like(npu_results.grad_sampling_locations), npu_results.grad_sampling_locations)
+
+    def test_inf(self):
+        shape = [6, 9680, 32, 8, 4, 4]
+        _, npu_inputs = self.gen_inputs(shape, torch.float32, float('inf'))
+        npu_results = self.npu_to_exec(npu_inputs)
+        self.assertRtolEqual(np.zeros_like(npu_results.output), npu_results.output)
+        self.assertRtolEqual(np.zeros_like(npu_results.grad_value), npu_results.grad_value)
+        self.assertRtolEqual(np.zeros_like(npu_results.grad_attention_weights), npu_results.grad_attention_weights)
+        self.assertRtolEqual(np.zeros_like(npu_results.grad_sampling_locations), npu_results.grad_sampling_locations)
 
 
 if __name__ == "__main__":
-- 
Gitee
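Two notes on the tests above. First, the float16 case now downcasts the float64 golden before comparing, so both sides carry identical rounding; comparing a double-precision golden directly against a half-precision NPU result can report mismatches that are pure precision artifacts:

```python
import numpy as np

golden = np.float64(0.1)   # CPU golden computed in double
npu = np.float16(0.1)      # device result stored in half

assert golden != npu                       # 0.1 rounds differently at each width
assert golden.astype(np.float16) == npu    # identical once both are fp16
```

Second, the all-zeros expectation in test_nan and test_inf follows directly from the kernel fix: with every input poisoned, every sampling location is invalid, so every attention weight and every gradient lane is selected to zero.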