From 70507a98cae7d4e616c6c236926f2dc5c8b566b1 Mon Sep 17 00:00:00 2001 From: lishuai183 Date: Tue, 20 Feb 2024 15:25:15 +0800 Subject: [PATCH 01/12] init nms3d op. --- ads/common/__init__.py | 1 + ads/common/ops/csrc/Nms3dOpApi.cpp | 19 + ads/common/ops/csrc/functions.h | 2 + ads/common/ops/csrc/pybind.cpp | 3 + ads/common/ops/kernels/op_host/nms3d.cpp | 88 ++++ ads/common/ops/kernels/op_host/nms3d_tiling.h | 24 + ads/common/ops/kernels/op_kernel/nms3d.cpp | 478 ++++++++++++++++++ ads/common/ops/npu_nms3d.py | 20 + tests/torch/test_npu_nms3d.py | 272 ++++++++++ 9 files changed, 907 insertions(+) create mode 100644 ads/common/ops/csrc/Nms3dOpApi.cpp create mode 100644 ads/common/ops/kernels/op_host/nms3d.cpp create mode 100644 ads/common/ops/kernels/op_host/nms3d_tiling.h create mode 100644 ads/common/ops/kernels/op_kernel/nms3d.cpp create mode 100644 ads/common/ops/npu_nms3d.py create mode 100644 tests/torch/test_npu_nms3d.py diff --git a/ads/common/__init__.py b/ads/common/__init__.py index ca4f48b0..d20c765a 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -30,3 +30,4 @@ from .ops.npu_multi_scale_deformable_attn_function import npu_multi_scale_deform from .ops.dynamic_voxelization import voxelization from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal +from .ops.npu_nms3d import npu_nms3d diff --git a/ads/common/ops/csrc/Nms3dOpApi.cpp b/ads/common/ops/csrc/Nms3dOpApi.cpp new file mode 100644 index 00000000..f3cad27c --- /dev/null +++ b/ads/common/ops/csrc/Nms3dOpApi.cpp @@ -0,0 +1,19 @@ +#include "csrc/OpApiCommon.h" +#include "functions.h" +#include + +std::tuple nms3d(const at::Tensor& boxes, + double threshold) +{ + int32_t box_num = boxes.size(0); + int32_t data_align = 16; + int32_t mask_num = ((box_num - 1) / data_align + 1) * data_align; + at::Tensor mask = + at::empty({box_num, mask_num}, boxes.options().dtype(at::kShort)); + EXEC_NPU_CMD(aclnnNms3d, boxes, threshold, mask); + + at::Tensor keep = at::zeros({box_num}, mask.options()); + at::Tensor num_out = at::zeros(1, mask.options()); + EXEC_NPU_CMD(aclnnGatherNms3dMask, mask, keep, num_out); + return std::tie(keep, num_out); +} diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index cf901d9c..58c4ad8b 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -160,4 +160,6 @@ at::Tensor DynamicVoxelization( const double coorsMinZ); std::tuple nms3d_normal(const at::Tensor &boxes, double nms_overlap_thresh); + +std::tuple nms3d(const at::Tensor &boxes, double threshold); #endif // COMMON_OPS_CSRC_FUNCTIONS_H_ diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index 3d2c805c..f3a8fb45 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -94,4 +94,7 @@ void init_common(pybind11::module &m) // nms3d_normal m.def("nms3d_normal", &nms3d_normal); + + // ads_nms3d + m.def("nms3d", &nms3d); } diff --git a/ads/common/ops/kernels/op_host/nms3d.cpp b/ads/common/ops/kernels/op_host/nms3d.cpp new file mode 100644 index 00000000..2ee0219c --- /dev/null +++ b/ads/common/ops/kernels/op_host/nms3d.cpp @@ -0,0 +1,88 @@ +#include "nms3d_tiling.h" +#include "register/op_def_registry.h" +#include "tiling/platform/platform_ascendc.h" + +using namespace std; + +namespace optiling { +static ge::graphStatus Nms3dTilingFunc(gert::TilingContext *context) +{ + Nms3dTilingData tiling; + + auto platformInfo = context->GetPlatformInfo(); + auto ascendcPlatform = 
platform_ascendc::PlatformAscendC(platformInfo); + static uint32_t coreNum = ascendcPlatform.GetCoreNumAiv(); + + auto boxShape = context->GetInputShape(0)->GetStorageShape(); + auto maskShape = context->GetOutputShape(0)->GetStorageShape(); + auto dtype = context->GetInputDesc(0)->GetDataType(); + auto attrs = context->GetAttrs(); + uint32_t boxNum = boxShape.GetDim(0); + uint32_t maskNum = maskShape.GetDim(1); + uint32_t dataAlign = 16; + if (ge::DT_FLOAT == dtype) { + context->SetTilingKey(1); + } else if (ge::DT_FLOAT16 == dtype) { + context->SetTilingKey(2); + } else { + return ge::GRAPH_FAILED; + } + + uint32_t usedCoreNum = std::min((boxNum - 1) / dataAlign + 1, coreNum); + uint32_t loopTime = (boxNum - 1) / (usedCoreNum * dataAlign) + 1; + uint32_t tailSum = boxNum - usedCoreNum * (loopTime - 1) * dataAlign; + uint32_t tailNum = (tailSum - 1) % dataAlign + 1; + float nms_overlap_thresh = *(attrs->GetAttrPointer(0)); + + context->SetBlockDim(usedCoreNum); + tiling.set_usedCoreNum(usedCoreNum); + tiling.set_boxNum(boxNum); + tiling.set_loopTime(loopTime); + tiling.set_eachSum(loopTime * dataAlign); + tiling.set_tailSum(tailSum); + tiling.set_tailNum(tailNum); + tiling.set_maskNum(maskNum); + tiling.set_overlapThresh(nms_overlap_thresh); + tiling.SaveToBuffer(context->GetRawTilingData()->GetData(), + context->GetRawTilingData()->GetCapacity()); + context->GetRawTilingData()->SetDataSize(tiling.GetDataSize()); + + size_t *currentWorkspace = context->GetWorkspaceSizes(1); + currentWorkspace[0] = 0; + return ge::GRAPH_SUCCESS; +} +} // namespace optiling + +namespace ge { +static ge::graphStatus Nms3dInferShape(gert::InferShapeContext *context) +{ + return GRAPH_SUCCESS; +} +} // namespace ge + +namespace ops { +class Nms3d : public OpDef { +public: + explicit Nms3d(const char *name) : OpDef(name) + { + this->Input("boxes") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Output("mask") + .ParamType(REQUIRED) + .DataType({ge::DT_INT16, ge::DT_INT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + this->Attr("threshold").AttrType(REQUIRED).Float(); + + this->SetInferShape(ge::Nms3dInferShape); + + this->AICore().SetTiling(optiling::Nms3dTilingFunc); + this->AICore().AddConfig("ascend910b"); + } +}; + +OP_ADD(Nms3d); +} // namespace ops \ No newline at end of file diff --git a/ads/common/ops/kernels/op_host/nms3d_tiling.h b/ads/common/ops/kernels/op_host/nms3d_tiling.h new file mode 100644 index 00000000..2db9787f --- /dev/null +++ b/ads/common/ops/kernels/op_host/nms3d_tiling.h @@ -0,0 +1,24 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. 
+ */ +#ifndef NMS3D_TILING_H +#define NMS3D_TILING_H + +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(Nms3dTilingData) + TILING_DATA_FIELD_DEF(uint32_t, usedCoreNum) // used cores + TILING_DATA_FIELD_DEF(uint32_t, boxNum) // count of boxes + TILING_DATA_FIELD_DEF(uint32_t, loopTime) // loop times + TILING_DATA_FIELD_DEF(uint32_t, eachSum) // count of each core, = loop_time * 8 + TILING_DATA_FIELD_DEF(uint32_t, tailSum) // count of tail core + TILING_DATA_FIELD_DEF(uint32_t, tailNum) // last time count of tail core + TILING_DATA_FIELD_DEF(uint32_t, maskNum) // mask align 32bit + TILING_DATA_FIELD_DEF(float, overlapThresh) +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(Nms3d, Nms3dTilingData) +} // namespace optiling + +#endif // NMS3D_TILING_H \ No newline at end of file diff --git a/ads/common/ops/kernels/op_kernel/nms3d.cpp b/ads/common/ops/kernels/op_kernel/nms3d.cpp new file mode 100644 index 00000000..f52853a7 --- /dev/null +++ b/ads/common/ops/kernels/op_kernel/nms3d.cpp @@ -0,0 +1,478 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023-2023. All rights reserved. + * + */ +#include "kernel_operator.h" +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_utils.h" + +#define M_PI 3.14159265358979323846 /* pi */ + +using namespace AscendC; +constexpr int32_t +BUFFER_NUM = 2; +constexpr float EPS = 1e-8; +constexpr float ATAN2_DEFAULT_VALUE = 1000.0; + + +struct Point { + float x, y; + + __aicore__ Point() {} + + __aicore__ Point(float _x, float _y) + { + x = _x; + y = _y; + } + + __aicore__ void set(float _x, float _y) + { + x = _x; + y = _y; + } + + __aicore__ Point + + operator+(const Point &b) const + { + return Point(x + b.x, y + b.y); + } + + __aicore__ Point + + operator-(const Point &b) const + { + return Point(x - b.x, y - b.y); + } +}; + +template +class KernelNms3d { +public: + __aicore__ inline KernelNms3d() {} + + __aicore__ inline void Init(GM_ADDR boxes, GM_ADDR mask, + const Nms3dTilingData *__restrict tiling_data) + { + ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); + usedCoreNum = tiling_data->usedCoreNum; + eachSum = tiling_data->eachSum; + boxNum = tiling_data->boxNum; + tailSum = tiling_data->tailSum; + tailNum = tiling_data->tailNum; + maskNum = tiling_data->maskNum; + loopTime = tiling_data->loopTime; + overlapThresh = tiling_data->overlapThresh; + + uint32_t core_id = GetBlockIdx(); + isLastCore = (core_id == (tiling_data->usedCoreNum - 1)); + + boxGm.SetGlobalBuffer(reinterpret_cast<__gm__ T * > (boxes), boxNum * 7); + maskGm.SetGlobalBuffer(reinterpret_cast<__gm__ int16_t * > (mask), + maskNum * boxNum); + + pipe.InitBuffer(inQueueCur, BUFFER_NUM, dataAlign * sizeof(T)); + pipe.InitBuffer(inQueueBox, BUFFER_NUM, dataAlign * 7 * sizeof(T)); + pipe.InitBuffer(outQueueMask, BUFFER_NUM, dataAlign * sizeof(int16_t)); + pipe.InitBuffer(oneMask, BUFFER_NUM, dataAlign * sizeof(int16_t)); + + pipe.InitBuffer(comBuf, dataAlign * sizeof(float)); + pipe.InitBuffer(p1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(p2Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(q1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(q2Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(angleBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(sinBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(cosBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(pointBuf, dataAlign * sizeof(T)); + pipe.InitBuffer(min1Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(min2Buf, dataAlign * sizeof(T)); + pipe.InitBuffer(max1Buf, dataAlign * sizeof(T)); + 
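// Per-tile scratch buffers of dataAlign (16) elements each; for half-precision
// boxes an extra float32 area (calcBuf) is reserved just below so the IoU math
// runs in float after a Cast.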
pipe.InitBuffer(max2Buf, dataAlign * sizeof(T)); + if constexpr(sizeof(T) == sizeof(half)) { + pipe.InitBuffer(calcBuf, dataAlign * 2 * 7 * sizeof(float)); + curTemp = calcBuf.Get(dataAlign * 2 * 7); + boxTemp = curTemp[8]; + } + } + + __aicore__ inline void Process() + { + uint32_t core_id = GetBlockIdx(); + LocalTensor oneLocal = oneMask.AllocTensor(); + Duplicate(oneLocal, static_cast(1), dataAlign); + for (size_t i = 0; i < boxNum; ++i) { + for (size_t j = 0; j < loopTime; ++j) { + uint32_t start = core_id * eachSum + dataAlign * j; + if (i >= start + dataAlign) { + DataCopy(maskGm[i * maskNum + start], oneLocal, dataAlign); + continue; + } + bool is_last = (isLastCore) && (j == loopTime - 1); + CopyIn(i, start, is_last); + Compute(i, start, is_last); + CopyOut(i, start); + } + } + oneMask.FreeTensor(oneLocal); + } + +private: + __aicore__ inline void CopyIn(int32_t cur_box, int32_t com_box, + bool is_last) + { + LocalTensor curLocal = inQueueCur.AllocTensor(); + LocalTensor boxLocal = inQueueBox.AllocTensor(); + DataCopy(curLocal, boxGm[cur_box * 7], dataAlign); + DataCopy(boxLocal, boxGm[com_box * 7], dataAlign * 7); + inQueueCur.EnQue(curLocal); + inQueueBox.EnQue(boxLocal); + } + + __aicore__ inline void Compute(int32_t cur_box, int32_t com_box, + bool is_last) + { + uint32_t cmpNum = is_last ? tailNum : dataAlign; + if constexpr(sizeof(T) == sizeof(half)) { + LocalTensor curLocal = inQueueCur.DeQue(); + LocalTensor boxLocal = inQueueBox.DeQue(); + Cast(curTemp, curLocal, RoundMode::CAST_NONE, dataAlign); + Cast(boxTemp, boxLocal, RoundMode::CAST_NONE, 7 * dataAlign); + inQueueCur.FreeTensor(curLocal); + inQueueBox.FreeTensor(boxLocal); + } else { + curTemp = inQueueCur.DeQue(); + boxTemp = inQueueBox.DeQue(); + } + + PipeBarrier(); + LocalTensor outLocal = outQueueMask.AllocTensor(); + for (size_t i = 0; i < cmpNum; i++) { + if (cur_box >= com_box + i) { + outLocal.SetValue(i, 1); + continue; + } + LocalTensor comLocal = comBuf.Get(); + for (size_t k = 0; k < 7; k++) { + comLocal.SetValue(k, static_cast(boxTemp.GetValue(i * 7 + k))); + } + auto flag = iou_bev(curTemp, comLocal); + if (flag > overlapThresh) { + outLocal.SetValue(i, 0); + } else { + outLocal.SetValue(i, 1); + } + } + PipeBarrier(); + outQueueMask.EnQue(outLocal); + if constexpr(sizeof(T) != sizeof(half)) { + inQueueCur.FreeTensor(curTemp); + inQueueBox.FreeTensor(boxTemp); + } + } + + __aicore__ inline void CopyOut(int32_t cur_box, int32_t com_box) + { + LocalTensor outLocal = outQueueMask.DeQue(); + DataCopy(maskGm[cur_box * maskNum + com_box], outLocal, dataAlign); + outQueueMask.FreeTensor(outLocal); + } + +private: + __aicore__ inline float cross(const Point &a, const Point &b) + { + return a.x * b.y - a.y * b.x; + } + + __aicore__ inline float cross(const Point &p1, const Point &p2, + const Point &p0) + { + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y); + } + + __aicore__ int check_rect_cross(const Point &p1, const Point &p2, + const Point &q1, const Point &q2) + { + int ret = min(p1.x, p2.x) <= max(q1.x, q2.x) && + min(q1.x, q2.x) <= max(p1.x, p2.x) && + min(p1.y, p2.y) <= max(q1.y, q2.y) && + min(q1.y, q2.y) <= max(p1.y, p2.y); + return ret; + } + + __aicore__ inline int check_in_box2d(const LocalTensor &box, + const Point &p) + { + const float MARGIN = 1e-2; + float center_x = box.GetValue(0); + float center_y = box.GetValue(1); + LocalTensor angleLocal = angleBuf.Get(); + LocalTensor sinLocal = sinBuf.Get(); + LocalTensor cosLocal = cosBuf.Get(); + angleLocal.SetValue(0, 
-box.GetValue(6)); + Sin(sinLocal, angleLocal); + Cos(cosLocal, angleLocal); + float angle_cos = cosLocal.GetValue(0); + float angle_sin = sinLocal.GetValue(0); + float rot_x = + (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin); + float rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos; + + return (abs(rot_x) < box.GetValue(3) / 2 + MARGIN && + abs(rot_y) < box.GetValue(4) / 2 + MARGIN); + } + + __aicore__ inline int intersection(const Point &p1, const Point &p0, + const Point &q1, const Point &q0, + Point &ans_point) + { + if (check_rect_cross(p0, p1, q0, q1) == 0) { + return 0; + } + float s1 = cross(q0, p1, p0); + float s2 = cross(p1, q1, p0); + float s3 = cross(p0, q1, q0); + float s4 = cross(q1, p1, q0); + if (!(s1 * s2 > static_cast(0.0) && + s3 * s4 > static_cast(0.0))) { + return 0; + } + float s5 = cross(q1, p1, p0); + if (abs(s5 - s1) > EPS) { + ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1) == 0 ? ((s5 - s1) + EPS) : (s5 - s1); + ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1) == 0 ? ((s5 - s1) + EPS) : (s5 - s1); + } else { + float a0 = p0.y - p1.y; + float b0 = p1.x - p0.x; + float c0 = p0.x * p1.y - p1.x * p0.y; + float a1 = q0.y - q1.y; + float b1 = q1.x - q0.x; + float c1 = q0.x * q1.y - q1.x * q0.y; + float D = a0 * b1 - a1 * b0; + + ans_point.x = (b0 * c1 - b1 * c0) / D == 0 ? D + EPS : D; + ans_point.y = (a1 * c0 - a0 * c1) / D == 0 ? D + EPS : D; + } + + return 1; + } + + __aicore__ inline void rotate_around_center(const Point ¢er, + const float angle_cos, + const float angle_sin, Point &p) + { + float new_x = + (p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x; + float new_y = + (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y; + p.set(new_x, new_y); + } + + __aicore__ inline int point_cmp(const Point &a, const Point &b, + const Point ¢er) + { + return math_atan2(a.y - center.y, a.x - center.x) > + math_atan2(b.y - center.y, b.x - center.x); + } + + __aicore__ inline float math_atan2(float a, float b) + { + float atan2val; + if (b > 0) { + atan2val = math_atan(a / b); + } else if ((b < 0) && (a >= 0)) { + atan2val = math_atan(a / b) + static_cast(M_PI); + } else if ((b < 0) && (a < 0)) { + atan2val = math_atan(a / b) - static_cast(M_PI); + } else if ((b == 0) && (a > 0)) { + atan2val = static_cast(M_PI) / 2; + } else if ((b == 0) && (a < 0)) { + atan2val = 0 - (static_cast(M_PI) / 2); + } else if ((b == 0) && (a == 0)) { + atan2val = ATAN2_DEFAULT_VALUE; + } + return atan2val; + } + + __aicore__ inline float math_atan(const float x) + { + LocalTensor angleLocal = angleBuf.Get(); + LocalTensor atanLocal = sinBuf.Get(); + angleLocal.SetValue(0, x); + Atan(atanLocal, angleLocal, 1); + return atanLocal.GetValue(0); + } + + __aicore__ inline float box_overlap(const LocalTensor &boxATensor, + const LocalTensor &boxBTensor) + { + // params box_a: [x, y, z, dx, dy, dz, heading] + // params box_b: [x, y, z, dx, dy, dz, heading] + + float a_angle = boxATensor.GetValue(6); + float b_angle = boxBTensor.GetValue(6); + float a_dx_half = boxATensor.GetValue(3) / 2; + float b_dx_half = boxBTensor.GetValue(3) / 2; + float a_dy_half = boxATensor.GetValue(4) / 2; + float b_dy_half = boxBTensor.GetValue(4) / 2; + float a_x1 = boxATensor.GetValue(0) - a_dx_half; + float a_y1 = boxATensor.GetValue(1) - a_dy_half; + float a_x2 = boxATensor.GetValue(0) + a_dx_half; + float a_y2 = boxATensor.GetValue(1) + a_dy_half; + float b_x1 = boxBTensor.GetValue(0) - b_dx_half; + float b_y1 = boxBTensor.GetValue(1) - 
b_dy_half; + float b_x2 = boxBTensor.GetValue(0) + b_dx_half; + float b_y2 = boxBTensor.GetValue(1) + b_dy_half; + + Point center_a(boxATensor.GetValue(0), boxATensor.GetValue(1)); + Point center_b(boxBTensor.GetValue(0), boxBTensor.GetValue(1)); + + Point box_a_corners[5]; + box_a_corners[0].set(a_x1, a_y1); + box_a_corners[1].set(a_x2, a_y1); + box_a_corners[2].set(a_x2, a_y2); + box_a_corners[3].set(a_x1, a_y2); + + Point box_b_corners[5]; + box_b_corners[0].set(b_x1, b_y1); + box_b_corners[1].set(b_x2, b_y1); + box_b_corners[2].set(b_x2, b_y2); + box_b_corners[3].set(b_x1, b_y2); + + // get oriented corners + LocalTensor angleLocal = angleBuf.Get(); + LocalTensor sinLocal = sinBuf.Get(); + LocalTensor cosLocal = cosBuf.Get(); + angleLocal.SetValue(0, a_angle); + angleLocal.SetValue(1, b_angle); + Sin(sinLocal, angleLocal); + Cos(cosLocal, angleLocal); + float a_angle_cos = cosLocal.GetValue(0); + float a_angle_sin = sinLocal.GetValue(0); + float b_angle_cos = cosLocal.GetValue(1); + float b_angle_sin = sinLocal.GetValue(1); + + for (int k = 0; k < 4; k++) { + rotate_around_center(center_a, a_angle_cos, a_angle_sin, + box_a_corners[k]); + rotate_around_center(center_b, b_angle_cos, b_angle_sin, + box_b_corners[k]); + } + + box_a_corners[4] = box_a_corners[0]; + box_b_corners[4] = box_b_corners[0]; + + // get intersection of lines + Point cross_points[16]; + Point poly_center; + int cnt = 0; + int flag = 0; + + poly_center.set(0, 0); + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + flag = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j], + cross_points[cnt]); + if (flag) { + poly_center = poly_center + cross_points[cnt]; + cnt++; + } + } + } + + // check corners + for (int k = 0; k < 4; k++) { + if (check_in_box2d(boxATensor, box_b_corners[k])) { + poly_center = poly_center + box_b_corners[k]; + cross_points[cnt] = box_b_corners[k]; + cnt++; + } + if (check_in_box2d(boxBTensor, box_a_corners[k])) { + poly_center = poly_center + box_a_corners[k]; + cross_points[cnt] = box_a_corners[k]; + cnt++; + } + } + if (cnt != 0) { + poly_center.x /= cnt; + poly_center.y /= cnt; + } + + // sort the points of polygon + Point temp; + for (int j = 0; j < cnt - 1; j++) { + for (int i = 0; i < cnt - j - 1; i++) { + if (point_cmp(cross_points[i], cross_points[i + 1], poly_center)) { + temp = cross_points[i]; + cross_points[i] = cross_points[i + 1]; + cross_points[i + 1] = temp; + } + } + } + + // get the overlap areas + float area = 0; + for (int k = 0; k < cnt - 1; k++) { + area += cross(cross_points[k] - cross_points[0], + cross_points[k + 1] - cross_points[0]); + } + + return abs(area) / static_cast(2.0); + } + + __aicore__ inline float iou_bev(const LocalTensor &boxATensor, + const LocalTensor &boxBTensor) + { + // params box_a: [x, y, z, dx, dy, dz, heading] + // params box_b: [x, y, z, dx, dy, dz, heading] + float sa = boxATensor.GetValue(3) * boxATensor.GetValue(4); + float sb = boxBTensor.GetValue(3) * boxBTensor.GetValue(4); + float s_overlap = box_overlap(boxATensor, boxBTensor); + return s_overlap / max(sa + sb - s_overlap, EPS); + } + +private: + TPipe pipe; + TQue inQueueCur, inQueueBox; + TQue outQueueMask, oneMask; + TBuf calcBuf; + TBuf comBuf; + + TBuf p1Buf, p2Buf, q1Buf, q2Buf; + TBuf angleBuf, sinBuf, cosBuf, pointBuf; + TBuf min1Buf, min2Buf, max1Buf, max2Buf; + + GlobalTensor boxGm; + GlobalTensor maskGm; + LocalTensor curTemp, boxTemp; + uint32_t usedCoreNum; + uint32_t loopTime; + uint32_t eachSum; + uint32_t boxNum; + 
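// Tiling-derived sizes: tailSum/tailNum cover the remainder handled by the
// last core and its final loop iteration (see Nms3dTilingData).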
uint32_t tailSum; + uint32_t tailNum; + uint32_t maskNum; + uint32_t dataAlign = 16; + float overlapThresh; + bool isLastCore; +}; + +extern "C" __global__ __aicore__ + +void nms3d(GM_ADDR boxes, GM_ADDR mask, + GM_ADDR workspace, GM_ADDR tiling) +{ + GET_TILING_DATA(tilingData, tiling); + const Nms3dTilingData *__restrict tilingDevice = &tilingData; + if (TILING_KEY_IS(1)) { + KernelNms3d op; + op.Init(boxes, mask, tilingDevice); + op.Process(); + } else if (TILING_KEY_IS(2)) { + KernelNms3d op; + op.Init(boxes, mask, tilingDevice); + op.Process(); + } +} \ No newline at end of file diff --git a/ads/common/ops/npu_nms3d.py b/ads/common/ops/npu_nms3d.py new file mode 100644 index 00000000..45731a98 --- /dev/null +++ b/ads/common/ops/npu_nms3d.py @@ -0,0 +1,20 @@ +import torch +from torch.autograd import Function + +import torch_npu +import ads_c + + +class Nms3dFunction(Function): + @staticmethod + def forward(ctx, boxes, scores, iou_threshold: float): + if boxes.shape[1] != 7: + raise Exception('Input boxes shape should be (N, 7)') + order = scores.sort(0, descending=True)[1] + boxes = boxes[order].contiguous() + + keep, num_out = ads_c.nms3d(boxes, iou_threshold) + return order[keep[:num_out].long()].contiguous() + + +npu_nms3d = Nms3dFunction.apply diff --git a/tests/torch/test_npu_nms3d.py b/tests/torch/test_npu_nms3d.py new file mode 100644 index 00000000..2046b393 --- /dev/null +++ b/tests/torch/test_npu_nms3d.py @@ -0,0 +1,272 @@ +import unittest +from math import cos, sin, fabs, atan2 +from typing import List + +import numpy as np +import torch +import torch_npu +from torch_npu.testing.common_utils import create_common_tensor +from torch_npu.testing.testcase import TestCase, run_tests + +import ads.common + +torch.npu.config.allow_internal_format = False +torch_npu.npu.set_compile_mode(jit_compile=False) +DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] +EPS = 1e-8 + + +class Point: + def __init__(self, x=0.0, y=0.0): + self.x = x + self.y = y + + def set(self, _x: float, _y: float): + self.x = _x + self.y = _y + + def __add__(self, other): + x = self.x + other.x + y = self.y + other.y + return Point(x, y) + + def __sub__(self, other): + x = self.x - other.x + y = self.y - other.y + return Point(x, y) + + +def cross(p1: Point, p2: Point, p0: Point) -> float: + if p0 is None: + return p1.x * p2.y - p1.y * p2.x + return (p1.x - p0.x) * (p2.y - p0.y) - (p2.x - p0.x) * (p1.y - p0.y) + + +def check_rect_cross(p1: Point, p2: Point, q1: Point, q2: Point) -> bool: + ret = min(p1.x, p2.x) <= max(q1.x, q2.x) and \ + min(q1.x, q2.x) <= max(p1.x, p2.x) and \ + min(p1.y, p2.y) <= max(q1.y, q2.y) and \ + min(q1.y, q2.y) <= max(p1.y, p2.y) + return ret + + +def box_overlap(box_a: List[float], box_b: List[float]): + a_angle = box_a[6] + b_angle = box_b[6] + a_dx_half = box_a[3] / 2 + b_dx_half = box_b[3] / 2 + a_dy_half = box_a[4] / 2 + b_dy_half = box_b[4] / 2 + a_x1 = box_a[0] - a_dx_half + a_y1 = box_a[1] - a_dy_half + a_x2 = box_a[0] + a_dx_half + a_y2 = box_a[1] + a_dy_half + b_x1 = box_b[0] - b_dx_half + b_y1 = box_b[1] - b_dy_half + b_x2 = box_b[0] + b_dx_half + b_y2 = box_b[1] + b_dy_half + + center_a = Point(box_a[0], box_a[1]) + center_b = Point(box_b[0], box_b[1]) + + box_a_corners = [Point()] * 5 + box_a_corners[0] = Point(a_x1, a_y1) + box_a_corners[1] = Point(a_x2, a_y1) + box_a_corners[2] = Point(a_x2, a_y2) + box_a_corners[3] = Point(a_x1, a_y2) + + box_b_corners = [Point()] * 5 + box_b_corners[0] = Point(b_x1, b_y1) + box_b_corners[1] = Point(b_x2, b_y1) + box_b_corners[2] = 
Point(b_x2, b_y2) + box_b_corners[3] = Point(b_x1, b_y2) + # get oriented corners + a_angle_cos = cos(a_angle) + a_angle_sin = sin(a_angle) + + b_angle_cos = cos(b_angle) + b_angle_sin = sin(b_angle) + for k in range(4): + rotate_point_a = rotate_around_center(center_a, a_angle_cos, a_angle_sin, box_a_corners[k]) + box_a_corners[k] = rotate_point_a + rotate_point_b = rotate_around_center(center_b, b_angle_cos, b_angle_sin, box_b_corners[k]) + box_b_corners[k] = rotate_point_b + box_a_corners[4] = box_a_corners[0] + box_b_corners[4] = box_b_corners[0] + cross_points = [Point()] * 16 + poly_center = Point(0, 0) + cnt = 0 + flag = 0 + for i in range(4): + for j in range(4): + flag, ans_point = intersection(box_a_corners[i + 1], box_a_corners[i], + box_b_corners[j + 1], box_b_corners[j]) + + cross_points[cnt] = ans_point + + if flag: + poly_center = poly_center + cross_points[cnt] + cnt += 1 + # check corners + for k in range(4): + if check_in_box2d(box_a, box_b_corners[k]): + poly_center = poly_center + box_b_corners[k] + cross_points[cnt] = box_b_corners[k] + cnt += 1 + if check_in_box2d(box_b, box_a_corners[k]): + poly_center = poly_center + box_a_corners[k] + cross_points[cnt] = box_a_corners[k] + cnt += 1 + + if cnt != 0: + poly_center.x /= cnt + poly_center.y /= cnt + # sort the points of polygon + + for j in range(cnt - 1): + for i in range(cnt - j - 1): + if point_cmp(cross_points[i], cross_points[i + 1], poly_center): + temp = cross_points[i] + cross_points[i] = cross_points[i + 1] + cross_points[i + 1] = temp + + # get the overlap areas + area = 0 + for k in range(cnt - 1): + v1 = cross_points[k] - cross_points[0] + v2 = cross_points[k + 1] - cross_points[0] + area += cross(v1, v2, None) + return fabs(area) / 2.0 + + +def rotate_around_center(center: Point, angle_cos: float, angle_sin: float, p: Point) -> Point: + new_x = (p.x - center.x) * angle_cos - (p.y - center.y) * angle_sin + center.x + new_y = (p.x - center.x) * angle_sin + (p.y - center.y) * angle_cos + center.y + p.set(new_x, new_y) + return p + + +def check_in_box2d(box: List[float], p: Point): + # params: box (7) [x, y, z, dx, dy, dz, heading] + MARGIN = 1e-2 + + center_x = box[0] + center_y = box[1] + # rotate the point in the opposite direction of box + angle_cos = cos(-box[6]) + angle_sin = sin(-box[6]) + rot_x = (p.x - center_x) * angle_cos + (p.y - center_y) * (-angle_sin) + rot_y = (p.x - center_x) * angle_sin + (p.y - center_y) * angle_cos + + return fabs(rot_x) < box[3] / 2 + MARGIN and fabs(rot_y) < box[4] / 2 + MARGIN + + +def point_cmp(a: Point, b: Point, center: Point): + return atan2(a.y - center.y, a.x - center.x) > atan2(b.y - center.y, b.x - center.x) + + +def intersection(p1: Point, p0: Point, q1: Point, q0: Point): + ans_point = Point() + # fast exclusion + if check_rect_cross(p0, p1, q0, q1) == 0: + return 0, ans_point + + # check cross standing + s1 = cross(q0, p1, p0) + s2 = cross(p1, q1, p0) + s3 = cross(p0, q1, q0) + s4 = cross(q1, p1, q0) + + if not (s1 * s2 > 0 and s3 * s4 > 0): + return 0, ans_point + + # calculate intersection of two lines + s5 = cross(q1, p1, p0) + if fabs(s5 - s1) > EPS: + try: + ans_point.x = (s5 * q0.x - s1 * q1.x) / (s5 - s1) + ans_point.y = (s5 * q0.y - s1 * q1.y) / (s5 - s1) + except ZeroDivisionError as e: + print("intersection value can not be 0.") + else: + a0 = p0.y - p1.y + b0 = p1.x - p0.x + c0 = p0.x * p1.y - p1.x * p0.y + a1 = q0.y - q1.y + b1 = q1.x - q0.x + c1 = q0.x * q1.y - q1.x * q0.y + + D = a0 * b1 - a1 * b0 + if D != 0: + ans_point.x = (b0 * c1 - b1 * c0) 
/ D + ans_point.y = (a1 * c0 - a0 * c1) / D + + return 1, ans_point + + +def iou_bev(box_a: List[float], box_b: List[float]): + # params box_a: [x, y, z, dx, dy, dz, heading] + # params box_b: [x, y, z, dx, dy, dz, heading] + sa = box_a[3] * box_a[4] + sb = box_b[3] * box_b[4] + s_overlap = box_overlap(box_a, box_b) + max_val = max(sa + sb - s_overlap, EPS) + try: + result = s_overlap / max_val + except ZeroDivisionError as e: + print("value of area union can not be 0.") + return result + + +class TestNms3d(TestCase): + def cpu_to_exec(self, boxes, scores, threshold=0.0): + boxes = boxes.numpy() + scores_npu = scores.npu() + order_npu = scores_npu.sort(0, descending=True)[1] + order_cpu = order_npu.cpu() + order = order_cpu.numpy() + boxes = boxes.take(order, 0) + keep, num_out = self.cpu_nms_forward(boxes, threshold) + keep = keep.astype(np.int64) + keep = order[keep[:num_out]] + return torch.from_numpy(keep) + + def cpu_nms_forward(self, boxes, nms_overlap_thresh=0.0): + mask = np.ones(boxes.shape[0], dtype=int) + keep = -np.ones(boxes.shape[0]) + num_out = 0 + for i in range(0, boxes.shape[0]): + if mask[i] == 0: + continue + keep[num_out] = i + num_out += 1 + for j in range(i + 1, boxes.shape[0]): + if iou_bev(boxes[i], boxes[j]) > nms_overlap_thresh: + mask[j] = 0 + return keep, num_out + + def npu_to_exec(self, boxes, scores, threshold=0.0): + keep = ads.common.npu_nms3d(boxes, scores, threshold) + return keep.cpu() + + @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `Nms3d` is only supported on 910B, skip this ut!") + def test_nms3d(self): + shape_format = [ + [[np.float32, -1, [5, 7]], [np.float32, -1, [5]], 0.1], + [[np.float32, -1, [100, 7]], [np.float32, -1, [100]], 0.2], + [[np.float32, -1, [500, 7]], [np.float32, -1, [500]], 0.3], + [[np.float16, -1, [5, 7]], [np.float16, -1, [5]], 0.1], + [[np.float16, -1, [100, 7]], [np.float16, -1, [100]], 0.2], + [[np.float16, -1, [500, 7]], [np.float16, -1, [500]], 0.3], + ] + for item in shape_format: + boxes_cpu, boxes_npu = create_common_tensor(item[0], 0, 10) + scores_cpu, scores_npu = create_common_tensor(item[1], 0, 1) + threshold = item[2] + out_cpu = self.cpu_to_exec(boxes_cpu, scores_cpu, threshold) + out_npu = self.npu_to_exec(boxes_npu, scores_npu, threshold) + self.assertRtolEqual(out_cpu, out_npu) + + +if __name__ == '__main__': + run_tests() -- Gitee From 39314df873c5f5343ce5b8594ec484036619c0ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E6=99=93=E6=AD=86?= Date: Tue, 20 Feb 2024 12:28:43 +0000 Subject: [PATCH 02/12] =?UTF-8?q?!54=20points=5Fin=5Fbox=20=E6=80=A7?= =?UTF-8?q?=E8=83=BD=E4=BC=98=E5=8C=96=20Merge=20pull=20request=20!54=20fr?= =?UTF-8?q?om=20=E7=8E=8B=E6=99=93=E6=AD=86/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ops/kernels/op_kernel/points_in_box.cpp | 118 +++++++----------- 1 file changed, 48 insertions(+), 70 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/points_in_box.cpp b/ads/common/ops/kernels/op_kernel/points_in_box.cpp index 2c80e479..7135f762 100644 --- a/ads/common/ops/kernels/op_kernel/points_in_box.cpp +++ b/ads/common/ops/kernels/op_kernel/points_in_box.cpp @@ -74,7 +74,13 @@ public: private: __aicore__ inline void Compute(int32_t progress, int32_t tensor_size, uint64_t address) { - LocalTensor boxesLocal = inQueueBOXES.AllocTensor(); + LocalTensor boxesLocal_cx = inQueueBOXES.AllocTensor(); + LocalTensor boxesLocal_cy = boxesLocal_cx[this->available_ub_size]; + LocalTensor boxesLocal_cz = 
boxesLocal_cx[this->available_ub_size * 2]; + LocalTensor boxesLocal_dx = boxesLocal_cx[this->available_ub_size * 3]; + LocalTensor boxesLocal_dy = boxesLocal_cx[this->available_ub_size * 4]; + LocalTensor boxesLocal_dz = boxesLocal_cx[this->available_ub_size * 5]; + LocalTensor boxesLocal_rz = boxesLocal_cx[this->available_ub_size * 6]; LocalTensor pointLocal = inQueuePTS.AllocTensor(); LocalTensor zLocal = outQueueOUTPUT.AllocTensor(); LocalTensor shiftx = shiftxque.Get(); @@ -93,133 +99,105 @@ private: uint64_t mask = 64; DataCopyParams copyParams_out{1, (uint16_t)(tensor_size * sizeof(DTYPE_BOXES_IDX_OF_POINTS)), 0, 0}; DataCopyParams copyParams_in{1, (uint16_t)(tensor_size * 3 * sizeof(DTYPE_BOXES)), 0, 0}; + DataCopyParams copyParams_box{1, (uint16_t)(this->box_number * sizeof(DTYPE_BOXES)), 0, 0}; DataCopyPadParams padParams{true, 0, 0, 0}; DataCopyPad(pointLocal, ptsGm[address * 3], copyParams_in, padParams); Duplicate(zLocal, oneminsnumber, tensor_size); + DataCopyPad(boxesLocal_cx, boxesGm, copyParams_box, padParams); + DataCopyPad(boxesLocal_cy, boxesGm[this->box_number], copyParams_box, padParams); + DataCopyPad(boxesLocal_cz, boxesGm[this->box_number*2], copyParams_box, padParams); + DataCopyPad(boxesLocal_dx, boxesGm[this->box_number*3], copyParams_box, padParams); + DataCopyPad(boxesLocal_dy, boxesGm[this->box_number*4], copyParams_box, padParams); + DataCopyPad(boxesLocal_dz, boxesGm[this->box_number*5], copyParams_box, padParams); + DataCopyPad(boxesLocal_rz, boxesGm[this->box_number*6], copyParams_box, padParams); + set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); for (int32_t i = 0; i < tensor_size; i++) { if (zLocal.GetValue(i) == -1) { uint32_t batch_id = i / this->npoints; auto x = pointLocal.GetValue(i * 3); auto y = pointLocal.GetValue(i * 3 + 1); auto z = pointLocal.GetValue(i * 3 + 2); - set_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7], copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + int repeat = (tensor_size + mask - 1) / mask; + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + // shift_x = x - boxes_ub[ :, 0] - Muls(shiftx, boxesLocal, oneminsnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - Adds(shiftx, shiftx, x, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Muls(shiftx, boxesLocal_cx, oneminsnumber, mask, repeat, { 1, 1, 8, 8 }); + Adds(shiftx, shiftx, x, mask, repeat, { 1, 1, 8, 8 }); // shift_y = y - boxes_ub[ :, 1] - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7 + this->box_number], - copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - - Muls(shifty, boxesLocal, oneminsnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - Adds(shifty, shifty, y, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Muls(shifty, boxesLocal_cy, oneminsnumber, mask, repeat, { 1, 1, 8, 8 }); + Adds(shifty, shifty, y, mask, repeat, { 1, 1, 8, 8 }); // cosa = Cos(-boxes_ub[ :, 6]) - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7 + this->box_number*6], - copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, 
EVENT_ID0); - Muls(temp, boxesLocal, oneminsnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); + Muls(temp, boxesLocal_rz, oneminsnumber, mask, repeat, { 1, 1, 8, 8 }); Cos(cosa, temp, uint8temp, tensor_size); // sina = Sin(-boxes_ub[ :, 6]) - Muls(temp, boxesLocal, oneminsnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); + Muls(temp, boxesLocal_rz, oneminsnumber, mask, repeat, { 1, 1, 8, 8 }); Sin(sina, temp, uint8temp, (tensor_size + mask - 1)); // local_x = shift_x * cosa + shift_y * (-sina) - Mul(temp, shiftx, cosa, mask, (tensor_size + mask - 1)/mask, {1, 1, 1, 8, 8, 8 }); + Mul(temp, shiftx, cosa, mask, repeat, {1, 1, 1, 8, 8, 8 }); Duplicate(xlocal, zeronumber, tensor_size); - Add(xlocal, xlocal, temp, mask, (tensor_size + mask - 1)/mask, {1, 1, 1, 8, 8, 8 }); - Muls(temp, sina, oneminsnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - Mul(temp, shifty, temp, mask, (tensor_size + mask - 1)/mask, {1, 1, 1, 8, 8, 8 }); - Add(xlocal, xlocal, temp, mask, (tensor_size + mask - 1)/mask, {1, 1, 1, 8, 8, 8 }); + Add(xlocal, xlocal, temp, mask, repeat, {1, 1, 1, 8, 8, 8 }); + Muls(temp, sina, oneminsnumber, mask, repeat, { 1, 1, 8, 8 }); + Mul(temp, shifty, temp, mask, repeat, {1, 1, 1, 8, 8, 8 }); + Add(xlocal, xlocal, temp, mask, repeat, {1, 1, 1, 8, 8, 8 }); // local_y = shift_x * sina + shift_y * cosa - Mul(temp, shiftx, sina, mask, (tensor_size + mask - 1)/mask, {1, 1, 1, 8, 8, 8 }); - Mul(sina, shifty, cosa, mask, (tensor_size + mask - 1)/mask, {1, 1, 1, 8, 8, 8 }); + Mul(temp, shiftx, sina, mask, repeat, {1, 1, 1, 8, 8, 8 }); + Mul(sina, shifty, cosa, mask, repeat, {1, 1, 1, 8, 8, 8 }); Add(ylocal, sina, temp, tensor_size); - Abs(xlocal, xlocal, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); + Abs(xlocal, xlocal, mask, repeat, { 1, 1, 8, 8 }); pipe_barrier(PIPE_V); - Abs(ylocal, ylocal, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); + Abs(ylocal, ylocal, mask, repeat, { 1, 1, 8, 8 }); // zlocal = z-cz sina - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7 + this->box_number*2], - copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - Muls(sina, boxesLocal, oneminsnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - Adds(sina, sina, z, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - Abs(sina, sina, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); + Muls(sina, boxesLocal_cz, oneminsnumber, mask, repeat, { 1, 1, 8, 8 }); + Adds(sina, sina, z, mask, repeat, { 1, 1, 8, 8 }); + Abs(sina, sina, mask, repeat, { 1, 1, 8, 8 }); // z_size + 1e-5 cosa - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7 + this->box_number*5], - copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - Muls(cosa, boxesLocal, halfnumber, (tensor_size + mask - 1)/mask); + Muls(cosa, boxesLocal_dz, halfnumber, mask, repeat, { 1, 1, 8, 8 }); pipe_barrier(PIPE_ALL); // x_size + 1e-5 shiftx - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7 + this->box_number*3], - copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - Muls(shiftx, boxesLocal, halfnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); - set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); - 
wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + Muls(shiftx, boxesLocal_dx, halfnumber, mask, repeat, { 1, 1, 8, 8 }); // y_size + 1e-5 shifty - DataCopyPad(boxesLocal, boxesGm[batch_id * this->box_number * 7 + this->box_number*4], - copyParams_in, padParams); - set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); - - Muls(shifty, boxesLocal, halfnumber, mask, (tensor_size + mask - 1)/mask, { 1, 1, 8, 8 }); + Muls(shifty, boxesLocal_dy, halfnumber, mask, repeat, { 1, 1, 8, 8 }); set_flag(PIPE_V, PIPE_S, EVENT_ID0); wait_flag(PIPE_V, PIPE_S, EVENT_ID0); uint64_t mask = 256/sizeof(float); - int repeat = tensor_size/mask; BinaryRepeatParams repeatParams = { 1, 1, 1, 8, 8, 8 }; set_flag(PIPE_S, PIPE_V, EVENT_ID0); wait_flag(PIPE_S, PIPE_V, EVENT_ID0); // dup full zeronumber tensor - Duplicate(boxesLocal, zeronumber, tensor_size); + Duplicate(boxesLocal_cx, zeronumber, tensor_size); // dup full onenumber tensor Duplicate(temp, onenumber, tensor_size); // cmp_1 = Abs(local_x) < x_size + 1e-5 uint8temp = xlocal <= shiftx; Duplicate(xlocal, zeronumber, tensor_size); - Select(xlocal, uint8temp, temp, boxesLocal, + Select(xlocal, uint8temp, temp, boxesLocal_cx, SELMODE::VSEL_TENSOR_TENSOR_MODE, mask, repeat, repeatParams); // cmp_2 = Abs(local_y) < y_size+ 1e-5 uint8temp = ylocal <= shifty; Duplicate(ylocal, zeronumber, tensor_size); - Select(ylocal, uint8temp, temp, boxesLocal, + Select(ylocal, uint8temp, temp, boxesLocal_cx, SELMODE::VSEL_TENSOR_TENSOR_MODE, mask, repeat, repeatParams); - // cmp_3 = Abs(zlocal) < z_size + 1e-5 + // cmp_3 = Abs(zlocal) < z_size uint8temp = sina <= cosa; Duplicate(sina, zeronumber, tensor_size); - Select(sina, uint8temp, temp, boxesLocal, + Select(sina, uint8temp, temp, boxesLocal_cx, SELMODE::VSEL_TENSOR_TENSOR_MODE, mask, repeat, repeatParams); pipe_barrier(PIPE_V); @@ -238,7 +216,7 @@ private: } pipe_barrier(PIPE_ALL); DataCopyPad(outputGm[address], zLocal, copyParams_out); - inQueuePTS.FreeTensor(boxesLocal); + inQueuePTS.FreeTensor(boxesLocal_cx); inQueueBOXES.FreeTensor(pointLocal); outQueueOUTPUT.FreeTensor(zLocal); } -- Gitee From 44b0168d22fc2fea0e1b2666bf12a244aa8d94f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=94=90=E7=87=95=E9=94=8B?= Date: Wed, 21 Feb 2024 06:43:59 +0000 Subject: [PATCH 03/12] =?UTF-8?q?!59=20=E5=8C=85=E5=90=8D=E4=BD=BF?= =?UTF-8?q?=E7=94=A8=E7=82=B9=E8=BF=9B=E8=A1=8C=E5=88=86=E5=89=B2=20Merge?= =?UTF-8?q?=20pull=20request=20!59=20from=20=E5=94=90=E7=87=95=E9=94=8B/ma?= =?UTF-8?q?ster?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3a2ca028..fb27eede 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,7 @@ def get_sha(pytorch_root: Union[str, Path]) -> str: except Exception: return "Unknown" -VERSION = "1.0_" + torch.__version__[0:6] +VERSION = "1.0." 
+ torch.__version__[0:6] torch_npu_root = Path(__file__).parent sha = get_sha(torch_npu_root) if not os.getenv("BUILD_WITHOUT_SHA"): -- Gitee From 29732ef8cdecc87cfad828fe1bc44c8d7bc8d843 Mon Sep 17 00:00:00 2001 From: l00636998 Date: Wed, 21 Feb 2024 06:49:43 +0000 Subject: [PATCH 04/12] =?UTF-8?q?!28=20=E6=96=B0=E5=A2=9EFurthesPointSampl?= =?UTF-8?q?ing=E7=AE=97=E5=AD=90=20Merge=20pull=20request=20!28=20from=20l?= =?UTF-8?q?00636998/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ads/common/__init__.py | 1 + .../csrc/FurthestPointSamplingKernelNpu.cpp | 27 + ads/common/ops/csrc/functions.h | 1 + ads/common/ops/csrc/pybind.cpp | 5 +- ads/common/ops/furthest_point_sampling.py | 22 + ads/common/ops/kernels/inc/base.h | 9 + .../op_host/furthest_point_sampling.cpp | 335 ++++++++++ .../op_host/furthest_point_sampling_tiling.h | 44 ++ .../op_kernel/furthest_point_sampling.cpp | 587 ++++++++++++++++++ .../op_kernel/furthest_point_sampling.h | 139 +++++ tests/test_furthest_point_sampling.py | 126 ++++ 11 files changed, 1295 insertions(+), 1 deletion(-) create mode 100644 ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp create mode 100644 ads/common/ops/furthest_point_sampling.py create mode 100644 ads/common/ops/kernels/op_host/furthest_point_sampling.cpp create mode 100644 ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h create mode 100644 ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp create mode 100644 ads/common/ops/kernels/op_kernel/furthest_point_sampling.h create mode 100644 tests/test_furthest_point_sampling.py diff --git a/ads/common/__init__.py b/ads/common/__init__.py index d20c765a..ee34c193 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -31,3 +31,4 @@ from .ops.dynamic_voxelization import voxelization from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal from .ops.npu_nms3d import npu_nms3d +from .ops.furthest_point_sampling import npu_furthest_point_sampling diff --git a/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp b/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp new file mode 100644 index 00000000..29d775fe --- /dev/null +++ b/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp @@ -0,0 +1,27 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
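// Adapter for the FurthestPointSampling custom op: allocates the int32 index
// output of shape (batch, num_points) and dispatches aclnnFurthestPointSampling.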
+ +#include +#include "OpApiCommon.h" +#include "functions.h" + +at::Tensor npu_furthest_point_sampling(const at::Tensor &point_xyz, const at::Tensor &nearset_temp, const int32_t num_points) +{ + at::Tensor output = at::empty({static_cast(point_xyz.sizes()[0]), static_cast(num_points)}, + nearset_temp.options().dtype(at::kInt)); + EXEC_NPU_CMD(aclnnFurthestPointSampling, point_xyz, nearset_temp, num_points, output); + return output; +} \ No newline at end of file diff --git a/ads/common/ops/csrc/functions.h b/ads/common/ops/csrc/functions.h index 58c4ad8b..8ca1e309 100644 --- a/ads/common/ops/csrc/functions.h +++ b/ads/common/ops/csrc/functions.h @@ -144,6 +144,7 @@ at::Tensor npu_multi_scale_deformable_attn_function(const at::Tensor& value, const at::Tensor& attention_weights); std::tuple multi_scale_deformable_attn_grad(const at::Tensor& value, const at::Tensor& shape, const at::Tensor& level_start_index, const at::Tensor& location, const at::Tensor& attn_weight, const at::Tensor& grad_output); +at::Tensor npu_furthest_point_sampling(const at::Tensor &point_xyz, const at::Tensor &nearset_temp, const int32_t num_points); at::Tensor npu_ads_add(const at::Tensor &tensor1, const at::Tensor &tensor2); at::Tensor DynamicVoxelization( diff --git a/ads/common/ops/csrc/pybind.cpp b/ads/common/ops/csrc/pybind.cpp index f3a8fb45..985d8fbd 100644 --- a/ads/common/ops/csrc/pybind.cpp +++ b/ads/common/ops/csrc/pybind.cpp @@ -91,10 +91,13 @@ void init_common(pybind11::module &m) // dyn_voxelization m.def("dynamic_voxelization", &DynamicVoxelization); - + // nms3d_normal m.def("nms3d_normal", &nms3d_normal); + // npu_furthest_point_sampling + m.def("npu_furthest_point_sampling", &npu_furthest_point_sampling); + // ads_nms3d m.def("nms3d", &nms3d); } diff --git a/ads/common/ops/furthest_point_sampling.py b/ads/common/ops/furthest_point_sampling.py new file mode 100644 index 00000000..2ad30a54 --- /dev/null +++ b/ads/common/ops/furthest_point_sampling.py @@ -0,0 +1,22 @@ +import numpy as np +import torch +from torch.autograd import Function +from torch.nn import Module + +import torch_npu +import ads_c + + +class AdsFurthestPointSampling(Function): + @staticmethod + def forward(ctx, point_xyz, num_points): + B, N = point_xyz.size()[:2] + point_xyz = point_xyz.permute(0, 2, 1).contiguous() + + nearest_dist = torch.tensor(np.ones((B, N)) * 1e10, dtype=torch.float32, device='npu').contiguous() + output = ads_c.npu_furthest_point_sampling(point_xyz, nearest_dist, num_points) + + return output + + +npu_furthest_point_sampling = AdsFurthestPointSampling.apply \ No newline at end of file diff --git a/ads/common/ops/kernels/inc/base.h b/ads/common/ops/kernels/inc/base.h index 3c853e6c..b0fe79bf 100644 --- a/ads/common/ops/kernels/inc/base.h +++ b/ads/common/ops/kernels/inc/base.h @@ -3,14 +3,23 @@ // .INPUT(x2, TensorType({DT_FLOAT})) // .OUTPUT(y, TensorType({DT_FLOAT})) // .OP_END_FACTORY_REG(Add) + // REG_OP(FurthestPointSamplingWithDist) // .INPUT(points_dist, TensorType({DT_FLOAT})) // .INPUT(nearest_temp, TensorType({DT_FLOAT})) // .OUTPUT(index, TensorType({DT_INT32})) // .REQUIRED_ATTR(num_points, Int) // .OP_END_FACTORY_REG(FurthestPointSamplingWithDist) + // REG_OP(Nms3dNormal) // .INPUT(boxes, TensorType({DT_FLOAT, DT_FLOAT16})) // .OUTPUT(keep, TensorType({DT_INT16})) // .REQUIRED_ATTR(nms_overlap_thresh, Float) // .OP_END_FACTORY_REG(Nms3dNormal) + +// REG_OP(FurthestPointSampling) +// .INPUT(point_xyz, TensorType({DT_FLOAT})) +// .INPUT(nearest_temp, TensorType({DT_FLOAT})) +// .OUTPUT(index, 
TensorType({DT_INT32})) +// .REQUIRED_ATTR(num_points, Int) +// .OP_END_FACTORY_REG(FurthestPointSampling) diff --git a/ads/common/ops/kernels/op_host/furthest_point_sampling.cpp b/ads/common/ops/kernels/op_host/furthest_point_sampling.cpp new file mode 100644 index 00000000..b8e4bfc5 --- /dev/null +++ b/ads/common/ops/kernels/op_host/furthest_point_sampling.cpp @@ -0,0 +1,335 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! + * \file furthest_point_sampling.cc + * \brief + */ +#include "tiling/platform/platform_ascendc.h" +#include "tiling/tiling_api.h" +#include "register/op_def_registry.h" +#include "furthest_point_sampling_tiling.h" + +using namespace ge; +using namespace std; +using namespace AscendC; + +namespace optiling { +/****************constexpr definition*****************/ +constexpr int64_t FP32_MODE = 0; + +/****************struct definition*****************/ +struct ub_memory_tag { + uint64_t ub_size = 0; + uint64_t ub_reserve = 2048; + // 8 :dev into 8 pieces(point_x, point_y, point_z, temp, distance, pointTempX, pointTempY, pointTempZ, store N data each) + uint64_t ub_data_blocks = 8; +}; + +/****************function definition*****************/ +template +inline T getSmallestMulVal(T data, T multiple) +{ + if (multiple == 0) { + return 0; + } + return ((data + multiple - 1) / multiple); +} + +template +inline T getSmallestMul(T data, T multiple) +{ + return (getSmallestMulVal(data, multiple) * multiple); +} + +/****************class definition*****************/ +class FurthestPointSamplingTiling { +public: + explicit FurthestPointSamplingTiling(gert::TilingContext* context) : TilingContext(context) {}; + ge::graphStatus Init(); + ge::graphStatus RunKernelTiling(); +private: + inline void SetTilingKeyMode(ge::DataType dType); + inline uint64_t UbBlocksDataSpace(uint64_t data_num); + inline uint64_t UbBlocksWorkSpace(uint64_t data_num); + inline uint64_t UbBlocksSpace(uint64_t data_num); + inline uint64_t FindMaxDataBlock(); + // in Kernel, we use ReduceMax requiring us to calc size of worklocal + inline uint32_t calcWorkLocalSize(uint32_t max_repeat_times); +private: + FurthestPointSamplingTilingData TilingData; + gert::TilingContext* TilingContext = nullptr; + + uint32_t coreNum; + + uint32_t batch; + uint32_t N; + uint32_t numPoints; + uint32_t pieces; + uint32_t formerNum; + uint32_t tailNum; + uint32_t workSize; + uint32_t idxTempSize; + uint32_t bigCoreBatch; + uint32_t smallCoreBatch; + uint32_t bigCoreNum; + uint32_t repeats; + + ub_memory_tag ub_memory; + + uint64_t point_dtype_size; +}; + +/****************class impl*****************/ +ge::graphStatus FurthestPointSamplingTiling::Init() +{ + const gert::StorageShape *point_xyz_shape = TilingContext->GetInputShape(0); + const gert::RuntimeAttrs *attrs = TilingContext->GetAttrs(); + uint64_t max_data_num; + + auto platformInfoPtr = TilingContext->GetPlatformInfo(); + if (platformInfoPtr == nullptr) { + return 
ge::GRAPH_FAILED; + } + + auto platformInfo = platform_ascendc::PlatformAscendC(platformInfoPtr); + + // Set Tiling Key + SetTilingKeyMode(TilingContext->GetInputDesc(0)->GetDataType()); + + // get core num + this->coreNum = platformInfo.GetCoreNumAiv(); + if (this->coreNum == 0) { + return ge::GRAPH_FAILED; + } + + // get ub_size,cal the capability that is aligned with 256 bytes + platformInfo.GetCoreMemSize(platform_ascendc::CoreMemType::UB, this->ub_memory.ub_size); + + // Get input args + this->batch = point_xyz_shape->GetStorageShape().GetDim(0); + this->N = point_xyz_shape->GetStorageShape().GetDim(2); + this->numPoints = *(attrs->GetAttrPointer(0)); + + // get the capability on UB + max_data_num = FindMaxDataBlock(); // pieces, repeats, workSize calc in this func + + if (this->repeats == 0) { + return ge::GRAPH_FAILED; + } + + // Tiling Args calc + this->bigCoreBatch = getSmallestMulVal(this->batch, this->coreNum); + this->smallCoreBatch = this->batch / this->coreNum; + if (this->bigCoreBatch == this->smallCoreBatch) { + this->bigCoreNum = this->coreNum; + } else if ((this->bigCoreBatch == 1) && (this->smallCoreBatch == 0)) { + this->bigCoreNum = this->batch; + } else { + this->bigCoreNum = (this->batch - (this->smallCoreBatch * this->coreNum)) / + (this->bigCoreBatch - this->smallCoreBatch); + } + + this->formerNum = ((this->repeats * 256) / this->point_dtype_size); + this->tailNum = this->N - this->formerNum * (this->pieces - 1); + + return ge::GRAPH_SUCCESS; +} + +ge::graphStatus FurthestPointSamplingTiling::RunKernelTiling() +{ + size_t sysWorkspaceSize = 16 * 1024 * 1024; // Alloc 16M workspace + size_t userWorkSpaceSize = this->batch * this->N * this->point_dtype_size; // NearestDist needs a space to be moved out + size_t *currentWorkSpace = TilingContext->GetWorkspaceSizes(1); + currentWorkSpace[0] = userWorkSpaceSize + sysWorkspaceSize; + + TilingData.set_batch(this->batch); + TilingData.set_N(this->N); + TilingData.set_numPoints(this->numPoints); + TilingData.set_pieces(this->pieces); + TilingData.set_formerNum(this->formerNum); + TilingData.set_tailNum(this->tailNum); + TilingData.set_workSize(this->workSize); + TilingData.set_idxTempSize(this->idxTempSize); + TilingData.set_bigCoreBatch(this->bigCoreBatch); + TilingData.set_smallCoreBatch(this->smallCoreBatch); + TilingData.set_bigCoreNum(this->bigCoreNum); + if (this->batch <= this->coreNum) { + TilingContext->SetBlockDim(this->batch); + } else { + TilingContext->SetBlockDim(this->coreNum); + } + TilingData.set_repeats(this->repeats); + + TilingData.SaveToBuffer(TilingContext->GetRawTilingData()->GetData(), TilingContext->GetRawTilingData()->GetCapacity()); + TilingContext->GetRawTilingData()->SetDataSize(TilingData.GetDataSize()); + + return ge::GRAPH_SUCCESS; +} + +inline uint32_t FurthestPointSamplingTiling::calcWorkLocalSize(uint32_t max_repeat_times) +{ + uint32_t elemPerBlock = 32 / this->point_dtype_size; // num of data one block can store + uint32_t elemPerRepeat = 256 / this->point_dtype_size; // num of data one repeat can deal + + uint32_t iter1OutputCount = max_repeat_times * 2; // num of temp data in 1st stage + uint32_t iter2AlignStart = getSmallestMul(iter1OutputCount, elemPerBlock); // align with 32 bytes + uint32_t iter2OutputCount = getSmallestMulVal(iter1OutputCount, elemPerRepeat) * 2; // num of temp data in 2nd stage + uint32_t iter3AlignStart = getSmallestMul(iter2OutputCount, elemPerBlock); // align with 32 bytes + uint32_t iter3OutputCount = getSmallestMulVal(iter1OutputCount, elemPerRepeat) * 2; 
// num of temp data in 3rd stage + uint32_t iter3AlignEnd = getSmallestMul(iter2OutputCount, elemPerBlock); // align with 32 bytes + + uint32_t finalWorkLocalNeedSize = iter2AlignStart + iter3AlignStart + iter3AlignEnd; + uint32_t totalBytes = finalWorkLocalNeedSize * this->point_dtype_size; + + if (totalBytes % 32 != 0) { + return ge::GRAPH_FAILED; + } + return totalBytes; +} + +inline uint64_t FurthestPointSamplingTiling::FindMaxDataBlock() +{ + // divide & conquer ==> find the capability, if bigger than N, split then cal + uint64_t M = this->ub_memory.ub_size - this->ub_memory.ub_reserve; + uint64_t low = 1; // at least there exits one data in UB + uint64_t high = M / this->point_dtype_size; + uint64_t max_data_num = 0; + + while (low <= high) { + uint64_t mid = low + (high - low) / 2; + if (UbBlocksSpace(mid) <= M) { + max_data_num = mid; + low = mid + 1; + } else { + high = mid - 1; + } + } + + return max_data_num; +} + +inline void FurthestPointSamplingTiling::SetTilingKeyMode(ge::DataType dType) +{ + switch (dType) { + case ge::DT_FLOAT: + TilingContext->SetTilingKey(FP32_MODE); + this->point_dtype_size = 4; // 4: float32, 4 bytes + break; + default: + TilingContext->SetTilingKey(FP32_MODE); + this->point_dtype_size = 4; // 4: float32, 4 bytes + break; + } +} + +inline uint64_t FurthestPointSamplingTiling::UbBlocksDataSpace(uint64_t data_num) +{ + // data type is the same among the first 5 blocks, the num is data_num, aligned with 256 bytes + return getSmallestMul(this->point_dtype_size * data_num, 256); +} + +inline uint64_t FurthestPointSamplingTiling::UbBlocksWorkSpace(uint64_t data_num) +{ + uint64_t singleBlockRepeats = getSmallestMulVal(this->point_dtype_size * data_num, 256); + uint64_t totalRepeats = getSmallestMulVal(this->point_dtype_size * this->N, 256); + + this->pieces = (uint32_t)getSmallestMulVal(totalRepeats, singleBlockRepeats); + this->repeats = (singleBlockRepeats < totalRepeats) ? 
((uint32_t)singleBlockRepeats) : ((uint32_t)totalRepeats); + return (uint64_t)calcWorkLocalSize(this->repeats); +} + +inline uint64_t FurthestPointSamplingTiling::UbBlocksSpace(uint64_t data_num) +{ + uint64_t dataSpace = UbBlocksDataSpace(data_num); + uint64_t workSpace = UbBlocksWorkSpace(data_num); + + this->workSize = workSpace; + this->idxTempSize = getSmallestMul(this->pieces * 2 * this->point_dtype_size, 32); + + return (this->ub_memory.ub_data_blocks * dataSpace + workSpace + this->idxTempSize); +} + +/****************main body*****************/ +static ge::graphStatus TilingFurthestPointSampling(gert::TilingContext* context) +{ + FurthestPointSamplingTiling tilingObject(context); + + tilingObject.Init(); + return tilingObject.RunKernelTiling(); +} +} + +namespace ge { +static ge::graphStatus InfershapeForFurthestPointSampling(gert::InferShapeContext *context) +{ + const gert::Shape *point_xyz_shape = context->GetInputShape(0); + const gert::RuntimeAttrs *attrs = context->GetAttrs(); + gert::Shape *index_shape = context->GetOutputShape(0); + if ((point_xyz_shape == nullptr) || (attrs == nullptr) || (index_shape == nullptr)) { + return ge::GRAPH_FAILED; + } + + uint32_t batch = point_xyz_shape->GetDim(0); + uint32_t N = point_xyz_shape->GetDim(2); + uint32_t num_points = *(attrs->GetAttrPointer(0)); + + index_shape->SetDimNum(2); + index_shape->SetDim(0, batch); + index_shape->SetDim(1, num_points); + + return GRAPH_SUCCESS; +} +} + +namespace ops { +class FurthestPointSampling : public OpDef { +public: + explicit FurthestPointSampling(const char* name) : OpDef(name) + { + this->Input("point_xyz") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Input("nearest_temp") + .ParamType(REQUIRED) + .DataType({ge::DT_FLOAT}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Output("index") + .ParamType(REQUIRED) + .DataType({ge::DT_INT32}) + .Format({ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND}); + this->Attr("num_points") + .AttrType(REQUIRED) + .Int(); + + this->SetInferShape(ge::InfershapeForFurthestPointSampling); + this->AICore().SetTiling(optiling::TilingFurthestPointSampling); + + OpAICoreConfig aicore_config; + aicore_config.DynamicCompileStaticFlag(true) + .DynamicFormatFlag(true) + .DynamicRankSupportFlag(true) + .DynamicShapeSupportFlag(true); + this->AICore().AddConfig("ascend910b", aicore_config); + } +}; + +OP_ADD(FurthestPointSampling); +} \ No newline at end of file diff --git a/ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h b/ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h new file mode 100644 index 00000000..3e962288 --- /dev/null +++ b/ads/common/ops/kernels/op_host/furthest_point_sampling_tiling.h @@ -0,0 +1,44 @@ +/** + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/*! 
+ * \file furthest_point_sampling.h + * \brief + */ +#ifndef OPS_BUILT_IN_OP_TILING_RUNTIME_FURTHEST_POINT_SAMPLING_H_ +#define OPS_BUILT_IN_OP_TILING_RUNTIME_FURTHEST_POINT_SAMPLING_H_ +#include "register/tilingdata_base.h" + +namespace optiling { +BEGIN_TILING_DATA_DEF(FurthestPointSamplingTilingData) + TILING_DATA_FIELD_DEF(uint32_t, N); + TILING_DATA_FIELD_DEF(uint32_t, batch); + TILING_DATA_FIELD_DEF(uint32_t, numPoints); + TILING_DATA_FIELD_DEF(uint32_t, pieces); + TILING_DATA_FIELD_DEF(uint32_t, formerNum); + TILING_DATA_FIELD_DEF(uint32_t, tailNum); + TILING_DATA_FIELD_DEF(uint32_t, workSize); + TILING_DATA_FIELD_DEF(uint32_t, idxTempSize); + TILING_DATA_FIELD_DEF(uint32_t, bigCoreBatch); + TILING_DATA_FIELD_DEF(uint32_t, smallCoreBatch); + TILING_DATA_FIELD_DEF(uint32_t, bigCoreNum); + TILING_DATA_FIELD_DEF(uint32_t, repeats); +END_TILING_DATA_DEF; + +REGISTER_TILING_DATA_CLASS(FurthestPointSampling, FurthestPointSamplingTilingData) +} + +#endif // OPS_BUILT_IN_OP_TILING_RUNTIME_FURTHEST_POINT_SAMPLING_H_ \ No newline at end of file diff --git a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp new file mode 100644 index 00000000..5f53f360 --- /dev/null +++ b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp @@ -0,0 +1,587 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * This file constains code of cpu debug and npu code.We read data from bin file + * and write result to file. + */ +#include "furthest_point_sampling.h" + +using namespace AscendC; + +// Entrance of kernel +extern "C" __global__ __aicore__ void furthest_point_sampling( + GM_ADDR point_xyz, + GM_ADDR temp, + GM_ADDR index, + GM_ADDR workspace, + GM_ADDR tiling) { + GET_TILING_DATA(tiling_data, tiling); + + tilingArgs TA; + // Since type of tiling_data unknown, create a class out of reliability. + TA.N = tiling_data.N; + TA.batch = tiling_data.batch; + TA.numPoints = tiling_data.numPoints; + TA.pieces = tiling_data.pieces; + TA.formerNum = tiling_data.formerNum; + TA.tailNum = tiling_data.tailNum; + TA.workSize = tiling_data.workSize; + TA.idxTempSize = tiling_data.idxTempSize; + TA.bigCoreBatch = tiling_data.bigCoreBatch; + TA.smallCoreBatch = tiling_data.smallCoreBatch; + TA.bigCoreNum = tiling_data.bigCoreNum; + TA.repeats = tiling_data.repeats; + + if (TILING_KEY_IS(0)) { + furthestPointSamplingKernel op(point_xyz, temp, index, workspace, &TA); + op.Process(); + } +} + +template +__aicore__ inline furthestPointSamplingKernel::furthestPointSamplingKernel(GM_ADDR point_xyz, + GM_ADDR temp, GM_ADDR index, GM_ADDR workspace, tilingArgs *tiling) +{ + // Init tiling args. + this->TA = tiling; + // host tiling have ensured formerNum is aligned with 32bytes and bigger than tailNum. + this->sizeofFormer = this->TA->formerNum * sizeof(dataType); + this->sizeofTail = this->TA->tailNum * sizeof(dataType); + this->dataNumIn32Bytes = 32 / sizeof(dataType); + this->dataNumIn64Bytes = 64 / sizeof(dataType); + this->dataNumIn256Bytes = 256 / sizeof(dataType); + this->dataNumIn1024Bytes = 1024 / sizeof(dataType); + // Init GM. + InitGm(point_xyz, temp, index, workspace); + + // Must be aligned with 32bytes. 
+ this->pipe.InitBuffer(this->pointXQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointYQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointZQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointTempXUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointTempYUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->pointTempZUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->nearestDistQue, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->distUb, BUFFER_NUM, this->sizeofFormer); + this->pipe.InitBuffer(this->workUb, BUFFER_NUM, this->TA->workSize); + + this->pipe.InitBuffer(this->idxQue, BUFFER_NUM, 1024); // copy out 256 fp32s once + + this->pipe.InitBuffer(this->idxTempUb, BUFFER_NUM, this->TA->idxTempSize); + this->pipe.InitBuffer(this->pointSampled, BUFFER_NUM, 32 * 3); + // Malloc. + this->ubBlocks.pointXLocal = pointXQue.AllocTensor(); + this->ubBlocks.pointYLocal = pointYQue.AllocTensor(); + this->ubBlocks.pointZLocal = pointZQue.AllocTensor(); + this->ubBlocks.pointTempXLocal = pointTempXUb.AllocTensor(); + this->ubBlocks.pointTempYLocal = pointTempYUb.AllocTensor(); + this->ubBlocks.pointTempZLocal = pointTempZUb.AllocTensor(); + this->ubBlocks.nearestDistLocal = nearestDistQue.AllocTensor(); + this->ubBlocks.distLocal = distUb.AllocTensor(); + this->ubBlocks.workLocal = workUb.AllocTensor(); + + this->ubBlocks.idxLocal = idxQue.AllocTensor(); + + this->ubBlocks.idxTempLocal = idxTempUb.AllocTensor(); + this->ubBlocks.pointSampledLocal = pointSampled.AllocTensor(); +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process() +{ + uint32_t batch_num = (GetBlockIdx() < this->TA->bigCoreNum) ? (this->TA->bigCoreBatch) : (this->TA->smallCoreBatch); + + for (this->core_batch = 0; this->core_batch < batch_num; this->core_batch++) { + this->batchOffsetPoint = this->core_batch * this->TA->N * 3; + this->batchOffsetNearest = this->core_batch * this->TA->N; + // Set:idxGm[0] = 0 + CopyInIdx(0); + if (this->TA->numPoints == 1) { + CopyOut(0); // special case: only one points sampled. + } + if (this->TA->pieces == 1) { + Process_complete_data(); + } else { + Process_split_data(); + } + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInIdx(uint32_t loopNum) +{ + DataCopyParams data_copy_param = {1, 1, 0, 0}; + uint32_t offsetGmX = this->batchOffsetPoint + this->maxDistIdx; + uint32_t offsetGmY = offsetGmX + this->TA->N; + uint32_t offsetGmZ = offsetGmY + this->TA->N; + uint32_t offsetLocalX = 0; + uint32_t offsetLocalY = this->dataNumIn32Bytes; + uint32_t offsetLocalZ = this->dataNumIn64Bytes; + uint32_t offsetIdx = loopNum & (this->dataNumIn1024Bytes - 1); // aka. 
loopNum % this->dataNumIn1024Bytes + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID0); + +#ifndef __GET_CODE_CHANNEL__ + DataCopy(this->ubBlocks.pointSampledLocal[offsetLocalX], pointGm[offsetGmX], data_copy_param); + DataCopy(this->ubBlocks.pointSampledLocal[offsetLocalY], pointGm[offsetGmY], data_copy_param); + DataCopy(this->ubBlocks.pointSampledLocal[offsetLocalZ], pointGm[offsetGmZ], data_copy_param); +#endif + + set_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_S, EVENT_ID0); + + this->ubBlocks.idxLocal.SetValue(offsetIdx, this->maxDistIdx); + this->pointXSampled = -1 * this->ubBlocks.pointSampledLocal.GetValue(offsetLocalX); + this->pointYSampled = -1 * this->ubBlocks.pointSampledLocal.GetValue(offsetLocalY); + this->pointZSampled = -1 * this->ubBlocks.pointSampledLocal.GetValue(offsetLocalZ); + this->maxDist = 0; + this->maxDistIdx = 0; +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process_complete_data() +{ + uint32_t loopNum; + + for (loopNum = 1; loopNum < this->TA->numPoints; loopNum++) { + if (loopNum == 1) { + Process_first_sampling(0); + } else { + ComputePointsSquare(); + + pipe_barrier(PIPE_V); + + ComputeDist(); + + pipe_barrier(PIPE_V); + + ComputeSamplePoints(0, 0); + } + pipe_barrier(PIPE_V); + + updateDist(); + + CopyInIdx(loopNum); + + CopyOut(loopNum); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process_split_data() +{ + uint32_t loopNum, loopSplit; + + for (loopNum = 1; loopNum < this->TA->numPoints; loopNum++) { + for (loopSplit = 0; loopSplit < this->TA->pieces; loopSplit++) { + if (loopNum == 1) { + Process_first_sampling(loopSplit); + } else { + uint32_t comBlock = (loopSplit + this->TA->pieces - 1) % this->TA->pieces; + + // Cal point_x -> Mov point_x, Cal point_y -> Mov point_y, Cal point_z -> Mov point_z + ComputePointDeltaSquare(this->ubBlocks.pointXLocal, this->ubBlocks.pointTempXLocal, this->pointXSampled); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID0); + + CopyInPointAxis(pointAxis_x, loopSplit); + + ComputePointDeltaSquare(this->ubBlocks.pointYLocal, this->ubBlocks.pointTempYLocal, this->pointYSampled); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID1); + + CopyInPointAxis(pointAxis_y, loopSplit); + + ComputePointDeltaSquare(this->ubBlocks.pointZLocal, this->ubBlocks.pointTempZLocal, this->pointZSampled); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID2); + + CopyInPointAxis(pointAxis_z, loopSplit); + + pipe_barrier(PIPE_V); + + ComputeDist(); + + pipe_barrier(PIPE_V); + + ComputeSamplePoints(loopSplit, comBlock); + + set_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + wait_flag(PIPE_V, PIPE_MTE2, EVENT_ID3); + + set_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + wait_flag(PIPE_MTE3, PIPE_MTE2, EVENT_ID0); + + CopyInNearestDistTemp(loopSplit); + } + } + pipe_barrier(PIPE_V); + + updateDist(); + + CopyInIdx(loopNum); + + CopyOut(loopNum); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::Process_first_sampling(uint32_t loopSplit) +{ + // Mov point_x -> Cal point_x, Mov point_y -> Cal point_y, Mov point_z -> Cal point_z + CopyInPointAxis(pointAxis_x, loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID0); + + ComputePointDeltaSquare(this->ubBlocks.pointXLocal, this->ubBlocks.pointTempXLocal, this->pointXSampled); + + CopyInPointAxis(pointAxis_y, loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, 
EVENT_ID1); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID1); + + ComputePointDeltaSquare(this->ubBlocks.pointYLocal, this->ubBlocks.pointTempYLocal, this->pointYSampled); + + CopyInPointAxis(pointAxis_z, loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID2); + + ComputePointDeltaSquare(this->ubBlocks.pointZLocal, this->ubBlocks.pointTempZLocal, this->pointZSampled); + + pipe_barrier(PIPE_V); + + ComputeDist(); + + CopyInNearestDist(loopSplit); + + set_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + wait_flag(PIPE_MTE2, PIPE_V, EVENT_ID3); + + ComputeSamplePoints(loopSplit, loopSplit); +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInPointAxis(PointAxis pointAxis, + uint32_t loopSplit) +{ + uint32_t offset; + DataCopyParams data_copy_param = {1, 0, 0, 0}; + DataCopyPadParams pad_param = {false, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + switch (pointAxis) { + case pointAxis_x: + offset = this->batchOffsetPoint + this->TA->formerNum * loopSplit; + break; + case pointAxis_y: + offset = this->batchOffsetPoint + this->TA->formerNum * loopSplit + this->TA->N; + break; + case pointAxis_z: + offset = this->batchOffsetPoint + this->TA->formerNum * loopSplit + this->TA->N * 2; + break; + default: + break; + } + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID1); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID1); + + switch (pointAxis) { + case pointAxis_x: +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.pointXLocal, pointGm[offset], data_copy_param, pad_param); +#endif + break; + case pointAxis_y: +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.pointYLocal, pointGm[offset], data_copy_param, pad_param); +#endif + break; + case pointAxis_z: +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.pointZLocal, pointGm[offset], data_copy_param, pad_param); +#endif + break; + default: + break; + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInNearestDist(uint32_t loopSplit) +{ + uint32_t offset = this->batchOffsetNearest + this->TA->formerNum * loopSplit; + DataCopyParams data_copy_param = {1, 0, 0, 0}; + DataCopyPadParams pad_param = {false, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.nearestDistLocal, nearestDistGm[offset], data_copy_param, pad_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyInNearestDistTemp(uint32_t loopSplit) +{ + uint32_t offset = this->batchOffsetNearest + this->TA->formerNum * loopSplit; + DataCopyParams data_copy_param = {1, 0, 0, 0}; + DataCopyPadParams pad_param = {false, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + + set_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + wait_flag(PIPE_S, PIPE_MTE2, EVENT_ID2); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(this->ubBlocks.nearestDistLocal, nearestDistTempGm[offset], data_copy_param, pad_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputePointsSquare() +{ + uint32_t total_num, dupTime, offset, comp_num; + + // while cal,every data block is aligned with 256 bytes. 
+ for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID3); + wait_flag(PIPE_S, PIPE_V, EVENT_ID3); + + Adds(this->ubBlocks.pointTempXLocal[offset], this->ubBlocks.pointXLocal[offset], this->pointXSampled, + this->dataNumIn256Bytes, dupTime, {1, 1, 8, 8}); + Adds(this->ubBlocks.pointTempYLocal[offset], this->ubBlocks.pointYLocal[offset], this->pointYSampled, + this->dataNumIn256Bytes, dupTime, {1, 1, 8, 8}); + Adds(this->ubBlocks.pointTempZLocal[offset], this->ubBlocks.pointZLocal[offset], this->pointZSampled, + this->dataNumIn256Bytes, dupTime, {1, 1, 8, 8}); + + pipe_barrier(PIPE_V); + + Mul(this->ubBlocks.pointTempXLocal[offset], this->ubBlocks.pointTempXLocal[offset], + this->ubBlocks.pointTempXLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + Mul(this->ubBlocks.pointTempYLocal[offset], this->ubBlocks.pointTempYLocal[offset], + this->ubBlocks.pointTempYLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + Mul(this->ubBlocks.pointTempZLocal[offset], this->ubBlocks.pointTempZLocal[offset], + this->ubBlocks.pointTempZLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputePointDeltaSquare( + LocalTensor &pointLocal, LocalTensor &pointTempLocal, dataType pointSampled) +{ + uint32_t total_num, dupTime, offset, comp_num; + + // while cal,every data block is aligned with 256 bytes. + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID3); + wait_flag(PIPE_S, PIPE_V, EVENT_ID3); + + Adds(pointTempLocal[offset], pointLocal[offset], pointSampled, this->dataNumIn256Bytes, + dupTime, {1, 1, 8, 8}); + + pipe_barrier(PIPE_V); + + Mul(pointTempLocal[offset], pointTempLocal[offset], pointTempLocal[offset], this->dataNumIn256Bytes, + dupTime, {1, 1, 1, 8, 8, 8}); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputeDist() +{ + uint32_t total_num, dupTime, offset, comp_num; + + // while cal,every data block is aligned with 256 bytes. + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 
255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID0); + wait_flag(PIPE_S, PIPE_V, EVENT_ID0); + + Add(this->ubBlocks.distLocal[offset], this->ubBlocks.pointTempXLocal[offset], + this->ubBlocks.pointTempYLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + + pipe_barrier(PIPE_V); + + Add(this->ubBlocks.distLocal[offset], this->ubBlocks.distLocal[offset], + this->ubBlocks.pointTempZLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::ComputeSamplePoints(uint32_t loopSplit, + uint32_t comBlock) +{ + uint32_t total_num, dupTime, offset, comp_num, reduceCnt, reduceOffset; + + reduceCnt = ((this->TA->formerNum != this->TA->tailNum) && (comBlock == (this->TA->pieces - 1))) ? + this->TA->tailNum : this->TA->formerNum; + reduceOffset = comBlock * 2; + + for (offset = 0, total_num = this->TA->formerNum; total_num > 0; + comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { + dupTime = (total_num * sizeof(dataType)) / 256; + dupTime = (dupTime > 255) ? 255 : dupTime; + + set_flag(PIPE_S, PIPE_V, EVENT_ID1); + wait_flag(PIPE_S, PIPE_V, EVENT_ID1); + + Min(this->ubBlocks.nearestDistLocal[offset], this->ubBlocks.nearestDistLocal[offset], + this->ubBlocks.distLocal[offset], this->dataNumIn256Bytes, dupTime, {1, 1, 1, 8, 8, 8}); + } + + if (this->TA->pieces > 1) { + // set_flag: After Updated nearestDistLocal, Mov nearestDistLocal to GM. + set_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_V, PIPE_MTE3, EVENT_ID0); + CopyOutNearestDistTemp(comBlock); + } + + pipe_barrier(PIPE_V); + + // ReduceMax + ReduceMax(this->ubBlocks.idxTempLocal[reduceOffset], this->ubBlocks.nearestDistLocal, + this->ubBlocks.workLocal, reduceCnt, 1); +} + +template +__aicore__ inline void furthestPointSamplingKernel::updateDist() +{ + dataType tempValue; + + // this->TA->pieces >= 1 + for (uint32_t i = 1; i < (2 * this->TA->pieces); i = (i + 2)) { + tempValue = this->ubBlocks.idxTempLocal.GetValue(i); + if (this->maxDist < this->ubBlocks.idxTempLocal.GetValue(i-1)) { + this->maxDist = this->ubBlocks.idxTempLocal.GetValue(i-1); + this->maxDistIdx = (this->TA->formerNum * (i / 2)) + (*reinterpret_cast(&tempValue)); + } + } +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyOut(uint32_t loopNum) +{ + uint32_t elemNum = this->dataNumIn1024Bytes; + // elemNum is a multiple of 2. 
+ if ((loopNum != 0) && (((loopNum + 1) & (elemNum - 1)) != 0) && ((loopNum + 1) != this->TA->numPoints)) { + // when num of sampled < 256 && not last loop, return; + return ; + } + + uint32_t offset = this->core_batch * this->TA->numPoints; + DataCopyExtParams data_copy_param = {1, sizeof(dataType), 0, 0, 0}; + if (((loopNum + 1) & (elemNum - 1)) == 0) { + data_copy_param.blockLen = 1024; + offset = offset + loopNum / elemNum * elemNum; + } else if ((loopNum + 1) == this->TA->numPoints) { + data_copy_param.blockLen = sizeof(dataType) * + (this->TA->numPoints - (this->TA->numPoints / elemNum * elemNum)); + offset = offset + (this->TA->numPoints / elemNum * elemNum); + } + + set_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); + wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID0); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(idxGm[offset], this->ubBlocks.idxLocal, data_copy_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::CopyOutNearestDistTemp(uint32_t loopSplit) +{ + uint32_t offset = this->batchOffsetNearest + this->TA->formerNum * loopSplit; + DataCopyExtParams data_copy_param = {1, 0, 0, 0, 0}; + + if (loopSplit == (this->TA->pieces - 1)) { + data_copy_param.blockLen = this->sizeofTail; + } else { + data_copy_param.blockLen = this->sizeofFormer; + } + + set_flag(PIPE_S, PIPE_MTE3, EVENT_ID1); + wait_flag(PIPE_S, PIPE_MTE3, EVENT_ID1); + +#ifndef __GET_CODE_CHANNEL__ + DataCopyPad(nearestDistTempGm[offset], this->ubBlocks.nearestDistLocal, data_copy_param); +#endif +} + +template +__aicore__ inline void furthestPointSamplingKernel::InitGm(GM_ADDR point_xyz, GM_ADDR temp, + GM_ADDR index, GM_ADDR workspace) +{ + GM_ADDR usrWorkspace = AscendC::GetUserWorkspace(workspace); + uint32_t coreId = GetBlockIdx(); + uint32_t skipData, numData, skipIdx, numIdx; + uint32_t numDataBigCore = this->TA->bigCoreBatch * this->TA->N; + uint32_t numIdxBigCore = this->TA->bigCoreBatch * this->TA->numPoints; + + if (coreId < this->TA->bigCoreNum) { + numData = numDataBigCore; + numIdx = numIdxBigCore; + skipData = numData * coreId; + skipIdx = numIdx * coreId; + } else { + numData = this->TA->smallCoreBatch * this->TA->N; + numIdx = this->TA->smallCoreBatch * this->TA->numPoints; + skipData = this->TA->bigCoreNum * numDataBigCore + (coreId - this->TA->bigCoreNum) * numData; + skipIdx = this->TA->bigCoreNum * numIdxBigCore + (coreId - this->TA->bigCoreNum) * numIdx; + } + + this->pointGm.SetGlobalBuffer((__gm__ dataType*)point_xyz + skipData * 3, numData * 3); + this->nearestDistGm.SetGlobalBuffer((__gm__ dataType*)temp + skipData, numData); + this->idxGm.SetGlobalBuffer((__gm__ idxType*)index + skipIdx, numIdx); + this->nearestDistTempGm.SetGlobalBuffer((__gm__ dataType*)usrWorkspace + skipData, numData); +} + +template +__aicore__ inline furthestPointSamplingKernel::~furthestPointSamplingKernel() +{ + this->pointXQue.FreeTensor(this->ubBlocks.pointXLocal); + this->pointYQue.FreeTensor(this->ubBlocks.pointYLocal); + this->pointZQue.FreeTensor(this->ubBlocks.pointZLocal); + this->pointTempXUb.FreeTensor(this->ubBlocks.pointTempXLocal); + this->pointTempYUb.FreeTensor(this->ubBlocks.pointTempYLocal); + this->pointTempZUb.FreeTensor(this->ubBlocks.pointTempZLocal); + this->nearestDistQue.FreeTensor(this->ubBlocks.nearestDistLocal); + this->distUb.FreeTensor(this->ubBlocks.distLocal); + this->workUb.FreeTensor(this->ubBlocks.workLocal); + + this->idxQue.FreeTensor(this->ubBlocks.idxLocal); + + this->idxTempUb.FreeTensor(this->ubBlocks.idxTempLocal); + 
this->pointSampled.FreeTensor(this->ubBlocks.pointSampledLocal); +} \ No newline at end of file diff --git a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h new file mode 100644 index 00000000..2fdb36df --- /dev/null +++ b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h @@ -0,0 +1,139 @@ +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2023. All rights reserved. + * This file constains code of cpu debug and npu code.We read data from bin file + * and write result to file. + */ +#ifndef FURTHEST_POINT_SAMPLING_H +#define FURTHEST_POINT_SAMPLING_H + +#include "kernel_tiling/kernel_tiling.h" +#include "kernel_operator.h" + +namespace AscendC { +constexpr uint32_t BUFFER_NUM = 1u; + +enum PointAxis { + pointAxis_x, + pointAxis_y, + pointAxis_z +}; + +template +struct UbBlocks_tag { + __aicore__ UbBlocks_tag() = default; + + LocalTensor pointXLocal; + LocalTensor pointYLocal; + LocalTensor pointZLocal; + LocalTensor pointTempXLocal; + LocalTensor pointTempYLocal; + LocalTensor pointTempZLocal; + LocalTensor nearestDistLocal; + LocalTensor distLocal; + LocalTensor idxLocal; + LocalTensor idxTempLocal; + LocalTensor pointSampledLocal; + LocalTensor workLocal; +}; +template +using UbBlocks = UbBlocks_tag; + +class tilingArgs { +public: + __aicore__ inline tilingArgs() = default; +public: + uint32_t N; + uint32_t batch; + uint32_t numPoints; + uint32_t pieces; + uint32_t formerNum; + uint32_t tailNum; + uint32_t workSize; + uint32_t idxTempSize; + uint32_t bigCoreBatch; + uint32_t smallCoreBatch; + uint32_t bigCoreNum; + uint32_t repeats; +}; + +template +class furthestPointSamplingKernel { +public: + __aicore__ inline furthestPointSamplingKernel(GM_ADDR point_xyz, GM_ADDR temp, GM_ADDR index, GM_ADDR workspace, + tilingArgs *tiling); + __aicore__ inline ~furthestPointSamplingKernel(); + __aicore__ inline void Process(); + +private: + __aicore__ inline void Process_first_sampling(uint32_t loopSplit = 0); + __aicore__ inline void Process_split_data(); + __aicore__ inline void Process_complete_data(); + +private: + __aicore__ inline void CopyInPointAxis(PointAxis pointAxis, uint32_t loopSplit = 0); + __aicore__ inline void CopyInNearestDist(uint32_t loopSplit = 0); + __aicore__ inline void CopyInNearestDistTemp(uint32_t loopSplit = 0); + __aicore__ inline void CopyInIdx(uint32_t loopNum); + __aicore__ inline void CopyOut(uint32_t loopNum); + __aicore__ inline void CopyOutNearestDistTemp(uint32_t loopSplit = 0); + +private: + __aicore__ inline void ComputePointsSquare(); + __aicore__ inline void ComputePointDeltaSquare(LocalTensor &pointLocal, + LocalTensor &pointTempLocal, dataType pointSampled); + __aicore__ inline void ComputeDist(); + __aicore__ inline void ComputeSamplePoints(uint32_t loopSplit, uint32_t ComBlock); + __aicore__ inline void updateDist(); + +private: + __aicore__ inline void InitGm(GM_ADDR point_xyz, GM_ADDR temp, GM_ADDR index, GM_ADDR workspace); + +private: + TPipe pipe; + TQue pointXQue; + TQue pointYQue; + TQue pointZQue; + TQue pointTempXUb; + TQue pointTempYUb; + TQue pointTempZUb; + TQue nearestDistQue; + TQue distUb; + TQue workUb; + + TQue idxQue; + + TQue idxTempUb; + TQue pointSampled; + +private: + GlobalTensor pointGm; + GlobalTensor nearestDistGm; + GlobalTensor idxGm; + GlobalTensor nearestDistTempGm; + UbBlocks ubBlocks; + +private: + dataType pointXSampled {0}; + dataType pointYSampled {0}; + dataType pointZSampled {0}; + dataType maxDist {0}; + idxType maxDistIdx {0}; + uint32_t 
core_batch; + +private: + // tiling value + tilingArgs *TA; + +private: + uint32_t sizeofFormer; + uint32_t sizeofTail; + uint32_t dataNumIn32Bytes; + uint32_t dataNumIn64Bytes; + uint32_t dataNumIn256Bytes; + uint32_t dataNumIn1024Bytes; + uint32_t batchOffsetPoint; + uint32_t batchOffsetNearest; +}; +} + +#endif // FURTHEST_POINT_SAMPLING_H \ No newline at end of file diff --git a/tests/test_furthest_point_sampling.py b/tests/test_furthest_point_sampling.py new file mode 100644 index 00000000..192b54c3 --- /dev/null +++ b/tests/test_furthest_point_sampling.py @@ -0,0 +1,126 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from abc import ABC, abstractmethod +import numpy as np +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + +DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] + + +class CreateBenchMarkTest(ABC): + def __init__(self): + self.batch = None + self.N = None + self.numPoints = None + + self.point = None + self.nearestDist = None + + @abstractmethod + def createData(self): + pass + + def compare_min(self, a): + if a[0] > a[1]: + return a[1] + else : + return a[0] + + def getCpuRes(self): + cpuRes = np.zeros([self.batch, self.numPoints], dtype=np.int32) + nearestDistCopy = self.nearestDist.copy() + + for i in range(self.batch): + sampled = 1 + index = 0 + while sampled < self.numPoints: + deltaX = self.point[i][0] - self.point[i][0][index] + deltaY = self.point[i][1] - self.point[i][1][index] + deltaZ = self.point[i][2] - self.point[i][2][index] + deltaX2 = deltaX * deltaX + deltaY2 = deltaY * deltaY + deltaZ2 = deltaZ * deltaZ + currentDist = deltaX2 + deltaY2 + deltaZ2 + + nearestDistCopy[i] = np.apply_along_axis(self.compare_min, 0, np.stack((currentDist, nearestDistCopy[i]), axis=0)) + index = np.argmax(nearestDistCopy[i]) + cpuRes[i][sampled] = index + sampled = sampled + 1 + return cpuRes + + +class Test1(CreateBenchMarkTest): + def createData(self): + self.batch = 47 + self.N = 717 + self.numPoints = 580 + + self.point = np.zeros([self.batch, 3, self.N], dtype=np.float32) + for i in range(self.batch): + for j in range(self.N): + self.point[i, 0, j] = j + + self.nearestDist = np.ones([self.batch, self.N], dtype=np.float32) * 1e10 + self.point = torch.from_numpy(self.point) + + +class Test2(CreateBenchMarkTest): + def createData(self): + self.batch = 193 + self.N = 579 + self.numPoints = 123 + + self.point = np.zeros([self.batch, 3, self.N], dtype=np.float32) + for i in range(self.batch): + for j in range(self.N): + self.point[i, 0, j] = j + self.point[i, 1, j] = j + 1 + self.point[i, 2, j] = j + 3 + + self.nearestDist = np.ones([self.batch, self.N], dtype=np.float32) * 1e10 + self.point = torch.from_numpy(self.point) + + +test1 = Test1() +test2 = Test2() + + +class TestFurthestPointSample(TestCase): + def cpu_op_exec(self, myTest): + return myTest.getCpuRes() + + def 
npu_op_exec(self, myTest): + return ads.common.npu_furthest_point_sampling(myTest.point.npu(), myTest.numPoints) + + def compare_res(self, myTest): + myTest.createData() + cpuOutput = self.cpu_op_exec(myTest) + npuOutput = self.npu_op_exec(myTest) + self.assertRtolEqual(cpuOutput, npuOutput) + + @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FurthestPointSampling` is only for 910B, skip it.") + def test_FurthestPointSample(self): + self.compare_res(test1) + self.compare_res(test2) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file -- Gitee From 2ca8ca0c7344e17eca86d6968f4d79bacc7251c4 Mon Sep 17 00:00:00 2001 From: l00636998 Date: Thu, 22 Feb 2024 08:58:09 +0000 Subject: [PATCH 05/12] =?UTF-8?q?!60=20=E7=AE=97=E5=AD=90Furthest=20Point?= =?UTF-8?q?=20Sampling=E9=81=97=E7=95=99=E9=97=AE=E9=A2=98=E6=95=B4?= =?UTF-8?q?=E6=94=B9=20Merge=20pull=20request=20!60=20from=20l00636998/mas?= =?UTF-8?q?ter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../csrc/FurthestPointSamplingKernelNpu.cpp | 2 +- .../op_kernel/furthest_point_sampling.cpp | 24 +++++++++---------- .../op_kernel/furthest_point_sampling.h | 2 ++ .../test_furthest_point_sampling.py | 4 ++-- 4 files changed, 17 insertions(+), 15 deletions(-) rename tests/{ => torch}/test_furthest_point_sampling.py (97%) diff --git a/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp b/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp index 29d775fe..dfa1884a 100644 --- a/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp +++ b/ads/common/ops/csrc/FurthestPointSamplingKernelNpu.cpp @@ -15,7 +15,7 @@ // limitations under the License. #include -#include "OpApiCommon.h" +#include "csrc/OpApiCommon.h" #include "functions.h" at::Tensor npu_furthest_point_sampling(const at::Tensor &point_xyz, const at::Tensor &nearset_temp, const int32_t num_points) diff --git a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp index 5f53f360..9abc447b 100644 --- a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp +++ b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.cpp @@ -64,7 +64,7 @@ __aicore__ inline furthestPointSamplingKernel::furthestPointS this->pipe.InitBuffer(this->distUb, BUFFER_NUM, this->sizeofFormer); this->pipe.InitBuffer(this->workUb, BUFFER_NUM, this->TA->workSize); - this->pipe.InitBuffer(this->idxQue, BUFFER_NUM, 1024); // copy out 256 fp32s once + this->pipe.InitBuffer(this->idxQue, BUFFER_NUM, 1024); // 1024: copy out 256 fp32s once this->pipe.InitBuffer(this->idxTempUb, BUFFER_NUM, this->TA->idxTempSize); this->pipe.InitBuffer(this->pointSampled, BUFFER_NUM, 32 * 3); @@ -201,11 +201,11 @@ __aicore__ inline void furthestPointSamplingKernel::Process_s CopyInPointAxis(pointAxis_z, loopSplit); - pipe_barrier(PIPE_V); + pipe_barrier(PIPE_ALL); ComputeDist(); - pipe_barrier(PIPE_V); + pipe_barrier(PIPE_ALL); ComputeSamplePoints(loopSplit, comBlock); @@ -366,8 +366,8 @@ __aicore__ inline void furthestPointSamplingKernel::ComputePo // while cal,every data block is aligned with 256 bytes. for (offset = 0, total_num = this->TA->formerNum; total_num > 0; comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { - dupTime = (total_num * sizeof(dataType)) / 256; - dupTime = (dupTime > 255) ? 255 : dupTime; + dupTime = (total_num * sizeof(dataType)) / ALLIGNED_BYTES; + dupTime = (dupTime > MAX_REPEAT_NUM) ? 
MAX_REPEAT_NUM : dupTime; set_flag(PIPE_S, PIPE_V, EVENT_ID3); wait_flag(PIPE_S, PIPE_V, EVENT_ID3); @@ -399,8 +399,8 @@ __aicore__ inline void furthestPointSamplingKernel::ComputePo // while cal,every data block is aligned with 256 bytes. for (offset = 0, total_num = this->TA->formerNum; total_num > 0; comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { - dupTime = (total_num * sizeof(dataType)) / 256; - dupTime = (dupTime > 255) ? 255 : dupTime; + dupTime = (total_num * sizeof(dataType)) / ALLIGNED_BYTES; + dupTime = (dupTime > MAX_REPEAT_NUM) ? MAX_REPEAT_NUM : dupTime; set_flag(PIPE_S, PIPE_V, EVENT_ID3); wait_flag(PIPE_S, PIPE_V, EVENT_ID3); @@ -423,8 +423,8 @@ __aicore__ inline void furthestPointSamplingKernel::ComputeDi // while cal,every data block is aligned with 256 bytes. for (offset = 0, total_num = this->TA->formerNum; total_num > 0; comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { - dupTime = (total_num * sizeof(dataType)) / 256; - dupTime = (dupTime > 255) ? 255 : dupTime; + dupTime = (total_num * sizeof(dataType)) / ALLIGNED_BYTES; + dupTime = (dupTime > MAX_REPEAT_NUM) ? MAX_REPEAT_NUM : dupTime; set_flag(PIPE_S, PIPE_V, EVENT_ID0); wait_flag(PIPE_S, PIPE_V, EVENT_ID0); @@ -451,8 +451,8 @@ __aicore__ inline void furthestPointSamplingKernel::ComputeSa for (offset = 0, total_num = this->TA->formerNum; total_num > 0; comp_num = dupTime * this->dataNumIn256Bytes, offset = offset + comp_num, total_num = total_num - comp_num) { - dupTime = (total_num * sizeof(dataType)) / 256; - dupTime = (dupTime > 255) ? 255 : dupTime; + dupTime = (total_num * sizeof(dataType)) / ALLIGNED_BYTES; + dupTime = (dupTime > MAX_REPEAT_NUM) ? 
MAX_REPEAT_NUM : dupTime; set_flag(PIPE_S, PIPE_V, EVENT_ID1); wait_flag(PIPE_S, PIPE_V, EVENT_ID1); @@ -468,7 +468,7 @@ __aicore__ inline void furthestPointSamplingKernel::ComputeSa CopyOutNearestDistTemp(comBlock); } - pipe_barrier(PIPE_V); + pipe_barrier(PIPE_ALL); // ReduceMax ReduceMax(this->ubBlocks.idxTempLocal[reduceOffset], this->ubBlocks.nearestDistLocal, diff --git a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h index 2fdb36df..926ce4d2 100644 --- a/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h +++ b/ads/common/ops/kernels/op_kernel/furthest_point_sampling.h @@ -11,6 +11,8 @@ namespace AscendC { constexpr uint32_t BUFFER_NUM = 1u; +constexpr uint32_t MAX_REPEAT_NUM = 255u; +constexpr uint32_t ALLIGNED_BYTES = 256u; enum PointAxis { pointAxis_x, diff --git a/tests/test_furthest_point_sampling.py b/tests/torch/test_furthest_point_sampling.py similarity index 97% rename from tests/test_furthest_point_sampling.py rename to tests/torch/test_furthest_point_sampling.py index 192b54c3..2ea02eca 100644 --- a/tests/test_furthest_point_sampling.py +++ b/tests/torch/test_furthest_point_sampling.py @@ -108,11 +108,11 @@ class TestFurthestPointSample(TestCase): return myTest.getCpuRes() def npu_op_exec(self, myTest): - return ads.common.npu_furthest_point_sampling(myTest.point.npu(), myTest.numPoints) + return ads.common.npu_furthest_point_sampling(myTest.point.clone().permute(0, 2, 1).npu(), myTest.numPoints) def compare_res(self, myTest): myTest.createData() - cpuOutput = self.cpu_op_exec(myTest) + cpuOutput = torch.from_numpy(self.cpu_op_exec(myTest)) npuOutput = self.npu_op_exec(myTest) self.assertRtolEqual(cpuOutput, npuOutput) -- Gitee From f002d947c6bbc3bcba476800d0869903e040404c Mon Sep 17 00:00:00 2001 From: lishuai183 Date: Tue, 20 Feb 2024 15:25:15 +0800 Subject: [PATCH 06/12] init nms3d op. 
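
A minimal usage sketch of the wrapper exported here, assuming its Python signature
mirrors the underlying nms3d binding (a box tensor plus an IoU threshold) and the
usual [x, y, z, dx, dy, dz, heading] box layout; the real API may differ:

    # Hypothetical example; the (boxes, threshold) signature and the returned
    # kept-box indices are assumptions based on the C++ binding, not a confirmed API.
    import torch
    import torch_npu
    import ads.common

    boxes = torch.rand(32, 7).npu()          # one candidate 3D box per row
    keep = ads.common.npu_nms3d(boxes, 0.3)  # boxes kept after NMS at IoU threshold 0.3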
--- ads/common/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ads/common/__init__.py b/ads/common/__init__.py index ee34c193..ea2b2c64 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -32,3 +32,4 @@ from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal from .ops.npu_nms3d import npu_nms3d from .ops.furthest_point_sampling import npu_furthest_point_sampling +from .ops.npu_nms3d import npu_nms3d -- Gitee From 71a527522ff87662cf36df44b063e9ba92927e1c Mon Sep 17 00:00:00 2001 From: l00636998 Date: Wed, 21 Feb 2024 06:49:43 +0000 Subject: [PATCH 07/12] =?UTF-8?q?!28=20=E6=96=B0=E5=A2=9EFurthesPointSampl?= =?UTF-8?q?ing=E7=AE=97=E5=AD=90=20Merge=20pull=20request=20!28=20from=20l?= =?UTF-8?q?00636998/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ads/common/__init__.py | 1 - tests/test_furthest_point_sampling.py | 126 ++++++++++++++++++++++++++ 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 tests/test_furthest_point_sampling.py diff --git a/ads/common/__init__.py b/ads/common/__init__.py index ea2b2c64..45920814 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -30,6 +30,5 @@ from .ops.npu_multi_scale_deformable_attn_function import npu_multi_scale_deform from .ops.dynamic_voxelization import voxelization from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal -from .ops.npu_nms3d import npu_nms3d from .ops.furthest_point_sampling import npu_furthest_point_sampling from .ops.npu_nms3d import npu_nms3d diff --git a/tests/test_furthest_point_sampling.py b/tests/test_furthest_point_sampling.py new file mode 100644 index 00000000..192b54c3 --- /dev/null +++ b/tests/test_furthest_point_sampling.py @@ -0,0 +1,126 @@ +# Copyright (c) 2020, Huawei Technologies.All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from abc import ABC, abstractmethod +import numpy as np +import torch + +import torch_npu +from torch_npu.testing.testcase import TestCase, run_tests +from torch_npu.testing.common_utils import create_common_tensor +import ads.common + +DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10] + + +class CreateBenchMarkTest(ABC): + def __init__(self): + self.batch = None + self.N = None + self.numPoints = None + + self.point = None + self.nearestDist = None + + @abstractmethod + def createData(self): + pass + + def compare_min(self, a): + if a[0] > a[1]: + return a[1] + else : + return a[0] + + def getCpuRes(self): + cpuRes = np.zeros([self.batch, self.numPoints], dtype=np.int32) + nearestDistCopy = self.nearestDist.copy() + + for i in range(self.batch): + sampled = 1 + index = 0 + while sampled < self.numPoints: + deltaX = self.point[i][0] - self.point[i][0][index] + deltaY = self.point[i][1] - self.point[i][1][index] + deltaZ = self.point[i][2] - self.point[i][2][index] + deltaX2 = deltaX * deltaX + deltaY2 = deltaY * deltaY + deltaZ2 = deltaZ * deltaZ + currentDist = deltaX2 + deltaY2 + deltaZ2 + + nearestDistCopy[i] = np.apply_along_axis(self.compare_min, 0, np.stack((currentDist, nearestDistCopy[i]), axis=0)) + index = np.argmax(nearestDistCopy[i]) + cpuRes[i][sampled] = index + sampled = sampled + 1 + return cpuRes + + +class Test1(CreateBenchMarkTest): + def createData(self): + self.batch = 47 + self.N = 717 + self.numPoints = 580 + + self.point = np.zeros([self.batch, 3, self.N], dtype=np.float32) + for i in range(self.batch): + for j in range(self.N): + self.point[i, 0, j] = j + + self.nearestDist = np.ones([self.batch, self.N], dtype=np.float32) * 1e10 + self.point = torch.from_numpy(self.point) + + +class Test2(CreateBenchMarkTest): + def createData(self): + self.batch = 193 + self.N = 579 + self.numPoints = 123 + + self.point = np.zeros([self.batch, 3, self.N], dtype=np.float32) + for i in range(self.batch): + for j in range(self.N): + self.point[i, 0, j] = j + self.point[i, 1, j] = j + 1 + self.point[i, 2, j] = j + 3 + + self.nearestDist = np.ones([self.batch, self.N], dtype=np.float32) * 1e10 + self.point = torch.from_numpy(self.point) + + +test1 = Test1() +test2 = Test2() + + +class TestFurthestPointSample(TestCase): + def cpu_op_exec(self, myTest): + return myTest.getCpuRes() + + def npu_op_exec(self, myTest): + return ads.common.npu_furthest_point_sampling(myTest.point.npu(), myTest.numPoints) + + def compare_res(self, myTest): + myTest.createData() + cpuOutput = self.cpu_op_exec(myTest) + npuOutput = self.npu_op_exec(myTest) + self.assertRtolEqual(cpuOutput, npuOutput) + + @unittest.skipIf(DEVICE_NAME != 'Ascend910B', "OP `FurthestPointSampling` is only for 910B, skip it.") + def test_FurthestPointSample(self): + self.compare_res(test1) + self.compare_res(test2) + + +if __name__ == "__main__": + run_tests() \ No newline at end of file -- Gitee From 798f9520780e47a1f16239ce988368f3bd8376a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=88=92=E6=B5=A9=E6=9D=B0?= Date: Fri, 23 Feb 2024 10:02:34 +0000 Subject: [PATCH 08/12] =?UTF-8?q?!64=20add=20operator=20constraints=20Merg?= =?UTF-8?q?e=20pull=20request=20!64=20from=20=E8=88=92=E6=B5=A9=E6=9D=B0/m?= =?UTF-8?q?aster?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp | 7 +++++++ .../op_host/multi_scale_deformable_attention_grad.cpp | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git 
a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp index e1e74560..8f7529ff 100644 --- a/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp +++ b/ads/common/ops/csrc/MultiScaleDeformableAttnFunctionKernelNpu.cpp @@ -93,6 +93,13 @@ std::tuple multi_scale_deformable_attn_grad( auto ori_dtype = value.scalar_type(); auto value_size = value.sizes(); auto location_size = location.sizes(); + auto channels = value_size[3]; + auto num_points = location_size[4]; + auto num_levels = location_size[3]; + auto data_total = channels + num_points + num_levels; + TORCH_CHECK(data_total < 512, "data_total is over 512: channels ", channels, " num_points is ", + num_points, " num_level is ", num_levels, "."); + TORCH_CHECK(channels % 8 == 0, "channels must be a multiple of eight, but channels is", channels, "."); auto grad_value_size = {value_size[0], value_size[1], value_size[2], value_size[3]}; auto grad_atten_weight_size = {location_size[0], location_size[1], location_size[2], location_size[3], location_size[4]}; auto grad_sample_loc_size = {location_size[0], location_size[1], location_size[2], location_size[3], location_size[5], location_size[4]}; diff --git a/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp b/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp index 54497529..148674a3 100644 --- a/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp +++ b/ads/common/ops/kernels/op_host/multi_scale_deformable_attention_grad.cpp @@ -31,7 +31,7 @@ namespace optiling auto channels = value_shape.GetDim(3); auto num_query = sampling_loc_shape.GetDim(1); auto num_levels = sampling_loc_shape.GetDim(3); - auto num_point = sampling_loc_shape.GetDim(4); + auto num_point = sampling_loc_shape.GetDim(5); auto task_per_core = (batch_size * num_query - 1) / core_num + 1; auto core_used = (batch_size * num_query - 1) / task_per_core + 1; auto task_tail_core = batch_size * num_query - (core_used - 1) * task_per_core; -- Gitee From 57016f44492aae291dc4d12807dd0dc2b032bd6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=93=B2=E7=BB=AD?= Date: Fri, 23 Feb 2024 11:51:37 +0000 Subject: [PATCH 09/12] =?UTF-8?q?!61=20Enhance=20MSDA=20Merge=20pull=20req?= =?UTF-8?q?uest=20!61=20from=20=E5=88=98=E5=93=B2=E7=BB=AD/master?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../multi_scale_deformable_attn_function.cpp | 359 ----------------- ...ulti_scale_deformable_attn_function_v2.cpp | 360 +++++++++--------- 2 files changed, 177 insertions(+), 542 deletions(-) delete mode 100644 ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp deleted file mode 100644 index c59529fa..00000000 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function.cpp +++ /dev/null @@ -1,359 +0,0 @@ - -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. - * - * This sample is a very basic sample that implements vector add on Ascend plaform. 
- */ -#include "kernel_operator.h" -using namespace AscendC; -constexpr int32_t BUFFER_NUM = 2; - -class KernelMultiScaleDeformableAttnFunctionV2 -{ -public: - __aicore__ inline KernelMultiScaleDeformableAttnFunctionV2() {} - __aicore__ inline void Init(GM_ADDR value, - GM_ADDR value_spatial_shapes, - GM_ADDR value_level_start_index, - GM_ADDR sampling_locations, - GM_ADDR attention_weights, - GM_ADDR output, MultiScaleDeformableAttnFunctionV2TilingData *tiling_data) - { - ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); - dataAlign = blockNum / sizeof(DTYPE_VALUE); - batchSize = tiling_data->batchSize; - numKeys = tiling_data->numKeys; - numHeads = tiling_data->numHeads; - embedDims = tiling_data->embedDims; - - numLevels = tiling_data->numLevels; - numQueries = tiling_data->numQueries; - numPoints = tiling_data->numPoints; - coreNum = tiling_data->coreNum; - - taskNum = batchSize * numQueries; - taskNumPerCore = DivCeil(taskNum, coreNum); - - embedDimsAlign = AlignUp(embedDims, dataAlign); - numPointsAlign = AlignUp(numPoints, dataAlign); - numLevelsAlign = AlignUp(numLevels, dataAlign); - - curBlockIdx = GetBlockIdx(); - startOffset = curBlockIdx * taskNumPerCore; - endOffset = (curBlockIdx + 1) * taskNumPerCore; - if (endOffset > taskNum) - { - endOffset = taskNum; - } - - valueGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(value), batchSize * numKeys * numHeads * embedDims); - locationGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(sampling_locations), batchSize * numQueries * numHeads * numLevels * numPoints * 2); - attentionWeightsGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(attention_weights), batchSize * numQueries * numHeads * numLevels * numPoints); - outputGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE *>(output), batchSize * numQueries * numHeads * embedDims); - - valueSpatialShapesGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_spatial_shapes), numLevels * 2); - valueLevelStartIndexGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_level_start_index), numLevels); - - pipe.InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(locationQueue, BUFFER_NUM, AlignUp(numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(outputQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpResUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpResUb2, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam1Ub, BUFFER_NUM, 
numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - - pipe.InitBuffer(leftTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - - pipe.InitBuffer(leftTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - } - - __aicore__ inline void Process() - { - for (uint32_t taskIdx = startOffset; taskIdx < endOffset; taskIdx++) - { - batch = taskIdx / numQueries; - query = taskIdx % numQueries; - pipe_barrier(PIPE_ALL); - Compute(batch, query); - } - } - -private: - __aicore__ inline bool isInRange(DTYPE_VALUE_SPATIAL_SHAPES x, DTYPE_VALUE_SPATIAL_SHAPES upper) - { - return 0 <= x && x < upper; - } - - __aicore__ inline void Compute(uint32_t batch, uint32_t query) - { - LocalTensor tmpResLocal = tmpResUb.Get(); - LocalTensor tmpResLocal2 = tmpResUb2.Get(); - - LocalTensor leftTopValueLocal = leftTopValueUb.Get(); - LocalTensor leftBottomValueUbLocal = leftBottomValueUb.Get(); - LocalTensor rightTopValueUbLocal = rightTopValueUb.Get(); - LocalTensor rightBottomValueUbLocal = rightBottomValueUb.Get(); - - LocalTensor leftTopWeiightLocal = leftTopWieightQueue.Get(); - LocalTensor leftBottomWeightLocal = leftBottomWieightQueue.Get(); - LocalTensor rightTopWeiightLocal = rightTopWieightQueue.Get(); - LocalTensor rightBottomWeightLocal = rightBottomWieightQueue.Get(); - - LocalTensor shapesLocal = shapeQueue.AllocTensor(); - LocalTensor offsetLocal = offsetQueue.AllocTensor(); - - LocalTensor locationLocal = locationQueue.AllocTensor(); - LocalTensor attentionWeightLocal = attentionWeightsUb.AllocTensor(); - - LocalTensor resLocal = outputQueue.AllocTensor(); - - LocalTensor xLocal = tmpXUb.Get(); - LocalTensor yLocal = tmpYUb.Get(); - - LocalTensor param0Local = tmpParam0Ub.Get(); - LocalTensor param1Local = tmpParam1Ub.Get(); - - LocalTensor x1Local = tmpIntX1Ub.Get(); - LocalTensor y1Local = tmpIntY1Ub.Get(); - - LocalTensor x0Local = tmpIntX0Ub.Get(); - LocalTensor y0Local = tmpIntY0Ub.Get(); - - LocalTensor tmpLocal1 = tmpUb1.Get(); - LocalTensor tmpLocal2 = tmpUb2.Get(); - LocalTensor tmpLocal3 = tmpUb3.Get(); - LocalTensor tmpLocal4 = tmpUb4.Get(); - - LocalTensor intOneLocal = intOneUb.Get(); - LocalTensor floatOneLocal = floatOneUb.Get(); - - Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); - Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); - DataCopyParams copyParams{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; - - DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); - DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); - moveOffset = batch 
* numQueries * numHeads * embedDims + query * numHeads * embedDims; - pipe_barrier(PIPE_ALL); - - for (uint32_t head = 0; head < numHeads; head++) - { - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - } - pipe_barrier(PIPE_ALL); - - for (uint32_t head = 0; head < numHeads; head++) - { - weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels + head * numLevels) * numPoints; - - pipe_barrier(PIPE_ALL); - - DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numLevels * numPoints * 2, dataAlign)); - DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numLevels * numPoints, dataAlign)); - - pipe_barrier(PIPE_ALL); - for (uint32_t level = 0; level < numLevels; level++) - { - h = shapesLocal.GetValue(level * 2); - w = shapesLocal.GetValue(level * 2 + 1); - for (uint32_t point = 0; point < numPoints; point++) - { - locationOffset = (level * numPoints + point) * 2; - xLocal.SetValue(point, locationLocal.GetValue(locationOffset)); - yLocal.SetValue(point, locationLocal.GetValue(locationOffset + 1)); - } - - pipe_barrier(PIPE_ALL); - - Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); - Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); - Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); - Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); - Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Sub(x0Local, x1Local, intOneLocal, numPointsAlign); - Sub(y0Local, y1Local, intOneLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); - Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); - Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Abs(param0Local, tmpLocal1, numPointsAlign); - Abs(param1Local, tmpLocal2, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); - Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Mul(leftTopWeiightLocal, xLocal, yLocal, numPointsAlign); - Mul(leftBottomWeightLocal, xLocal, param1Local, numPointsAlign); - Mul(rightTopWeiightLocal, param0Local, yLocal, numPointsAlign); - Mul(rightBottomWeightLocal, param0Local, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); - - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); - - for (uint32_t point = 0; point < numPoints; point++) - { - Duplicate(leftTopValueLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(leftBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightTopValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - - x0 = x0Local.GetValue(point); - y0 = y0Local.GetValue(point); - x1 = x1Local.GetValue(point); - y1 = y1Local.GetValue(point); - - valueOffset = batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head; - pipe_barrier(PIPE_ALL); - - if (isInRange(x0, w)) - { - if (isInRange(y0, h)) - { - DataCopy(leftTopValueLocal, valueGm[(valueOffset + (y0 * w + x0) * numHeads) * embedDims], embedDimsAlign); 
- } - if (isInRange(y1, h)) - { - DataCopy(leftBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x0) * numHeads) * embedDims], embedDimsAlign); - } - } - if (isInRange(x1, w)) - { - if (isInRange(y0, h)) - { - DataCopy(rightTopValueUbLocal, valueGm[(valueOffset + (y0 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } - if (isInRange(y1, h)) - { - DataCopy(rightBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } - } - pipe_barrier(PIPE_ALL); - - Muls(leftTopValueLocal, leftTopValueLocal, leftTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(rightTopValueUbLocal, rightTopValueUbLocal, rightTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(leftBottomValueUbLocal, leftBottomValueUbLocal, leftBottomWeightLocal.GetValue(point), embedDimsAlign); - Muls(rightBottomValueUbLocal, rightBottomValueUbLocal, rightBottomWeightLocal.GetValue(point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, leftTopValueLocal, rightTopValueUbLocal, embedDimsAlign); - Add(tmpResLocal2, leftBottomValueUbLocal, rightBottomValueUbLocal, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, tmpResLocal, tmpResLocal2, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Muls(tmpResLocal, tmpResLocal, attentionWeightLocal.GetValue(level * numPoints + point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(resLocal, resLocal, tmpResLocal, embedDimsAlign); - } - pipe_barrier(PIPE_ALL); - - SetAtomicAdd(); - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - SetAtomicNone(); - } - } - locationQueue.FreeTensor(locationLocal); - attentionWeightsUb.FreeTensor(attentionWeightLocal); - outputQueue.FreeTensor(resLocal); - shapeQueue.FreeTensor(shapesLocal); - offsetQueue.FreeTensor(offsetLocal); - } - -private: - TPipe pipe; - GlobalTensor valueGm, locationGm, attentionWeightsGm, outputGm; - GlobalTensor valueSpatialShapesGm, valueLevelStartIndexGm; - - TQue locationQueue, attentionWeightsUb, shapeQueue, offsetQueue; - TQue outputQueue; - - TBuf tmpResUb, tmpResUb2, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; - TBuf intOneUb, floatOneUb, leftTopValueUb, leftBottomValueUb, rightTopValueUb, rightBottomValueUb; - TBuf leftTopWieightQueue, leftBottomWieightQueue, rightTopWieightQueue, rightBottomWieightQueue; - - uint32_t batchSize; - uint32_t numKeys; - uint32_t numHeads; - uint32_t embedDims; - - uint32_t numLevels; - uint32_t numQueries; - uint32_t numPoints; - uint32_t coreNum; - - uint32_t embedDimsAlign; - uint32_t numPointsAlign; - uint32_t numLevelsAlign; - - uint32_t batch; - uint32_t query; - uint32_t head; - - uint32_t taskNum; - uint32_t taskNumPerCore; - uint32_t curBlockIdx; - uint32_t startOffset; - uint32_t endOffset; - uint32_t dataAlign; - uint32_t blockNum = 32; - - DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, moveOffset; -}; - -extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM_ADDR value, - GM_ADDR value_spatial_shapes, - GM_ADDR value_level_start_index, - GM_ADDR sampling_locations, - GM_ADDR attention_weights, - GM_ADDR output, GM_ADDR workspace, GM_ADDR tiling) -{ - GET_TILING_DATA(tiling_data, tiling); - KernelMultiScaleDeformableAttnFunctionV2 op; - op.Init(value, value_spatial_shapes, value_level_start_index, - sampling_locations, attention_weights, output, &tiling_data); - op.Process(); -} diff --git 
a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp index c59529fa..0419e791 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp @@ -1,12 +1,11 @@ - /* - * Copyright (c) Huawei Technologies Co., Ltd. 2022-2023. All rights reserved. + * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved. * * This sample is a very basic sample that implements vector add on Ascend plaform. */ #include "kernel_operator.h" using namespace AscendC; -constexpr int32_t BUFFER_NUM = 2; +constexpr int32_t BUFFER_NUM = 1; class KernelMultiScaleDeformableAttnFunctionV2 { @@ -17,8 +16,11 @@ public: GM_ADDR value_level_start_index, GM_ADDR sampling_locations, GM_ADDR attention_weights, - GM_ADDR output, MultiScaleDeformableAttnFunctionV2TilingData *tiling_data) + GM_ADDR output, + MultiScaleDeformableAttnFunctionV2TilingData *tiling_data, + TPipe *tmpPipe) { + pipe = tmpPipe; ASSERT(GetBlockNum() != 0 && "block dim can not be zero!"); dataAlign = blockNum / sizeof(DTYPE_VALUE); batchSize = tiling_data->batchSize; @@ -31,7 +33,9 @@ public: numPoints = tiling_data->numPoints; coreNum = tiling_data->coreNum; - taskNum = batchSize * numQueries; + tailNum = numHeads * embedDims; + + taskNum = numQueries; taskNumPerCore = DivCeil(taskNum, coreNum); embedDimsAlign = AlignUp(embedDims, dataAlign); @@ -54,53 +58,48 @@ public: valueSpatialShapesGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_spatial_shapes), numLevels * 2); valueLevelStartIndexGm.SetGlobalBuffer(reinterpret_cast<__gm__ DTYPE_VALUE_SPATIAL_SHAPES *>(value_level_start_index), numLevels); - pipe.InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(shapeQueue, BUFFER_NUM, AlignUp(numLevels * 2, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(offsetQueue, BUFFER_NUM, numLevelsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(locationQueue, BUFFER_NUM, AlignUp(numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(outputQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(locationQueue, BUFFER_NUM, AlignUp(numHeads * numLevels * numPoints * 2, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(attentionWeightsUb, BUFFER_NUM, AlignUp(numHeads * numLevels * numPoints, dataAlign) * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(outputQueue, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(emptyUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpResUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpResUb2, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + 
pipe->InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpParam1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpXUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpYUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpParam1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe.InitBuffer(leftTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomWieightQueue, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(leftTopWieightQueue, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(leftBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightTopValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe.InitBuffer(rightBottomValueUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(valueUb, BUFFER_NUM, numPoints * 4 * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb2, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb3, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); } __aicore__ inline void Process() { for (uint32_t taskIdx = startOffset; taskIdx < endOffset; taskIdx++) { - batch = taskIdx / numQueries; - query = taskIdx % numQueries; - pipe_barrier(PIPE_ALL); - Compute(batch, query); + SetAtomicAdd(); + Compute(taskIdx); + SetAtomicNone(); } } @@ -110,215 +109,208 @@ private: return 0 <= x && x < upper; } - __aicore__ inline void Compute(uint32_t batch, uint32_t query) + __aicore__ inline void 
Compute(uint32_t query) { - LocalTensor tmpResLocal = tmpResUb.Get(); - LocalTensor tmpResLocal2 = tmpResUb2.Get(); - - LocalTensor leftTopValueLocal = leftTopValueUb.Get(); - LocalTensor leftBottomValueUbLocal = leftBottomValueUb.Get(); - LocalTensor rightTopValueUbLocal = rightTopValueUb.Get(); - LocalTensor rightBottomValueUbLocal = rightBottomValueUb.Get(); - - LocalTensor leftTopWeiightLocal = leftTopWieightQueue.Get(); - LocalTensor leftBottomWeightLocal = leftBottomWieightQueue.Get(); - LocalTensor rightTopWeiightLocal = rightTopWieightQueue.Get(); - LocalTensor rightBottomWeightLocal = rightBottomWieightQueue.Get(); + LocalTensor locationLocal = locationQueue.AllocTensor(); + LocalTensor attentionWeightLocal = attentionWeightsUb.AllocTensor(); LocalTensor shapesLocal = shapeQueue.AllocTensor(); LocalTensor offsetLocal = offsetQueue.AllocTensor(); - LocalTensor locationLocal = locationQueue.AllocTensor(); - LocalTensor attentionWeightLocal = attentionWeightsUb.AllocTensor(); + DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); + DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - LocalTensor resLocal = outputQueue.AllocTensor(); + DataCopyParams copyParams{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; - LocalTensor xLocal = tmpXUb.Get(); - LocalTensor yLocal = tmpYUb.Get(); + LocalTensor valueLocal = valueUb.Get(); - LocalTensor param0Local = tmpParam0Ub.Get(); - LocalTensor param1Local = tmpParam1Ub.Get(); + event_t eventIdVToMte3 = static_cast(GetTPipePtr()->AllocEventID()); + event_t eventIdMte2ToV = static_cast(GetTPipePtr()->AllocEventID()); - LocalTensor x1Local = tmpIntX1Ub.Get(); - LocalTensor y1Local = tmpIntY1Ub.Get(); + for (uint32_t batch = 0; batch < batchSize; batch++) + { + LocalTensor emptyUbLocal = emptyUb.Get(); - LocalTensor x0Local = tmpIntX0Ub.Get(); - LocalTensor y0Local = tmpIntY0Ub.Get(); + LocalTensor weightLocal = leftTopWieightQueue.Get(); - LocalTensor tmpLocal1 = tmpUb1.Get(); - LocalTensor tmpLocal2 = tmpUb2.Get(); - LocalTensor tmpLocal3 = tmpUb3.Get(); - LocalTensor tmpLocal4 = tmpUb4.Get(); + LocalTensor xLocal = tmpXUb.Get(); + LocalTensor yLocal = tmpYUb.Get(); - LocalTensor intOneLocal = intOneUb.Get(); - LocalTensor floatOneLocal = floatOneUb.Get(); + LocalTensor tmpResLocal = tmpResUb.Get(); + LocalTensor tmpResLocal2 = tmpResUb2.Get(); + LocalTensor tmpResLocal3 = tmpResUb3.Get(); - Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); - Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); - DataCopyParams copyParams{1, (uint16_t)(embedDims * sizeof(DTYPE_VALUE)), 0, 0}; + LocalTensor param0Local = tmpParam0Ub.Get(); + LocalTensor param1Local = tmpParam1Ub.Get(); - DataCopy(shapesLocal, valueSpatialShapesGm, AlignUp(numLevels * 2, dataAlign)); - DataCopy(offsetLocal, valueLevelStartIndexGm, numLevelsAlign); - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); - moveOffset = batch * numQueries * numHeads * embedDims + query * numHeads * embedDims; - pipe_barrier(PIPE_ALL); + LocalTensor x1Local = tmpIntX1Ub.Get(); + LocalTensor y1Local = tmpIntY1Ub.Get(); - for (uint32_t head = 0; head < numHeads; head++) - { - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - } - pipe_barrier(PIPE_ALL); + LocalTensor x0Local = tmpIntX0Ub.Get(); + LocalTensor y0Local = tmpIntY0Ub.Get(); - for (uint32_t head = 0; head < numHeads; head++) - { - weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels + head * numLevels) * 
numPoints; + LocalTensor tmpLocal1 = tmpUb1.Get(); + LocalTensor tmpLocal2 = tmpUb2.Get(); + LocalTensor tmpLocal3 = tmpUb3.Get(); + LocalTensor tmpLocal4 = tmpUb4.Get(); + + LocalTensor intOneLocal = intOneUb.Get(); + LocalTensor floatOneLocal = floatOneUb.Get(); - pipe_barrier(PIPE_ALL); + Duplicate(intOneLocal, (DTYPE_VALUE_SPATIAL_SHAPES)1, numPointsAlign); + Duplicate(floatOneLocal, (DTYPE_VALUE)1, numPointsAlign); - DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numLevels * numPoints * 2, dataAlign)); - DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numLevels * numPoints, dataAlign)); + Duplicate(emptyUbLocal, DTYPE_VALUE(0), embedDimsAlign); + moveOffset = batch * numQueries * numHeads * embedDims + query * numHeads * embedDims; - pipe_barrier(PIPE_ALL); - for (uint32_t level = 0; level < numLevels; level++) + for (uint32_t head = 0; head < numHeads; head++) { - h = shapesLocal.GetValue(level * 2); - w = shapesLocal.GetValue(level * 2 + 1); - for (uint32_t point = 0; point < numPoints; point++) - { - locationOffset = (level * numPoints + point) * 2; - xLocal.SetValue(point, locationLocal.GetValue(locationOffset)); - yLocal.SetValue(point, locationLocal.GetValue(locationOffset + 1)); - } + DataCopyPad(outputGm[moveOffset + head * embedDims], emptyUbLocal, copyParams); + } - pipe_barrier(PIPE_ALL); + weightOffset = (batch * numQueries * numHeads * numLevels + query * numHeads * numLevels) * numPoints; - Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); - Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - pipe_barrier(PIPE_ALL); + DataCopy(locationLocal, locationGm[weightOffset * 2], AlignUp(numHeads * numLevels * numPoints * 2, dataAlign)); + DataCopy(attentionWeightLocal, attentionWeightsGm[weightOffset], AlignUp(numHeads * numLevels * numPoints, dataAlign)); - Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); - Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); - pipe_barrier(PIPE_ALL); + for (uint32_t head = 0; head < numHeads; head++) + { + for (uint32_t level = 0; level < numLevels; level++) + { + h = shapesLocal.GetValue(level * 2); + w = shapesLocal.GetValue(level * 2 + 1); - Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); - Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); - pipe_barrier(PIPE_ALL); + weightOffset = (head * numLevels + level) * numPoints; + locationOffset = weightOffset * 2; + for (uint32_t point = 0; point < numPoints; point++) + { + xLocal.SetValue(point, locationLocal.GetValue(locationOffset + point * 2)); + yLocal.SetValue(point, locationLocal.GetValue(locationOffset + point * 2 + 1)); + } - Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); - Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - pipe_barrier(PIPE_ALL); + Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); + Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - Sub(x0Local, x1Local, intOneLocal, numPointsAlign); - Sub(y0Local, y1Local, intOneLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); + Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); + Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); - Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); - Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); - pipe_barrier(PIPE_ALL); + Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); + Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); - Sub(tmpLocal1, tmpLocal3, xLocal, 
numPointsAlign); - Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); - pipe_barrier(PIPE_ALL); + Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); + Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - Abs(param0Local, tmpLocal1, numPointsAlign); - Abs(param1Local, tmpLocal2, numPointsAlign); - pipe_barrier(PIPE_ALL); + Sub(x0Local, x1Local, intOneLocal, numPointsAlign); + Sub(y0Local, y1Local, intOneLocal, numPointsAlign); - Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); - Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); + Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); + Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); - Mul(leftTopWeiightLocal, xLocal, yLocal, numPointsAlign); - Mul(leftBottomWeightLocal, xLocal, param1Local, numPointsAlign); - Mul(rightTopWeiightLocal, param0Local, yLocal, numPointsAlign); - Mul(rightBottomWeightLocal, param0Local, param1Local, numPointsAlign); - pipe_barrier(PIPE_ALL); + Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); + Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); - Duplicate(resLocal, DTYPE_VALUE(0), embedDimsAlign); + Abs(param0Local, tmpLocal1, numPointsAlign); + Abs(param1Local, tmpLocal2, numPointsAlign); - for (uint32_t point = 0; point < numPoints; point++) - { - Duplicate(leftTopValueLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(leftBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightTopValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); - Duplicate(rightBottomValueUbLocal, DTYPE_VALUE(0), embedDimsAlign); + Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); + Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); + + Mul(weightLocal, xLocal, yLocal, numPointsAlign); + Mul(weightLocal[numPointsAlign], xLocal, param1Local, numPointsAlign); + Mul(weightLocal[numPointsAlign * 2], param0Local, yLocal, numPointsAlign); + Mul(weightLocal[numPointsAlign * 3], param0Local, param1Local, numPointsAlign); - x0 = x0Local.GetValue(point); - y0 = y0Local.GetValue(point); - x1 = x1Local.GetValue(point); - y1 = y1Local.GetValue(point); + Mul(weightLocal, weightLocal, attentionWeightLocal[weightOffset], numPointsAlign); + Mul(weightLocal[numPointsAlign], weightLocal[numPointsAlign], attentionWeightLocal[weightOffset], numPointsAlign); + Mul(weightLocal[numPointsAlign * 2], weightLocal[numPointsAlign * 2], attentionWeightLocal[weightOffset], numPointsAlign); + Mul(weightLocal[numPointsAlign * 3], weightLocal[numPointsAlign * 3], attentionWeightLocal[weightOffset], numPointsAlign); - valueOffset = batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head; - pipe_barrier(PIPE_ALL); + valueOffset = (batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head) * embedDims; - if (isInRange(x0, w)) + Duplicate(valueLocal, DTYPE_VALUE(0), 4 * numPoints * embedDimsAlign); + for (uint32_t point = 0; point < numPoints; point++) { - if (isInRange(y0, h)) + x0 = x0Local.GetValue(point); + y0 = y0Local.GetValue(point); + x1 = x1Local.GetValue(point); + y1 = y1Local.GetValue(point); + + if (isInRange(x0, w)) { - DataCopy(leftTopValueLocal, valueGm[(valueOffset + (y0 * w + x0) * numHeads) * embedDims], embedDimsAlign); + if (isInRange(y0, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4], valueGm[valueOffset + (y0 * w + x0) * tailNum], embedDimsAlign); + } + if (isInRange(y1, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], valueGm[valueOffset + (y1 * w + x0) * tailNum], 
embedDimsAlign); + } } - if (isInRange(y1, h)) + if (isInRange(x1, w)) { - DataCopy(leftBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x0) * numHeads) * embedDims], embedDimsAlign); + if (isInRange(y0, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueGm[valueOffset + (y0 * w + x1) * tailNum], embedDimsAlign); + } + if (isInRange(y1, h)) + { + DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], valueGm[valueOffset + (y1 * w + x1) * tailNum], embedDimsAlign); + } } } - if (isInRange(x1, w)) + SetFlag(eventIdMte2ToV); + WaitFlag(eventIdMte2ToV); + + for (uint32_t point = 0; point < numPoints; point++) { - if (isInRange(y0, h)) - { - DataCopy(rightTopValueUbLocal, valueGm[(valueOffset + (y0 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } - if (isInRange(y1, h)) - { - DataCopy(rightBottomValueUbLocal, valueGm[(valueOffset + (y1 * w + x1) * numHeads) * embedDims], embedDimsAlign); - } + leftTopWeight = weightLocal.GetValue(point); + leftBottomWeight = weightLocal.GetValue(numPointsAlign + point); + rightTopWeiight = weightLocal.GetValue(numPointsAlign * 2 + point); + rightBottomWeight = weightLocal.GetValue(numPointsAlign * 3 + point); + + Muls(valueLocal[point * embedDimsAlign * 4], valueLocal[point * embedDimsAlign * 4], leftTopWeight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], leftBottomWeight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], rightTopWeiight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], rightBottomWeight, embedDimsAlign); + + Add(tmpResLocal[point * embedDimsAlign], valueLocal[point * embedDimsAlign * 4], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], embedDimsAlign); + Add(tmpResLocal2[point * embedDimsAlign], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], embedDimsAlign); + Add(tmpResLocal3[point * embedDimsAlign], tmpResLocal[point * embedDimsAlign], tmpResLocal2[point * embedDimsAlign], embedDimsAlign); + + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + DataCopyPad(outputGm[moveOffset + head * embedDims], tmpResLocal3[point * embedDimsAlign], copyParams); } - pipe_barrier(PIPE_ALL); - - Muls(leftTopValueLocal, leftTopValueLocal, leftTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(rightTopValueUbLocal, rightTopValueUbLocal, rightTopWeiightLocal.GetValue(point), embedDimsAlign); - Muls(leftBottomValueUbLocal, leftBottomValueUbLocal, leftBottomWeightLocal.GetValue(point), embedDimsAlign); - Muls(rightBottomValueUbLocal, rightBottomValueUbLocal, rightBottomWeightLocal.GetValue(point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, leftTopValueLocal, rightTopValueUbLocal, embedDimsAlign); - Add(tmpResLocal2, leftBottomValueUbLocal, rightBottomValueUbLocal, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(tmpResLocal, tmpResLocal, tmpResLocal2, embedDimsAlign); - pipe_barrier(PIPE_ALL); - Muls(tmpResLocal, tmpResLocal, attentionWeightLocal.GetValue(level * numPoints + point), embedDimsAlign); - pipe_barrier(PIPE_ALL); - Add(resLocal, resLocal, tmpResLocal, embedDimsAlign); } - pipe_barrier(PIPE_ALL); - - SetAtomicAdd(); - DataCopyPad(outputGm[moveOffset + head * embedDims], resLocal, copyParams); - 
SetAtomicNone(); } } locationQueue.FreeTensor(locationLocal); attentionWeightsUb.FreeTensor(attentionWeightLocal); - outputQueue.FreeTensor(resLocal); + shapeQueue.FreeTensor(shapesLocal); offsetQueue.FreeTensor(offsetLocal); + + GetTPipePtr()->ReleaseEventID(eventIdVToMte3); + GetTPipePtr()->ReleaseEventID(eventIdMte2ToV); } private: - TPipe pipe; + TPipe *pipe; GlobalTensor valueGm, locationGm, attentionWeightsGm, outputGm; GlobalTensor valueSpatialShapesGm, valueLevelStartIndexGm; TQue locationQueue, attentionWeightsUb, shapeQueue, offsetQueue; TQue outputQueue; - TBuf tmpResUb, tmpResUb2, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; - TBuf intOneUb, floatOneUb, leftTopValueUb, leftBottomValueUb, rightTopValueUb, rightBottomValueUb; - TBuf leftTopWieightQueue, leftBottomWieightQueue, rightTopWieightQueue, rightBottomWieightQueue; + TBuf tmpResUb, tmpResUb2, tmpResUb3, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; + TBuf intOneUb, floatOneUb, leftTopWieightQueue, emptyUb; + TBuf valueUb; uint32_t batchSize; uint32_t numKeys; uint32_t numHeads; uint32_t embedDims; + uint32_t tailNum; uint32_t numLevels; uint32_t numQueries; @@ -341,6 +333,7 @@ private: uint32_t dataAlign; uint32_t blockNum = 32; + DTYPE_VALUE leftTopWeight, rightTopWeiight, leftBottomWeight, rightBottomWeight, attnWeight; DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, moveOffset; }; @@ -351,9 +344,10 @@ extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM GM_ADDR attention_weights, GM_ADDR output, GM_ADDR workspace, GM_ADDR tiling) { + TPipe pipe; // GET_TILING_DATA(tiling_data, tiling); KernelMultiScaleDeformableAttnFunctionV2 op; op.Init(value, value_spatial_shapes, value_level_start_index, - sampling_locations, attention_weights, output, &tiling_data); + sampling_locations, attention_weights, output, &tiling_data, &pipe); op.Process(); } -- Gitee From 1aed8658b74e683fc5ba458d4217c21f58f93a67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E5=93=B2=E7=BB=AD?= Date: Sat, 24 Feb 2024 19:08:43 +0800 Subject: [PATCH 10/12] Type: Enhance MSDA. Team: Pytorch_Ops_Dev. Description: Enhance MSDA. 
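Concretely, the diff below introduces batchOffset = numPoints * embedDimsAlign, folds the tmpUb1..tmpUb4 and tmpIntX0/Y0/X1/Y1 temporaries into single tmpFloatUb / tmpIntUb buffers that are floored with one Cast over 4 * numPointsAlign elements, lays valueLocal out as four contiguous corner planes of batchOffset elements each, and hoists the V-to-MTE3 flag out of the per-point loop so the three Add calls each cover a whole batchOffset block.

The per-point sampling math itself is unchanged. As a plain-Python sketch of what the +0.5 / CAST_FLOOR / Abs sequence computes for one sampling location (a hypothetical helper written only for illustration, not code from this repo):

    import math

    def bilinear_corners(loc_x, loc_y, w, h):
        # Floor of the +0.5-shifted coordinate gives the "high" neighbour (Cast with CAST_FLOOR),
        # and the "low" neighbour is one pixel below it (Sub with intOneLocal).
        x_hi = math.floor(loc_x * w + 0.5)
        y_hi = math.floor(loc_y * h + 0.5)
        x_lo, y_lo = x_hi - 1, y_hi - 1
        # Abs(Sub(...)) yields the fractional offsets used as bilinear weights.
        dx = abs((loc_x * w - 0.5) - x_lo)
        dy = abs((loc_y * h - 0.5) - y_lo)
        corners = {
            (y_lo, x_lo): (1 - dx) * (1 - dy),   # left-top
            (y_hi, x_lo): (1 - dx) * dy,         # left-bottom
            (y_lo, x_hi): dx * (1 - dy),         # right-top
            (y_hi, x_hi): dx * dy,               # right-bottom
        }
        # Corners that fail isInRange contribute nothing, i.e. zero padding outside the map.
        return {c: wt for c, wt in corners.items() if 0 <= c[0] < h and 0 <= c[1] < w}

Each gathered corner is scaled by this weight (pre-multiplied with the point's attention weight in weightLocal), and the per-point sums are written out through DataCopyPad with atomic add enabled, so partial results accumulate directly in outputGm.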
--- ...ulti_scale_deformable_attn_function_v2.cpp | 113 ++++++++---------- 1 file changed, 48 insertions(+), 65 deletions(-) diff --git a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp index 0419e791..8be13356 100644 --- a/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp +++ b/ads/common/ops/kernels/op_kernel/multi_scale_deformable_attn_function_v2.cpp @@ -42,6 +42,8 @@ public: numPointsAlign = AlignUp(numPoints, dataAlign); numLevelsAlign = AlignUp(numLevels, dataAlign); + batchOffset = numPoints * embedDimsAlign; + curBlockIdx = GetBlockIdx(); startOffset = curBlockIdx * taskNumPerCore; endOffset = (curBlockIdx + 1) * taskNumPerCore; @@ -67,10 +69,7 @@ public: pipe->InitBuffer(emptyUb, BUFFER_NUM, embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpUb1, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpUb2, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpUb3, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpUb4, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpFloatUb, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(intOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); pipe->InitBuffer(floatOneUb, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); @@ -80,17 +79,14 @@ public: pipe->InitBuffer(tmpParam0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); pipe->InitBuffer(tmpParam1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpIntX0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe->InitBuffer(tmpIntY0Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe->InitBuffer(tmpIntX1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); - pipe->InitBuffer(tmpIntY1Ub, BUFFER_NUM, numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); + pipe->InitBuffer(tmpIntUb, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE_SPATIAL_SHAPES)); pipe->InitBuffer(leftTopWieightQueue, BUFFER_NUM, 4 * numPointsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(valueUb, BUFFER_NUM, numPoints * 4 * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpResUb, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpResUb2, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); - pipe->InitBuffer(tmpResUb3, BUFFER_NUM, numPoints * embedDimsAlign * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(valueUb, BUFFER_NUM, batchOffset * 4 * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb, BUFFER_NUM, batchOffset * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb2, BUFFER_NUM, batchOffset * sizeof(DTYPE_VALUE)); + pipe->InitBuffer(tmpResUb3, BUFFER_NUM, batchOffset * sizeof(DTYPE_VALUE)); } __aicore__ inline void Process() @@ -143,16 +139,8 @@ private: LocalTensor param0Local = tmpParam0Ub.Get(); LocalTensor param1Local = tmpParam1Ub.Get(); - LocalTensor x1Local = tmpIntX1Ub.Get(); - LocalTensor y1Local = tmpIntY1Ub.Get(); - - LocalTensor x0Local = tmpIntX0Ub.Get(); - LocalTensor y0Local = tmpIntY0Ub.Get(); - - LocalTensor tmpLocal1 = tmpUb1.Get(); - LocalTensor tmpLocal2 = tmpUb2.Get(); - LocalTensor tmpLocal3 = tmpUb3.Get(); - LocalTensor tmpLocal4 = tmpUb4.Get(); + LocalTensor tmpIntLocal = tmpIntUb.Get(); + LocalTensor tmpFloatLocal = tmpFloatUb.Get(); LocalTensor intOneLocal = intOneUb.Get(); LocalTensor 
floatOneLocal = floatOneUb.Get(); @@ -184,33 +172,25 @@ private: locationOffset = weightOffset * 2; for (uint32_t point = 0; point < numPoints; point++) { - xLocal.SetValue(point, locationLocal.GetValue(locationOffset + point * 2)); - yLocal.SetValue(point, locationLocal.GetValue(locationOffset + point * 2 + 1)); - } - - Muls(tmpLocal1, xLocal, (DTYPE_VALUE)w, numPointsAlign); - Muls(tmpLocal2, yLocal, (DTYPE_VALUE)h, numPointsAlign); - - Adds(param0Local, tmpLocal1, (DTYPE_VALUE)0.5, numPointsAlign); - Adds(param1Local, tmpLocal2, (DTYPE_VALUE)0.5, numPointsAlign); + tmp1 = locationLocal.GetValue(locationOffset + point * 2) * (DTYPE_VALUE)w; + tmp2 = locationLocal.GetValue(locationOffset + point * 2 + 1) * (DTYPE_VALUE)h; - Cast(x1Local, param0Local, RoundMode::CAST_FLOOR, numPointsAlign); - Cast(y1Local, param1Local, RoundMode::CAST_FLOOR, numPointsAlign); + tmpFloatLocal.SetValue(point, tmp1 + (DTYPE_VALUE)0.5); + tmpFloatLocal.SetValue(point + numPointsAlign, tmp2 + (DTYPE_VALUE)0.5); - Adds(tmpLocal3, param0Local, (DTYPE_VALUE)-1, numPointsAlign); - Adds(tmpLocal4, param1Local, (DTYPE_VALUE)-1, numPointsAlign); - - Sub(x0Local, x1Local, intOneLocal, numPointsAlign); - Sub(y0Local, y1Local, intOneLocal, numPointsAlign); + tmpFloatLocal.SetValue(point + numPointsAlign * 2, tmp1 - (DTYPE_VALUE)0.5); + tmpFloatLocal.SetValue(point + numPointsAlign * 3, tmp2 - (DTYPE_VALUE)0.5); + } - Cast(xLocal, x0Local, RoundMode::CAST_NONE, numPointsAlign); - Cast(yLocal, y0Local, RoundMode::CAST_NONE, numPointsAlign); + Cast(tmpIntLocal, tmpFloatLocal, RoundMode::CAST_FLOOR, 4 * numPointsAlign); + Cast(xLocal, tmpIntLocal, RoundMode::CAST_NONE, numPointsAlign); + Cast(yLocal, tmpIntLocal[numPointsAlign], RoundMode::CAST_NONE, numPointsAlign); - Sub(tmpLocal1, tmpLocal3, xLocal, numPointsAlign); - Sub(tmpLocal2, tmpLocal4, yLocal, numPointsAlign); + Sub(tmpFloatLocal, tmpFloatLocal[numPointsAlign * 2], xLocal, numPointsAlign); + Sub(tmpFloatLocal[numPointsAlign], tmpFloatLocal[numPointsAlign * 3], yLocal, numPointsAlign); - Abs(param0Local, tmpLocal1, numPointsAlign); - Abs(param1Local, tmpLocal2, numPointsAlign); + Abs(param0Local, tmpFloatLocal, numPointsAlign); + Abs(param1Local, tmpFloatLocal[numPointsAlign], numPointsAlign); Sub(xLocal, floatOneLocal, param0Local, numPointsAlign); Sub(yLocal, floatOneLocal, param1Local, numPointsAlign); @@ -227,34 +207,34 @@ private: valueOffset = (batch * numKeys * numHeads + offsetLocal.GetValue(level) * numHeads + head) * embedDims; - Duplicate(valueLocal, DTYPE_VALUE(0), 4 * numPoints * embedDimsAlign); + Duplicate(valueLocal, DTYPE_VALUE(0), 4 * batchOffset); for (uint32_t point = 0; point < numPoints; point++) { - x0 = x0Local.GetValue(point); - y0 = y0Local.GetValue(point); - x1 = x1Local.GetValue(point); - y1 = y1Local.GetValue(point); + x0 = tmpIntLocal.GetValue(point); + y0 = tmpIntLocal.GetValue(point + numPointsAlign); + x1 = tmpIntLocal.GetValue(point + numPointsAlign * 2); + y1 = tmpIntLocal.GetValue(point + numPointsAlign * 3); if (isInRange(x0, w)) { if (isInRange(y0, h)) { - DataCopy(valueLocal[point * embedDimsAlign * 4], valueGm[valueOffset + (y0 * w + x0) * tailNum], embedDimsAlign); + DataCopy(valueLocal[point * embedDimsAlign], valueGm[valueOffset + (y0 * w + x0) * tailNum], embedDimsAlign); } if (isInRange(y1, h)) { - DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], valueGm[valueOffset + (y1 * w + x0) * tailNum], embedDimsAlign); + DataCopy(valueLocal[batchOffset + point * embedDimsAlign], valueGm[valueOffset + (y1 * w + x0) 
* tailNum], embedDimsAlign); } } if (isInRange(x1, w)) { if (isInRange(y0, h)) { - DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueGm[valueOffset + (y0 * w + x1) * tailNum], embedDimsAlign); + DataCopy(valueLocal[batchOffset * 2 + point * embedDimsAlign], valueGm[valueOffset + (y0 * w + x1) * tailNum], embedDimsAlign); } if (isInRange(y1, h)) { - DataCopy(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], valueGm[valueOffset + (y1 * w + x1) * tailNum], embedDimsAlign); + DataCopy(valueLocal[batchOffset * 3 + point * embedDimsAlign], valueGm[valueOffset + (y1 * w + x1) * tailNum], embedDimsAlign); } } } @@ -268,17 +248,20 @@ private: rightTopWeiight = weightLocal.GetValue(numPointsAlign * 2 + point); rightBottomWeight = weightLocal.GetValue(numPointsAlign * 3 + point); - Muls(valueLocal[point * embedDimsAlign * 4], valueLocal[point * embedDimsAlign * 4], leftTopWeight, embedDimsAlign); - Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], leftBottomWeight, embedDimsAlign); - Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], rightTopWeiight, embedDimsAlign); - Muls(valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], rightBottomWeight, embedDimsAlign); + Muls(valueLocal[point * embedDimsAlign], valueLocal[point * embedDimsAlign], leftTopWeight, embedDimsAlign); + Muls(valueLocal[batchOffset + point * embedDimsAlign], valueLocal[batchOffset + point * embedDimsAlign], leftBottomWeight, embedDimsAlign); + Muls(valueLocal[batchOffset * 2 + point * embedDimsAlign], valueLocal[batchOffset * 2 + point * embedDimsAlign], rightTopWeiight, embedDimsAlign); + Muls(valueLocal[batchOffset * 3 + point * embedDimsAlign], valueLocal[batchOffset * 3 + point * embedDimsAlign], rightBottomWeight, embedDimsAlign); + } - Add(tmpResLocal[point * embedDimsAlign], valueLocal[point * embedDimsAlign * 4], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign], embedDimsAlign); - Add(tmpResLocal2[point * embedDimsAlign], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 2], valueLocal[point * embedDimsAlign * 4 + embedDimsAlign * 3], embedDimsAlign); - Add(tmpResLocal3[point * embedDimsAlign], tmpResLocal[point * embedDimsAlign], tmpResLocal2[point * embedDimsAlign], embedDimsAlign); + Add(tmpResLocal, valueLocal, valueLocal[batchOffset], batchOffset); + Add(tmpResLocal2, valueLocal[batchOffset * 2], valueLocal[batchOffset * 3], batchOffset); + Add(tmpResLocal3, tmpResLocal, tmpResLocal2, batchOffset); - SetFlag(eventIdVToMte3); - WaitFlag(eventIdVToMte3); + SetFlag(eventIdVToMte3); + WaitFlag(eventIdVToMte3); + for (uint32_t point = 0; point < numPoints; point++) + { DataCopyPad(outputGm[moveOffset + head * embedDims], tmpResLocal3[point * embedDimsAlign], copyParams); } } @@ -302,7 +285,7 @@ private: TQue locationQueue, attentionWeightsUb, shapeQueue, offsetQueue; TQue outputQueue; - TBuf tmpResUb, tmpResUb2, tmpResUb3, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntX0Ub, tmpIntY0Ub, tmpIntX1Ub, tmpIntY1Ub, tmpUb1, tmpUb2, tmpUb3, tmpUb4; + TBuf tmpResUb, tmpResUb2, tmpResUb3, tmpXUb, tmpYUb, tmpParam0Ub, tmpParam1Ub, tmpIntUb, tmpFloatUb; TBuf intOneUb, floatOneUb, leftTopWieightQueue, emptyUb; TBuf valueUb; @@ -333,8 +316,8 @@ private: uint32_t dataAlign; uint32_t blockNum = 32; - DTYPE_VALUE leftTopWeight, rightTopWeiight, leftBottomWeight, rightBottomWeight, 
attnWeight; - DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, moveOffset; + DTYPE_VALUE tmp1, tmp2, leftTopWeight, rightTopWeiight, leftBottomWeight, rightBottomWeight, attnWeight; + DTYPE_VALUE_SPATIAL_SHAPES h, w, x0, y0, x1, y1, valueOffset, weightOffset, locationOffset, moveOffset, batchOffset; }; extern "C" __global__ __aicore__ void multi_scale_deformable_attn_function_v2(GM_ADDR value, -- Gitee From f3e0485ca4d6ebb9935e84b765d15188b198e89e Mon Sep 17 00:00:00 2001 From: l00636998 Date: Wed, 21 Feb 2024 06:49:43 +0000 Subject: [PATCH 11/12] !28 Add FurthestPointSampling operator. Merge pull request !28 from l00636998/master MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ads/common/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ads/common/__init__.py b/ads/common/__init__.py index 45920814..a57c71b1 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -32,3 +32,4 @@ from .ops.dynamic_voxelization import Voxelization from .ops.nms3d_normal import npu_nms3d_normal from .ops.furthest_point_sampling import npu_furthest_point_sampling from .ops.npu_nms3d import npu_nms3d +from .ops.furthest_point_sampling import npu_furthest_point_sampling -- Gitee From aecee1fb26cd720b8c9630f1e33a71fc37d81398 Mon Sep 17 00:00:00 2001 From: lishuai183 Date: Tue, 20 Feb 2024 15:25:15 +0800 Subject: [PATCH 12/12] init nms3d op. --- ads/common/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/ads/common/__init__.py b/ads/common/__init__.py index a57c71b1..17c0637a 100644 --- a/ads/common/__init__.py +++ b/ads/common/__init__.py @@ -33,3 +33,4 @@ from .ops.nms3d_normal import npu_nms3d_normal from .ops.furthest_point_sampling import npu_furthest_point_sampling from .ops.npu_nms3d import npu_nms3d from .ops.furthest_point_sampling import npu_furthest_point_sampling +from .ops.npu_nms3d import npu_nms3d -- Gitee
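With the series applied, the v2 kernel can be sanity-checked against an eager PyTorch reference in the spirit of mmcv's multi_scale_deformable_attn_pytorch. The sketch below is only an assumed reference, not part of these patches: it presumes value is laid out as (bs, num_keys, num_heads, embed_dims), sampling locations are normalized to [0, 1], and out-of-range samples read as zero, which is what the kernel's isInRange checks and loc * size - 0.5 indexing (align_corners=False) correspond to.

    import torch
    import torch.nn.functional as F

    def msda_reference(value, spatial_shapes, sampling_locations, attention_weights):
        # value:              (bs, num_keys, num_heads, embed_dims), num_keys == sum(H * W)
        # spatial_shapes:     (num_levels, 2) rows of (H, W)
        # sampling_locations: (bs, num_queries, num_heads, num_levels, num_points, 2) in [0, 1]
        # attention_weights:  (bs, num_queries, num_heads, num_levels, num_points)
        bs, _, num_heads, embed_dims = value.shape
        _, num_queries, _, num_levels, num_points, _ = sampling_locations.shape
        value_list = value.split([int(h) * int(w) for h, w in spatial_shapes], dim=1)
        grids = 2 * sampling_locations - 1  # grid_sample expects coordinates in [-1, 1]
        sampled = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            v = value_list[lvl].flatten(2).transpose(1, 2)
            v = v.reshape(bs * num_heads, embed_dims, int(h), int(w))
            g = grids[:, :, :, lvl].transpose(1, 2).flatten(0, 1)
            # align_corners=False matches the kernel's loc * size - 0.5 pixel indexing
            sampled.append(F.grid_sample(v, g, mode='bilinear',
                                         padding_mode='zeros', align_corners=False))
        attn = attention_weights.transpose(1, 2).reshape(
            bs * num_heads, 1, num_queries, num_levels * num_points)
        out = (torch.stack(sampled, dim=-2).flatten(-2) * attn).sum(-1)
        # (bs, num_queries, num_heads * embed_dims), the same layout the kernel's moveOffset assumes
        return out.view(bs, num_heads * embed_dims, num_queries).transpose(1, 2).contiguous()

Comparing this against the NPU output with a loose tolerance (for example torch.allclose with rtol/atol on the order of 1e-3 for half precision) would be the natural way to exercise the kernel in a unit test.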