From 9db1f9a6058877e887f9f97df784ac47458f319a Mon Sep 17 00:00:00 2001
From: p00669756
Date: Fri, 12 Sep 2025 11:56:41 +0800
Subject: [PATCH 1/2] Fix shape mismatch in codegen-generated kernels
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_npu/_inductor/codegen/triton.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py
index 0ac327904..18c3d789a 100644
--- a/torch_npu/_inductor/codegen/triton.py
+++ b/torch_npu/_inductor/codegen/triton.py
@@ -219,11 +219,10 @@ class IterationRangesEntryNPUIndex(IterationRangesEntry):
             return f"[{','.join(self.directions)}]"
 
         tiling_axis = [x.symbol() for x in self.kernel.tiling_axis]
-        rev_orders = [x for x in self.kernel.golden_var_list if x in tiling_axis]
+        var_orders = [x for x in tiling_axis if x in self.kernel.golden_var_list]
         self.directions = ["None"] * len(tiling_axis)
-        if len(tiling_axis) != len(rev_orders):
-            raise RuntimeError(f"assert tiling len={len(tiling_axis)}, not equal to golden varlist len ={len(rev_orders)}")
-        var_orders = list(reversed(rev_orders))
+        if len(tiling_axis) != len(var_orders):
+            raise RuntimeError(f"tiling axis count {len(tiling_axis)} does not match golden_var_list count {len(var_orders)}")
         index = var_orders.index(self.symbol())
         self.directions[index] = ":"
         return f"[{','.join(self.directions)}]"
-- 
Gitee

From 73bae6e3d21ed1ec37f62f8a14cbff0dea8cbdfc Mon Sep 17 00:00:00 2001
From: p00669756
Date: Thu, 25 Sep 2025 15:54:09 +0800
Subject: [PATCH 2/2] Adjust autotune benchmarking
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_npu/_inductor/__init__.py              |   2 +-
 torch_npu/_inductor/npu_triton_heuristics.py | 206 ++++++++++---------
 2 files changed, 106 insertions(+), 102 deletions(-)

diff --git a/torch_npu/_inductor/__init__.py b/torch_npu/_inductor/__init__.py
index 62443484e..3250f7dc8 100644
--- a/torch_npu/_inductor/__init__.py
+++ b/torch_npu/_inductor/__init__.py
@@ -104,7 +104,7 @@ patch_async_compile()
 
 # register fx_pass should be put behind of _register_npu_inductor_decompositons
 def _replace_benchmark_all_configs():
-    from torch._inductor.triton_heuristics import CachingAutotuner
+    from torch._inductor.runtime.triton_heuristics import CachingAutotuner
     from .npu_triton_heuristics import benchmark_all_configs
     CachingAutotuner.benchmark_all_configs = benchmark_all_configs
 
diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py
index f954d4719..42871940a 100644
--- a/torch_npu/_inductor/npu_triton_heuristics.py
+++ b/torch_npu/_inductor/npu_triton_heuristics.py
@@ -1198,117 +1198,121 @@ def foreach(triton_meta, num_warps, filename=None, inductor_meta=None):
     )
 
 
-@dynamo_timed
-def benchmark_all_configs(self, *args, input_grid, **kwargs):
+def benchmark_all_configs(self, *args, **kwargs):
     print(f"candidate launcher count = {len(self.launchers)}")
-
-    tilling_kernel_list = []
-
-    def kernel_call(launcher):
-        def call_kernel():
-            if launcher.config.pre_hook is not None:
-                launcher.config.pre_hook(
-                    {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
+    with dynamo_timed("benchmark_all_configs"):
+        tilling_kernel_list = []
+
+        def kernel_call(launcher):
+            def call_kernel():
+                if launcher.config.pre_hook is not None:
+                    launcher.config.pre_hook(
+                        {**dict(zip(self.arg_names, args)), **launcher.config.kwargs}
+                    )
+                cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
+                launcher(
+                    *cloned_args,
+                    **cloned_kwargs,
+                    stream=stream,
                 )
-            cloned_args, cloned_kwargs = self.clone_args(*args, **kwargs)
-            launcher(
-                *cloned_args,
-                **cloned_kwargs,
-                grid=input_grid,
-                stream=stream,
-            )
-
-    return call_kernel
-
-    for launcher in self.launchers:
-        if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold:
-            return float("inf")
+            return call_kernel
 
-        stream = self.gpu_device.get_raw_stream( # type: ignore[call-arg]
-            self.gpu_device.current_device()
-        )
-        tilling_kernel_list.append(kernel_call(launcher))
+        for launcher in self.launchers:
+            if not self.custom_kernel and launcher.n_spills > config.triton.spill_threshold:
+                return float("inf")
 
-    def do_batch_benchmark(tilling_kernel_list):
+            device_interface = self.get_device_interface()
+            stream = device_interface.get_raw_stream(device_interface.current_device())
 
-        def delete_file(base_path):
-            if os.path.exists(base_path):
-                shutil.rmtree(base_path)
+            tilling_kernel_list.append(kernel_call(launcher))
 
-        stream = torch.npu.current_stream()
-        experimental_config = torch_npu.profiler._ExperimentalConfig(
-            aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
-            profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
-            l2_cache=False,
-            data_simplification=False
-        )
+        def do_batch_benchmark(tilling_kernel_list):
 
-        random_uuid = uuid.uuid4().hex
-        md5_hash = hashlib.md5(random_uuid.encode()).hexdigest()
+            def delete_file(base_path):
+                if os.path.exists(base_path):
+                    shutil.rmtree(base_path)
 
-        from torch_npu._inductor.config import profile_path
+            stream = torch.npu.current_stream()
+            experimental_config = torch_npu.profiler._ExperimentalConfig(
+                aic_metrics=torch_npu.profiler.AiCMetrics.PipeUtilization,
+                profiler_level=torch_npu.profiler.ProfilerLevel.Level1,
+                l2_cache=False,
+                data_simplification=False
+            )
 
-        torch_path = profile_path + md5_hash
-        rep = 1
-        with torch_npu.profiler.profile(
-            activities=[
-                torch_npu.profiler.ProfilerActivity.NPU
-            ],
-            schedule=torch_npu.profiler.schedule(wait=0, warmup=1, active=rep, repeat=1, skip_first=1),
-            on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path),
-            record_shapes=False,
-            profile_memory=False,
-            with_stack=False,
-            with_flops=False,
-            with_modules=False,
-            experimental_config=experimental_config) as prof:
-            stream.synchronize()
-            for _ in range(rep + 3):
-                for fn in tilling_kernel_list:
-                    fn()
-                prof.step()
-            stream.synchronize()
-
-        import pandas as pd
-        for root, _, files in os.walk(torch_path):
-            for file in files:
-                if file != 'kernel_details.csv':
-                    continue
-                target_file = os.path.join(root, file)
-                df = pd.read_csv(target_file)
-                triton_rows = df[df['Name'].str.startswith('triton', na=False)]
-                ret = triton_rows['Duration(us)'].astype(float).tolist()
-                delete_file(torch_path)
-                return ret
-
-        delete_file(torch_path)
-        return []
-
-    try:
-        timinglist = do_batch_benchmark(tilling_kernel_list)
-        if not len(timinglist) == len(self.launchers):
-            raise RuntimeError("not len(timinglist) == len(self.launchers)")
-        timings = {launcher: timing for launcher, timing in zip(self.launchers, timinglist)}
-    except Exception as e:
-        print("some cases in batch benchmark has error! Logging Exception as:")
-        print(e)
-        print("switched to single bench...")
-        timings = {
-            launcher: self.bench(launcher, *args, **kwargs)
-            for launcher in self.launchers
-        }
+            random_uuid = uuid.uuid4().hex
+            md5_hash = hashlib.md5(random_uuid.encode()).hexdigest()
+
+            from torch_npu._inductor.config import profile_path
+
+            torch_path = profile_path + md5_hash
+            WAIT = 1
+            WARMUP = 1
+            ACTIVE = 1
+            REPEAT = 1
+            SKIP_FIRST = 1
+            TOTAL_STEP = (WAIT + WARMUP + ACTIVE + SKIP_FIRST) * REPEAT
+            with torch_npu.profiler.profile(
+                activities=[
+                    torch_npu.profiler.ProfilerActivity.NPU
+                ],
+                schedule=torch_npu.profiler.schedule(wait=WAIT, warmup=WARMUP, active=ACTIVE, repeat=REPEAT, skip_first=SKIP_FIRST),
+                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(torch_path),
+                record_shapes=False,
+                profile_memory=False,
+                with_stack=False,
+                with_flops=False,
+                with_modules=False,
+                experimental_config=experimental_config) as prof:
+                stream.synchronize()
+                for _ in range(TOTAL_STEP):
+                    for fn in tilling_kernel_list:
+                        fn()
+                    torch.npu.synchronize()
+                    prof.step()
+                stream.synchronize()
+
+            import pandas as pd
+            for root, _, files in os.walk(torch_path):
+                for file in files:
+                    if file != 'kernel_details.csv':
+                        continue
+                    target_file = os.path.join(root, file)
+                    df = pd.read_csv(target_file)
+                    triton_rows = df[df['Name'].str.startswith('triton', na=False)]
+                    ret = triton_rows['Duration(us)'].astype(float).tolist()
+                    delete_file(torch_path)
+                    return ret
+
+            delete_file(torch_path)
+            return []
 
-    for k, v in timings.items():
-        self.coordesc_tuner.cache_benchmark_result(k.config, v)
+        try:
+            timinglist = do_batch_benchmark(tilling_kernel_list)
+            if len(timinglist) != len(self.launchers):
+                raise RuntimeError("batch benchmark timing count does not match launcher count")
+            timings = {launcher: timing for launcher, timing in zip(self.launchers, timinglist)}
+        except Exception as e:
+            print("some cases in the batch benchmark raised an error! Logging the exception:")
+            print(e)
+            print("falling back to single benchmarking...")
+            timings = {
+                launcher: self.bench(launcher, *args, **kwargs)
+                for launcher in self.launchers
+            }
 
-    if log.isEnabledFor(logging.DEBUG):
-        for k, v in timings.items():
-            log.debug(
-                "%s: %f, nreg %d, nspill %d, #shared-mem %s",
-                k.config,
-                v,
-                k.n_regs,
-                k.n_spills,
-                k.shared,
-            )
-    return timings
+        for k, v in timings.items():
+            self.coordesc_tuner.cache_benchmark_result(k.config, v)
+
+        if log.isEnabledFor(logging.DEBUG):
+            for k, v in timings.items():
+                log.debug(
+                    "%s: %f, nreg %d, nspill %d, #shared-mem %s",
+                    k.config,
+                    v,
+                    k.n_regs,
+                    k.n_spills,
+                    k.shared,
+                )
+        return timings
-- 
Gitee
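
Reviewer note on the schedule arithmetic in PATCH 2/2: assuming torch_npu.profiler.schedule follows the same phase semantics as torch.profiler.schedule (skip_first steps first, then repeating wait/warmup/active cycles), a kernel invocation is only recorded during an "active" step. The sketch below is a plain-Python illustration of that accounting; classify_step is a hypothetical local helper written for this note, not a torch_npu API. It shows why stepping the profiler TOTAL_STEP times with skip_first=SKIP_FIRST yields exactly ACTIVE * REPEAT recorded steps, whereas passing skip_first=TOTAL_STEP would classify every step as "skip_first" and record nothing.

# Plain-Python sketch of profiler schedule accounting (assumption: mirrors
# the documented torch.profiler.schedule phase order). classify_step is a
# local illustrative helper, not part of torch or torch_npu.
def classify_step(step, wait, warmup, active, repeat, skip_first):
    # The first skip_first steps are discarded entirely.
    if step < skip_first:
        return "skip_first"
    step -= skip_first
    cycle = wait + warmup + active  # length of one profiling cycle
    if repeat > 0 and step >= cycle * repeat:
        return "done"  # all requested cycles finished; extra steps are ignored
    pos = step % cycle
    if pos < wait:
        return "wait"
    if pos < wait + warmup:
        return "warmup"
    return "active"

WAIT = WARMUP = ACTIVE = REPEAT = SKIP_FIRST = 1
TOTAL_STEP = (WAIT + WARMUP + ACTIVE + SKIP_FIRST) * REPEAT  # 4 for these values

phases = [classify_step(s, WAIT, WARMUP, ACTIVE, REPEAT, SKIP_FIRST)
          for s in range(TOTAL_STEP)]
print(phases)  # ['skip_first', 'wait', 'warmup', 'active']
assert phases.count("active") == ACTIVE * REPEAT

With REPEAT == 1 the patch's (WAIT + WARMUP + ACTIVE + SKIP_FIRST) * REPEAT coincides with the exact step count SKIP_FIRST + (WAIT + WARMUP + ACTIVE) * REPEAT; for REPEAT > 1 it merely over-steps, which the schedule tolerates as "done" steps. This accounting is why the schedule call in the patch above is written with skip_first=SKIP_FIRST.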