diff --git a/torch_npu/npu/graphs.py b/torch_npu/npu/graphs.py index 7e21ce5ed9a78512d66e1bd915eb834732aa00fc..1337964ed9e3e28ce803fe888be69e7affb6af98 100644 --- a/torch_npu/npu/graphs.py +++ b/torch_npu/npu/graphs.py @@ -7,6 +7,7 @@ import re import typing from copy import deepcopy from dataclasses import dataclass, field +import threading from typing import List, Dict, Any, Optional, Tuple import torch @@ -236,10 +237,14 @@ class NPUGraph(torch_npu._C._NPUGraph): """ return super().pool() - def update(self, cpu_update_input): + def update(self, cpu_update_input, blocking=True): if not self.auto_dispatch_capture: raise RuntimeError("The current graph configuration does not support update," "Try to capture by setting auto_dispatch_capture=True during capture", pta_error(ErrCode.PARAM)) + if not blocking: + thread = threading.Thread(target=self.graph_dispatch_mode.update_capture_record, kwargs={"cpu_update_input": cpu_update_input}) + thread.start() + return self.graph_dispatch_mode.update_capture_record(cpu_update_input)