From 56c26c0ee5cf13e52f47ef56bcb4beede7417263 Mon Sep 17 00:00:00 2001 From: zhuweichen Date: Fri, 29 Mar 2024 17:18:30 +0800 Subject: [PATCH] update msdagrad --- .../op_kernel/multi_scale_deformable_attention_v2_grad.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_v2_grad.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_v2_grad.cpp index 73043d10..aca37b9d 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_v2_grad.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attention_v2_grad.cpp @@ -350,6 +350,7 @@ private: } } SetFlag(eventIdMte3ToV); + SetFlag(eventIdVToMte2); Mul(tmpLocal, zerosLocal[topGradValueId * baseOffsetUb], zerosLocal[gradWWeightId * baseOffsetUb], numPoints * embedDims); Muls(gradSampleXLocLocal, tmpLocal, (DTYPE_VALUE)w, numPoints * embedDims); @@ -370,7 +371,6 @@ private: WaitFlag(eventIdVToMte3Y); DataCopyPad(gradLocationGm[offsetLocation + level * 2 * numPoints + numPoints], yLocal, copyParams); WaitFlag(eventIdMte3ToV); - SetFlag(eventIdVToMte2); WaitFlag(eventIdVToMte2); } } -- Gitee