diff --git a/torch_npu/csrc/core/npu/NPUAffinityController.cpp b/torch_npu/csrc/core/npu/NPUAffinityController.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..6d951d6f6f2acd8a7bab052aa8f551df7fc2000f
--- /dev/null
+++ b/torch_npu/csrc/core/npu/NPUAffinityController.cpp
@@ -0,0 +1,426 @@
+
+#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
+#include "torch_npu/csrc/core/npu/NPUFunctions.h"
+
+#include <pthread.h>
+#include <sched.h>
+#include <unistd.h>
+#include <sys/prctl.h>
+#include <sys/syscall.h>
+#include <algorithm>
+#include <array>
+#include <cstdio>
+#include <cstring>
+#include <fstream>
+#include <memory>
+#include <regex>
+#include <sstream>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+namespace c10_npu {
+
+    // TID of the process main thread and the process PID, captured in GetAffinityInfo()
+    // so the main thread can be re-pinned later from worker threads.
+    static pthread_t mainthread_tid;
+    static pid_t parentPid;
+    // Longest contiguous CPU range the process was originally allowed to run on.
+    static coreIdRange originalRange;
+
+    const std::unordered_map<ThreadType, std::string> threadTypeToNameMap = {
+        {releaseThread, "release_thread"},
+        {aclThread, "acl_thread"},
+        {mainThread, "main_thread"},
+        {hcclCommWatchdogThread, "hcclComm_watchd"}, // thread name no more than 15 chars
+        {backwardThread, "backward_thread"}};
+
+    const std::unordered_map<std::string, ThreadType> threadNameToTypeMap = {
+        {"release_thread", releaseThread},
+        {"acl_thread", aclThread},
+        {"main_thread", mainThread},
+        {"hcclComm_watchd", hcclCommWatchdogThread},
+        {"backward_thread", backwardThread}};
+
+    // Returns the longest contiguous run of CPUs present in `thread`'s affinity mask.
+    coreIdRange FindLongestCoreAffinityRange(pthread_t thread)
+    {
+        cpu_set_t mask;
+        CPU_ZERO(&mask);
+
+        coreIdRange range = {-1, -1};
+        int max_length = 0;
+        int current_start = -1;
+        int current_length = 0;
+
+        if (pthread_getaffinity_np(thread, sizeof(mask), &mask) == 0) {
+            for (int i = 0; i < CPU_SETSIZE; i++) {
+                if (CPU_ISSET(i, &mask)) {
+                    if (current_start == -1) {
+                        current_start = i;
+                    }
+                    current_length++;
+                } else {
+                    if (current_length > max_length) {
+                        max_length = current_length;
+                        range.start = current_start;
+                        range.end = i - 1;
+                    }
+                    current_start = -1;
+                    current_length = 0;
+                }
+            }
+
+            // Handle a run that extends to the last possible CPU id.
+            if (current_length > max_length) {
+                max_length = current_length;
+                range.start = current_start;
+                range.end = CPU_SETSIZE - 1;
+            }
+        } else {
+            ASCEND_LOGW("[affinity] Failed to get thread affinity");
+        }
+
+        return range;
+    }
+
+    // Record main-thread TID, process PID and the original CPU range; call once at init.
+    void GetAffinityInfo()
+    {
+        mainthread_tid = pthread_self();
+        parentPid = getpid();
+        originalRange = FindLongestCoreAffinityRange(mainthread_tid);
+        ASCEND_LOGI("[affinity] Original Affinity is %d-%d", originalRange.start, originalRange.end);
+    }
+
+    // Map the calling thread's prctl name back to a ThreadType.
+    ThreadType getCurrentThreadType()
+    {
+        char thread_name[16];
+
+        if (prctl(PR_GET_NAME, thread_name, 0, 0, 0) == 0) {
+            std::string name(thread_name);
+
+            auto it = threadNameToTypeMap.find(name);
+            if (it != threadNameToTypeMap.end()) {
+                return std::get<1>(*it);
+            }
+        }
+        return ThreadType::unknownThread;
+    }
+
+    // Map an arbitrary tid's /proc/<tid>/comm name to a ThreadType.
+    ThreadType getThreadType(pid_t tid)
+    {
+        char thread_name[16];
+        std::string commFile = "/proc/" + std::to_string(tid) + "/comm"; // Path to thread name
+
+        std::ifstream commStream(commFile);
+        if (commStream.is_open()) {
+            commStream.getline(thread_name, sizeof(thread_name));
+
+            std::string name(thread_name);
+            auto it = threadNameToTypeMap.find(name);
+            if (it != threadNameToTypeMap.end()) {
+                return std::get<1>(*it);
+            }
+        }
+
+        return ThreadType::unknownThread; // Default if not found
+    }
+
+    // Pin `thread` to the inclusive core range [core_range.start, core_range.end].
+    aclError SetThreadAffinity(coreIdRange core_range, pthread_t thread)
+    {
+        cpu_set_t mask;
+        CPU_ZERO(&mask);
+
+        for (auto i = core_range.start; i <= core_range.end; i++) {
+            CPU_SET(i, &mask);
+        }
+        if (!pthread_setaffinity_np(thread, sizeof(mask), &mask)) {
+            ASCEND_LOGI("[affinity] Set Thread Affinity to %d-%d", core_range.start, core_range.end);
+            return ACL_ERROR_NONE;
+        } else {
+            ASCEND_LOGW("[affinity] Set Thread Affinity to %d-%d failed", core_range.start, core_range.end);
+        }
+        return ACL_ERROR_FEATURE_UNSUPPORTED;
+    }
+
+    // Pin a task identified by pid/tid (sched_setaffinity) to the given core range.
+    void bindToCoreRange(pid_t pid, const coreIdRange &core_range)
+    {
+        cpu_set_t cpuset;
+        CPU_ZERO(&cpuset);
+
+        for (int core = core_range.start; core <= core_range.end; ++core) {
+            CPU_SET(core, &cpuset);
+        }
+
+        if (sched_setaffinity(pid, sizeof(cpu_set_t), &cpuset) == -1) {
+            ASCEND_LOGW("[affinity] sched_setaffinity failed");
+        } else {
+            ASCEND_LOGI("[affinity] Set Thread %d Affinity to %d-%d", pid, core_range.start, core_range.end);
+        }
+    }
+
+    // Split the original CPU range evenly across devices; returns device_id's slice.
+    coreIdRange GetCPUDefaultRange(c10::DeviceIndex device_id)
+    {
+        int offset = originalRange.start;
+        int core_nums = originalRange.end - originalRange.start + 1;
+        int device_nums = device_count_ensure_non_zero();
+        int block_size = (core_nums > 0 && device_nums > 0) ? (core_nums + device_nums - 1) / device_nums : 0;
+        return coreIdRange{offset + static_cast<coreId>(device_id * block_size),
+                           offset + static_cast<coreId>(std::min((device_id + 1) * block_size, core_nums) - 1)};
+    }
+
+    std::string GetAffinityMapAsString(const std::unordered_map<ThreadType, coreIdRange> &threadToCoreidMap, c10::DeviceIndex device_id)
+    {
+        std::ostringstream oss;
+        oss << "threadToCoreidMap plan to bind device " << static_cast<int>(device_id) << " to "
+            << " [" << threadToCoreidMap.at(unknownThread).start << "," << threadToCoreidMap.at(unknownThread).end << "]、"
+            << " [" << threadToCoreidMap.at(mainThread).start << "," << threadToCoreidMap.at(mainThread).end << "]、"
+            << " [" << threadToCoreidMap.at(backwardThread).start << "," << threadToCoreidMap.at(backwardThread).end << "]、"
+            << " [" << threadToCoreidMap.at(aclThread).start << "," << threadToCoreidMap.at(aclThread).end << "]、"
+            << " [" << threadToCoreidMap.at(releaseThread).start << "," << threadToCoreidMap.at(releaseThread).end << "]、"
+            << " [" << threadToCoreidMap.at(hcclCommWatchdogThread).start << "," << threadToCoreidMap.at(hcclCommWatchdogThread).end << "]";
+
+        return oss.str();
+    }
+
+    // Build the per-thread-type core map for mode 2: one core per known thread type,
+    // remainder of the range for unknown threads; whole range for everyone if too small.
+    std::unordered_map<ThreadType, coreIdRange> GetCpuAffinityMap(c10::DeviceIndex device_id, coreIdRange current_core_range)
+    {
+        std::unordered_map<ThreadType, coreIdRange> threadToCoreidMap;
+        std::initializer_list<ThreadType> thread_types = {unknownThread, mainThread, backwardThread, aclThread,
+                                                          releaseThread, hcclCommWatchdogThread};
+
+        coreId offset = current_core_range.start;
+
+        // calculate env2 default map
+        coreId core_nums = current_core_range.end - current_core_range.start;
+        if (core_nums < thread_types.size()) {
+            ASCEND_LOGW("[affinity] Available core numbers (%d) are insufficient for all %zu thread types. Binding available cores to all threads.",
+                        core_nums, thread_types.size());
+            for (auto thread_type : thread_types) {
+                threadToCoreidMap[thread_type] = current_core_range;
+            }
+        } else {
+            int remaining_type_count = thread_types.size() - 1;
+            int i = 0;
+            for (auto thread_type : thread_types) {
+                if (thread_type == ThreadType::unknownThread) {
+                    threadToCoreidMap[ThreadType::unknownThread] =
+                        coreIdRange{current_core_range.start + remaining_type_count, current_core_range.end};
+                } else {
+                    threadToCoreidMap[thread_type] = coreIdRange{offset + i, offset + (i++)};
+                }
+            }
+        }
+
+        ASCEND_LOGI("[affinity] Thread affinity map for device %d: %s", device_id, GetAffinityMapAsString(threadToCoreidMap, device_id).c_str());
+
+        return threadToCoreidMap;
+    }
+
+    // Convenience overload: derive the thread type from the current thread's name.
+    aclError SetThreadAffinity(c10::DeviceIndex device_id)
+    {
+        return SetThreadAffinity(device_id, getCurrentThreadType());
+    }
+
+    void printCoreRanges(const std::vector<coreIdRange> &ranges, uint32_t mode)
+    {
+        std::ostringstream oss;
+        oss << "Mode: " << mode << " ";
+
+        for (size_t i = 0; i < ranges.size(); ++i) {
+            oss << "Device " << i << " Core Range: " << ranges[i].start << " - " << ranges[i].end << " ";
+        }
+
+        ASCEND_LOGI("[affinity] Core ranges: %s", oss.str().c_str());
+    }
+
+    bool isAllDigits(const std::string &str)
+    {
+        if (str.empty()) {
+            return false;
+        }
+        return std::all_of(str.begin(), str.end(), [](unsigned char c) {
+            return std::isdigit(c);
+        });
+    }
+
+    // Parse CPU_AFFINITY_CONF ("mode:N,npuK:a-b,..." or a bare mode number) into
+    // `mode` and per-device core ranges (defaults from GetCPUDefaultRange).
+    void parseCPUAffinityConf(uint32_t &mode, std::vector<coreIdRange> &ranges)
+    {
+        const char *input = c10_npu::option::OptionsManager::GetCpuAffinityConf();
+
+        if (input == nullptr || strlen(input) == 0) {
+            mode = 0;
+            return;
+        }
+
+        mode = 0;
+        int device_nums = device_count_ensure_non_zero();
+        ranges.clear();
+        ranges.resize(device_nums);
+
+        // init
+        for (int i = 0; i < device_nums; ++i) {
+            ranges[i] = GetCPUDefaultRange(i);
+        }
+
+        std::string inputStr(input);
+        std::istringstream stream(inputStr);
+        std::string option;
+
+        // Handle cases where only `mode` is provided, or `mode:` without value
+        if (isAllDigits(inputStr)) {
+            mode = static_cast<uint32_t>(std::stoi(inputStr));
+            return; // Return directly, `mode` has already been processed
+        }
+
+        // Parse each option
+        while (std::getline(stream, option, ',')) {
+            // Split `option` based on colon
+            size_t colonPos = option.find(':');
+            if (colonPos != std::string::npos) {
+                std::string key = option.substr(0, colonPos);
+                std::string value = option.substr(colonPos + 1);
+
+                // Process `mode`
+                if (key == "mode") {
+                    if (isAllDigits(value)) {
+                        mode = static_cast<uint32_t>(std::stoi(value));
+                    } else {
+                        ASCEND_LOGW("[affinity] mode is %s, should be all digits", value.c_str());
+                    }
+                } else if (key.rfind("npu", 0) == 0) {
+                    // Handle NPU core binding range
+                    if (isAllDigits(key.substr(3))) {
+                        int device_id = std::stoi(key.substr(3)); // Parse NPU device ID
+                        if (device_id < device_nums) {
+                            size_t dashPos = value.find('-');
+                            if (dashPos != std::string::npos) {
+                                std::string startStr = value.substr(0, dashPos);
+                                std::string endStr = value.substr(dashPos + 1);
+                                if (isAllDigits(startStr) && isAllDigits(endStr)) {
+                                    coreId start = static_cast<coreId>(std::stoi(startStr));
+                                    coreId end = static_cast<coreId>(std::stoi(endStr));
+                                    ranges[device_id] = {start, end};
+                                } else {
+                                    ASCEND_LOGW("[affinity] core range is %s-%s, should be all digits", startStr.c_str(), endStr.c_str());
+                                }
+                            } else {
+                                if (isAllDigits(value)) {
+                                    coreId singleCore = static_cast<coreId>(std::stoi(value));
+                                    ranges[device_id] = {singleCore, singleCore};
+                                } else {
+                                    ASCEND_LOGW("[affinity] core range is string : %s, should be all digits", value.c_str());
+                                }
+                            }
+                        }
+                    }
+                }
+            } else if (isAllDigits(option)) {
+                // If no colon and the value is a number, use it directly as `mode`
+                mode = static_cast<uint32_t>(std::stoi(option));
+            }
+        }
+    }
+
+    // Function to execute a shell command and capture its output
+    std::string executeCommand(const std::string &exe)
+    {
+        std::array<char, 128> buffer;
+        std::string result;
+        FILE *fp = popen(exe.c_str(), "r");
+        if (fp == nullptr) {
+            // Bail out early: reading from a null pipe would dereference nullptr.
+            ASCEND_LOGE("[affinity] %s failed.", exe.c_str());
+            return result;
+        }
+        std::shared_ptr<FILE> pipe(fp, pclose);
+        while (fgets(buffer.data(), buffer.size(), pipe.get()) != nullptr) {
+            result += buffer.data();
+        }
+        return result;
+    }
+
+    // Function to parse PIDs and TIDs from pstree output
+    std::vector<pid_t> parsePIDsFromPstree(const std::string &pstreeOutput)
+    {
+        std::vector<pid_t> pids;
+        std::regex pidRegex(R"(\((\d+)\))"); // Matches numbers inside parentheses
+        std::smatch match;
+        std::string::const_iterator searchStart(pstreeOutput.cbegin());
+        while (std::regex_search(searchStart, pstreeOutput.cend(), match, pidRegex)) {
+            pids.push_back(std::stoi(match[1]));
+            searchStart = match.suffix().first;
+        }
+        return pids;
+    }
+
+    // Confine every child task of this process that is not a known PTA thread
+    // (and not the main thread itself) to `core_range`.
+    void SetAffinityForRemainingTasks(coreIdRange core_range)
+    {
+        // Check if the platform is Linux
+#ifdef __linux__
+        // Check if pstree command exists
+        if (access("/usr/bin/pstree", F_OK) == 0) {
+            // Run pstree to get child processes and threads
+            std::string pstreeCommand = "/usr/bin/pstree -p " + std::to_string(parentPid) + " -t";
+            std::string pstreeOutput = executeCommand(pstreeCommand);
+
+            // Parse PIDs/TIDs from the pstree output
+            std::vector<pid_t> pids = parsePIDsFromPstree(pstreeOutput);
+            ASCEND_LOGI("[affinity] Parse %zu PIDs/TIDs from the pstree output of parentPid %d", pids.size(), parentPid);
+
+            // Bind each PID/TID to the core range
+            for (pid_t pid : pids) {
+                ThreadType type = getThreadType(pid);
+                if (type == ThreadType::unknownThread && pid != parentPid) {
+                    bindToCoreRange(pid, core_range);
+                }
+            }
+        } else {
+            ASCEND_LOGW("[affinity] pstree not found. Please install pstree or check your PATH.");
+        }
+#else
+        ASCEND_LOGW("[affinity] This function is only supported on Linux platforms.");
+#endif
+    }
+
+    aclError SetThreadAffinity(c10::DeviceIndex device_id, ThreadType current_thread_type)
+    {
+        uint32_t bind_conf;
+        std::vector<coreIdRange> ranges;
+        parseCPUAffinityConf(bind_conf, ranges);
+        printCoreRanges(ranges, bind_conf);
+
+        // bind_conf=1, bind cores averagely based on device_id
+        if (bind_conf == 1) {
+            return SetThreadAffinity(ranges[device_id], pthread_self());
+        } else if (bind_conf == 2) {
+            auto thread_core_map = GetCpuAffinityMap(device_id, ranges[device_id]);
+            // When the PTA_init function runs on device 0, the main thread is initially assigned to this device 0.
+            // However, when the acl_thread is initialized, the target device ID(maybe 0-7) is determined.
+            // Therefore, the main thread should be rescheduled to the target device.
+            if (current_thread_type == ThreadType::aclThread)
+                SetThreadAffinity(thread_core_map.at(ThreadType::mainThread), mainthread_tid);
+            // To isolate interference, all such processes must be confined to separate regions before the dispatch phase.
+            if (current_thread_type == ThreadType::backwardThread || current_thread_type == ThreadType::unknownThread) {
+                SetThreadAffinity(thread_core_map.at(ThreadType::mainThread), mainthread_tid);
+                SetAffinityForRemainingTasks(thread_core_map.at(ThreadType::unknownThread));
+            }
+            return SetThreadAffinity(thread_core_map.at(current_thread_type), pthread_self());
+        } else {
+            ASCEND_LOGI("[affinity] Thread affinity setting is disabled.");
+        }
+        return ACL_ERROR_NONE;
+    }
+
+    void SetBackwardThreadName(c10::DeviceIndex device_id)
+    {
+        static thread_local bool seted = false;
+        if (!seted) {
+            seted = true;
+            // Only rename when running off the main thread (tid != pid).
+            if (syscall(SYS_gettid) != getpid()) {
+                ASCEND_LOGI("[affinity] Set Backward Thread Name");
+                SetThreadName(ThreadType::backwardThread);
+                SetThreadAffinity(device_id);
+            }
+        }
+    }
+
+    void SetThreadName(ThreadType type)
+    {
+        // Ensure this is called at the start of the thread's execution to avoid frequent triggering of this function.
+        if (prctl(PR_SET_NAME, threadTypeToNameMap.at(type).c_str()) != 0) {
+            ASCEND_LOGW("[affinity] set thread name failed!");
+        }
+    }
+
+}
\ No newline at end of file
diff --git a/torch_npu/csrc/core/npu/NPUAffinityController.h b/torch_npu/csrc/core/npu/NPUAffinityController.h
new file mode 100644
index 0000000000000000000000000000000000000000..f2e78b69b68c689835f55733396f18445c2a11ed
--- /dev/null
+++ b/torch_npu/csrc/core/npu/NPUAffinityController.h
@@ -0,0 +1,35 @@
+#pragma once
+#include "torch_npu/csrc/core/npu/npu_log.h"
+
+namespace c10_npu {
+
+    typedef unsigned int coreId;
+
+    struct coreIdRange {
+        coreId start;
+        coreId end;
+    };
+
+    enum ThreadType {
+        unknownThread = 0, // Mostly refers to threads in PyTorch's motorized sleep thread pool, which are not considered in PTA.
+        mainThread = 1, // 1st performance hotspot, responsible for operator dispatching during the forward phase.
+        backwardThread = 2, // 2nd performance hotspot, responsible for operator dispatching during the backward phase.
+        aclThread = 3, // 3rd performance hotspot in PTA, responsible for handling the task queue.
+        releaseThread = 4, // Thread responsible for resource release.
+        hcclCommWatchdogThread = 5 // Thread responsible for HCCL communication monitoring.
+    };
+
+    aclError SetThreadAffinity(c10::DeviceIndex device);
+    aclError SetThreadAffinity(c10::DeviceIndex device, ThreadType current_thread_type);
+    void SetThreadName(ThreadType type);
+
+    // The main thread of PTA, which is also the main thread of PyTorch, handles multiple phases of tasks
+    // (e.g., first parallel checkpoint data loading, then transitioning to forward training).
+    // Each phase may require different thread affinity settings. Therefore, we record the thread's TID
+    // to adjust its affinity later as needed.
+    void GetAffinityInfo();
+
+    // Set backwardThread Name Once
+    void SetBackwardThreadName(c10::DeviceIndex device_id);
+
+}
\ No newline at end of file
diff --git a/torch_npu/csrc/core/npu/NPUFunctions.cpp b/torch_npu/csrc/core/npu/NPUFunctions.cpp
index 59456b3349d9aa0424beb4e4c75873d3025f6f11..4b7a40ec11a0932007e2d04c467f14910e82f81d 100644
--- a/torch_npu/csrc/core/npu/NPUFunctions.cpp
+++ b/torch_npu/csrc/core/npu/NPUFunctions.cpp
@@ -66,20 +66,6 @@ aclError GetDevice(int32_t *device)
     return err;
 }
 
