From b1e82d26c04877fd1632c2570d5c5aa20c995630 Mon Sep 17 00:00:00 2001 From: wangyan <13167191585@163.com> Date: Mon, 18 Mar 2024 20:06:00 +0800 Subject: [PATCH] AddConstraintsToScatterMax Type: Bugfix Team: PyTorch_Ops_Dev InventoryUpdate: False Issue: Issue_no Description: Add constraints to the ScatterMax operator. --- ads/common/ops/csrc/ScatterMaxKernelNpu.cpp | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp b/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp index b71c8a06..c41485fa 100644 --- a/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp +++ b/ads/common/ops/csrc/ScatterMaxKernelNpu.cpp @@ -3,16 +3,31 @@ using namespace std; +void npu_scatter_max_check(const at::Tensor& updates, const at::Tensor& indices) +{ + auto indicesSizes = indices.sizes(); + int32_t updatesLength = 1; + int32_t indicesLength = 1; + for (size_t i = 1; i < updates.dim(); i++) { + updatesLength *= updatesSizes[i]; + } + for (size_t i = 1; i < indices.dim(); i++) { + indicesLength *= indicesSizes[i]; + } + TORCH_CHECK(updatesLength % 8 == 0, "The dim of input tensor [indices] should be equal to 1."); + TORCH_CHECK(indicesLength == 1, "All the dims's range except the first dim of input tensor [indices] should be equal to 1."); + TORCH_CHECK(indices.sizes()[0] == updates.sizes()[0], "input's updates size of dim 0 should be equal to indices's size."); + TORCH_CHECK(indices.max().item().toLong() + 1 <= 122880, "The maximum value of input tensor [indices] should be less than 122880.") +} + std::tuple npu_scatter_max(const at::Tensor& updates, const at::Tensor& indices, c10::optional out) { + npu_scatter_max_check(updates, indices); auto sizes = updates.sizes().vec(); - sizes[0] = indices.max().item().toLong() + 1; - at::Tensor result = out.value_or(at::zeros(sizes, updates.options().dtype(at::kFloat))); at::Tensor argmax = at::empty(result.sizes(), result.options().dtype(at::kInt)); - at_npu::native::OpCommand cmd; cmd.Name("ScatterMaxWithArgmax").Input(result).Input(indices).Input(updates).Output(result).Output(argmax).Run(); -- Gitee