diff --git a/kernels/op_host/bev_pool.cpp b/kernels/op_host/bev_pool.cpp index 3acdc5783f0bc2d259f27bfd948f37961e5f65ce..22e3f09144325604420fc785b1fde3e76de1cea3 100644 --- a/kernels/op_host/bev_pool.cpp +++ b/kernels/op_host/bev_pool.cpp @@ -134,23 +134,33 @@ static graphStatus InferShapeForBEVPool(gert::InferShapeContext* context) static graphStatus InferShapeForBEVPoolGrad(gert::InferShapeContext* context) { + CHECK_NULLPTR(context); const gert::Shape* GeomFeatShape = context->GetInputShape(GEOM_FEAT_IDX); + CHECK_NULLPTR(GeomFeatShape); const auto n = GeomFeatShape->GetDim(0); auto attrs = context->GetAttrs(); - CHECK_NULLPTR(attrs) - auto c = *attrs->GetInt(C_IDX); + CHECK_NULLPTR(attrs); + auto c_ptr = attrs->GetInt(C_IDX); + CHECK_NULLPTR(c_ptr); + auto c = *c_ptr; gert::Shape* gradFeatShape = context->GetOutputShape(0); + CHECK_NULLPTR(gradFeatShape); *gradFeatShape = {n, c}; return GRAPH_SUCCESS; } static graphStatus InferShapeForBEVPoolV2Grad(gert::InferShapeContext* context) { + CHECK_NULLPTR(context); gert::Shape* gradDepthShape = context->GetOutputShape(0); const gert::Shape* depthShape = context->GetInputShape(1); + CHECK_NULLPTR(gradDepthShape); + CHECK_NULLPTR(depthShape); *gradDepthShape = *depthShape; gert::Shape* gradFeatShape = context->GetOutputShape(1); const gert::Shape* featShape = context->GetInputShape(2); + CHECK_NULLPTR(gradFeatShape); + CHECK_NULLPTR(featShape); *gradFeatShape = *featShape; return GRAPH_SUCCESS; }