From edbc659387e6a54c050354e06bd7274c6e91f866 Mon Sep 17 00:00:00 2001 From: zhanhao Date: Mon, 25 Dec 2023 16:28:04 +0800 Subject: [PATCH] fix ci --- ads/common/ops/csrc/OpApiCommon.h | 23 ++++++++++++++++++++++- utils/extension.py | 7 +++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/ads/common/ops/csrc/OpApiCommon.h b/ads/common/ops/csrc/OpApiCommon.h index 717543ac..92332df1 100644 --- a/ads/common/ops/csrc/OpApiCommon.h +++ b/ads/common/ops/csrc/OpApiCommon.h @@ -1,3 +1,18 @@ +// Copyright (c) 2023 Huawei Technologies Co., Ltd +// Copyright (c) 2019, Facebook CORPORATION. +// All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. #include #include #include @@ -19,6 +34,12 @@ #define NPU_NAME_SPACE at_npu::native +#ifdef COMPILE_WITH_XLA +#define DEVICE_TYPE at_npu::key::NativeDeviceType +#else +#define DEVICE_TYPE c10::DeviceType::PrivateUse1 +#endif + typedef struct aclOpExecutor aclOpExecutor; typedef struct aclTensor aclTensor; typedef struct aclScalar aclScalar; @@ -180,7 +201,7 @@ inline at::Tensor CopyTensorHostToDevice(const at::Tensor &cpu_tensor) at::Tensor cpuPinMemTensor = cpu_tensor.pin_memory(); int deviceIndex = 0; return cpuPinMemTensor.to( - c10::Device(at_npu::key::NativeDeviceType, deviceIndex), cpuPinMemTensor.scalar_type(), true, true); + c10::Device(DEVICE_TYPE, deviceIndex), cpuPinMemTensor.scalar_type(), true, true); } inline at::Tensor CopyScalarToDevice(const c10::Scalar &cpu_scalar, at::ScalarType scalar_data_type) diff --git a/utils/extension.py b/utils/extension.py index 28e398ae..57908424 100644 --- a/utils/extension.py +++ b/utils/extension.py @@ -16,6 +16,7 @@ import os import site +from pkg_resources import parse_version import setuptools import torch @@ -75,4 +76,10 @@ def NpuExtension(name, sources, *args, **kwargs): kwargs['libraries'] = libraries kwargs['language'] = 'c++' + + define_macros = [] + if parse_version(torch.__version__) < parse_version('2.1.0'): + define_macros += [('COMPILE_WITH_XLA', None)] + kwargs['define_macros'] = define_macros + return setuptools.Extension(name, sources, *args, **kwargs) -- Gitee