diff --git a/ads/common/__init__.py b/ads/common/__init__.py
index b88f4d130b86415b1be11840f0d18c75be1a9048..3c569e16b1b30519dbb2a7e655853c6e4053efce 100644
--- a/ads/common/__init__.py
+++ b/ads/common/__init__.py
@@ -22,4 +22,5 @@ from .ops.npu_batch_nms import npu_batch_nms
 from .ops.npu_confusion_transpose import npu_confusion_transpose
 from .ops.npu_broadcast import npu_broadcast
 from .ops.npu_moe_tutel import npu_moe_tutel
+from .ops.npu_dynamic_scatter import npu_dynamic_scatter
 from .ops.ads_add import npu_ads_add
diff --git a/ads/common/ops/csrc/DynamicScatterKernelNpuOpApi.cpp b/ads/common/ops/csrc/DynamicScatterKernelNpuOpApi.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..c2ccb7a2c32693acc1faff686d254ec2141b136d
--- /dev/null
+++ b/ads/common/ops/csrc/DynamicScatterKernelNpuOpApi.cpp
@@ -0,0 +1,70 @@
+#include <map>
+#include <torch/csrc/autograd/custom_function.h>
+#include "torch_npu/csrc/framework/OpCommand.h"
+#include "torch_npu/csrc/framework/utils/OpPreparation.h"
+#include "torch_npu/csrc/framework/utils/NpuUtils.h"
+#include "torch_npu/csrc/aten/NPUNativeFunctions.h"
+#include "torch_npu/csrc/aten/CustomFunctions.h"
+#include "functions.h"
+#include "common.h"
+#include "OpApiCommon.h"
+
+using npu_preparation = at_npu::native::OpPreparation;
+using torch::autograd::Function;
+using torch::autograd::AutogradContext;
+using tensor_tuple = std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>;
+
+namespace {
+inline void npu_dynamic_scatter_check(
+    int64_t reduce_type)
+{
+    TORCH_CHECK(reduce_type == 0 || reduce_type == 1 || reduce_type == 2,
+        "reduce_type must be 0(sum) or 1(mean) or 2(max).");
+}
+} // namespace
+
+static std::map<int64_t, std::string> REDUCE_TYPE_MAP = {{0, "sum"}, {1, "mean"}, {2, "max"}};
+
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> npu_dynamic_scatter(
+    const at::Tensor &feats,
+    const at::Tensor &coors,
+    int64_t reduce_type)
+{
+    npu_dynamic_scatter_check(reduce_type);
+    auto num_input = feats.size(0);
+    auto num_feats = feats.size(1);
+    if (num_input == 0) {
+        return {feats.clone().detach(), coors.clone().detach(),
+            coors.new_empty({0}, at::kInt), coors.new_empty({0}, at::kInt)};
+    }
+
+    // collapse every row that contains a negative coordinate to all -1
+    auto coors_clean = coors.masked_fill(coors.lt(0).any(-1, true), -1);
+
+    // deduplicate voxel coordinates on the CPU; only the reduction runs on the NPU
+    at::Tensor out_coors_cpu;
+    at::Tensor coors_map_cpu;
+    at::Tensor reduce_count_cpu;
+    at::Tensor coors_clean_cpu = coors_clean.to("cpu");
+    std::tie(out_coors_cpu, coors_map_cpu, reduce_count_cpu) = at::unique_dim(coors_clean_cpu, 0, true, true, true);
+    if (out_coors_cpu[0][0].lt(0).item<bool>()) {
+        // drop the leading all -1 voxel that collects the invalid points
+        out_coors_cpu = out_coors_cpu.slice(0, 1);
+        reduce_count_cpu = reduce_count_cpu.slice(0, 1);
+        coors_map_cpu = coors_map_cpu - 1;
+    }
+    coors_map_cpu = coors_map_cpu.to(at::kInt);
+    reduce_count_cpu = reduce_count_cpu.to(at::kInt);
+    auto npuDevice = coors.device();
+    at::Tensor out_coors = out_coors_cpu.to(npuDevice);
+    at::Tensor coors_map = coors_map_cpu.to(npuDevice);
+    at::Tensor reduce_count = reduce_count_cpu.to(npuDevice);
+
+    auto reduced_feats = at::empty({out_coors.size(0), num_feats}, feats.options());
+
+    // the kernel implements only "max" and "sum"; "mean" is finished as sum / count below
+    const char *reduce_type_string = REDUCE_TYPE_MAP[reduce_type] == "max" ? "max" : "sum";
+    EXEC_NPU_CMD(aclnnDynamicScatter, feats, coors_map, reduce_type_string, reduced_feats);
+
+    if (reduce_type == 1) {
+        reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype());
+    }
+
+    return {reduced_feats, out_coors, coors_map, reduce_count};
+}
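
The op-api layer above keeps the irregular work on the host: voxel coordinates are deduplicated with `at::unique_dim` on the CPU, and only the per-point reduction is dispatched to the NPU via `aclnnDynamicScatter`. A minimal PyTorch sketch of that host-side preprocessing (CPU-only, illustrative):

```python
import torch

def dedup_coors(coors: torch.Tensor):
    # collapse rows containing any negative coordinate to all -1,
    # mirroring the masked_fill in the C++ above
    coors_clean = coors.masked_fill(coors.lt(0).any(-1, keepdim=True), -1)
    out_coors, coors_map, reduce_count = torch.unique(
        coors_clean, dim=0, sorted=True, return_inverse=True, return_counts=True)
    if out_coors[0, 0].lt(0):
        out_coors = out_coors[1:]        # drop the all -1 "invalid" voxel
        reduce_count = reduce_count[1:]
        coors_map = coors_map - 1        # invalid points now map to -1
    return out_coors, coors_map.int(), reduce_count.int()
```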
"max" : "sum"); + EXEC_NPU_CMD(aclnnDynamicScatter, feats, coors_map, reduce_type_string, reduced_feats); + + if (reduce_type == 1) { + reduced_feats /= reduce_count.unsqueeze(-1).to(reduced_feats.dtype()); + } + + return {reduced_feats, out_coors, coors_map, reduce_count}; +} diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index 1384a1538dd30108b58583a791c79ad0bf2690b1..b666ed1479d488ccf753016053d456de419b92b6 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -116,6 +116,10 @@ at::Tensor npu_conv_transpose2d( int64_t groups); at::Tensor npu_broadcast(const at::Tensor& self, at::IntArrayRef size); at::Tensor& npu_broadcast_out(const at::Tensor& self, at::IntArrayRef size, at::Tensor& result); +std::tuple npu_dynamic_scatter( + const at::Tensor &feats, + const at::Tensor &coors, + int64_t reduce_type); at::Tensor npu_moe_tutel( const at::Tensor &self, const at::Tensor &gates, diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index 5638ac6b09c986dba4e68e9fe99862475dbaa323..0f7e4275723822a01fbb8e3c28f978ea7615d8c5 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -71,6 +71,10 @@ void init_common(pybind11::module &m) m.def("npu_moe_tutel", &npu_moe_tutel, "npu_moe_tutel NPU version"); m.def("npu_moe_tutel_data_backward", &npu_moe_tutel_data_backward, "npu_moe_tutel_data_backward NPU version"); m.def("npu_moe_tutel_gate_backward", &npu_moe_tutel_gate_backward, "npu_moe_tutel_gate_backward NPU version"); + + // npu_dynamic_scatter + m.def("npu_dynamic_scatter", &npu_dynamic_scatter, "npu_dynamic_scatter NPU version"); + // ads_add m.def("npu_ads_add", &npu_ads_add); } diff --git a/ads/common/ops/kernels/ads_op/CMakePresets.json b/ads/common/ops/kernels/ads_op/CMakePresets.json index add05853f7195befe9689580c8aaf0e3f02ce381..a23c07b8bf823cc052ddf980a835408a9e3b918a 100644 --- a/ads/common/ops/kernels/ads_op/CMakePresets.json +++ b/ads/common/ops/kernels/ads_op/CMakePresets.json @@ -27,7 +27,7 @@ }, "ASCEND_COMPUTE_UNIT": { "type": "STRING", - "value": "ascend310p;ascend910;ascend910b" + "value": "ascend910b" }, "ENABLE_TEST": { "type": "BOOL", diff --git a/ads/common/ops/kernels/ads_op/op_host/dynamic_scatter.cpp b/ads/common/ops/kernels/ads_op/op_host/dynamic_scatter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e6b30b14eb81fef6ecac7939f9f26b4fe2ea7b34 --- /dev/null +++ b/ads/common/ops/kernels/ads_op/op_host/dynamic_scatter.cpp @@ -0,0 +1,211 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ */ +#include "dynamic_scatter_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_api.h" + +using namespace ge; +using namespace std; +using namespace AscendC; + +namespace optiling { +constexpr uint32_t BYTE_BLOCK = 32; +constexpr uint32_t SIZE_OF_FP16 = 2; +constexpr uint32_t SIZE_OF_FP32 = 4; +constexpr uint32_t DIM_INDEX0 = 0; +constexpr uint32_t DIM_INDEX1 = 1; +constexpr uint32_t BYTES_PER_DATA = 20; +constexpr int KEY_FP16 = 0; +constexpr int KEY_FP32 = 1; +static std::map REDUCE_TYPE_MAP = {{"max", 0}, {"sum", 1}}; +class DynamicScatterTiling { +public: + explicit DynamicScatterTiling(gert::TilingContext* context) : tilingContext(context){}; + ge::graphStatus Init(); + ge::graphStatus RunKernelTiling(); + +private: + void SetTilingKeyMode(ge::DataType dType, uint32_t reduceTypeNum) const; + uint32_t GetNeedCoreNum(const uint32_t coreNumPlatform) const; + void CalTilingAligned(ge::DataType dType); + +private: + DynamicScatterTilingData tilingData; + gert::TilingContext* tilingContext = nullptr; + uint32_t pointNum; + uint32_t featsNum; + uint32_t coreNum; + uint32_t totalLength = 1; // the length of input + uint32_t formerNum; // deal more data core num + uint32_t tailNum; // deal less data core num + uint32_t formerLength; // deal more data length + uint32_t tailLength; // deal less data length + uint32_t alignNum; // data count per block + uint32_t totalLengthAligned; // length to align 32B + uint32_t outPointNum; + uint32_t outPointNumAligned; + uint32_t featsAligned; + uint32_t formerInputNum; + uint32_t tailInputNum; + uint32_t tileLength; + uint64_t ubSizePlatForm; +}; + +void DynamicScatterTiling::SetTilingKeyMode(ge::DataType dType, uint32_t reduceTypeNum) const +{ + switch (dType) { + case ge::DT_FLOAT: + tilingContext->SetTilingKey(KEY_FP32 * 100 + reduceTypeNum); + break; + case ge::DT_FLOAT16: + tilingContext->SetTilingKey(KEY_FP16 * 100 + reduceTypeNum); + break; + default: + tilingContext->SetTilingKey(100); + break; + } +} + +uint32_t DynamicScatterTiling::GetNeedCoreNum(const uint32_t coreNumPlatform) const +{ + uint32_t tempCoreNum = pointNum; + if (tempCoreNum == 0) { + tempCoreNum = 1; + } + if (tempCoreNum < coreNumPlatform) { + return tempCoreNum; + } else { + return coreNumPlatform; + } +} + +void DynamicScatterTiling::CalTilingAligned(ge::DataType dType) +{ + alignNum = BYTE_BLOCK / SIZE_OF_FP32; + if (dType == ge::DT_FLOAT16) { + alignNum = BYTE_BLOCK / SIZE_OF_FP16; + } + tileLength = ubSizePlatForm / BYTES_PER_DATA; + tileLength = tileLength / (featsNum * alignNum) * (featsNum * alignNum); + featsAligned = (featsNum + alignNum - 1) / alignNum * alignNum; + tailInputNum = pointNum / coreNum; + formerNum = pointNum % coreNum; + tailNum = coreNum - formerNum; + formerInputNum = formerNum > 0 ? 
+
+ge::graphStatus DynamicScatterTiling::Init()
+{
+    size_t sysWorkspaceSize = 16 * 1024 * 1024;
+    size_t *currentWorkSpace = tilingContext->GetWorkspaceSizes(1);
+    currentWorkSpace[0] = sysWorkspaceSize;
+    auto featsShape = tilingContext->GetInputShape(0)->GetStorageShape();
+    pointNum = featsShape.GetDim(DIM_INDEX0);
+    featsNum = featsShape.GetDim(DIM_INDEX1);
+    totalLength = featsShape.GetShapeSize();
+    auto reducedFeatsShape = tilingContext->GetOutputShape(0)->GetStorageShape();
+    outPointNum = reducedFeatsShape.GetDim(DIM_INDEX0);
+
+    auto platformInfo = tilingContext->GetPlatformInfo();
+    if (platformInfo == nullptr) {
+        return ge::GRAPH_FAILED;
+    }
+    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
+    coreNum = ascendcPlatform.GetCoreNumAiv();
+    if (coreNum == 0) {
+        return ge::GRAPH_FAILED;
+    }
+    coreNum = GetNeedCoreNum(coreNum);
+    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSizePlatForm);
+
+    const char* reduceTypePtr = tilingContext->GetAttrs()->GetAttrPointer<char>(DIM_INDEX0);
+    std::string reduceType(reduceTypePtr);
+    if (reduceType != "max" && reduceType != "sum") {
+        return ge::GRAPH_PARAM_INVALID;
+    }
+    auto dType = tilingContext->GetInputDesc(0)->GetDataType();
+    SetTilingKeyMode(dType, REDUCE_TYPE_MAP[reduceType]);
+    CalTilingAligned(dType);
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus DynamicScatterTiling::RunKernelTiling()
+{
+    tilingContext->SetBlockDim(coreNum);
+    tilingData.set_totalLength(totalLength);
+    tilingData.set_formerNum(formerNum);
+    tilingData.set_tailNum(tailNum);
+    tilingData.set_formerLength(formerLength);
+    tilingData.set_tailLength(tailLength);
+    tilingData.set_alignNum(alignNum);
+    tilingData.set_totalLengthAligned(totalLengthAligned);
+    tilingData.set_formerInputNum(formerInputNum);
+    tilingData.set_tailInputNum(tailInputNum);
+    tilingData.set_featsNum(featsNum);
+    tilingData.set_outPointNum(outPointNum);
+    tilingData.set_outPointNumAligned(outPointNumAligned);
+    tilingData.set_featsAligned(featsAligned);
+    tilingData.set_tileLength(tileLength);
+    tilingData.SaveToBuffer(tilingContext->GetRawTilingData()->GetData(),
+                            tilingContext->GetRawTilingData()->GetCapacity());
+    tilingContext->GetRawTilingData()->SetDataSize(tilingData.GetDataSize());
+    return ge::GRAPH_SUCCESS;
+}
+
+ge::graphStatus TilingDynamicScatter(gert::TilingContext* context)
+{
+    DynamicScatterTiling tilingObject(context);
+    if (tilingObject.Init() != ge::GRAPH_SUCCESS) {
+        return ge::GRAPH_FAILED;
+    }
+    return tilingObject.RunKernelTiling();
+}
+} // namespace optiling
+
+namespace ge {
+static ge::graphStatus InferShape(gert::InferShapeContext* context)
+{
+    const gert::Shape* featShape = context->GetInputShape(0);
+
+    gert::Shape* outShape = context->GetOutputShape(0);
+    outShape->SetDimNum(2);
+    outShape->SetDim(0, -1);  // number of voxels is data dependent
+    outShape->SetDim(1, featShape->GetDim(1));
+    return GRAPH_SUCCESS;
+}
+} // namespace ge
+
+namespace ops {
+class DynamicScatter : public OpDef {
+public:
+    explicit DynamicScatter(const char* name) : OpDef(name)
+    {
+        this->Input("feats")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Input("coors_map")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_INT32})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Output("reduced_feats")
+            .ParamType(REQUIRED)
+            .DataType({ge::DT_FLOAT})
+            .Format({ge::FORMAT_ND})
+            .UnknownShapeFormat({ge::FORMAT_ND});
+        this->Attr("reduce_type").AttrType(REQUIRED).String("max");
+        this->SetInferShape(ge::InferShape);
+        this->AICore().SetTiling(optiling::TilingDynamicScatter);
+        this->AICore().AddConfig("ascend910b");
+    }
+};
+
+OP_ADD(DynamicScatter);
+} // namespace ops
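
SetTilingKeyMode packs the data type and the reduce type into a single tiling key, which the kernel entry point (in op_kernel/dynamic_scatter.cpp below) switches on. The encoding restated in Python:

```python
KEY_FP16, KEY_FP32 = 0, 1
REDUCE_TYPE_MAP = {"max": 0, "sum": 1}

def tiling_key(dtype_key, reduce_type):
    # same packing as SetTilingKeyMode: dtype in the hundreds digit,
    # reduce type in the units digit
    return dtype_key * 100 + REDUCE_TYPE_MAP[reduce_type]

assert tiling_key(KEY_FP32, "max") == 100  # dispatched to DynamicScatterMax
assert tiling_key(KEY_FP32, "sum") == 101  # dispatched to DynamicScatterSum
```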
diff --git a/ads/common/ops/kernels/ads_op/op_host/dynamic_scatter_tiling.h b/ads/common/ops/kernels/ads_op/op_host/dynamic_scatter_tiling.h
new file mode 100644
index 0000000000000000000000000000000000000000..8503cd128b062145d222a7413c9600bb2f53f0cc
--- /dev/null
+++ b/ads/common/ops/kernels/ads_op/op_host/dynamic_scatter_tiling.h
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ */
+#ifndef DYNAMIC_SCATTER_TILING_H
+#define DYNAMIC_SCATTER_TILING_H
+#include "register/tilingdata_base.h"
+
+namespace optiling {
+BEGIN_TILING_DATA_DEF(DynamicScatterTilingData)
+    TILING_DATA_FIELD_DEF(uint32_t, totalLength);
+    TILING_DATA_FIELD_DEF(uint32_t, formerNum);
+    TILING_DATA_FIELD_DEF(uint32_t, tailNum);
+    TILING_DATA_FIELD_DEF(uint32_t, formerLength);
+    TILING_DATA_FIELD_DEF(uint32_t, tailLength);
+    TILING_DATA_FIELD_DEF(uint32_t, alignNum);
+    TILING_DATA_FIELD_DEF(uint32_t, totalLengthAligned);
+    TILING_DATA_FIELD_DEF(uint32_t, formerInputNum);
+    TILING_DATA_FIELD_DEF(uint32_t, tailInputNum);
+    TILING_DATA_FIELD_DEF(uint32_t, featsNum);
+    TILING_DATA_FIELD_DEF(uint32_t, outPointNum);
+    TILING_DATA_FIELD_DEF(uint32_t, outPointNumAligned);
+    TILING_DATA_FIELD_DEF(uint32_t, featsAligned);
+    TILING_DATA_FIELD_DEF(uint32_t, tileLength);
+END_TILING_DATA_DEF;
+
+REGISTER_TILING_DATA_CLASS(DynamicScatter, DynamicScatterTilingData)
+} // namespace optiling
+#endif // DYNAMIC_SCATTER_TILING_H
\ No newline at end of file
diff --git a/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter.cpp b/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..4784f37874efb0d51d9de9296adba5199ee0b67a
--- /dev/null
+++ b/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter.cpp
@@ -0,0 +1,21 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ */
+#include "dynamic_scatter_max.h"
+#include "dynamic_scatter_sum.h"
+
+using namespace DynamicScatterN;
+
+extern "C" __global__ __aicore__ void dynamic_scatter(GM_ADDR feats, GM_ADDR coors_map, GM_ADDR reduced_feats,
+    GM_ADDR workspace, GM_ADDR tiling)
+{
+    GET_TILING_DATA(tilingData, tiling);
+    if (TILING_KEY_IS(100)) {  // fp32 + max
+        DynamicScatterN::DynamicScatterMax<float> op;
+        op.Init(feats, coors_map, reduced_feats, &tilingData);
+        op.Process();
+    } else if (TILING_KEY_IS(101)) {  // fp32 + sum (mean is finished on the host)
+        DynamicScatterN::DynamicScatterSum<float> op;
+        op.Init(feats, coors_map, reduced_feats, &tilingData);
+        op.Process();
+    }
+}
\ No newline at end of file
+ */ +#ifndef _DYNAMIC_SCATTER_BASE_H_ +#define _DYNAMIC_SCATTER_BASE_H_ + +#include +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_operator.h" + +namespace DynamicScatterN { +using namespace AscendC; + +constexpr int32_t BUFFER_NUM = 2; + +template +class DynamicScatterBase { +public: + __aicore__ inline DynamicScatterBase() {} + __aicore__ inline void BaseInit(DynamicScatterTilingData* tilingData) + { + TilingDataInit(tilingData); + MemberDataInit(); + uint32_t inputNumAligned = (inputNum + alignNum - 1) / alignNum * alignNum; + tileLength = inputNumAligned * featsNum > tileLength ? tileLength : inputNumAligned * featsNum; + tilePointNum = tileLength / featsNum; + loop = blockLength / tileLength; + lastLength = blockLength % tileLength; + lastPointNum = lastLength / featsNum; + featsLastStartIndex = blockLength - lastLength; + featsLastStartIndex = featsLastStartIndex > 0 ? featsLastStartIndex : 0; + mapLastStartIndex = featsLastStartIndex / featsNum; + outLength = outPointNumAligned * featsNum; + CopyParamasInit(); + } + + __aicore__ inline void TilingDataInit(DynamicScatterTilingData* tilingData) + { + totalLength = tilingData->totalLength; + formerNum = tilingData->formerNum; + tailNum = tilingData->tailNum; + formerLength = tilingData->formerLength; + tailLength = tilingData->tailLength; + alignNum = tilingData->alignNum; + totalLengthAligned = tilingData->totalLengthAligned; + formerInputNum = tilingData->formerInputNum; + tailInputNum = tilingData->tailInputNum; + featsNum = tilingData->featsNum; + outPointNum = tilingData->outPointNum; + outPointNumAligned = tilingData->outPointNumAligned; + featsAligned = tilingData->featsAligned; + tileLength = tilingData->tileLength; + } + + __aicore__ inline void MemberDataInit() + { + if (GetBlockIdx() < formerNum) { + blockLength = formerLength; + inputNum = formerInputNum; + featsOffset = blockLength * GetBlockIdx(); + coorsMapOffset = formerInputNum * GetBlockIdx(); + } else { + blockLength = tailLength; + inputNum = tailInputNum; + featsOffset = formerLength * formerNum + tailLength * (GetBlockIdx() - formerNum); + coorsMapOffset = formerInputNum * formerNum + tailInputNum * (GetBlockIdx() - formerNum); + } + } + + __aicore__ inline void CopyParamasInit() + { + copyParamsOut.blockCount = 1; + copyParamsOut.blockLen = static_cast(featsNum * sizeof(T)); + copyParamsOut.srcStride = 0; + copyParamsOut.dstStride = 0; + copyParamsOut.rsv = 0; + } + + __aicore__ inline void CopyOutMax(GlobalTensor reducedFeatsGm, uint32_t index, LocalTensor featsLocal) + { + SetAtomicMax(); + DataCopyPad(reducedFeatsGm[index], featsLocal, copyParamsOut); + SetAtomicNone(); + pipe_barrier(PIPE_ALL); + } + + __aicore__ inline void CopyOutAdd(GlobalTensor reducedFeatsGm, uint32_t index, LocalTensor featsLocal) + { + SetAtomicAdd(); + DataCopyPad(reducedFeatsGm[index], featsLocal, copyParamsOut); + SetAtomicNone(); + pipe_barrier(PIPE_ALL); + } +protected: + uint32_t totalLength, formerNum, tailNum, formerLength, tailLength, alignNum, totalLengthAligned, outLength; + uint32_t formerInputNum, tailInputNum, featsNum; + uint32_t outPointNum, outPointNumAligned, featsAligned; + uint32_t blockLength, tilePointNum, inputNum, featsOffset, coorsMapOffset; + DataCopyExtParams copyParamsOut; + uint32_t blockLengthAligned; + int32_t tileLength = 8; + int32_t loop, lastLength, featsLastStartIndex, mapLastStartIndex, lastPointNum; +}; +} // DynamicScatterN +#endif // _DYNAMIC_SCATTER_BASE_H_ \ No newline at end of file diff --git 
a/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter_max.h b/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter_max.h
new file mode 100644
index 0000000000000000000000000000000000000000..3b43858a5b225ab2ce75e5e0851748d2137f7ced
--- /dev/null
+++ b/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter_max.h
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ */
+#ifndef _DYNAMIC_SCATTER_MAX_H_
+#define _DYNAMIC_SCATTER_MAX_H_
+
+#include "dynamic_scatter_base.h"
+
+namespace DynamicScatterN {
+using namespace AscendC;
+
+template <typename T>
+class DynamicScatterMax : public DynamicScatterBase<T> {
+public:
+    __aicore__ inline DynamicScatterMax() {}
+    __aicore__ inline void Init(GM_ADDR feats, GM_ADDR coorsMap, GM_ADDR reducedFeats,
+        DynamicScatterTilingData* tilingData)
+    {
+        this->BaseInit(tilingData);
+        BufferInit();
+        featsGm.SetGlobalBuffer((__gm__ T *)feats + this->featsOffset, this->blockLength);
+        coorsMapGm.SetGlobalBuffer((__gm__ int32_t *)coorsMap + this->coorsMapOffset, this->inputNum);
+        reducedFeatsGm.SetGlobalBuffer((__gm__ T *)reducedFeats, this->outLength);
+        if (GetBlockIdx() == 0) {
+            // identity of max: every voxel row starts at -inf
+            InitOutput(this->reducedFeatsGm, this->outLength, static_cast<T>(-INFINITY));
+        }
+        SyncAll();
+    }
+
+    __aicore__ inline void Process()
+    {
+        // process the full tiles assigned to this core
+        for (int32_t i = 0; i < this->loop; i++) {
+            CopyIn(i * this->tilePointNum);
+            pipe_barrier(PIPE_ALL);
+            Compute(i * this->tileLength, this->tilePointNum);
+            pipe_barrier(PIPE_ALL);
+        }
+        // then the partial tile, if any
+        if (this->lastLength) {
+            CopyInTail(this->mapLastStartIndex);
+            pipe_barrier(PIPE_ALL);
+            ComputeTail(this->lastPointNum);
+        }
+    }
+
+private:
+    __aicore__ inline void BufferInit()
+    {
+        pipe.InitBuffer(inQueueFeats, BUFFER_NUM, this->featsAligned * sizeof(T));
+        pipe.InitBuffer(inQueueCoorsMap, BUFFER_NUM, this->tilePointNum * sizeof(int32_t));
+    }
+
+    __aicore__ inline void CopyIn(int32_t startIndex)
+    {
+        // alloc an index tensor from queue memory
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.AllocTensor<int32_t>();
+        // copy the current tile of coors_map from global to local memory
+        DataCopy(coorsMapLocal, this->coorsMapGm[startIndex], this->tilePointNum);
+        // enqueue to the VECIN queue
+        inQueueCoorsMap.EnQue(coorsMapLocal);
+    }
+
+    __aicore__ inline void Compute(int32_t gmOffset, int32_t pointNum)
+    {
+        LocalTensor<T> featsLocal = inQueueFeats.AllocTensor<T>();
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.DeQue<int32_t>();
+
+        for (int32_t idx = 0; idx < pointNum; idx++) {
+            int32_t reduce_to = coorsMapLocal.GetValue(idx);
+            if (reduce_to > -1) {  // points mapped to -1 belong to invalid voxels
+                DataCopy(featsLocal, this->featsGm[gmOffset + idx * this->featsNum], this->featsAligned);
+                pipe_barrier(PIPE_ALL);
+                this->CopyOutMax(reducedFeatsGm, reduce_to * this->featsNum, featsLocal);
+            }
+        }
+        // free input tensors for reuse
+        inQueueFeats.FreeTensor(featsLocal);
+        inQueueCoorsMap.FreeTensor(coorsMapLocal);
+    }
+
+    __aicore__ inline void CopyInTail(int32_t mapLastStartIndex)
+    {
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.AllocTensor<int32_t>();
+        // the tail tile reuses the full-tile buffer size
+        DataCopy(coorsMapLocal, this->coorsMapGm[mapLastStartIndex], this->tilePointNum);
+        inQueueCoorsMap.EnQue(coorsMapLocal);
+    }
+
+    __aicore__ inline void ComputeTail(int32_t pointNum)
+    {
+        LocalTensor<T> featsLocal = inQueueFeats.AllocTensor<T>();
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.DeQue<int32_t>();
+
+        for (int32_t idx = 0; idx < pointNum; idx++) {
+            int32_t reduce_to = coorsMapLocal.GetValue(idx);
+            if (reduce_to > -1) {
+                DataCopy(featsLocal, this->featsGm[this->featsLastStartIndex + idx * this->featsNum],
+                    this->featsAligned);
+                pipe_barrier(PIPE_ALL);
+                this->CopyOutMax(reducedFeatsGm, reduce_to * this->featsNum, featsLocal);
+            }
+        }
+        inQueueFeats.FreeTensor(featsLocal);
+        inQueueCoorsMap.FreeTensor(coorsMapLocal);
+    }
+
+private:
+    TPipe pipe;
+    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueFeats, inQueueCoorsMap;
+    GlobalTensor<T> featsGm, reducedFeatsGm;
+    GlobalTensor<int32_t> coorsMapGm;
+};
+} // namespace DynamicScatterN
+#endif // _DYNAMIC_SCATTER_MAX_H_
\ No newline at end of file
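
Semantically, DynamicScatterMax performs an atomic scatter-max: each valid point's feature row is combined into its voxel's output row, with the output initialised to -inf. An equivalent CPU reference in plain PyTorch (a sketch of the semantics, not of the kernel):

```python
import torch

def scatter_max_reference(feats, coors_map, num_voxels):
    # output starts at -inf, like the InitOutput call in DynamicScatterMax::Init
    out = torch.full((num_voxels, feats.size(1)), float("-inf"), dtype=feats.dtype)
    for idx in range(feats.size(0)):
        voxel = int(coors_map[idx])
        if voxel > -1:  # points mapped to -1 sit in invalid voxels and are skipped
            out[voxel] = torch.maximum(out[voxel], feats[idx])
    return out
```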
diff --git a/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter_sum.h b/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter_sum.h
new file mode 100644
index 0000000000000000000000000000000000000000..0a9ecd3ea364cfe2fec165e798aaff81c4582bea
--- /dev/null
+++ b/ads/common/ops/kernels/ads_op/op_kernel/dynamic_scatter_sum.h
@@ -0,0 +1,122 @@
+/*
+ * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved.
+ */
+#ifndef _DYNAMIC_SCATTER_SUM_H_
+#define _DYNAMIC_SCATTER_SUM_H_
+
+#include "dynamic_scatter_base.h"
+
+namespace DynamicScatterN {
+using namespace AscendC;
+
+template <typename T>
+class DynamicScatterSum : public DynamicScatterBase<T> {
+public:
+    __aicore__ inline DynamicScatterSum() {}
+    __aicore__ inline void Init(GM_ADDR feats, GM_ADDR coorsMap, GM_ADDR reducedFeats,
+        DynamicScatterTilingData* tilingData)
+    {
+        this->BaseInit(tilingData);
+        BufferInit();
+        featsGm.SetGlobalBuffer((__gm__ T *)feats + this->featsOffset, this->blockLength);
+        coorsMapGm.SetGlobalBuffer((__gm__ int32_t *)coorsMap + this->coorsMapOffset, this->inputNum);
+        reducedFeatsGm.SetGlobalBuffer((__gm__ T *)reducedFeats, this->outLength);
+        if (GetBlockIdx() == 0) {
+            // identity of sum: every voxel row starts at zero
+            InitOutput(this->reducedFeatsGm, this->outLength, static_cast<T>(0.0));
+        }
+        SyncAll();
+    }
+
+    __aicore__ inline void Process()
+    {
+        // process the full tiles assigned to this core
+        for (int32_t i = 0; i < this->loop; i++) {
+            CopyIn(i * this->tilePointNum);
+            pipe_barrier(PIPE_ALL);
+            Compute(i * this->tileLength, this->tilePointNum);
+            pipe_barrier(PIPE_ALL);
+        }
+        // then the partial tile, if any
+        if (this->lastLength) {
+            CopyInTail(this->mapLastStartIndex);
+            pipe_barrier(PIPE_ALL);
+            ComputeTail(this->lastPointNum);
+        }
+    }
+
+private:
+    __aicore__ inline void BufferInit()
+    {
+        pipe.InitBuffer(inQueueFeats, BUFFER_NUM, this->featsAligned * sizeof(T));
+        pipe.InitBuffer(inQueueCoorsMap, BUFFER_NUM, this->tilePointNum * sizeof(int32_t));
+    }
+
+    __aicore__ inline void CopyIn(int32_t startIndex)
+    {
+        // alloc an index tensor from queue memory
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.AllocTensor<int32_t>();
+        // copy the current tile of coors_map from global to local memory
+        DataCopy(coorsMapLocal, this->coorsMapGm[startIndex], this->tilePointNum);
+        // enqueue to the VECIN queue
+        inQueueCoorsMap.EnQue(coorsMapLocal);
+    }
+
+    __aicore__ inline void Compute(int32_t gmOffset, int32_t pointNum)
+    {
+        LocalTensor<T> featsLocal = inQueueFeats.AllocTensor<T>();
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.DeQue<int32_t>();
+
+        for (int32_t idx = 0; idx < pointNum; idx++) {
+            int32_t reduce_to = coorsMapLocal.GetValue(idx);
+            if (reduce_to > -1) {  // points mapped to -1 belong to invalid voxels
+                DataCopy(featsLocal, this->featsGm[gmOffset + idx * this->featsNum], this->featsAligned);
+                pipe_barrier(PIPE_ALL);
+                this->CopyOutAdd(reducedFeatsGm, reduce_to * this->featsNum, featsLocal);
+            }
+        }
+        // free input tensors for reuse
+        inQueueFeats.FreeTensor(featsLocal);
+        inQueueCoorsMap.FreeTensor(coorsMapLocal);
+    }
+
+    __aicore__ inline void CopyInTail(int32_t mapLastStartIndex)
+    {
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.AllocTensor<int32_t>();
+        // the tail tile reuses the full-tile buffer size
+        DataCopy(coorsMapLocal, this->coorsMapGm[mapLastStartIndex], this->tilePointNum);
+        inQueueCoorsMap.EnQue(coorsMapLocal);
+    }
+
+    __aicore__ inline void ComputeTail(int32_t pointNum)
+    {
+        LocalTensor<T> featsLocal = inQueueFeats.AllocTensor<T>();
+        LocalTensor<int32_t> coorsMapLocal = inQueueCoorsMap.DeQue<int32_t>();
+
+        for (int32_t idx = 0; idx < pointNum; idx++) {
+            int32_t reduce_to = coorsMapLocal.GetValue(idx);
+            if (reduce_to > -1) {
+                DataCopy(featsLocal, this->featsGm[this->featsLastStartIndex + idx * this->featsNum],
+                    this->featsAligned);
+                pipe_barrier(PIPE_ALL);
+                this->CopyOutAdd(reducedFeatsGm, reduce_to * this->featsNum, featsLocal);
+            }
+        }
+        inQueueFeats.FreeTensor(featsLocal);
+        inQueueCoorsMap.FreeTensor(coorsMapLocal);
+    }
+
+private:
+    TPipe pipe;
+    TQue<QuePosition::VECIN, BUFFER_NUM> inQueueFeats, inQueueCoorsMap;
+    GlobalTensor<T> featsGm, reducedFeatsGm;
+    GlobalTensor<int32_t> coorsMapGm;
+};
+} // namespace DynamicScatterN
+#endif // _DYNAMIC_SCATTER_SUM_H_
\ No newline at end of file
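
DynamicScatterSum is structurally identical to the max kernel but initialises the output to zero and accumulates with SetAtomicAdd; mean reduction is then finished on the host by dividing the sums by reduce_count (see DynamicScatterKernelNpuOpApi.cpp above). A matching CPU sketch:

```python
import torch

def scatter_sum_reference(feats, coors_map, num_voxels, reduce_count=None):
    # output starts at zero, like the InitOutput call in DynamicScatterSum::Init
    out = torch.zeros(num_voxels, feats.size(1), dtype=feats.dtype)
    for idx in range(feats.size(0)):
        voxel = int(coors_map[idx])
        if voxel > -1:
            out[voxel] += feats[idx]
    if reduce_count is not None:
        # mean is derived on the host: sum divided by per-voxel point count
        out /= reduce_count.unsqueeze(-1).to(out.dtype)
    return out
```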
diff --git a/ads/common/ops/npu_dynamic_scatter.py b/ads/common/ops/npu_dynamic_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..755915eaa143be3b56e05edd8feea794191b36cd
--- /dev/null
+++ b/ads/common/ops/npu_dynamic_scatter.py
@@ -0,0 +1,27 @@
+import torch
+from torch.autograd import Function
+
+import torch_npu
+import ads_c
+
+
+class DynamicScatterFunction(Function):
+    @staticmethod
+    def forward(ctx, feats, coors, reduce_type):
+        (voxel_feats, voxel_coors, point2voxel_map,
+         voxel_points_count) = ads_c.npu_dynamic_scatter(feats, coors, reduce_type)
+        ctx.reduce_type = reduce_type
+        ctx.save_for_backward(feats, voxel_feats, point2voxel_map, voxel_points_count)
+        ctx.mark_non_differentiable(voxel_coors)
+        return voxel_feats, voxel_coors
+
+    @staticmethod
+    def backward(ctx, grad_voxel_feats, grad_voxel_coors):
+        raise NotImplementedError("npu_dynamic_scatter does not currently support backward.")
+
+
+npu_dynamic_scatter = DynamicScatterFunction.apply
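
With the autograd wrapper in place, usage mirrors mmcv's dynamic_scatter. An end-to-end example, assuming an Ascend 910B device; reduce_type uses the integer encoding enforced by the C++ check (0 sum, 1 mean, 2 max):

```python
import torch
import torch_npu
import ads.common

feats = torch.rand(1000, 3).npu()
coors = torch.randint(-1, 20, (1000, 3), dtype=torch.int32).npu()
voxel_feats, voxel_coors = ads.common.npu_dynamic_scatter(feats, coors, 2)  # 2 == max
```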
diff --git a/tests/test_npu_dynamic_scatter.py b/tests/test_npu_dynamic_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbbea560785deda3f4e5a86283c75332d5cc8fb3
--- /dev/null
+++ b/tests/test_npu_dynamic_scatter.py
@@ -0,0 +1,68 @@
+import unittest
+import numpy as np
+import torch
+
+import torch_npu
+from torch_npu.testing.testcase import TestCase, run_tests
+from torch_npu.testing.common_utils import create_common_tensor
+import ads.common
+
+DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]
+reduce_type_mapping = {"mean": 1, "max": 2}
+
+
+class TestDynamicScatter(TestCase):
+
+    def cpu_op_exec(self, feats, coors, reduce_type):
+        clean_coors = coors.masked_fill(coors.lt(0).any(-1, True), -1)
+        out_coors, coors_map, reduce_count = clean_coors.unique(
+            dim=0, sorted=True, return_inverse=True, return_counts=True)
+        if out_coors[0][0].lt(0):
+            # drop the leading all -1 voxel that collects the invalid points
+            out_coors = out_coors[1:]
+            reduce_count = reduce_count[1:]
+            coors_map = coors_map - 1
+
+        output_feats = []
+        if reduce_type == 1:
+            for ref_voxel_coors in out_coors:
+                voxel_mask = (coors == ref_voxel_coors).all(dim=-1)
+                output_feats.append(feats[voxel_mask].mean(dim=0))
+        else:
+            for ref_voxel_coors in out_coors:
+                voxel_mask = (coors == ref_voxel_coors).all(dim=-1)
+                output_feats.append(feats[voxel_mask].max(dim=0).values)
+        output_feats = torch.stack(output_feats)
+        return output_feats.numpy(), out_coors.numpy()
+
+    def npu_op_exec(self, feats, coors, reduce_type):
+        output_feats, output_coors = ads.common.npu_dynamic_scatter(feats, coors, reduce_type)
+        return output_feats.cpu().numpy(), output_coors.cpu().numpy()
+
+    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `DynamicScatter` is only supported on 910B, skip this ut!")
+    def test_dynamic_scatter_max_fp32(self):
+        shape_feats = (2000, 3)
+        shape_coors = (2000, 3)
+        cpu_feats, npu_feats = create_common_tensor(["float32", 2, shape_feats], -50, 50)
+        cpu_coors, npu_coors = create_common_tensor(["int32", 2, shape_coors], -1, 20)
+        reduce_type = reduce_type_mapping["max"]
+        cpu_output = self.cpu_op_exec(cpu_feats, cpu_coors, reduce_type)
+        npu_output = self.npu_op_exec(npu_feats, npu_coors, reduce_type)
+        self.assertRtolEqual(cpu_output[0], npu_output[0])
+        self.assertRtolEqual(cpu_output[1], npu_output[1])
+
+    @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `DynamicScatter` is only supported on 910B, skip this ut!")
+    def test_dynamic_scatter_mean_fp32(self):
+        shape_feats = (2000, 3)
+        shape_coors = (2000, 3)
+        cpu_feats, npu_feats = create_common_tensor(["float32", 2, shape_feats], -50, 50)
+        cpu_coors, npu_coors = create_common_tensor(["int32", 2, shape_coors], -1, 20)
+        reduce_type = reduce_type_mapping["mean"]
+        cpu_output = self.cpu_op_exec(cpu_feats, cpu_coors, reduce_type)
+        npu_output = self.npu_op_exec(npu_feats, npu_coors, reduce_type)
+        self.assertRtolEqual(cpu_output[0], npu_output[0])
+        self.assertRtolEqual(cpu_output[1], npu_output[1])
+
+
+if __name__ == "__main__":
+    run_tests()