diff --git a/kernels/op_host/group_points_grad.cpp b/kernels/op_host/group_points_grad.cpp index 88007519e4d70fb28253c1456c28aa14997d3d87..61f46acb4dee710152f031444179e4d9cd07b56c 100644 --- a/kernels/op_host/group_points_grad.cpp +++ b/kernels/op_host/group_points_grad.cpp @@ -48,6 +48,9 @@ static ge::graphStatus TilingForGroupPointsGrad(gert::TilingContext* context) if (core_num == 0) { return ge::GRAPH_FAILED; } + if (context->GetInputDesc(0) == nullptr || context->GetAttrs() == nullptr) { + return ge::GRAPH_FAILED; + } auto dtype = context->GetInputDesc(0)->GetDataType(); if (ge::DT_FLOAT == dtype) { @@ -98,6 +101,10 @@ static ge::graphStatus TilingForGroupPointsGrad(gert::TilingContext* context) tiling.set_taskLast(taskLast); tiling.set_usedCoreNum(usedCoreNum); + if (context->GetRawTilingData() == nullptr) { + return ge::GRAPH_FAILED; + } + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), context->GetRawTilingData()->GetCapacity()); context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); @@ -109,6 +116,9 @@ static ge::graphStatus TilingForGroupPointsGrad(gert::TilingContext* context) namespace ge { static ge::graphStatus InferShapeForGroupPointsGrad(gert::InferShapeContext* context) { + if (context->GetOutputShape(0) == nullptr || context->GetAttrs() == nullptr) { + return ge::GRAPH_SUCCESS; + } auto attrs = context->GetAttrs(); auto getAttr = [attrs](size_t idx) -> int32_t { auto ptr = attrs->GetInt(idx);