diff --git a/docs/api/context/multi_scale_deformable_attn.md b/docs/api/context/multi_scale_deformable_attn.md index 2460a0d3a0a16c647356df87dca3368cd872f3ec..8d50b9b31da1b1270449ee6069aba001cf96b61f 100644 --- a/docs/api/context/multi_scale_deformable_attn.md +++ b/docs/api/context/multi_scale_deformable_attn.md @@ -21,7 +21,7 @@ mx_driving.point.npu_multi_scale_deformable_attn_function(Tensor value, Tensor v - Atlas A2 训练系列产品 ### 约束说明 - 当前版本只支持`num_points * num_levels` ≤ 64,`num_heads` ≤ 8,`embed_dims` ≤ 64。 -- 注意:当`num_points * num_levels` = 64且`embed_dims` = 64时,`num_heads`最大为7。 +- 注意:当`num_points * num_levels` = 64且`(embed_dims + 7) / 8 * 8` = 64时,`num_heads`最大为7。 ### 调用示例 ```python import torch, torch_npu