diff --git a/third_party/acl/inc/aml/aml_fwk_detect.h b/third_party/acl/inc/aml/aml_fwk_detect.h
index 26dbde4549643e80f7b8c675936dc8576feb85d1..4a9632b1649569da4a0cd31b55004adab8597a87 100644
--- a/third_party/acl/inc/aml/aml_fwk_detect.h
+++ b/third_party/acl/inc/aml/aml_fwk_detect.h
@@ -42,6 +42,8 @@ AmlStatus AmlAicoreDetectOnline(int32_t deviceId, const AmlAicoreDetectAttr *att
 
 AmlStatus AmlP2PDetectOnline(int32_t devId, void *comm, const AmlP2PDetectAttr *attr);
 
+AmlStatus AmlStressRestore(int32_t deviceId);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/torch_npu/csrc/core/npu/interface/MlInterface.cpp b/torch_npu/csrc/core/npu/interface/MlInterface.cpp
index 4008c8eb27ac669a9f7e230444b1e3454726accc..363e11025e5e879cc6a100444bd3454aef6c0dc6 100644
--- a/torch_npu/csrc/core/npu/interface/MlInterface.cpp
+++ b/torch_npu/csrc/core/npu/interface/MlInterface.cpp
@@ -34,6 +34,19 @@ bool IsExistAmlP2PDetectOnline()
     return isExist;
 }
 
+// Reports whether the AmlStressRestore symbol could be resolved via GET_FUNC.
+// The lookup runs once; the result is cached in a function-local static.
+// NOTE(review): presumably GET_FUNC only resolves symbols registered with
+// LOAD_FUNCTION earlier in this file -- confirm AmlStressRestore is registered.
+bool IsExistAmlStressRestore()
+{
+    const static bool isExist = []() -> bool {
+        static auto func = GET_FUNC(AmlStressRestore);
+        return func != nullptr;
+    }();
+    return isExist;
+}
+
 AmlStatus AmlAicoreDetectOnlineFace(int32_t deviceId, const AmlAicoreDetectAttr *attr)
 {
     typedef AmlStatus (*amlAicoreDetectOnline)(int32_t, const AmlAicoreDetectAttr *);
@@ -56,5 +69,19 @@ AmlStatus AmlP2PDetectOnlineFace(int32_t deviceId, void *comm, const AmlP2PDetec
     return func(deviceId, comm, attr);
 }
 
+// Forwards to the dynamically resolved AmlStressRestore and returns its status.
+// Named with the "Face" suffix like the sibling wrappers so it does not share a
+// name with the extern "C" AmlStressRestore declared in aml_fwk_detect.h.
+AmlStatus AmlStressRestoreFace(int32_t deviceId)
+{
+    typedef AmlStatus (*amlStressRestore)(int32_t);
+    static amlStressRestore func = nullptr;
+    if (func == nullptr) {
+        func = (amlStressRestore) GET_FUNC(AmlStressRestore);
+    }
+    TORCH_CHECK(func, "Failed to find function ", "AmlStressRestore", PTA_ERROR(ErrCode::NOT_FOUND));
+    return func(deviceId);
+}
+
 } // namespace amlapi
 } // namespace c10_npu
diff --git a/torch_npu/csrc/core/npu/interface/MlInterface.h b/torch_npu/csrc/core/npu/interface/MlInterface.h
index 33389fcce3a53f7902c69ab7515dc7eb53a7d13b..695d118f3d786be859a0aac899e0369e5a762e26 100644
--- a/torch_npu/csrc/core/npu/interface/MlInterface.h
+++ b/torch_npu/csrc/core/npu/interface/MlInterface.h
@@ -13,6 +13,11 @@ bool IsExistAmlAicoreDetectOnline();
  */
 bool IsExistAmlP2PDetectOnline();
 
+/**
+ * This API is used to check whether AmlStressRestore exists.
+ */
+bool IsExistAmlStressRestore();
+
 /**
  * This API is used to call AmlAicoreDetectOnline.
  */
@@ -23,5 +28,10 @@ AmlStatus AmlAicoreDetectOnlineFace(int32_t deviceId, const AmlAicoreDetectAttr
  */
 AmlStatus AmlP2PDetectOnlineFace(int32_t deviceId, void *comm, const AmlP2PDetectAttr *attr);
 
+/**
+ * This API is used to call AmlStressRestore.
+ */
+AmlStatus AmlStressRestoreFace(int32_t deviceId);
+
 } // namespace amlapi
 } // namespace c10_npu
diff --git a/torch_npu/csrc/npu/Stress_detect.cpp b/torch_npu/csrc/npu/Stress_detect.cpp
index 4fb0f4a978356952cfd3373677a53ea051e0db96..532d56fbb32f2ea6ec011464e6c8bc79e4e8fb60 100644
--- a/torch_npu/csrc/npu/Stress_detect.cpp
+++ b/torch_npu/csrc/npu/Stress_detect.cpp
@@ -56,6 +56,17 @@ void StressDetector::worker_thread()
                 ASCEND_LOGI("Stress detect with StressDetect start, device id is %d.", device_id);
                 ret = c10_npu::acl::AclStressDetect(device_id, workspaceAddr, workspaceSize);
                 ASCEND_LOGI("Stress detect with StressDetect end, device id is %d, result is %d.", device_id, ret);
+                if (c10_npu::amlapi::IsExistAmlStressRestore()) {
+                    ASCEND_LOGI("Stress detect with AmlStressRestore start, device id is %d.", device_id);
+                    // Keep the restore status separate so a successful restore cannot hide a failed stress detect.
+                    const auto restoreRet = c10_npu::amlapi::AmlStressRestoreFace(device_id);
+                    ASCEND_LOGI("Stress detect with AmlStressRestore end, device id is %d, result is %d.", device_id, restoreRet);
+                    // NOTE(review): assumes 0 is the success code of AclStressDetect -- confirm.
+                    if (ret == 0) {
+                        ret = restoreRet;
+                    }
+                }
             }
         } else {
             if (c10_npu::amlapi::IsExistAmlP2PDetectOnline()) {