From 601420e7784f5dea43ceab0c21e75b1306f72820 Mon Sep 17 00:00:00 2001 From: yu-liang-bin Date: Thu, 11 Sep 2025 09:17:25 +0800 Subject: [PATCH] fix cann --- .../prof_db_parse/_fwk_api_db_parser.py | 208 +++++++++++++----- 1 file changed, 149 insertions(+), 59 deletions(-) diff --git a/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py b/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py index 98ea10ea68..d3a6405407 100644 --- a/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py +++ b/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from ...prof_common_func._db_manager import TorchDb from .._base_parser import BaseParser from ...prof_common_func._log import ProfilerLogger @@ -79,7 +81,7 @@ class FwkApiDbParser(BaseParser): mstx_mark_apis.sort(key=lambda x: x[TorchOpDataOri.START_NS]) mstx_op_len = len(mstx_mark_apis) if task_enqueues and task_dequeues: - self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, mstx_mark_apis, mstx_op_len, + self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, mstx_mark_apis, cann_tx_apis) def get_torch_op_connection_ids_with_cann_api(self, task_enqueues: list, task_dequeues: list, torch_op_apis: list): @@ -99,70 +101,158 @@ class FwkApiDbParser(BaseParser): if not node_launch_apis: raise RuntimeWarning("Failed to get node launch apis") torch_op_apis.sort(key=lambda x: x[TorchOpDataOri.START_NS]) - torch_op_len = len(torch_op_apis) + task_dequeues.sort(key=lambda x: x[TaskQueueDataOri.START_NS]) if task_enqueues and task_dequeues: - self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, torch_op_apis, torch_op_len, + self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, torch_op_apis, node_launch_apis) else: - self.get_torch_op_connection_ids_without_task_queue(torch_op_apis, torch_op_len, node_launch_apis) - - def get_torch_op_connection_ids_with_task_queue(self, task_enqueues: list, task_dequeues: list, torch_op_apis: list, torch_op_len: int, node_lauch_apis: list): - connection_id_manager = ConnectionIdManager() - enqueue_corr_ids = {connection_id_manager.get_connection_ids_from_id(task_enqueue[TaskQueueDataOri.CORRELATION_ID])[0] for task_enqueue in task_enqueues} - dequeue_corr_ids = {connection_id_manager.get_connection_ids_from_id(task_dequeue[TaskQueueDataOri.CORRELATION_ID])[0] for task_dequeue in task_dequeues} - matched_corr_ids = enqueue_corr_ids & dequeue_corr_ids - enqueue_list = [enqueue for enqueue in task_enqueues if connection_id_manager.get_connection_ids_from_id(enqueue[TaskQueueDataOri.CORRELATION_ID])[0] in matched_corr_ids] - dequeue_list = [dequeue for dequeue in task_dequeues if connection_id_manager.get_connection_ids_from_id(dequeue[TaskQueueDataOri.CORRELATION_ID])[0] in matched_corr_ids] - - last_dequeue_index = 0 - last_torch_op_index = 0 - dequeue_len = len(dequeue_list) - for node_launch_api in node_lauch_apis: - for idx in range(last_dequeue_index, dequeue_len): - if node_launch_api[CannNodeLaunchApiOri.START_NS] > dequeue_list[idx][TaskQueueDataOri.START_NS] and \ - node_launch_api[CannNodeLaunchApiOri.END_NS] < dequeue_list[idx][TaskQueueDataOri.END_NS]: - last_dequeue_index = idx - enqeue = enqueue_list[idx] - last_torch_op_index = self.get_torch_op_connection_ids_with_enqueue(torch_op_apis, - torch_op_len, - enqeue, - last_torch_op_index, - node_launch_api[CannNodeLaunchApiOri.CORRELATION_ID]) + self._get_torch_op_connection_ids_without_task_queue(torch_op_apis, node_launch_apis) + + def get_torch_op_connection_ids_with_task_queue(self, task_enqueues: list, task_dequeues: list, torch_op_apis: list, + node_launch_apis: list): + # 1. Match node launch and dequeue + dequeue_corrections_ids = self._match_node_launch_and_dequeue(node_launch_apis, task_dequeues) + + if not dequeue_corrections_ids: + return + + # 2. Match dequeue and enqueue + enqueue_dict = self._match_dequeue_and_enqueue(dequeue_corrections_ids, task_enqueues) + + # 3. Match enqueue and torch op + self._match_enqueue_and_torch_op(enqueue_dict, torch_op_apis) + + @staticmethod + def _match_node_launch_and_dequeue(node_launch_apis, task_dequeues): + dequeue_dict = defaultdict(list) + node_launch_dict = defaultdict(list) + for dequeue in task_dequeues: + dequeue_dict[dequeue[TaskQueueDataOri.GLOBAL_TID]].append(dequeue) + for node_launch in node_launch_apis: + node_launch_dict[node_launch[CannNodeLaunchApiOri.GLOBAL_TID]].append(node_launch) + common_keys = dequeue_dict.keys() & node_launch_dict.keys() + dequeue_dict = {k: dequeue_dict[k] for k in common_keys} + node_launch_dict = {k: node_launch_dict[k] for k in common_keys} + dequeue_corrections_ids = [] + for tid in common_keys: + dequeue_index = 0 + for node_launch in node_launch_dict[tid]: + while dequeue_index < len(dequeue_dict[tid]): + if dequeue_dict[tid][dequeue_index][TaskQueueDataOri.START_NS] > node_launch[ + CannNodeLaunchApiOri.START_NS]: + break + if dequeue_dict[tid][dequeue_index][TaskQueueDataOri.END_NS] < node_launch[ + CannNodeLaunchApiOri.START_NS]: + break + if ( + dequeue_dict[tid][dequeue_index][TaskQueueDataOri.START_NS] < node_launch[ + CannNodeLaunchApiOri.START_NS] + and dequeue_dict[tid][dequeue_index][TaskQueueDataOri.END_NS] > node_launch[ + CannNodeLaunchApiOri.END_NS] + ): + dequeue_correction_id = dequeue_dict[tid][dequeue_index][ + TaskQueueDataOri.CORRELATION_ID] + node_launch_correction_id = node_launch[CannNodeLaunchApiOri.CORRELATION_ID] + dequeue_corrections_ids.append([dequeue_correction_id, node_launch_correction_id]) + dequeue_index += 1 + return dequeue_corrections_ids + + @staticmethod + def _match_dequeue_and_enqueue(dequeue_corrections_ids, task_enqueues): + dequeue_corrections_ids.sort(key=lambda x: x[0]) + task_enqueues.sort(key=lambda x: x[TaskQueueDataOri.CORRELATION_ID]) + enqueue_dict = defaultdict(list) + idx = 0 + for enqueue in task_enqueues: + while idx < len(dequeue_corrections_ids): + if enqueue[TaskQueueDataOri.CORRELATION_ID] < dequeue_corrections_ids[idx][0]: break - if dequeue_list[idx][TaskQueueDataOri.START_NS] > node_launch_api[CannNodeLaunchApiOri.END_NS]: + if enqueue[TaskQueueDataOri.CORRELATION_ID] == dequeue_corrections_ids[idx][0]: + enqueue_dict[enqueue[TaskQueueDataOri.GLOBAL_TID]].append(( + dequeue_corrections_ids[idx][1], + enqueue[TaskQueueDataOri.START_NS], + enqueue[TaskQueueDataOri.END_NS] + )) + idx += 1 break + idx += 1 + return enqueue_dict - def get_torch_op_connection_ids_with_enqueue(self, torch_op_apis: list, torch_op_len: int, enqeue: list, last_torch_op_index: int, connection_id: int) -> int: - last_op_api = None - for idx in range(last_torch_op_index, torch_op_len): - if enqeue[TaskQueueDataOri.START_NS] > torch_op_apis[idx][TorchOpDataOri.END_NS]: - continue - if enqeue[TaskQueueDataOri.START_NS] > torch_op_apis[idx][TorchOpDataOri.START_NS] and enqeue[TaskQueueDataOri.END_NS] < torch_op_apis[idx][TorchOpDataOri.END_NS]: - last_op_api = torch_op_apis[idx] - last_torch_op_index = idx - elif last_op_api: - break - if last_op_api: - torch_op_apis[last_torch_op_index][TorchOpDataOri.CONNECTION_ID].append(connection_id) - return last_torch_op_index - - def get_torch_op_connection_ids_without_task_queue(self, torch_op_apis: list, torch_op_len: int, node_lauch_apis: list): - last_op_api = None - last_op_index = 0 - for node_launch_api in node_lauch_apis: - for idx in range(last_op_index, torch_op_len): - if torch_op_apis[idx][TorchOpDataOri.GLOBAL_TID] != node_launch_api[CannNodeLaunchApiOri.GLOBAL_TID]: - continue - if node_launch_api[CannNodeLaunchApiOri.START_NS] > torch_op_apis[idx][TorchOpDataOri.END_NS]: - continue - if node_launch_api[CannNodeLaunchApiOri.START_NS] > torch_op_apis[idx][TorchOpDataOri.START_NS] and \ - node_launch_api[CannNodeLaunchApiOri.END_NS] < torch_op_apis[idx][TorchOpDataOri.END_NS]: - last_op_api = torch_op_apis[idx] - last_op_index = idx - elif last_op_api: - torch_op_apis[last_op_index][TorchOpDataOri.CONNECTION_ID].append(node_launch_api[CannNodeLaunchApiOri.CORRELATION_ID]) - last_op_api = None - break + @staticmethod + def _match_enqueue_and_torch_op(enqueue_dict, torch_op_apis): + torch_op_dict = defaultdict(list) + for torch_op_api in torch_op_apis: + torch_op_dict[torch_op_api[TorchOpDataOri.GLOBAL_TID]].append(torch_op_api) + common_keys = enqueue_dict.keys() & torch_op_dict.keys() + enqueue_dict = {k: enqueue_dict[k] for k in common_keys} + torch_op_dict = {k: torch_op_dict[k] for k in common_keys} + for tid in common_keys: + enqueues = enqueue_dict[tid] + torch_ops = torch_op_dict[tid] + torch_ops_len = len(torch_ops) + last_torch_op_index = 0 + for correction_id, enqueue_start_time, enqueue_end_time in enqueues: + last_torch_op_api = None + while last_torch_op_index < torch_ops_len: + current_op = torch_ops[last_torch_op_index] + op_start = current_op[TorchOpDataOri.START_NS] + op_end = current_op[TorchOpDataOri.END_NS] + if op_start > enqueue_start_time: + break + if op_end < enqueue_start_time: + last_torch_op_index += 1 + continue + if op_start < enqueue_start_time and op_end > enqueue_end_time: + last_torch_op_api = current_op + last_torch_op_index += 1 + else: + if last_torch_op_api: + break + last_torch_op_index += 1 + + if last_torch_op_api: + torch_ops[last_torch_op_index - 1][TorchOpDataOri.CONNECTION_ID].append(correction_id) + + @staticmethod + def _get_torch_op_connection_ids_without_task_queue(torch_op_apis: list, node_launch_apis: list): + torch_op_dict = defaultdict(list) + node_launch_dict = defaultdict(list) + for torch_op_api in torch_op_apis: + torch_op_dict[torch_op_api[TorchOpDataOri.GLOBAL_TID]].append(torch_op_api) + for node_launch in node_launch_apis: + node_launch_dict[node_launch[CannNodeLaunchApiOri.GLOBAL_TID]].append(node_launch) + common_keys = torch_op_dict.keys() & node_launch_dict.keys() + node_launch_dict = {k: node_launch_dict[k] for k in common_keys} + torch_op_dict = {k: torch_op_dict[k] for k in common_keys} + for tid in common_keys: + node_launch_apis = node_launch_dict[tid] + torch_ops = torch_op_dict[tid] + torch_ops_len = len(torch_ops) + for node_launch_api in node_launch_apis: + last_torch_op_api = None + last_torch_op_index = 0 + node_start = node_launch_api[CannNodeLaunchApiOri.START_NS] + node_end = node_launch_api[CannNodeLaunchApiOri.END_NS] + node_corr_id = node_launch_api[CannNodeLaunchApiOri.CORRELATION_ID] + while last_torch_op_index < torch_ops_len: + current_op = torch_ops[last_torch_op_index] + op_start = current_op[TorchOpDataOri.START_NS] + op_end = current_op[TorchOpDataOri.END_NS] + if op_start > node_start: + break + if node_start > op_end: + last_torch_op_index += 1 + continue + if node_start > op_start and node_end < op_end: + last_torch_op_api = current_op + last_torch_op_index += 1 + else: + if last_torch_op_api: + break + last_torch_op_index += 1 + + if last_torch_op_api: + torch_ops[last_torch_op_index - 1][TorchOpDataOri.CONNECTION_ID].append(node_corr_id) def set_start_string_id(self): Str2IdManager().set_start_id(DbConstant.START_STRING_ID_FWK_API) -- Gitee