-inline bool has_set_pthread_affinity()
-{
-    unsigned int core_nums = static_cast<unsigned int>(sysconf(_SC_NPROCESSORS_ONLN));
-
-    cpu_set_t mask;
-    pthread_getaffinity_np(pthread_self(), sizeof(mask), &mask);
-    for (unsigned int i = 0; i < core_nums; i++) {
-        if (!CPU_ISSET(i, &mask)) {
-            return true;
-        }
-    }
-    return false;
-}
-
 aclError SetDevice(c10::DeviceIndex device)
 {
     TORCH_CHECK(device >= 0, "device id must be positive!", PTA_ERROR(ErrCode::VALUE));
@@ -88,26 +74,6 @@ aclError SetDevice(c10::DeviceIndex device)
         return ACL_ERROR_NONE;
     }
 
-    static uint32_t bind_conf = c10_npu::option::OptionsManager::GetCpuAffinityConf();
-    // bind_conf=1, bind cores averagely based on device_id
-    if (bind_conf == 1) {
-        static const bool set_pthread_affinity = has_set_pthread_affinity();
-        if (!set_pthread_affinity) {
-            int core_nums = sysconf(_SC_NPROCESSORS_ONLN);
-            int device_nums = device_count_ensure_non_zero();
-            int block_size = (core_nums + device_nums - 1) / device_nums;
-            unsigned int start_core = static_cast<unsigned int>(device * block_size);
-            unsigned int end_core = static_cast<unsigned int>(std::min((device + 1) * block_size, core_nums));
-
-            cpu_set_t mask;
-            CPU_ZERO(&mask);
-            for (unsigned int i = start_core; i < end_core; i++) {
-                CPU_SET(i, &mask);
-            }
-            pthread_setaffinity_np(pthread_self(), sizeof(mask), &mask);
-        }
-    }
-
     aclError err = aclrtSetDevice(device);
     if (err == ACL_ERROR_NONE) {
         local_device = device;
diff --git a/torch_npu/csrc/core/npu/NPUQueue.cpp b/torch_npu/csrc/core/npu/NPUQueue.cpp
index 39bb3514f1305f4d71b9e0da7815cc8149655ba1..0ea9d985271bbde4bb5516d0eef0c683684df9ae 100644
--- a/torch_npu/csrc/core/npu/NPUQueue.cpp
+++ b/torch_npu/csrc/core/npu/NPUQueue.cpp
@@ -1,6 +1,7 @@
 #include "torch_npu/csrc/core/npu/NPUQueue.h"
 #include "torch_npu/csrc/core/npu/NPUStream.h"
 #include "torch_npu/csrc/core/npu/npu_log.h"
+#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
 #include "torch_npu/csrc/framework/utils/NpuUtils.h"
 #include "torch_npu/csrc/core/npu/NPUFunctions.h"
 #include "torch_npu/csrc/framework/OpParamMaker.h"
@@ -15,7 +16,6 @@
 #include <thread>
 #include <mutex>
 #include <condition_variable>
-#include <sys/prctl.h>
 #include <unistd.h>
 
 namespace c10_npu {
@@ -587,9 +587,8 @@ bool Repository::CheckInit() const {
 }
 
 void StartConsume(Repository* repo, c10::DeviceIndex device_id) {
-  if (prctl(PR_SET_NAME, ("ACL_thread")) != 0) {
-    ASCEND_LOGE("set thread name failed!");
-  }
+  SetThreadName(ThreadType::aclThread);
+  SetThreadAffinity(device_id);
 
   aclError ret = c10_npu::SetDevice(device_id);
   if (ret != 0) {
@@ -619,7 +618,7 @@ void Repository::InitRepo(c10::DeviceIndex device_id) {
   std::thread cur_consumer(StartConsume, this, device_id);
   consumer = std::move(cur_consumer);
 
-  releaseQueue.InitReleaseQueue();
+  releaseQueue.InitReleaseQueue(device_id);
 }
 
 std::string Repository::GetPara()
@@ -697,17 +696,17 @@ void ReleaseQueue::PopFromReleaseQueue() {
 }
 
 void StartRelease(ReleaseQueue* releaseQue) {
-  if (prctl(PR_SET_NAME, ("Release_thread")) != 0) {
-    ASCEND_LOGE("set thread name failed!");
-  }
+    SetThreadName(ThreadType::releaseThread);
+    SetThreadAffinity(releaseQue->GetDeviceID());
 
-  while (releaseQue->GetStatus() != RepoStatus::CAN_EXIT) {
-    releaseQue->PopFromReleaseQueue();
-  }
-  return;
+    while (releaseQue->GetStatus() != RepoStatus::CAN_EXIT) {
+        releaseQue->PopFromReleaseQueue();
+    }
+    return;
 }
 
-void ReleaseQueue::InitReleaseQueue() {
+void ReleaseQueue::InitReleaseQueue(c10::DeviceIndex device_id)
+{
   if (datas == nullptr) {
     datas = releaseManager().Init(kReleaseQueueCapacity);
   }
@@ -716,6 +715,7 @@ void ReleaseQueue::InitReleaseQueue() {
   SetStatus(INIT);
+  device_idx = device_id;
   std::thread cur_releaser(StartRelease, this);
   releaser = std::move(cur_releaser);
 }
 
 ReleaseQueue::~ReleaseQueue() {
@@ -740,6 +740,12 @@ RepoStatus ReleaseQueue::GetStatus() const {
   return repo_status.load();
 }
 
+c10::DeviceIndex ReleaseQueue::GetDeviceID() const
+{
+    return device_idx;
+}
+
+
 void ReleaseQueue::SetStatus(RepoStatus desired) {
   if (initialized == false) {
     ASCEND_LOGE("Release queue is not initialized, shouldn't call SetStatus(). !!");
diff --git a/torch_npu/csrc/core/npu/NPUQueue.h b/torch_npu/csrc/core/npu/NPUQueue.h
index 66e648069fb48c0d6b9eb776594dd0554bd956c5..2375ef945b985bcdb6339fb47be50590e0c2f3c9 100644
--- a/torch_npu/csrc/core/npu/NPUQueue.h
+++ b/torch_npu/csrc/core/npu/NPUQueue.h
@@ -38,8 +38,9 @@ public:
   ~ReleaseQueue();
   void PushToReleaseQueue(void* cur_paras);
   void PopFromReleaseQueue();
-  void InitReleaseQueue();
+  void InitReleaseQueue(c10::DeviceIndex device_id);
   RepoStatus GetStatus() const;
+  c10::DeviceIndex GetDeviceID() const;
 
 private:
   inline bool IsEmptyQueue() {return read_idx.idx == write_idx.idx;};
@@ -52,6 +53,7 @@ private:
 private:
   void* datas = nullptr;
   std::thread releaser;
+  c10::DeviceIndex device_idx;
 
 private:
   sring_idx read_idx;
diff --git a/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h b/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h
index 4359db01364d853854a67bcb17469b0d66c32284..705e772799ca4e3412f29a2d545f8913555e5f63 100644
--- a/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h
+++ b/torch_npu/csrc/core/npu/impl/NPUGuardImpl.h
@@ -8,6 +8,7 @@
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/core/npu/NPUException.h"
 #include "torch_npu/csrc/core/npu/NPUFunctions.h"
+#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
 #include "torch_npu/csrc/core/npu/NPUStream.h"
 #include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
 #include "torch_npu/csrc/aten/NPUNativeFunctions.h"
@@ -53,6 +54,7 @@ struct NPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     uncheckedSetDevice(d);
   }
   void uncheckedSetDevice(c10::Device d) const noexcept override {
+    SetBackwardThreadName(d.index());
     NPU_CHECK_WARN(c10_npu::SetDevice(d.index()));
   }
   c10::Stream getStream(c10::Device d) const noexcept override {
diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.cpp b/torch_npu/csrc/core/npu/register/OptionsManager.cpp
index 850336102048b30b0e748e941c8b2a28d97f9b65..6a07e170c46701a8f86ab83c12827187cc28f4e3 100644
--- a/torch_npu/csrc/core/npu/register/OptionsManager.cpp
+++ b/torch_npu/csrc/core/npu/register/OptionsManager.cpp
@@ -350,14 +350,9 @@ uint32_t OptionsManager::GetP2PBufferSize()
     return buf_size;
 }
 
-uint32_t OptionsManager::GetCpuAffinityConf()
+char* OptionsManager::GetCpuAffinityConf()
 {
-    const static uint32_t cpu_affinity_conf = []() -> uint32_t {
-        char* cpu_affinity_str = std::getenv("CPU_AFFINITY_CONF");
-        int64_t cpu_affinity_conf = (cpu_affinity_str != nullptr) ? strtol(cpu_affinity_str, nullptr, 10) : 0;
-        return static_cast<uint32_t>(cpu_affinity_conf);
-    }();
-    return cpu_affinity_conf;
+    return std::getenv("CPU_AFFINITY_CONF");
 }
 
 uint32_t OptionsManager::GetTaskQueueEnable()
diff --git a/torch_npu/csrc/core/npu/register/OptionsManager.h b/torch_npu/csrc/core/npu/register/OptionsManager.h
index 98e8fd72dc364ea330d821ba740194ea1cd114a4..65a9c38a4b9342d7c92e3ee402f067edcda86305 100644
--- a/torch_npu/csrc/core/npu/register/OptionsManager.h
+++ b/torch_npu/csrc/core/npu/register/OptionsManager.h
@@ -51,7 +51,7 @@ public:
     static std::pair GetSilenceSigmaThresh();
     static uint32_t GetP2PBufferSize();
     static uint32_t GetTaskQueueEnable();
-    static uint32_t GetCpuAffinityConf();
+    static char* GetCpuAffinityConf();
     static bool CheckForceUncached();
     static std::string GetOomSnapshotDumpPath();
     static void IsOomSnapshotEnable();
diff --git a/torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.cpp b/torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.cpp
index 679b2a262ae52eae5c4fe2b83e1c66eb55d97813..9081b686c4dc6d3864994726e8b91f9bf283381e 100644
--- a/torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.cpp
+++ b/torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.cpp
@@ -13,6 +13,7 @@
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/core/npu/NPUWorkspaceAllocator.h"
 #include "torch_npu/csrc/core/npu/NPUStream.h"
+#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
 #include "torch_npu/csrc/core/npu/NpuVariables.h"
 #include "torch_npu/csrc/core/npu/register/OptionRegister.h"
 #include "torch_npu/csrc/core/npu/register/OptionsManager.h"
@@ -223,6 +224,7 @@ NpuSysCtrl::SysStatus NpuSysCtrl::Initialize(int device_id)
         ASCEND_LOGW("Npu device %d has been set before global init.", device_id_);
     }
 
+    GetAffinityInfo();
 
     if (c10_npu::option::OptionsManager::CheckAclDumpDateEnable()) {
         const char *aclConfigPath = "acl.json";
@@ -266,8 +268,10 @@ NpuSysCtrl::SysStatus NpuSysCtrl::Initialize(int device_id)
         const auto& in = iter.second;
         call_(in);
     }
+    lazy_fn_.clear();
+
     init_flag_ = true;
 
     ASCEND_LOGD("Npu sys ctrl initialize successfully.");
diff --git a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
index e8a9f5a2834b9de7e81c98e96f1f54e8399c2f1b..fc9e268f076bc9667a14ca4526b3dd313721b74a 100644
--- a/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
+++ b/torch_npu/csrc/distributed/ProcessGroupHCCL.cpp
@@ -25,6 +25,7 @@
 #include "torch_npu/csrc/core/NPUStorageImpl.h"
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/core/npu/NPUGuard.h"
+#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
 #include "torch_npu/csrc/core/npu/NPUStream.h"
 #include "torch_npu/csrc/core/npu/register/OptionsManager.h"
 #include "torch_npu/csrc/distributed/HCCLUtils.hpp"
@@ -791,6 +792,8 @@ ProcessGroupHCCL::~ProcessGroupHCCL()
 void ProcessGroupHCCL::hcclCommWatchdog()
 {
     try {
+        c10_npu::SetThreadName(c10_npu::ThreadType::hcclCommWatchdogThread);
+
         VLOG(2) << "[Rank " << rank_ << "] HCCL watchdog thread started!";
         workCleanupLoop();
         VLOG(2) << "[Rank " << rank_
@@ -873,7 +876,9 @@ void ProcessGroupHCCL::workCleanupLoop()
             auto& work = *it;
             try {
                 if (needSetDevice) {
-                    NPU_CHECK_ERROR(c10_npu::SetDevice(static_cast<c10::DeviceIndex>(work.devices_[0].index())));
+                    c10::DeviceIndex device = static_cast<c10::DeviceIndex>(work.devices_[0].index());
+                    c10_npu::SetThreadAffinity(device);
+                    NPU_CHECK_ERROR(c10_npu::SetDevice(device));
                     needSetDevice = false;
                 }
             } catch (const std::exception& e) {
diff --git a/torch_npu/csrc/npu/Module.cpp b/torch_npu/csrc/npu/Module.cpp
index d73b536e94c520d305795b57a95bdaf28f08812c..948061f00899caafec9532ada7d23316c551dac0 100644
--- a/torch_npu/csrc/npu/Module.cpp
+++ b/torch_npu/csrc/npu/Module.cpp
@@ -24,6 +24,7 @@
 #include "torch_npu/csrc/core/npu/NPUCachingAllocator.h"
 #include "torch_npu/csrc/core/npu/NPUStream.h"
 #include "torch_npu/csrc/core/npu/NPUQueue.h"
+#include "torch_npu/csrc/core/npu/NPUAffinityController.h"
 #include "torch_npu/csrc/core/npu/NPUGuard.h"
 #include "torch_npu/csrc/core/npu/NpuVariables.h"
 #include "torch_npu/csrc/core/npu/sys_ctrl/npu_sys_ctrl.h"
@@ -1211,6 +1212,28 @@ PyObject* THNPModule_npu_support_silentClientV2(PyObject* self, PyObject* noargs
     END_HANDLE_TH_ERRORS
}
 
+PyObject* THNPModule_npu_set_thread_affinity(PyObject* self, PyObject* noargs)
+{
+    HANDLE_TH_ERRORS
+    int device_index;
+    NPU_CHECK_ERROR_WITHOUT_UCE(c10_npu::GetDevice(&device_index));
+    c10::DeviceIndex device = static_cast<c10::DeviceIndex>(device_index);
+    c10_npu::SetThreadAffinity(device, c10_npu::ThreadType::mainThread);
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
+PyObject* THNPModule_npu_reset_thread_affinity(PyObject* self, PyObject* noargs)
+{
+    HANDLE_TH_ERRORS
+    int device_index;
+    NPU_CHECK_ERROR_WITHOUT_UCE(c10_npu::GetDevice(&device_index));
+    c10::DeviceIndex device = static_cast<c10::DeviceIndex>(device_index);
+    c10_npu::SetThreadAffinity(device, c10_npu::ThreadType::unknownThread);
+    Py_RETURN_NONE;
+    END_HANDLE_TH_ERRORS
+}
+
 static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_init", (PyCFunction)THNPModule_initExtension, METH_NOARGS, nullptr},
     {"_npu_set_run_yet_variable_to_false", (PyCFunction)THNPModule_set_run_yet_variable_to_false_wrap, METH_NOARGS, nullptr},
@@ -1260,6 +1283,8 @@ static struct PyMethodDef THNPModule_methods[] = {
     {"_npu_set_call_state", (PyCFunction)THNPModule_npu_set_call_state, METH_O, nullptr},
     {"_npu_set_module_train_state", (PyCFunction)THNPModule_npu_set_module_train_state, METH_O, nullptr},
     {"_npu_support_silentClientV2", (PyCFunction)THNPModule_npu_support_silentClientV2, METH_NOARGS, nullptr},
+    {"_npu_set_threads_affinity", (PyCFunction)THNPModule_npu_set_thread_affinity, METH_NOARGS, nullptr},
+    {"_npu_reset_threads_affinity", (PyCFunction)THNPModule_npu_reset_thread_affinity, METH_NOARGS, nullptr},
     {nullptr}};
 
 TORCH_NPU_API PyMethodDef* THNPModule_get_methods()
diff --git a/torch_npu/utils/_module.py b/torch_npu/utils/_module.py
index 92213c7ca99bdc59e83cb2de8bcd05c9ac068270..50f8e6ad3bac12ea417460db0641b37e643ba81d 100644
--- a/torch_npu/utils/_module.py
+++ b/torch_npu/utils/_module.py
@@ -362,7 +362,9 @@ def _mpdl_iter_init(self, *args, **kwargs):
         torch_npu.npu.synchronize()
     except:
         pass
+    torch_npu._C._npu_reset_threads_affinity()
     origin_mpdl_iter_init(self, *args, **kwargs)
+    torch_npu._C._npu_set_threads_affinity()
 
 
 def _parallel_apply(