diff --git a/src/common/backend/utils/adt/vector.cpp b/src/common/backend/utils/adt/vector.cpp index 83cf27c0b3b5e4ade740deb40d74cb5219fc415e..d3ccedf860f92c04f73fb17ef54848756883728f 100644 --- a/src/common/backend/utils/adt/vector.cpp +++ b/src/common/backend/utils/adt/vector.cpp @@ -825,24 +825,27 @@ VECTOR_TARGET_CLONES float VectorInnerProduct(int dim, float *ax, float *bx) { float dis = 0.0f; - float32x4_t sum = vdupq_n_f32(0.0f); + float32x4_t sum0 = vdupq_n_f32(0.0); + float32x4_t sum1 = vdupq_n_f32(0.0); + float32x4_t sum2 = vdupq_n_f32(0.0); + float32x4_t sum3 = vdupq_n_f32(0.0); float *pta = ax; float *ptb = bx; - int i = 0; - int prefetch_len = 8; - int batch_num = 4; + int batch_num = 16; for (; i + batch_num <= dim; i += batch_num) { - prefetch_L1(pta + prefetch_len); - prefetch_L1(ptb + prefetch_len); - float32x4_t sub_a = vld1q_f32(pta); - float32x4_t sub_b = vld1q_f32(ptb); - sum = vmlaq_f32(sum, sub_a, sub_b); + float32x4x4_t packdata_a = vld1q_f32_x4(pta); + float32x4x4_t packdata_b = vld1q_f32_x4(ptb); + + sum0 = vmlaq_f32(sum0, packdata_a.val[0], packdata_b.val[0]); + sum1 = vmlaq_f32(sum1, packdata_a.val[1], packdata_b.val[1]); + sum2 = vmlaq_f32(sum2, packdata_a.val[2], packdata_b.val[2]); + sum3 = vmlaq_f32(sum3, packdata_a.val[3], packdata_b.val[3]); pta += batch_num; ptb += batch_num; } - dis = vaddvq_f32(sum); + dis = vaddvq_f32(sum0) + vaddvq_f32(sum1) + vaddvq_f32(sum2) + vaddvq_f32(sum3); for (; i < dim; ++i) { dis += ax[i] * bx[i]; }