diff --git a/torch_npu/npu/streams.py b/torch_npu/npu/streams.py index c4ba551f670d36b6b09ce10fb06c7f94bf8efca4..7a330fcf9f41387930a81f4d2e4b2790dcb85b28 100644 --- a/torch_npu/npu/streams.py +++ b/torch_npu/npu/streams.py @@ -3,6 +3,8 @@ import ctypes import torch_npu import torch_npu._C +import weakref +import threading __all__ = ["Stream", "Event", "SyncLaunchStream", "ExternalEvent"] @@ -196,6 +198,22 @@ class Event(torch_npu._C._NPUEventBase): else: return '' +# 全局映射(Event -> tag),使用弱引用字典自动清理无效键 +_GLOBAL_EVENT_TO_TAG = weakref.WeakKeyDictionary() +_GLOBAL_LOCK = threading.Lock() +_TAG_COUNTER = 0 +def get_unique_event_tag(event) -> str: + """为 Event 生成唯一 tag,析构时自动清理""" + with _GLOBAL_LOCK: + if event in _GLOBAL_EVENT_TO_TAG: + return _GLOBAL_EVENT_TO_TAG[event] + # 生成唯一 tag(不依赖字典长度, 否则event析构时删除会出问题) + global _TAG_COUNTER + unique_tag = f"external_event_{_TAG_COUNTER}" + _TAG_COUNTER += 1 + _GLOBAL_EVENT_TO_TAG[event] = unique_tag + + return unique_tag class ExternalEvent(torch_npu._C._NPUEventBase): r"""Wrapper around a NPU event with graph_external=True. @@ -213,8 +231,10 @@ class ExternalEvent(torch_npu._C._NPUEventBase): """ def __new__(cls): - return super(ExternalEvent, cls).__new__(cls, enable_timing=False, blocking=False, - interprocess=False, graph_external=True) + instance = super(ExternalEvent, cls).__new__(cls, enable_timing=False, blocking=False, + interprocess=False, graph_external=True) + instance._tag = get_unique_event_tag(instance) + return instance def record(self, stream=None): r"""Records the event in a given stream. @@ -256,6 +276,17 @@ class ExternalEvent(torch_npu._C._NPUEventBase): else: return '' + @property + def tag(self): + return self._tag + + @classmethod + def get_event_by_tag(cls, tag: str): + with _GLOBAL_LOCK: + for event, event_tag in _GLOBAL_EVENT_TO_TAG.items(): + if event_tag == tag: + return event + raise ValueError(f"No event found with tag: {tag!r}") class SyncLaunchStream(torch_npu._C._NPUStreamBase): r"""Wrapper around a SyncLaunch NPU stream.