diff --git a/ads/common/__init__.py b/ads/common/__init__.py index ca4f48b01b5dbf8ce013e8be03a890032c33ebc6..01f4df19d3d4f28a3eb57f3ecd0c9c84c903f395 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -30,3 +30,4 @@ from .ops.npu_multi_scale_deformable_attn_function import npu_multi_scale_deform from .ops.dynamic_voxelization import voxelization from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal +from .ops.furthest_point_sampling import npu_furthest_point_sampling diff --git a/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp b/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..29d775fe1e0390040173dc75aa19c6b1dd1485fd --- /dev/null +++ b/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "OpApiCommon.h" +#include "functions.h" + +at::Tensor npu_furthest_point_sampling(const at::Tensor &point_xyz, const at::Tensor &nearset_temp, const int32_t num_points) +{ + at::Tensor output = at::empty({static_cast(point_xyz.sizes()[0]), static_cast(num_points)}, + nearset_temp.options().dtype(at::kInt)); + EXEC_NPU_CMD(aclnnFurthestPointSampling, point_xyz, nearset_temp, num_points, output); + return output; +} \ No newline at end of file diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index cf901d9cc56948605f33ae8be666b10a79e384fb..9af883c2e248a5af1caa95b11c8fc9e926854b9f 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -144,6 +144,7 @@ at::Tensor npu_multi_scale_deformable_attn_function(const at::Tensor& value, const at::Tensor& attention_weights); std::tuple multi_scale_deformable_attn_grad(const at::Tensor& value, const at::Tensor& shape, const at::Tensor& level_start_index, const at::Tensor& location, const at::Tensor& attn_weight, const at::Tensor& grad_output); +at::Tensor npu_furthest_point_sampling(const at::Tensor &point_xyz, const at::Tensor &nearset_temp, const int32_t num_points); at::Tensor npu_ads_add(const at::Tensor &tensor1, const at::Tensor &tensor2); at::Tensor DynamicVoxelization( diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index 3d2c805cd16c06284a63aeec820fca005e481ffd..0ab27ddeb88d028371dd0743a1a3b852bee32f5c 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -91,7 +91,10 @@ void init_common(pybind11::module &m) // dyn_voxelization m.def("dynamic_voxelization", &DynamicVoxelization); - + // nms3d_normal m.def("nms3d_normal", &nms3d_normal); + + // npu_furthest_point_sampling + m.def("npu_furthest_point_sampling", &npu_furthest_point_sampling); } diff --git a/ads/common/ops/furthest_point_sampling.py b/ads/common/ops/furthest_point_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..2ad30a54bc45ffdf4c64f5f43fbda22a02d10678 --- /dev/null +++ b/ads/common/ops/furthest_point_sampling.py @@ -0,0 +1,22 @@ +import numpy as np +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads_c + + +class AdsFurthestPointSampling(Function): + @staticmethod + def forward(ctx, point_xyz, num_points): + B, N = point_xyz.size()[:2] + point_xyz = point_xyz.permute(0, 2, 1).contiguous() + + nearest_dist = torch.tensor(np.ones((B, N)) * 1e10, dtype=torch.float32, device='npu').contiguous() + output = ads_c.npu_furthest_point_sampling(point_xyz, nearest_dist, num_points) + + return output + + +npu_furthest_point_sampling = AdsFurthestPointSampling.apply \ No newline at end of file diff --git a/ads/common/ops/kernels/inc/base.h b/ads/common/ops/kernels/inc/base.h index 3c853e6cf5157779e7a6379a67ea0a95619672bc..b0fe79bfe9b714c8640a969a633a4bb0ba9988c0 100644 --- a/ads/common/ops/kernels/inc/base.h +++ b/ads/common/ops/kernels/inc/base.h @@ -3,14 +3,23 @@ // .INPUT(x2, TensorType({DT_FLOAT})) // .OUTPUT(y, TensorType({DT_FLOAT})) // .OP_END_FACTORY_REG(Add) + // REG_OP(FurthestPointSamplingWithDist) // .INPUT(points_dist, TensorType({DT_FLOAT})) // .INPUT(nearest_temp, TensorType({DT_FLOAT})) // .OUTPUT(index, TensorType({DT_INT32})) // .REQUIRED_ATTR(num_points, Int) // .OP_END_FACTORY_REG(FurthestPointSamplingWithDist) + // REG_OP(Nms3dNormal) // .INPUT(boxes, TensorType({DT_FLOAT, DT_FLOAT16})) // .OUTPUT(keep, TensorType({DT_INT16})) // .REQUIRED_ATTR(nms_overlap_thresh, Float) // .OP_END_FACTORY_REG(Nms3dNormal) + +// REG_OP(FurthestPointSampling) +// .INPUT(point_xyz, TensorType({DT_FLOAT})) +// .INPUT(nearest_temp, TensorType({DT_FLOAT})) +// .OUTPUT(index, TensorType({DT_INT32})) +// .REQUIRED_ATTR(num_points, Int) +// .OP_END_FACTORY_REG(FurthestPointSampling) diff --git a/ads/common/ops/kernels/op_host/furthest_point_sampling.cpp b/ads/common/ops/kernels/op_host/furthest_point_sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b8e4bfc55d21e580000987341eb77024cb198427 --- /dev/null +++ b/ads/common/ops/kernels/op_host/furthest_point_sampling.cpp @@ -0,0 +1,335 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file furthest_point_sampling.cc + * \brief + */ +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_api.h" +#include "register/op_def_registry.h" +#include "furthest_point_sampling_tiling.h" + +using namespace ge; +using namespace std; +using namespace AscendC; + +namespace optiling { +/****************constexpr definition*****************/ +constexpr int64_t FP32_MODE = 0; + +/****************struct definition*****************/ +struct ub_memory_tag { + uint64_t ub_size = 0; + uint64_t ub_reserve = 2048; + // 8 :dev into 8 pieces(point_x, point_y, point_z, temp, distance, pointTempX, pointTempY, pointTempZ, store N data each) + uint64_t ub_data_blocks = 8; +}; + +/****************function definition*****************/ +template +inline T getSmallestMulVal(T data, T multiple) +{ + if (multiple == 0) { + return 0; + } + return ((data + multiple - 1) / multiple); +} + +template +inline T getSmallestMul(T data, T multiple) +{ + return (getSmallestMulVal(data, multiple) * multiple); +} + +/****************class definition*****************/ +class FurthestPointSamplingTiling { +public: + explicit FurthestPointSamplingTiling(gert::TilingContext* context) : TilingContext(context) {}; + ge::graphStatus Init(); + ge::graphStatus RunKernelTiling(); +private: + inline void SetTilingKeyMode(ge::DataType dType); + inline uint64_t UbBlocksDataSpace(uint64_t data_num); + inline uint64_t UbBlocksWorkSpace(uint64_t data_num); + inline uint64_t UbBlocksSpace(uint64_t data_num); + inline uint64_t FindMaxDataBlock(); + // in Kernel, we use ReduceMax requiring us to calc size of worklocal + inline uint32_t calcWorkLocalSize(uint32_t max_repeat_times); +private: + FurthestPointSamplingTilingData TilingData; + gert::TilingContext* TilingContext = nullptr; + + uint32_t coreNum; + + uint32_t batch; + uint32_t N; + uint32_t numPoints; + uint32_t pieces; + uint32_t formerNum; + uint32_t tailNum; + uint32_t workSize; + uint32_t idxTempSize; + uint32_t bigCoreBatch; + uint32_t smallCoreBatch; + uint32_t bigCoreNum; + uint32_t repeats; + + ub_memory_tag ub_memory; + + uint64_t point_dtype_size; +}; + +/****************class impl*****************/ +ge::graphStatus FurthestPointSamplingTiling::Init() +{ + const gert::StorageShape *point_xyz_shape = TilingContext->GetInputShape(0); + const gert::RuntimeAttrs *attrs = TilingContext->GetAttrs(); + uint64_t max_data_num; + + auto platformInfoPtr = TilingContext->GetPlatformInfo(); + if (platformInfoPtr == nullptr) { + return ge::GRAPH_FAILED; + } + + auto platformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr); + + // Set Tiling Key + SetTilingKeyMode(TilingContext->GetInputDesc(0)->GetDataType()); + + // get core num + this->coreNum = platformInfo.GetCoreNumAiv(); + if (this->coreNum == 0) { + return ge::GRAPH_FAILED; + } + + // get ub_size,cal the capability that is aligned with 256 bytes + platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, this->ub_memory.ub_size); + + // Get input args + this->batch = point_xyz_shape->GetStorageShape().GetDim(0); + this->N = point_xyz_shape->GetStorageShape().GetDim(2); + this->numPoints = *(attrs->GetAttrPointer(0)); + + // get the capability on UB + max_data_num = FindMaxDataBlock(); // pieces, repeats, workSize calc in this func + + if (this->repeats == 0) { + return ge::GRAPH_FAILED; + } + + // Tiling Args calc + this->bigCoreBatch = getSmallestMulVal(this->batch, this->coreNum); + this->smallCoreBatch = this->batch / this->coreNum; + if (this->bigCoreBatch == this->smallCoreBatch) { + this->bigCoreNum = this->coreNum; + } else if ((this->bigCoreBatch == 1) && (this->smallCoreBatch == 0)) { + this->bigCoreNum = this->batch; + } else { + this->bigCoreNum = (this->batch - (this->smallCoreBatch * this->coreNum)) / + (this->bigCoreBatch - this->smallCoreBatch); + } + + this->formerNum = ((this->repeats * 256) / this->point_dtype_size); + this->tailNum = this->N - this->formerNum * (this->pieces - 1); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus FurthestPointSamplingTiling::RunKernelTiling() +{ + size_t sysWorkspaceSize = 16 * 1024 * 1024; // Alloc 16M workspace + size_t userWorkSpaceSize = this->batch * this->N * this->point_dtype_size; // NearestDist needs a space to be moved out + size_t *currentWorkSpace = TilingContext->GetWorkspaceSizes(1); + currentWorkSpace[0] = userWorkSpaceSize + sysWorkspaceSize; + + TilingData.set_batch(this->batch); + TilingData.set_N(this->N); + TilingData.set_numPoints(this->numPoints); + TilingData.set_pieces(this->pieces); + TilingData.set_formerNum(this->formerNum); + TilingData.set_tailNum(this->tailNum); + TilingData.set_workSize(this->workSize); + TilingData.set_idxTempSize(this->idxTempSize); + TilingData.set_bigCoreBatch(this->bigCoreBatch); + TilingData.set_smallCoreBatch(this->smallCoreBatch); + TilingData.set_bigCoreNum(this->bigCoreNum); + if (this->batch <= this->coreNum) { + TilingContext->SetBlockDim(this->batch); + } else { + TilingContext->SetBlockDim(this->coreNum); + } + TilingData.set_repeats(this->repeats); + + TilingData.SaveToBuffer(TilingContext->GetRawTilingData()->GetData(), TilingContext->GetRawTilingData()->GetCapacity()); + TilingContext->GetRawTilingData()->SetDataSize(TilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} + +inline uint32_t FurthestPointSamplingTiling::calcWorkLocalSize(uint32_t max_repeat_times) +{ + uint32_t elemPerBlock = 32 / this->point_dtype_size; // num of data one block can store + uint32_t elemPerRepeat = 256 / this->point_dtype_size; // num of data one repeat can deal + + uint32_t iter1OutputCount = max_repeat_times * 2; // num of temp data in 1st stage + uint32_t iter2AlignStart = getSmallestMul(iter1OutputCount, elemPerBlock); // align with 32 bytes + uint32_t iter2OutputCount = getSmallestMulVal(iter1OutputCount, elemPerRepeat) * 2; // num of temp data in 2nd stage + uint32_t iter3AlignStart = getSmallestMul(iter2OutputCount, elemPerBlock); // align with 32 bytes + uint32_t iter3OutputCount = getSmallestMulVal(iter1OutputCount, elemPerRepeat) * 2; // num of temp data in 3rd stage + uint32_t iter3AlignEnd = getSmallestMul(iter2OutputCount, elemPerBlock); // align with 32 bytes + + uint32_t finalWorkLocalNeedSize = iter2AlignStart + iter3AlignStart + iter3AlignEnd; + uint32_t totalBytes = finalWorkLocalNeedSize * this->point_dtype_size; + + if (totalBytes % 32 != 0) { + return ge::GRAPH_FAILED; + } + return totalBytes; +} + +inline uint64_t FurthestPointSamplingTiling::FindMaxDataBlock() +{ + // divide & conquer ==> find the capability, if bigger than N, split then cal + uint64_t M = this->ub_memory.ub_size - this->ub_memory.ub_reserve; + uint64_t low = 1; // at least there exits one data in UB + uint64_t high = M / this->point_dtype_size; + uint64_t max_data_num = 0; + + while (low <= high) { + uint64_t mid = low + (high - low) / 2; + if (UbBlocksSpace(mid) <= M) { + max_data_num = mid; + low = mid + 1; + } else { + high = mid - 1; + } + } + + return max_data_num; +} + +inline void FurthestPointSamplingTiling::SetTilingKeyMode(ge::DataType dType) +{ + switch (dType) { + case ge::DT_FLOAT: + TilingContext->SetTilingKey(FP32_MODE); + this->point_dtype_size = 4; // 4: float32, 4 bytes + break; + default: + TilingContext->SetTilingKey(FP32_MODE); + this->point_dtype_size = 4; // 4: float32, 4 bytes + break; + } +} + +inline uint64_t FurthestPointSamplingTiling::UbBlocksDataSpace(uint64_t data_num) +{ + // data type is the same among the first 5 blocks, the num is data_num, aligned with 256 bytes + return getSmallestMul(this->point_dtype_size * data_num, 256); +} + +inline uint64_t FurthestPointSamplingTiling::UbBlocksWorkSpace(uint64_t data_num) +{ + uint64_t singleBlockRepeats = getSmallestMulVal(this->point_dtype_size * data_num, 256); + uint64_t totalRepeats = getSmallestMulVal(this->point_dtype_size * this->N, 256); + + this->pieces = (uint32_t)getSmallestMulVal(totalRepeats, singleBlockRepeats); + this->repeats = (singleBlockRepeats < totalRepeats) ? ((uint32_t)singleBlockRepeats) : ((uint32_t)totalRepeats); + return (uint64_t)calcWorkLocalSize(this->repeats); +} + +inline uint64_t FurthestPointSamplingTiling::UbBlocksSpace(uint64_t data_num) +{ + uint64_t dataSpace = UbBlocksDataSpace(data_num); + uint64_t workSpace = UbBlocksWorkSpace(data_num); + + this->workSize = workSpace; + this->idxTempSize = getSmallestMul(this->pieces * 2 * this->point_dtype_size, 32); + + return (this->ub_memory.ub_data_blocks * dataSpace + workSpace + this->idxTempSize); +} + +/****************main body*****************/ +static ge::graphStatus TilingFurthestPointSampling(gert::TilingContext* context) +{ + FurthestPointSamplingTiling tilingObject(context); + + tilingObject.Init(); + return tilingObject.RunKernelTiling(); +} +} + +namespace ge { +static ge::graphStatus InfershapeForFurthestPointSampling(gert::InferShapeContext *context) +{ + const gert::Shape *point_xyz_shape = context->GetInputShape(0); + const gert::RuntimeAttrs *attrs = context->GetAttrs(); + gert::Shape *index_shape = context->GetOutputShape(0); + if ((point_xyz_shape == nullptr) || (attrs == nullptr) || (index_shape == nullptr)) { + return ge::GRAPH_FAILED; + } + + uint32_t batch = point_xyz_shape->GetDim(0); + uint32_t N = point_xyz_shape->GetDim(2); + uint32_t num_points = *(attrs->GetAttrPointer(0)); + + index_shape->SetDimNum(2); + index_shape->SetDim(0, batch); + index_shape->SetDim(1, num_points); + + return GRAPH_SUCCESS; +} +} + +namespace ops { +class FurthestPointSampling : public OpDef { +public: + explicit FurthestPointSampling(const char* name) : OpDef(name) + { + this->Input("point_xyz") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("nearest_temp") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("index") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("num_points") + .AttrType(REQUIRED) + .Int(); + + this->SetInferShape(ge::InfershapeForFurthestPointSampling); + this->AICore().SetTiling(optiling::TilingFurthestPointSampling); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true); + this->AICore().AddConfig("ascend910b", aicore_config); + } +}; + +OP_ADD(FurthestPointSampling); +} \ No newline at end of file diff --git a/ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h b/ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h new file mode 100644 index 0000000000000000000000000000000000000000..3e9622889f8e020150fc0efeb9903a102dcbfad4 --- /dev/null +++ b/ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h @@ -0,0 +1,44 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file furthest_point_sampling.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_TILING_RUNTIME_FURTHEST_POINT_SAMPLING_H_ +#define OPS_BUILT_IN_OP_TILING_RUNTIME_FURTHEST_POINT_SAMPLING_H_ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(FurthestPointSamplingTilingData) + TILING_DATA_FIELD_DEF(uint32_t, N); + TILING_DATA_FIELD_DEF(uint32_t, batch); + TILING_DATA_FIELD_DEF(uint32_t, numPoints); + TILING_DATA_FIELD_DEF(uint32_t, pieces); + TILING_DATA_FIELD_DEF(uint32_t, formerNum); + TILING_DATA_FIELD_DEF(uint32_t, tailNum); + TILING_DATA_FIELD_DEF(uint32_t, workSize); + TILING_DATA_FIELD_DEF(uint32_t, idxTempSize); + TILING_DATA_FIELD_DEF(uint32_t, bigCoreBatch); + TILING_DATA_FIELD_DEF(uint32_t, smallCoreBatch); + TILING_DATA_FIELD_DEF(uint32_t, bigCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, repeats); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(FurthestPointSampling, FurthestPointSamplingTilingData) +} + +#endif // OPS_BUILT_IN_OP_TILING_RUNTIME_FURTHEST_POINT_SAMPLING_H_ \ No newline at end of file diff --git a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5f53f3607548a0ff3e28ae6d5e81cc3eaaf0d74b --- /dev/null +++ b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp @@ -0,0 +1,587 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * This file constains code of cpu debug and npu code.We read data from bin file + * and write result to file. + */ +#include "furthest_point_sampling.h" + +using namespace AscendC; + +// Entrance of kernel +extern "C" __global__ __aicore__ void furthest_point_sampling( + GM_ADDR point_xyz, + GM_ADDR temp, + GM_ADDR index, + GM_ADDR workspace, + GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + + tilingArgs TA; + // Since type of tiling_data unknown, create a class out of reliability. + TA.N = tiling_data.N; + TA.batch = tiling_data.batch; + TA.numPoints = tiling_data.numPoints; + TA.pieces = tiling_data.pieces; + TA.formerNum = tiling_data.formerNum; + TA.tailNum = tiling_data.tailNum; + TA.workSize = tiling_data.workSize; + TA.idxTempSize = tiling_data.idxTempSize; + TA.bigCoreBatch = tiling_data.bigCoreBatch; + TA.smallCoreBatch = tiling_data.smallCoreBatch; + TA.bigCoreNum = tiling_data.bigCoreNum; + TA.repeats = tiling_data.repeats; + + if (TILING_KEY_IS(0)) { + furthestPointSamplingKernel op(point_xyz, temp, index, workspace, &TA); + op.Process(); + } +} + +template +__aicore__ inline furthestPointSamplingKernel::furthestPointSamplingKernel(GM_ADDR point_xyz, + GM_ADDR temp, GM_ADDR index, GM_ADDR workspace, tilingArgs *tiling) +{ + // Init tiling args. + this->TA = tiling; + // host tiling have ensured formerNum is aligned with 32bytes and bigger than tailNum. + this->sizeofFormer = this->TA->formerNum * sizeof(dataType); + this->sizeofTail = this->TA->tailNum * sizeof(dataType); + this->dataNumIn32Bytes = 32 / sizeof(dataType); + this->dataNumIn64Bytes = 64 / sizeof(dataType); + this->dataNumIn256Bytes = 256 / sizeof(dataType); + this->dataNumIn1024Bytes = 1024 / sizeof(dataType); + // Init GM. + InitGm(point_xyz, temp, index, workspace); + + // Must be aligned with 32bytes. + this->pipe.InitBuffer(this->pointXQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointYQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointZQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointTempXUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointTempYUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointTempZUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->nearestDistQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->distUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->workUb, BUFFER_NUM, this->TA->workSize); + + this->pipe.InitBuffer(this->idxQue, BUFFER_NUM, 1024); // copy out 256 fp32s once + + this->pipe.InitBuffer(this->idxTempUb, BUFFER_NUM, this->TA->idxTempSize); + this->pipe.InitBuffer(this->pointSampled, BUFFER_NUM, 32 * 3); + // Malloc. + this->ubBlocks.pointXLocal = pointXQue.AllocTensor(); + this->ubBlocks.pointYLocal = pointYQue.AllocTensor(); + this->ubBlocks.pointZLocal = pointZQue.AllocTensor(); + this->ubBlocks.pointTempXLocal = pointTempXUb.AllocTensor(); + this->ubBlocks.pointTempYLocal = pointTempYUb.AllocTensor(); + this->ubBlocks.pointTempZLocal = pointTempZUb.AllocTensor(); + this->ubBlocks.nearestDistLocal = nearestDistQue.AllocTensor(); + this->ubBlocks.distLocal = distUb.AllocTensor(); + this->ubBlocks.workLocal = workUb.AllocTensor(); + + this->ubBlocks.idxLocal = idxQue.AllocTensor(); + + this->ubBlocks.idxTempLocal = idxTempUb.AllocTensor(); + this->ubBlocks.pointSampledLocal = pointSampled.AllocTensor(); +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process() +{ + uint32_t batch_num = (GetBlockIdx() < this->TA->bigCoreNum) ? (this->TA->bigCoreBatch) : (this->TA->smallCoreBatch); + + for (this->core_batch = 0; this->core_batch < batch_num; this->core_batch++) { + this->batchOffsetPoint = this->core_batch * this->TA->N * 3; + this->batchOffsetNearest = this->core_batch * this->TA->N; + // Set:idxGm[0] = 0 + CopyInIdx(0); + if (this->TA->numPoints == 1) { + CopyOut(0); // special case: only one points sampled. + } + if (this->TA->pieces == 1) { + Process_complete_data(); + } else { + Process_split_data(); + } + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInIdx(uint32_t loopNum) +{ + DataCopyParams data_copy_param = {1, 1, 0, 0}; + uint32_t offsetGmX = this->batchOffsetPoint + this->maxDistIdx; + uint32_t offsetGmY = offsetGmX + this->TA->N; + uint32_t offsetGmZ = offsetGmY + this->TA->N; + uint32_t offsetLocalX = 0; + uint32_t offsetLocalY = this->dataNumIn32Bytes; + uint32_t offsetLocalZ = this->dataNumIn64Bytes; + uint32_t offsetIdx = loopNum & (this->dataNumIn1024Bytes - 1); // aka. loopNum % this->dataNumIn1024Bytes + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); + +#ifndef __GET_CODE_CHANNEL__ + DataCopy(this->ubBlocks.pointSampledLocal[offsetLocalX], pointGm[offsetGmX], data_copy_param); + DataCopy(this->ubBlocks.pointSampledLocal[offsetLocalY], pointGm[offsetGmY], data_copy_param); + DataCopy(this->ubBlocks.pointSampledLocal[offsetLocalZ], pointGm[offsetGmZ], data_copy_param); +#endif + + set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + + this->ubBlocks.idxLocal.SetValue(offsetIdx, this->maxDistIdx); + this->pointXSampled = -1 * this->ubBlocks.pointSampledLocal.GetValue(offsetLocalX); + this->pointYSampled = -1 * this->ubBlocks.pointSampledLocal.GetValue(offsetLocalY); + this->pointZSampled = -1 * this->ubBlocks.pointSampledLocal.GetValue(offsetLocalZ); + this->maxDist = 0; + this->maxDistIdx = 0; +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process_complete_data() +{ + uint32_t loopNum; + + for (loopNum = 1; loopNum < this->TA->numPoints; loopNum++) { + if (loopNum == 1) { + Process_first_sampling(0); + } else { + ComputePointsSquare(); + + pipe_barrier(PIPE_V); + + ComputeDist(); + + pipe_barrier(PIPE_V); + + ComputeSamplePoints(0, 0); + } + pipe_barrier(PIPE_V); + + updateDist(); + + CopyInIdx(loopNum); + + CopyOut(loopNum); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process_split_data() +{ + uint32_t loopNum, loopSplit; + + for (loopNum = 1; loopNum < this->TA->numPoints; loopNum++) { + for (loopSplit = 0; loopSplit < this->TA->pieces; loopSplit++) { + if (loopNum == 1) { + Process_first_sampling(loopSplit); + } else { + uint32_t comBlock = (loopSplit + this->TA->pieces - 1) % this->TA->pieces; + + // Cal point_x -> Mov point_x, Cal point_y -> Mov point_y, Cal point_z -> Mov point_z + ComputePointDeltaSquare(this->ubBlocks.pointXLocal, this->ubBlocks.pointTempXLocal, this->pointXSampled); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + + CopyInPointAxis(pointAxis_x, loopSplit); + + ComputePointDeltaSquare(this->ubBlocks.pointYLocal, this->ubBlocks.pointTempYLocal, this->pointYSampled); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + + CopyInPointAxis(pointAxis_y, loopSplit); + + ComputePointDeltaSquare(this->ubBlocks.pointZLocal, this->ubBlocks.pointTempZLocal, this->pointZSampled); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + + CopyInPointAxis(pointAxis_z, loopSplit); + + pipe_barrier(PIPE_V); + + ComputeDist(); + + pipe_barrier(PIPE_V); + + ComputeSamplePoints(loopSplit, comBlock); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + CopyInNearestDistTemp(loopSplit); + } + } + pipe_barrier(PIPE_V); + + updateDist(); + + CopyInIdx(loopNum); + + CopyOut(loopNum); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process_first_sampling(uint32_t loopSplit) +{ + // Mov point_x -> Cal point_x, Mov point_y -> Cal point_y, Mov point_z -> Cal point_z + CopyInPointAxis(pointAxis_x, loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + ComputePointDeltaSquare(this->ubBlocks.pointXLocal, this->ubBlocks.pointTempXLocal, this->pointXSampled); + + CopyInPointAxis(pointAxis_y, loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + ComputePointDeltaSquare(this->ubBlocks.pointYLocal, this->ubBlocks.pointTempYLocal, this->pointYSampled); + + CopyInPointAxis(pointAxis_z, loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + + ComputePointDeltaSquare(this->ubBlocks.pointZLocal, this->ubBlocks.pointTempZLocal, this->pointZSampled); + + pipe_barrier(PIPE_V); + + ComputeDist(); + + CopyInNearestDist(loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + + ComputeSamplePoints(loopSplit, loopSplit); +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInPointAxis(PointAxis pointAxis, + uint32_t loopSplit) +{ + uint32_t offset; + DataCopyParams data_copy_param = {1, 0, 0, 0}; + DataCopyPadParams pad_param = {false, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + switch (pointAxis) { + case pointAxis_x: + offset = this->batchOffsetPoint + this->TA->formerNum * loopSplit; + break; + case pointAxis_y: + offset = this->batchOffsetPoint + this->TA->formerNum * loopSplit + this->TA->N; + break; + case pointAxis_z: + offset = this->batchOffsetPoint + this->TA->formerNum * loopSplit + this->TA->N * 2; + break; + default: + break; + } + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID1); + + switch (pointAxis) { + case pointAxis_x: +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.pointXLocal, pointGm[offset], data_copy_param, pad_param); +#endif + break; + case pointAxis_y: +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.pointYLocal, pointGm[offset], data_copy_param, pad_param); +#endif + break; + case pointAxis_z: +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.pointZLocal, pointGm[offset], data_copy_param, pad_param); +#endif + break; + default: + break; + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInNearestDist(uint32_t loopSplit) +{ + uint32_t offset = this->batchOffsetNearest + this->TA->formerNum * loopSplit; + DataCopyParams data_copy_param = {1, 0, 0, 0}; + DataCopyPadParams pad_param = {false, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.nearestDistLocal, nearestDistGm[offset], data_copy_param, pad_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInNearestDistTemp(uint32_t loopSplit) +{ + uint32_t offset = this->batchOffsetNearest + this->TA->formerNum * loopSplit; + DataCopyParams data_copy_param = {1, 0, 0, 0}; + DataCopyPadParams pad_param = {false, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.nearestDistLocal, nearestDistTempGm[offset], data_copy_param, pad_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputePointsSquare() +{ + uint32_t total_num, dupTime, offset, comp_num; + + // while cal,every data block is aligned with 256 bytes. + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID3); + wait_flag(PIPE_S, PIPE_V, EVENT_ID3); + + Adds(this->ubBlocks.pointTempXLocal[offset], this->ubBlocks.pointXLocal[offset], this->pointXSampled, + this->dataNumIn256Bytes, dupTime, {1, 1, 8, 8}); + Adds(this->ubBlocks.pointTempYLocal[offset], this->ubBlocks.pointYLocal[offset], this->pointYSampled, + this->dataNumIn256Bytes, dupTime, {1, 1, 8, 8}); + Adds(this->ubBlocks.pointTempZLocal[offset], this->ubBlocks.pointZLocal[offset], this->pointZSampled, + this->dataNumIn256Bytes, dupTime, {1, 1, 8, 8}); + + pipe_barrier(PIPE_V); + + Mul(this->ubBlocks.pointTempXLocal[offset], this->ubBlocks.pointTempXLocal[offset], + this->ubBlocks.pointTempXLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + Mul(this->ubBlocks.pointTempYLocal[offset], this->ubBlocks.pointTempYLocal[offset], + this->ubBlocks.pointTempYLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + Mul(this->ubBlocks.pointTempZLocal[offset], this->ubBlocks.pointTempZLocal[offset], + this->ubBlocks.pointTempZLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputePointDeltaSquare( + LocalTensor &pointLocal, LocalTensor &pointTempLocal, dataType pointSampled) +{ + uint32_t total_num, dupTime, offset, comp_num; + + // while cal,every data block is aligned with 256 bytes. + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID3); + wait_flag(PIPE_S, PIPE_V, EVENT_ID3); + + Adds(pointTempLocal[offset], pointLocal[offset], pointSampled, this->dataNumIn256Bytes, + dupTime, {1, 1, 8, 8}); + + pipe_barrier(PIPE_V); + + Mul(pointTempLocal[offset], pointTempLocal[offset], pointTempLocal[offset], this->dataNumIn256Bytes, + dupTime, {1, 1, 1, 8, 8, 8}); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputeDist() +{ + uint32_t total_num, dupTime, offset, comp_num; + + // while cal,every data block is aligned with 256 bytes. + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + Add(this->ubBlocks.distLocal[offset], this->ubBlocks.pointTempXLocal[offset], + this->ubBlocks.pointTempYLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + + pipe_barrier(PIPE_V); + + Add(this->ubBlocks.distLocal[offset], this->ubBlocks.distLocal[offset], + this->ubBlocks.pointTempZLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputeSamplePoints(uint32_t loopSplit, + uint32_t comBlock) +{ + uint32_t total_num, dupTime, offset, comp_num, reduceCnt, reduceOffset; + + reduceCnt = ((this->TA->formerNum != this->TA->tailNum) && (comBlock == (this->TA->pieces - 1))) ? + this->TA->tailNum : this->TA->formerNum; + reduceOffset = comBlock * 2; + + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID1); + wait_flag(PIPE_S, PIPE_V, EVENT_ID1); + + Min(this->ubBlocks.nearestDistLocal[offset], this->ubBlocks.nearestDistLocal[offset], + this->ubBlocks.distLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + } + + if (this->TA->pieces > 1) { + // set_flag: After Updated nearestDistLocal, Mov nearestDistLocal to GM. + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + CopyOutNearestDistTemp(comBlock); + } + + pipe_barrier(PIPE_V); + + // ReduceMax + ReduceMax(this->ubBlocks.idxTempLocal[reduceOffset], this->ubBlocks.nearestDistLocal, + this->ubBlocks.workLocal, reduceCnt, 1); +} + +template +__aicore__ inline void furthestPointSamplingKernel::updateDist() +{ + dataType tempValue; + + // this->TA->pieces >= 1 + for (uint32_t i = 1; i < (2 * this->TA->pieces); i = (i + 2)) { + tempValue = this->ubBlocks.idxTempLocal.GetValue(i); + if (this->maxDist < this->ubBlocks.idxTempLocal.GetValue(i-1)) { + this->maxDist = this->ubBlocks.idxTempLocal.GetValue(i-1); + this->maxDistIdx = (this->TA->formerNum * (i / 2)) + (*reinterpret_cast(&tempValue)); + } + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyOut(uint32_t loopNum) +{ + uint32_t elemNum = this->dataNumIn1024Bytes; + // elemNum is a multiple of 2. + if ((loopNum != 0) && (((loopNum + 1) & (elemNum - 1)) != 0) && ((loopNum + 1) != this->TA->numPoints)) { + // when num of sampled < 256 && not last loop, return; + return ; + } + + uint32_t offset = this->core_batch * this->TA->numPoints; + DataCopyExtParams data_copy_param = {1, sizeof(dataType), 0, 0, 0}; + if (((loopNum + 1) & (elemNum - 1)) == 0) { + data_copy_param.blockLen = 1024; + offset = offset + loopNum / elemNum * elemNum; + } else if ((loopNum + 1) == this->TA->numPoints) { + data_copy_param.blockLen = sizeof(dataType) * + (this->TA->numPoints - (this->TA->numPoints / elemNum * elemNum)); + offset = offset + (this->TA->numPoints / elemNum * elemNum); + } + + set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(idxGm[offset], this->ubBlocks.idxLocal, data_copy_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyOutNearestDistTemp(uint32_t loopSplit) +{ + uint32_t offset = this->batchOffsetNearest + this->TA->formerNum * loopSplit; + DataCopyExtParams data_copy_param = {1, 0, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + + set_flag(PIPE_S, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID1); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(nearestDistTempGm[offset], this->ubBlocks.nearestDistLocal, data_copy_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::InitGm(GM_ADDR point_xyz, GM_ADDR temp, + GM_ADDR index, GM_ADDR workspace) +{ + GM_ADDR usrWorkspace = AscendC::GetUserWorkspace(workspace); + uint32_t coreId = GetBlockIdx(); + uint32_t skipData, numData, skipIdx, numIdx; + uint32_t numDataBigCore = this->TA->bigCoreBatch * this->TA->N; + uint32_t numIdxBigCore = this->TA->bigCoreBatch * this->TA->numPoints; + + if (coreId < this->TA->bigCoreNum) { + numData = numDataBigCore; + numIdx = numIdxBigCore; + skipData = numData * coreId; + skipIdx = numIdx * coreId; + } else { + numData = this->TA->smallCoreBatch * this->TA->N; + numIdx = this->TA->smallCoreBatch * this->TA->numPoints; + skipData = this->TA->bigCoreNum * numDataBigCore + (coreId - this->TA->bigCoreNum) * numData; + skipIdx = this->TA->bigCoreNum * numIdxBigCore + (coreId - this->TA->bigCoreNum) * numIdx; + } + + this->pointGm.SetGlobalBuffer((__gm__ dataType*)point_xyz + skipData * 3, numData * 3); + this->nearestDistGm.SetGlobalBuffer((__gm__ dataType*)temp + skipData, numData); + this->idxGm.SetGlobalBuffer((__gm__ idxType*)index + skipIdx, numIdx); + this->nearestDistTempGm.SetGlobalBuffer((__gm__ dataType*)usrWorkspace + skipData, numData); +} + +template +__aicore__ inline furthestPointSamplingKernel::~furthestPointSamplingKernel() +{ + this->pointXQue.FreeTensor(this->ubBlocks.pointXLocal); + this->pointYQue.FreeTensor(this->ubBlocks.pointYLocal); + this->pointZQue.FreeTensor(this->ubBlocks.pointZLocal); + this->pointTempXUb.FreeTensor(this->ubBlocks.pointTempXLocal); + this->pointTempYUb.FreeTensor(this->ubBlocks.pointTempYLocal); + this->pointTempZUb.FreeTensor(this->ubBlocks.pointTempZLocal); + this->nearestDistQue.FreeTensor(this->ubBlocks.nearestDistLocal); + this->distUb.FreeTensor(this->ubBlocks.distLocal); + this->workUb.FreeTensor(this->ubBlocks.workLocal); + + this->idxQue.FreeTensor(this->ubBlocks.idxLocal); + + this->idxTempUb.FreeTensor(this->ubBlocks.idxTempLocal); + this->pointSampled.FreeTensor(this->ubBlocks.pointSampledLocal); +} \ No newline at end of file diff --git a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h new file mode 100644 index 0000000000000000000000000000000000000000..2fdb36df87b7a26730d80632870adfbb78baa0e6 --- /dev/null +++ b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * This file constains code of cpu debug and npu code.We read data from bin file + * and write result to file. + */ +#ifndef FURTHEST_POINT_SAMPLING_H +#define FURTHEST_POINT_SAMPLING_H + +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_operator.h" + +namespace AscendC { +constexpr uint32_t BUFFER_NUM = 1u; + +enum PointAxis { + pointAxis_x, + pointAxis_y, + pointAxis_z +}; + +template +struct UbBlocks_tag { + __aicore__ UbBlocks_tag() = default; + + LocalTensor pointXLocal; + LocalTensor pointYLocal; + LocalTensor pointZLocal; + LocalTensor pointTempXLocal; + LocalTensor pointTempYLocal; + LocalTensor pointTempZLocal; + LocalTensor nearestDistLocal; + LocalTensor distLocal; + LocalTensor idxLocal; + LocalTensor idxTempLocal; + LocalTensor pointSampledLocal; + LocalTensor workLocal; +}; +template +using UbBlocks = UbBlocks_tag; + +class tilingArgs { +public: + __aicore__ inline tilingArgs() = default; +public: + uint32_t N; + uint32_t batch; + uint32_t numPoints; + uint32_t pieces; + uint32_t formerNum; + uint32_t tailNum; + uint32_t workSize; + uint32_t idxTempSize; + uint32_t bigCoreBatch; + uint32_t smallCoreBatch; + uint32_t bigCoreNum; + uint32_t repeats; +}; + +template +class furthestPointSamplingKernel { +public: + __aicore__ inline furthestPointSamplingKernel(GM_ADDR point_xyz, GM_ADDR temp, GM_ADDR index, GM_ADDR workspace, + tilingArgs *tiling); + __aicore__ inline ~furthestPointSamplingKernel(); + __aicore__ inline void Process(); + +private: + __aicore__ inline void Process_first_sampling(uint32_t loopSplit = 0); + __aicore__ inline void Process_split_data(); + __aicore__ inline void Process_complete_data(); + +private: + __aicore__ inline void CopyInPointAxis(PointAxis pointAxis, uint32_t loopSplit = 0); + __aicore__ inline void CopyInNearestDist(uint32_t loopSplit = 0); + __aicore__ inline void CopyInNearestDistTemp(uint32_t loopSplit = 0); + __aicore__ inline void CopyInIdx(uint32_t loopNum); + __aicore__ inline void CopyOut(uint32_t loopNum); + __aicore__ inline void CopyOutNearestDistTemp(uint32_t loopSplit = 0); + +private: + __aicore__ inline void ComputePointsSquare(); + __aicore__ inline void ComputePointDeltaSquare(LocalTensor &pointLocal, + LocalTensor &pointTempLocal, dataType pointSampled); + __aicore__ inline void ComputeDist(); + __aicore__ inline void ComputeSamplePoints(uint32_t loopSplit, uint32_t ComBlock); + __aicore__ inline void updateDist(); + +private: + __aicore__ inline void InitGm(GM_ADDR point_xyz, GM_ADDR temp, GM_ADDR index, GM_ADDR workspace); + +private: + TPipe pipe; + TQue pointXQue; + TQue pointYQue; + TQue pointZQue; + TQue pointTempXUb; + TQue pointTempYUb; + TQue pointTempZUb; + TQue nearestDistQue; + TQue distUb; + TQue workUb; + + TQue idxQue; + + TQue idxTempUb; + TQue pointSampled; + +private: + GlobalTensor pointGm; + GlobalTensor nearestDistGm; + GlobalTensor idxGm; + GlobalTensor nearestDistTempGm; + UbBlocks ubBlocks; + +private: + dataType pointXSampled {0}; + dataType pointYSampled {0}; + dataType pointZSampled {0}; + dataType maxDist {0}; + idxType maxDistIdx {0}; + uint32_t core_batch; + +private: + // tiling value + tilingArgs *TA; + +private: + uint32_t sizeofFormer; + uint32_t sizeofTail; + uint32_t dataNumIn32Bytes; + uint32_t dataNumIn64Bytes; + uint32_t dataNumIn256Bytes; + uint32_t dataNumIn1024Bytes; + uint32_t batchOffsetPoint; + uint32_t batchOffsetNearest; +}; +} + +#endif // FURTHEST_POINT_SAMPLING_H \ No newline at end of file diff --git a/tests/test_furthest_point_sampling.py b/tests/test_furthest_point_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..192b54c3efc6fbc242fd8d1158a0d1c54987a871 --- /dev/null +++ b/tests/test_furthest_point_sampling.py @@ -0,0 +1,126 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from abc import ABC, abstractmethod +import numpy as np +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + +DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] + + +class CreateBenchMarkTest(ABC): + def __init__(self): + self.batch = None + self.N = None + self.numPoints = None + + self.point = None + self.nearestDist = None + + @abstractmethod + def createData(self): + pass + + def compare_min(self, a): + if a[0] > a[1]: + return a[1] + else : + return a[0] + + def getCpuRes(self): + cpuRes = np.zeros([self.batch, self.numPoints], dtype=np.int32) + nearestDistCopy = self.nearestDist.copy() + + for i in range(self.batch): + sampled = 1 + index = 0 + while sampled < self.numPoints: + deltaX = self.point[i][0] - self.point[i][0][index] + deltaY = self.point[i][1] - self.point[i][1][index] + deltaZ = self.point[i][2] - self.point[i][2][index] + deltaX2 = deltaX * deltaX + deltaY2 = deltaY * deltaY + deltaZ2 = deltaZ * deltaZ + currentDist = deltaX2 + deltaY2 + deltaZ2 + + nearestDistCopy[i] = np.apply_along_axis(self.compare_min, 0, np.stack((currentDist, nearestDistCopy[i]), axis=0)) + index = np.argmax(nearestDistCopy[i]) + cpuRes[i][sampled] = index + sampled = sampled + 1 + return cpuRes + + +class Test1(CreateBenchMarkTest): + def createData(self): + self.batch = 47 + self.N = 717 + self.numPoints = 580 + + self.point = np.zeros([self.batch, 3, self.N], dtype=np.float32) + for i in range(self.batch): + for j in range(self.N): + self.point[i, 0, j] = j + + self.nearestDist = np.ones([self.batch, self.N], dtype=np.float32) * 1e10 + self.point = torch.from_numpy(self.point) + + +class Test2(CreateBenchMarkTest): + def createData(self): + self.batch = 193 + self.N = 579 + self.numPoints = 123 + + self.point = np.zeros([self.batch, 3, self.N], dtype=np.float32) + for i in range(self.batch): + for j in range(self.N): + self.point[i, 0, j] = j + self.point[i, 1, j] = j + 1 + self.point[i, 2, j] = j + 3 + + self.nearestDist = np.ones([self.batch, self.N], dtype=np.float32) * 1e10 + self.point = torch.from_numpy(self.point) + + +test1 = Test1() +test2 = Test2() + + +class TestFurthestPointSample(TestCase): + def cpu_op_exec(self, myTest): + return myTest.getCpuRes() + + def npu_op_exec(self, myTest): + return ads.common.npu_furthest_point_sampling(myTest.point.npu(), myTest.numPoints) + + def compare_res(self, myTest): + myTest.createData() + cpuOutput = self.cpu_op_exec(myTest) + npuOutput = self.npu_op_exec(myTest) + self.assertRtolEqual(cpuOutput, npuOutput) + + @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FurthestPointSampling` is only for 910B, skip it.") + def test_FurthestPointSample(self): + self.compare_res(test1) + self.compare_res(test2) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file