diff --git a/torch_npu/_inductor/codegen/triton.py b/torch_npu/_inductor/codegen/triton.py index 201136d3f6f072ac01597b46b31c50f7df24dbc8..ade172451312e6cdff9c298a65b7c96acd26fafe 100644 --- a/torch_npu/_inductor/codegen/triton.py +++ b/torch_npu/_inductor/codegen/triton.py @@ -1379,7 +1379,7 @@ class NPUIndexTritonKernel(TritonKernel): result_var.mask_vars = indexing.mask_vars # type: ignore[assignment] if append_broadcast and append_broadcast != '[]': - line = f"tl.broadcast_to({result_var}, {append_broadcast})" + line = f"tl.reshape({result_var}, {append_broadcast})" result_var = self.cse.generate(load_buffer, line, dtype=dtype) # triton can handle broadcast elif index_analyze.need_permute: