From bbb6eb0708f6184952e8d2b05fa3fbf9bad25d85 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B4=BA=E9=9B=A8=E6=9D=B0?= Date: Wed, 18 Dec 2024 07:10:19 +0000 Subject: [PATCH] update kernels/op_host/group_points_grad.cpp. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 贺雨杰 --- kernels/op_host/group_points_grad.cpp | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/kernels/op_host/group_points_grad.cpp b/kernels/op_host/group_points_grad.cpp index 88007519..61f46acb 100644 --- a/kernels/op_host/group_points_grad.cpp +++ b/kernels/op_host/group_points_grad.cpp @@ -48,6 +48,9 @@ static ge::graphStatus TilingForGroupPointsGrad(gert::TilingContext* context) if (core_num == 0) { return ge::GRAPH_FAILED; } + if (context->GetInputDesc(0) == nullptr || context->GetAttrs() == nullptr) { + return ge::GRAPH_FAILED; + } auto dtype = context->GetInputDesc(0)->GetDataType(); if (ge::DT_FLOAT == dtype) { @@ -98,6 +101,10 @@ static ge::graphStatus TilingForGroupPointsGrad(gert::TilingContext* context) tiling.set_taskLast(taskLast); tiling.set_usedCoreNum(usedCoreNum); + if (context->GetRawTilingData() == nullptr) { + return ge::GRAPH_FAILED; + } + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); @@ -109,6 +116,9 @@ static ge::graphStatus TilingForGroupPointsGrad(gert::TilingContext* context) namespace ge { static ge::graphStatus InferShapeForGroupPointsGrad(gert::InferShapeContext* context) { + if (context->GetOutputShape(0) == nullptr || context->GetAttrs() == nullptr) { + return ge::GRAPH_SUCCESS; + } auto attrs = context->GetAttrs(); auto getAttr = [attrs](size_t idx) -> int32_t { auto ptr = attrs->GetInt(idx); -- Gitee