diff --git a/CMakeLists.txt b/CMakeLists.txt
index 935f23c93372e99371f6a7a11f6db3cf7d980116..6b12fb0619736d6be606b86ee304d65e933eaf45 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -203,6 +203,7 @@ endif()
 include_directories(${PROJECT_SOURCE_DIR})
 include_directories(${PROJECT_SOURCE_DIR}/torch_npu/csrc/aten)
+include_directories(${PROJECT_SOURCE_DIR}/torch_npu/csrc/inductor)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/hccl/inc)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/acl/inc)
 include_directories(${PROJECT_SOURCE_DIR}/third_party/Tensorpipe)
@@ -229,6 +230,7 @@ set(ATEN_SRCS)
 set(CORE_SRCS)
 set(FRAMEWORK_SRCS)
 set(LOGGING_SRCS)
+set(INDUCTOR_SRCS)
 
 if (NOT DEFINED BUILD_LIBTORCH)
     set(DIST_SRCS)
@@ -251,6 +253,7 @@ add_subdirectory(${TORCHNPU_ROOT}/framework)
 add_subdirectory(${TORCHNPU_ROOT}/flopcount)
 add_subdirectory(${TORCHNPU_ROOT}/logging)
 add_subdirectory(${TORCHNPU_ROOT}/custom_dtype)
+add_subdirectory(${TORCHNPU_ROOT}/inductor)
 
 if (NOT DEFINED BUILD_LIBTORCH)
     add_subdirectory(${TORCHNPU_ROOT}/distributed)
@@ -279,10 +282,10 @@ if (DEFINED BUILD_TENSORPIPE)
 endif()
 
 if (DEFINED BUILD_LIBTORCH)
-    set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS})
+    set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS})
 else()
     # Compile code with pybind11
-    set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS} ${AFD_SRCS})
+    set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS} ${AFD_SRCS})
 endif()
 
 add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS})
diff --git a/build_libtorch_npu.py b/build_libtorch_npu.py
index b91c2a28e6ea8ccac256220f2b143af7d6abb4fc..77f7a0b3bde5d11e638ccc9871686978db8b1f78 100644
--- a/build_libtorch_npu.py
+++ b/build_libtorch_npu.py
@@ -225,7 +225,8 @@ def copy_hpp():
         "torch_npu/csrc/framework/*.h",
         "torch_npu/csrc/framework/*/*.h",
         "torch_npu/csrc/framework/*/*/*.h",
-        "torch_npu/csrc/libs/*.h"
+        "torch_npu/csrc/libs/*.h",
+        "torch_npu/csrc/inductor/**/*.h",
     ]
     glob_header_files = []
     for regex_pattern in header_files:
diff --git a/setup.py b/setup.py
index 4bbf41448f4eee650159af467034e461ee9233b3..5df0b17e27a7eadc52bf3a159510bec1974bef49 100644
--- a/setup.py
+++ b/setup.py
@@ -481,6 +481,28 @@ def get_src_py_and_dst():
                 os.path.relpath(src, os.path.join(BASE_DIR, "patch/include")))
             os.makedirs(os.path.dirname(dst), exist_ok=True)
             ret.append((src, dst))
+
+    aot_inductor_files = [
+        # Follow torch v2.6.0.
+        # These aoti_runtime/*.cpp files are not compiled into libtorch_npu;
+        # they act as header files when generating the cpp wrapper in AOT Inductor.
+ "torch_npu/_inductor/codegen/aoti_runtime/*.cpp" + ] + glob_aoti_files = [] + for regex_pattern in aot_inductor_files: + glob_aoti_files += glob.glob( + os.path.join(BASE_DIR, regex_pattern), recursive=True + ) + + for src in glob_aoti_files: + # Dst: torch_npu/_inductor/codegen/aoti_runtime/*.cpp + dst = os.path.join( + os.path.join(BASE_DIR, "build/packages/torch_npu/"), + os.path.relpath(src, os.path.join(BASE_DIR, "torch_npu")), + ) + os.makedirs(os.path.dirname(dst), exist_ok=True) + ret.append((src, dst)) + return ret diff --git a/torch_npu/_inductor/codecache.py b/torch_npu/_inductor/codecache.py index 1efec225f25adfc124285f8f489446adeb9bf4d7..3379bcc4211d7f3c8e3c4bfb341084dc28488b4a 100644 --- a/torch_npu/_inductor/codecache.py +++ b/torch_npu/_inductor/codecache.py @@ -88,29 +88,39 @@ def patch_aot_code_compiler_compile(): # which could not be skipped, so here we try to create a new npu op_json, # and clear the content of default op_json. from torch._inductor.codecache import AotCodeCompiler + AotCodeCompiler.src_compile = AotCodeCompiler.compile @classmethod def compile_npu( cls, graph: GraphLowering, - source_code: str, + wrapper_code: str, + kernel_code: str, serialized_extern_kernel_nodes: Optional[str], + *, device_type: str, - additional_files: List[str], + additional_files: list[str], ) -> Union[List[str], str]: result = cls.src_compile( - graph, source_code, serialized_extern_kernel_nodes, - device_type, additional_files + graph, + wrapper_code, + kernel_code, + serialized_extern_kernel_nodes, + device_type=device_type, + additional_files=additional_files, ) generated_files = additional_files if not config.aot_inductor.package: return result - + output_so = [r for r in result if r.endswith(".so")] if len(output_so) > 1: - raise RuntimeError(f"Could not generate npu op json, because there are" - f"more than one so in generated files: {result}" + pta_error(ErrCode.INTERNAL)) + raise RuntimeError( + f"Could not generate npu op json, because there are" + f"more than one so in generated files: {result}" + + pta_error(ErrCode.INTERNAL) + ) output_so = output_so[0] key = os.path.basename(output_so)[0].replace(".", "_") dir_basename = os.path.splitext(output_so)[0] @@ -120,11 +130,11 @@ def patch_aot_code_compiler_compile(): with open(extern_kernel_nodes_json, "w") as f: f.write(serialized_extern_kernel_nodes) generated_files.append(extern_kernel_nodes_json) - + if serialized_extern_kernel_nodes: source_json_file = dir_basename + ".json" with open(source_json_file, "w") as f: f.write(empty_json) return generated_files + AotCodeCompiler.compile = compile_npu - \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/aoti_runtime/interface.cpp b/torch_npu/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2d583effa34a95780d8dd643c10de115f6b914fb --- /dev/null +++ b/torch_npu/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,272 @@ +// Definition of AOTI runtime interface functions + +#include +#include + +#include +#include +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception &e) { \ + std::cerr << "Error: " << e.what() << std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred." 
<< std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK(actual_size == expected_size, "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) { aoti_torch_grad_mode_set_enabled(false); } + ~AOTINoGradGuard() { aoti_torch_grad_mode_set_enabled(prev_mode); } + bool prev_mode; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate(AOTInductorModelContainerHandle *container_handle, size_t num_models, + bool is_cpu, const char *cubin_dir) { + return AOTInductorModelContainerCreateWithDevice(container_handle, num_models, is_cpu ? "cpu" : "npu", cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(AOTInductorModelContainerHandle *container_handle, + size_t num_models, const char *device_str, + const char *cubin_dir) { + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0" << std::endl; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto *container = + new torch::aot_inductor::AOTInductorModelContainer(num_models, std::string(device_str), cubin_dir_opt); + *container_handle = reinterpret_cast(container); + }) +} + +AOTIRuntimeError AOTInductorModelContainerDelete(AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto *container = reinterpret_cast(container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle) { + auto *container = reinterpret_cast(container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run(input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle) { + auto 
*container = reinterpret_cast(container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_single_threaded(input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants(AOTInductorModelContainerHandle container_handle, + size_t *num_constants) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName(AOTInductorModelContainerHandle container_handle, size_t idx, + const char **name) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(AOTInductorModelContainerHandle container_handle, + size_t idx, const char **original_fqn) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(AOTInductorModelContainerHandle container_handle, + size_t idx, bool *from_folded) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantType(AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *type) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *dtype) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, bool validate_full_update) { + auto *container = reinterpret_cast(container_handle); + auto input_map = reinterpret_cast *>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { container->update_constant_buffer(*input_map, use_inactive, validate_full_update); }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(AOTInductorModelContainerHandle container_handle, + bool use_inactive, AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto *container = reinterpret_cast(container_handle); + auto stream = reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(AOTInductorModelContainerHandle container_handle) { + auto 
*container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ container->swap_constant_buffer(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs(AOTInductorModelContainerHandle container_handle, + size_t *ret_num_inputs) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName(AOTInductorModelContainerHandle container_handle, + size_t input_idx, const char **ret_input_names) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(AOTInductorModelContainerHandle container_handle, + size_t *ret_num_outputs) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName(AOTInductorModelContainerHandle container_handle, + size_t output_idx, const char **ret_output_names) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec(AOTInductorModelContainerHandle container_handle, + const char **in_spec, const char **out_spec) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate(AOTInductorModelHandle *model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast *>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + ""); + + if (input_map) { + for (auto const &kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun(AOTInductorModelHandle model_handle, AtenTensorHandle *input_handles, + AtenTensorHandle *output_handles) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl(input_handles, output_handles, (torch::aot_inductor::DeviceStreamType) nullptr, nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + delete model; +})} + +AOTIRuntimeError AOTInductorModelGetNumOutputs(AOTInductorModelHandle model_handle, + size_t *ret_num_outputs){CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); +})} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap(AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = reinterpret_cast *>(constant_map_handle); + + for (auto const &kv : *input_map) { + 
constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +} // extern "C" diff --git a/torch_npu/_inductor/codegen/cpp_wrapper.py b/torch_npu/_inductor/codegen/cpp_wrapper.py index 9a16cfabf8af4321b5b3ebfff157a2093c03803f..6511b550dc96b47c7e5498571f09c019cd6aad33 100644 --- a/torch_npu/_inductor/codegen/cpp_wrapper.py +++ b/torch_npu/_inductor/codegen/cpp_wrapper.py @@ -1,4 +1,6 @@ import functools +import re +import dataclasses import os import sys from itertools import chain, count, zip_longest @@ -12,12 +14,13 @@ from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper from torch._inductor.codegen.common import get_device_op_overrides from torch._inductor.codegen.cpp_utils import cexpr, DTYPE_TO_CPP, DEVICE_TO_ATEN +from torch._inductor.codegen.triton_utils import should_unwrap_unspec_arg from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.codegen.multi_kernel import MultiKernelCall from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg -from torch._inductor.ir import IRNode, TensorBox +from torch._inductor.ir import IRNode, TensorBox, GraphPartitionSignature from torch._inductor.runtime.runtime_utils import dynamo_timed -from torch._inductor.utils import DeferredLineBase +from torch._inductor.utils import DeferredLineBase, IndentedBuffer from torch._inductor.virtualized import V from torch._inductor.utils import _align, ALIGN_BYTES @@ -34,153 +37,164 @@ def checkIfTrue(value, msg): return True -class DeferredNpuKernelLine(DeferredLineBase): - """ - When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels - to be tuned and stored as cubin files, so use a deferred line to backfill those information - """ - - def __init__( - self, - kernel_name: str, - line_template: str, - keys: Tuple[str, ...], - additional_files: List[str], - ): - super().__init__(line_template) - checkIfTrue(not isinstance(line_template, DeferredLineBase), "line template can not be DeferredLineBase") - self.additional_files = additional_files - self.kernel_name = kernel_name - self.line_template = line_template - self.keys = keys - - def __call__(self): - if self.kernel_name.startswith("multi_kernel_"): - # MultiKernel will select one kernel after running the autotune block - self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) - params = CudaKernelParamCache.get(self.kernel_name) - checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") - - for key in self.keys: - checkIfTrue(key in params, f"{key} not found in CudaKernelParamCache[{self.kernel_name}]") - - if key == get_cpp_wrapper_cubin_path_name(): - checkIfTrue(os.path.exists(params[key]), f"{params[key]} does not exist") - self.additional_files.append(params[key]) - - return self.line_template % tuple(params[key] for key in self.keys) - - def _new_line(self, line): - return DeferredNpuKernelLine( - self.kernel_name, line, self.keys, self.additional_files - ) - +_cpp_string_literal_escapes = { + "\\": "\\\\", + '"': '\\"', + "\n": "\\n", + "\t": "\\t", + "\r": "\\r", +} +_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]') -class DeferredNpuDefaultGrid: - """ - A container for the default grid, which may be used by DeferredNpuGridLine - """ - - def __init__( - self, - kernel_name: str, - grid, - grid_callable: Optional[Callable[..., Any]] = None, - **grid_extra_kwargs, - ): - 
self.kernel_name = kernel_name - self.grid = grid - self.grid_callable = grid_callable - self.grid_extra_kwargs = grid_extra_kwargs - - def __iter__(self): - # DeferredNpuDefaultGrid can be passed to the base class, PythonWrapperCodegen, - # to generate the autotune code block, and thus we need this iterator - return iter(self.grid) - - def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): - if isinstance(grid, (list, tuple)): - return [self._process_grid(e) for e in grid] - else: - return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid - - def __call__(self): - if self.kernel_name.startswith("multi_kernel_"): - # MultiKernel will select one kernel after running the autotune block - self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) - - grid = self.grid - checkIfTrue(isinstance(grid, (list, tuple)), f"expected {grid=} to be a list") - grid = self._process_grid(grid) +def cpp_string_literal(s: str) -> str: + escaped = _cpp_string_literal_pattern.sub( + lambda match: _cpp_string_literal_escapes[match.group(0)], s + ) + return f'"{escaped}"' - checkIfTrue(self.grid_callable is not None, "grid_callable can't be None") - if not self.grid_extra_kwargs: - grid_fn = self.grid_callable(*grid) - else: - grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) - - params = CudaKernelParamCache.get(self.kernel_name) - checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") +@dataclasses.dataclass +class UnwrapUnspecArg: + """Marker that we need to call .item() on the tensor""" - return grid_fn(params["meta"]) + dtype: torch_dtype -class DeferredNpuGridLine(DeferredLineBase): +@dataclasses.dataclass +class DeferredNpuTritonCallWrapper: """ - When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels - to be tuned and stored as cubin files, so use a deferred line to backfill those information + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred generating the final wrapper around + the triton kernel until right before the prefix is written. """ - def __init__( - self, - kernel_name: str, - grid_var: str, - grid, - autotune_configs, - ): - super().__init__("") - self.kernel_name = kernel_name - self.grid_var = grid_var - self.grid = grid - self.autotune_configs = autotune_configs + wrapper_name: str + kernel_name: str + arg_types: list[Any] + kernel_id: int - def __call__(self): + def generate(self, wrapper): + prefix = wrapper.prefix if self.kernel_name.startswith("multi_kernel_"): # MultiKernel will select one kernel after running the autotune block self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) - params = CudaKernelParamCache.get(self.kernel_name) + def_args = params["def_args"] + arg_types = self.arg_types + inductor_meta = params["inductor_meta"] + + if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types): + # extra_launcher_args should already be in def_args + arg_types = arg_types + [SymbolicCallArg] * len( + inductor_meta["extra_launcher_args"] + ) - checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") - - if self.autotune_configs is not None: - # This indicates the Triton kernel is a user-defined one. 
- grid = None - if len(self.grid) == 1: - grid = self.grid[0] - else: - for i, c in enumerate(self.autotune_configs): - if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): - grid = self.grid[i] - break - checkIfTrue(grid is not None, "grid can not be None") - grid_args_str = ", ".join( - [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + if not V.graph.aot_mode: + prefix.writeline( + maybe_hipify_code_wrapper( + f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;" + ) ) + kernel_var_name = self.kernel_name else: - launch_grid = (params['grid_x'], params['grid_y'], params['grid_z']) - grid_args_str = ", ".join( - [cexpr(item) for item in launch_grid] + kernel_var_name = f"kernels_.{self.kernel_name}" + + # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types + template_types = [ + f"typename {name}_type_" + for name, arg_type in zip(def_args, arg_types) + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)) + ] + if V.graph.aot_mode: + template_types.append("typename kernels_type_") + if template_types: + prefix.writeline(f"template <{', '.join(template_types)}>") + prefix.writeline(f"static inline void {self.wrapper_name}(") + with prefix.indent(): + for name, arg_type in zip(def_args, arg_types): + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)): + prefix.writeline(f"const {name}_type_& {name},") + elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)): + prefix.writeline(f"int64_t {name},") + elif arg_type is float: + prefix.writeline(f"float {name},") + elif arg_type is bool: + prefix.writeline(f"bool {name},") + else: + raise ValueError(f"Unexpected arg type {arg_type}") + prefix.writeline(f"{wrapper.device_codegen.cpp_stream_type()} stream_,") + if V.graph.aot_mode: + prefix.writeline("kernels_type_& kernels_,") + prefix.writeline( + "const std::optional& cubin_dir_ = std::nullopt" ) + prefix.writeline("){") + with prefix.indent(): + self.generate_grid(prefix, inductor_meta, params) + self.generate_load_kernel(prefix, kernel_var_name, params) + self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) + prefix.writeline("}") + # Ensure the cubin file is included in the package + V.graph.wrapper_code.additional_files.append( + params[get_cpp_wrapper_cubin_path_name()] + ) - return f"\n Grid {self.grid_var} = Grid({grid_args_str});\n" + def generate_grid( + self, + prefix: IndentedBuffer, + inductor_meta: dict[str, Any], + params: dict[str, Any], + ): + from ..npu_triton_heuristics import GridExprNpu - def _new_line(self, line): - return DeferredNpuGridLine( - self.kernel_name, self.grid_var, self.grid, self.autotune_configs + numels = [arg for arg in params["call_args"] if "_numel" in arg] + grid = GridExprNpu.from_meta_and_set_numel( + inductor_meta, params["config"], numels, "cpp" + ) + for line in grid.prefix: + prefix.writeline(line) + prefix.splice( + f"""\ + uint32_t grid_0 = {grid.x_grid}; + uint32_t grid_1 = {grid.y_grid}; + uint32_t grid_2 = {grid.z_grid}; + """ + ) + prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;") + + def generate_load_kernel(self, prefix, kernel_var_name, params): + prefix.writeline(f"if ({kernel_var_name} == nullptr) {{") + with prefix.indent(): + load_kernel_args = [ + cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]), + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + "cubin_dir_", + ] + prefix.writeline( + f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); " + ) 
+ prefix.writeline("}") + + def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): + triton_meta = params["triton_meta"] + arg_type_lookup = dict(zip(params["def_args"], self.arg_types)) + # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants + call_args = [name for name in params["call_args"] if name not in triton_meta["constants"]] + arg_types = [arg_type_lookup[name] for name in call_args] + arg_signatures = [triton_meta["signature"][name] for name in call_args] + call, call_args_str = wrapper.generate_args_decl( + prefix, + call_args, + arg_types, + arg_signatures, + kernel_var_name, + self.kernel_id, ) + prefix.writeline(f"{call_args_str}") + + prefix.writeline(r"launchKernel({}, {});".format(call, f'"{kernel_var_name}"')) class CppWrapperNpu(CppWrapperCpu): @@ -189,24 +203,28 @@ class CppWrapperNpu(CppWrapperCpu): """ def __init__(self) -> None: - self.device = 'npu' + self.device = "npu" self.device_codegen = get_device_op_overrides(self.device) super().__init__() self.grid_id = count() self.visited_raii_handle = set() self.visited_handle_for_kernel_id = dict() + self._triton_call_wrappers: dict[str, DeferredNpuTritonCallWrapper] = {} @staticmethod def create( - is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, ): # comment at CppWrapperCpu `codegen_subgraph` function. return CppWrapperNpu() def super_write_header_rewrite(self): """Copied from CppWrapperCpu to: - (1) change __file__ path for cpython, so that we can use aoti_runtime in current path. - (2) rewrite include path of aoti header file. + (1) change __file__ path for cpython, so that we can use aoti_runtime in current path. + (2) rewrite include path of aoti header file. """ if V.graph.is_const_graph: # We do not write header for constant graph, it will be written by main module. 
@@ -343,11 +361,7 @@ class CppWrapperNpu(CppWrapperCpu): if V.graph.aot_mode and V.graph.inputs_to_check: for idx in V.graph.inputs_to_check: input_name = V.graph.graph_input_names[idx] - checkIfTrue(input_name in V.graph.graph_inputs, f"{input_name} not found in graph inputs") - value = V.graph.graph_inputs[input_name] - checkIfTrue(isinstance(value, TensorBox), - f"{input_name} is expected to be tensor but found as {type(value)}") self.prefix.splice( f""" @@ -360,11 +374,11 @@ class CppWrapperNpu(CppWrapperCpu): super().codegen_inputs() def define_kernel( - self, - kernel_name: str, - kernel_body: str, - metadata: Optional[str] = None, - gpu=True, + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu=True, ): if gpu: if config.triton.autotune_at_compile_time: @@ -382,10 +396,10 @@ class CppWrapperNpu(CppWrapperCpu): self.prefix.writeline("\n") if not V.graph.aot_mode: for kernel in chain( - sorted(self.src_to_kernel.values()), - sorted( - [entry[0] for entry in self.user_defined_kernel_cache.values()] - ), + sorted(self.src_to_kernel.values()), + sorted( + [entry[0] for entry in self.user_defined_kernel_cache.values()] + ), ): self.prefix.writeline( maybe_hipify_code_wrapper( @@ -396,17 +410,17 @@ class CppWrapperNpu(CppWrapperCpu): return super().generate(is_inference) def generate_user_defined_triton_kernel( - self, - kernel_name: str, - raw_args: List[Any], - grid: List[Any], - configs, - triton_meta, - constexprs, + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, ): if ( - config.triton.autotune_at_compile_time - and kernel_name not in self.kernel_autotune_names + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names ): # Call PythonWrapperCodegen to create the autotune code block PythonWrapperCodegen.generate_user_defined_triton_kernel( @@ -442,28 +456,8 @@ class CppWrapperNpu(CppWrapperCpu): autotune_configs=configs, ) - @functools.lru_cache(None) # noqa: B019 - def generate_load_kernel_once( - self, - kernel_name: str, - device_index, - graph: "GraphLowering", # for per-graph caching - ): - keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem") - kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name - self.writeline(f"if ({kernel_var_name} == nullptr) {{") - deferred_gpu_kernel_line = DeferredNpuKernelLine( - kernel_name, - " " + kernel_var_name + r' = loadKernel("%s", "%s", %s, this->cubin_dir_);', - keys, - self.additional_files, - ) - self.writeline(deferred_gpu_kernel_line) - self.writeline("}") - return kernel_var_name - def codegen_tensor_item_npu( - self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None ): dtype_str = str(dtype).split(".")[-1] writer = indented_buffer or self @@ -475,22 +469,22 @@ class CppWrapperNpu(CppWrapperCpu): f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" ) writer.writeline(f"float {scalar} = float({scalar_tmp});") - struct_data = f'float {scalar} __attribute__((aligned(4)));' - arg_data = f'static_cast({scalar})' + struct_data = f"float {scalar} __attribute__((aligned(4)));" + arg_data = f"static_cast({scalar})" else: writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") writer.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" ) - struct_data = f'{DTYPE_TO_CPP[dtype]} {scalar} 
__attribute__((aligned(sizeof({DTYPE_TO_CPP[dtype]} ))));' - arg_data = f'static_cast<{DTYPE_TO_CPP[dtype]}>({scalar})' + struct_data = f"{DTYPE_TO_CPP[dtype]} {scalar} __attribute__((aligned(sizeof({DTYPE_TO_CPP[dtype]} ))));" + arg_data = f"static_cast<{DTYPE_TO_CPP[dtype]}>({scalar})" return struct_data, arg_data def codegen_device(self, device): if device.type not in DEVICE_TO_ATEN: raise RuntimeError(device.type + "not found in DEVICE_TO_ATEN") - device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" if device_str == "privateuse1": device_str = "npu" self.used_cached_devices.add(device_str) @@ -525,21 +519,21 @@ class CppWrapperNpu(CppWrapperCpu): return "" if kernel_id not in self.visited_handle_for_kernel_id: self.visited_handle_for_kernel_id[kernel_id] = set() - + def get_tensor_from_handle(h, t): if h in self.visited_handle_for_kernel_id[kernel_id]: return "" self.visited_handle_for_kernel_id[kernel_id].add(h) return f" auto {t} = *tensor_handle_to_tensor_pointer({h});\n" - + # Only dump tensor args, e.g, ['buf2', '8L', '4L'] => ['buf2'] tensor_args = [arg for arg in args if not arg[0].isdigit()] tensor_args_h = [f"{arg}_h" for arg in tensor_args] tensor_args_t = [f"{arg}_t" for arg in tensor_args] - handle_tensor_str = "".join([ - get_tensor_from_handle(h, t) for h, t in zip(tensor_args_h, tensor_args_t) - ]) + handle_tensor_str = "".join( + [get_tensor_from_handle(h, t) for h, t in zip(tensor_args_h, tensor_args_t)] + ) dump_path = npu_config.aot_inductor.dump_path_cpp return f""" @@ -549,16 +543,30 @@ class CppWrapperNpu(CppWrapperCpu): torch::save(arg_{mark}, "{dump_path}/{kernel_id}_{kernel_name}_{mark}.pt"); """ - def generate_launch_call( + def generate_args_decl( self, + code, call_args, arg_types, arg_signatures, + kernel_name, kernel_id, - grid_var, - kernel_name + is_triton_kernel=True, ): - kernel_val_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + """ + Generates any declarations of args to pass into a kernel call, and then returns the arg names. + + In more detail: + * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;` + * returns: a string with the list of args, e.g. "var_0, var_1" + + call_args: list of call arguments + arg_types: list of argument types + arg_signatures: list with signatures of all the args + is_triton_kernel: whether these are passed into a triton kernel or not. In particular, + calls to triton kernels will have an additional global scratch space + arg injected at the front of the arg list. 
+ """ new_args: list[str] = [] # Add more cases for other types as needed @@ -577,115 +585,105 @@ class CppWrapperNpu(CppWrapperCpu): "fp64": "double", } - - struct_def_body = '' - struct_arg_body = '' + struct_def_body = "" + struct_arg_body = "" def process_args(arg, arg_type, arg_signature=None): var_name = f"var_{next(self.arg_var_id)}" # ignore nvTmaDesc, as host-side TMA descriptors need # to be passed to the compiled Triton kernel by value - if isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": - if arg.endswith(".item()"): # scalar - # Need to declare a scalar in this case - arg = arg[:-7] - struct_data, arg_data = self.codegen_tensor_item_npu( - arg_type, - arg, - var_name, - ) - else: - # void* - device_ptr_type = self.device_codegen.cpp_device_ptr() - self.writeline( - maybe_hipify_code_wrapper( - f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" - ) + if isinstance(arg_type, UnwrapUnspecArg) and arg_signature != "nvTmaDesc": + self.codegen_tensor_item_npu( + arg_type.dtype, + arg, + var_name, + indented_buffer=code, + ) + elif isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": + device_ptr_type = self.device_codegen.cpp_device_ptr() + code.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" ) - if npu_config.aot_inductor.debug_kernel: - if arg not in self.visited_raii_handle: - self.writeline( - f"AtenTensorHandle {arg}_h = {arg}.get();" - ) - self.visited_raii_handle.add(arg) - struct_data = f'void* {var_name} __attribute__((aligned(8)));' - arg_data = f'static_cast({var_name})' - + ) + if npu_config.aot_inductor.debug_kernel: + if arg not in self.visited_raii_handle: + self.writeline(f"AtenTensorHandle {arg}_h = {arg}.get();") + self.visited_raii_handle.add(arg) + struct_data = f"void* {var_name} __attribute__((aligned(8)));" + arg_data = f"static_cast({var_name})" elif arg_type in (sympy.Integer, int): - # int - self.writeline(f"int {var_name} = {cexpr(arg)};") - struct_data = f'int {var_name} __attribute__((aligned(4)));' - arg_data = f'static_cast({var_name})' - + code.writeline(f"int {var_name} = {cexpr(arg)};") + struct_data = f"int {var_name} __attribute__((aligned(4)));" + arg_data = f"static_cast({var_name})" elif arg_type in (sympy.Float, float): - # float - self.writeline(f"float {var_name} = {cexpr(arg)};") - struct_data = f'float {var_name} __attribute__((aligned(4)));' - arg_data = f'static_cast({var_name})' - + code.writeline(f"float {var_name} = {cexpr(arg)};") + struct_data = f"float {var_name} __attribute__((aligned(4)));" + arg_data = f"static_cast({var_name})" # For symbolic call arguments, examine the arg signatures from triton meta # to explicitly cast to the right type # Reason: `auto` can infer unexpected type against kernel input signature. 
elif ( - isinstance(arg_type, type(SymbolicCallArg)) - and arg_signature is not None - and arg_signature in signature2dtype.keys() + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() ): - # or scalar symbolic type,currently only support scalar symbolic type - self.writeline( + code.writeline( f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" ) - struct_data = f'{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));' - arg_data = f'static_cast<{signature2dtype[arg_signature]}>({var_name})' + struct_data = f"{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));" + arg_data = f"static_cast<{signature2dtype[arg_signature]}>({var_name})" else: raise TypeError("Infer arg_type to cpp failed!") - nonlocal struct_def_body nonlocal struct_arg_body - struct_def_body += struct_data + ' ' - struct_arg_body += arg_data + ', ' + struct_def_body += struct_data + " " + struct_arg_body += arg_data + ", " for arg, arg_type, arg_signature in zip_longest( - call_args, arg_types, arg_signatures + call_args, arg_types, arg_signatures ): process_args(arg, arg_type, arg_signature) - debug_str_before_kernel = self.generate_debug_str(call_args, kernel_name, kernel_id, "before") - debug_str_after_kernel = self.generate_debug_str(call_args, kernel_name, kernel_id, "after") - + debug_str_before_kernel = self.generate_debug_str( + call_args, kernel_name, kernel_id, "before" + ) + debug_str_after_kernel = self.generate_debug_str( + call_args, kernel_name, kernel_id, "after" + ) launch_str = f""" auto launch_call_{kernel_id} = [=]() {{ - int32_t grid_x = {grid_var}.grid_x; - int32_t grid_y = {grid_var}.grid_y; - int32_t grid_z = {grid_var}.grid_z; rtError_t ret; void* ffts_addr = NULL; uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len); if (ret != RT_ERROR_NONE) return ret; void* workspace_addr = NULL; + void* sync_block_lock = NULL; struct __attribute__((packed)) {{ void* ffts_addr __attribute__((aligned(8))); + void* sync_block_lock __attribute__((aligned(8))); void* workspace_addr __attribute__((aligned(8))); {struct_def_body} - int32_t grid_x __attribute__((aligned(4))); - int32_t grid_y __attribute__((aligned(4))); - int32_t grid_z __attribute__((aligned(4))); + int32_t grid_0 __attribute__((aligned(4))); + int32_t grid_1 __attribute__((aligned(4))); + int32_t grid_2 __attribute__((aligned(4))); }} kernel_args = {{ static_cast(ffts_addr), + static_cast(sync_block_lock), static_cast(workspace_addr), {struct_arg_body} - static_cast(grid_x), - static_cast(grid_y), - static_cast(grid_z) + static_cast(grid_0), + static_cast(grid_1), + static_cast(grid_2) }}; - uint32_t block_num = grid_x * grid_y * grid_z; + uint32_t block_num = grid_0 * grid_1 * grid_2; auto arg_ptr = static_cast(&kernel_args); auto arg_size = sizeof(kernel_args); {debug_str_before_kernel} - ret = rtKernelLaunch({kernel_val_name}, block_num, arg_ptr, arg_size, NULL, stream); + ret = rtKernelLaunch({kernel_name}, block_num, arg_ptr, arg_size, NULL, stream_); {debug_str_after_kernel} if (ret != RT_ERROR_NONE) return ret; return ret; @@ -693,43 +691,24 @@ class CppWrapperNpu(CppWrapperCpu): """ return f"launch_call_{kernel_id}", launch_str - def generate_default_grid( - self, - kernel_name: str, - grid_args: List[Any], - gpu: bool = True, - grid_callable: Optional[Callable[..., Any]] = None, - **grid_extra_kwargs, - ): - """ - Generate 
grid configs for launching a CUDA kernel using the grid - function from triton_heuristics. Because its computation needs - to read kernel config after autotune, it is done in a deferred way - using DeferredNpuDefaultGrid. - """ - checkIfTrue(gpu, "CppWrapperNpu.generate_default_grid does not support non-NPU") - return DeferredNpuDefaultGrid( - kernel_name, grid_args, grid_callable, **grid_extra_kwargs - ) - def generate_kernel_call_npu( - self, - kernel_name: str, - call_args, - grid=None, - device_index=None, - npu=True, - triton=True, - arg_types=None, - raw_args=None, - grid_fn: str = "grid", - triton_meta=None, - autotune_configs=None, - grid_extra_kwargs="", + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + npu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", ): if ( - config.triton.autotune_at_compile_time - and kernel_name not in self.kernel_autotune_names + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names ): # Call PythonWrapperCodegen to create the autotune code block PythonWrapperCodegen.generate_kernel_call( @@ -759,69 +738,25 @@ class CppWrapperNpu(CppWrapperCpu): ) if triton: - device_index, call_args = self.prepare_triton_kernel_call( - device_index, call_args + call_args, arg_types = self.prepare_triton_wrapper_args( + call_args, arg_types ) - _ = self.generate_load_kernel_once(kernel_name, device_index, V.graph) - - # args with value 1 are added into equal_to_1 and constants - # in triton_meta (in the Python codegen) which makes them - # inlined in the PTX and compiled CUBIN - arg_signatures = [] - if ( - triton_meta is not None - and triton_meta.get("configs") - and triton_meta.get("signature") - ): - equal_to_1 = triton_meta["configs"][0].equal_to_1 - call_args = [ - arg - for i, arg in enumerate(call_args) - if i not in equal_to_1 - ] - arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] - # extract the arg signatures from triton_meta - arg_signatures = triton_meta["signature"].values() - arg_signatures = [ - v - for i, v in enumerate(arg_signatures) - if i not in equal_to_1 - ] - + wrapper_name = f"call_{kernel_name}" current_kernel_id = next(self.kernel_callsite_id) - current_grid_id = next(self.grid_id) - - # gen grids - grid_var = f"{kernel_name}_grid_{current_grid_id}" - self.writeline( - DeferredNpuGridLine(kernel_name, grid_var, grid, autotune_configs) - ) - - call, call_args_str = self.generate_launch_call( - call_args, arg_types, arg_signatures, current_kernel_id, grid_var, kernel_name - ) - self.writeline(f"{call_args_str}") - - # add debug printer code for all triton kernel related calls + if wrapper_name not in self._triton_call_wrappers: + self._triton_call_wrappers[wrapper_name] = DeferredNpuTritonCallWrapper( + wrapper_name, kernel_name, arg_types, current_kernel_id + ) + call_args.append(stream) + if V.graph.aot_mode: + call_args.append("kernels") + call_args.append("this->cubin_dir_") debug_printer_manager = V.graph.wrapper_code.debug_printer debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_types, None + call_args[: len(arg_types)], kernel_name, arg_types, None ) with debug_printer_manager: - self.writeline(f"if ({grid_var}.is_non_zero()) {{") - self.writeline( - DeferredNpuKernelLine( - kernel_name, - r" launchKernel({}, {});".format( \ - call, - f'"{kernel_name}"', - ), - (), - self.additional_files, - ), - ) - - self.writeline("}\n") + 
self.writeline(f"{wrapper_name}({', '.join(call_args)});") else: casted = [] for arg_type, arg in zip(arg_types, call_args): @@ -833,19 +768,19 @@ class CppWrapperNpu(CppWrapperCpu): self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") def generate_kernel_call( - self, - kernel_name: str, - call_args, - grid=None, - device_index=None, - gpu=True, - triton=True, - arg_types=None, - raw_args=None, - grid_fn: str = "grid", - triton_meta=None, - autotune_configs=None, - grid_extra_kwargs="", + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + gpu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", ): """ Override the default value of argument 'gpu' to True here. @@ -892,5 +827,37 @@ class CppWrapperNpu(CppWrapperCpu): grid_extra_kwargs, ) + def finalize_prefix(self): + """Define the triton kernels now that autotuning is finished""" + old_prefix = self.prefix # new content should go at start of prefix + self.prefix = IndentedBuffer() + super().finalize_prefix() + for kernel in self._triton_call_wrappers.values(): + self.prefix.writeline("\n") + kernel.generate(self) + self.prefix.writeline("\n") + self.prefix.splice(old_prefix) + + @staticmethod + def prepare_triton_wrapper_args( + call_args: list[Any], arg_types: list[Any] + ) -> tuple[list[Any], list[Any]]: + new_args = [] + new_args_types = [] + for arg, arg_type in zip(call_args, arg_types): + if isinstance(arg, str): + if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + arg_type = UnwrapUnspecArg(dtype=arg_type) + new_args.append(arg) + elif isinstance(arg, bool): + new_args.append(str(arg).lower()) + elif isinstance(arg, (int, float, SymbolicCallArg)): + new_args.append(str(arg)) + else: + new_args.append(cexpr(V.graph.sizevars.simplify(arg))) + new_args_types.append(arg_type) + return new_args, new_args_types + def make_zero_buffer(self, name): return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py index 2b47e091af8f49d3662f4c613a97b505f3e9266b..fa37ca8077407067730f3fef994ea0e4148bf3ec 100644 --- a/torch_npu/_inductor/lowering.py +++ b/torch_npu/_inductor/lowering.py @@ -143,7 +143,6 @@ def _register_npu_inductor_fallbacks(): make_reduction("argmin", override_return_dtype=torch.int64) ) - @register_lowering(aten.max, type_promotion_kind=None) def reduce_max(x, dim=None, keepdim=False): if dim is not None: diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index db9c427e60d69e95e39e9a7b83396198831d6070..2ec168739995f5eb1457c58bfa03d52eaec99874 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -74,7 +74,7 @@ GENERATE_LIST = [ aten.bitwise_and, aten.squeeze, aten.copy, - aten.reciprocal + aten.reciprocal, ] GENERATE_LIST2 = [ diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py index 21275d77f0e69cdcf82b67a06e1a488fecdc6f25..877ceb7dada8bda15549566a0dc17d261d8ac8ca 100644 --- a/torch_npu/_inductor/npu_triton_heuristics.py +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -157,20 +157,29 @@ def do_bench_using_profiling_npu(fn, warmup=2, rep=10, grad_to_none=None, quanti @dataclasses.dataclass class GridNpu(GridExpr): numels: List[str] = None + mode: Literal["python", 
"cpp"] = "python" def generate(self, meta: dict[str, int]) -> None: numel_args = [] split_axis = meta.get("split_axis", None) split_blocks = meta.get("split_blocks", None) if split_axis is None or split_blocks is None: - raise RuntimeError(f"Could not get split_axis or split_blocks from meta {meta}.") + raise RuntimeError( + f"Could not get split_axis or split_blocks from meta {meta}." + ) def grid_fn(i): if i >= len(split_axis): return "1" axis = split_axis[i] block = split_blocks[i] - return f"({self.numels[axis]} + {block} - 1) // {block}" + if block is None or block == 1: + return self.numels[axis] + if self.mode == "python": + return f"({self.numels[axis]} + {block} - 1) // {block}" + else: + return f"(({self.numels[axis]} + ({block} - 1)) / ({block}))" + self.x_grid = grid_fn(0) self.y_grid = grid_fn(1) self.z_grid = grid_fn(2) @@ -186,8 +195,10 @@ class GridExprNpu(GridExpr): ) -> GridExpr: grid_cls = globals()[inductor_meta["grid_type"]] if not issubclass(grid_cls, GridNpu): - raise AssertionError(f"grid_type in inductor_meta must be subclass of GridNpu" - f"but got {inductor_meta['grid_type']}") + raise AssertionError( + f"grid_type in inductor_meta must be subclass of GridNpu" + f"but got {inductor_meta['grid_type']}" + ) grid = grid_cls(inductor_meta=inductor_meta, mode=mode, numels=numels) if isinstance(cfg, Config): cfg = config_to_dict(cfg) @@ -586,6 +597,11 @@ class NPUCachingAutotuner(CachingAutotuner): "stream": input_stream, # User defined triton kernels will have arbitrary kwarg names "meta": input_launcher.config.kwargs, + "config": config_to_dict(input_launcher.config), + "inductor_meta": self.inductor_meta, + "triton_meta": self.triton_meta, + "def_args": input_launcher.def_args, + "call_args": input_launcher.call_args, } from torch._inductor.codecache import CudaKernelParamCache diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py index 095f1f69cf2bff023c2ee492c726940d801c75c5..24aac049060294dc88ac7617db9f3030d319a02e 100644 --- a/torch_npu/_inductor/utils.py +++ b/torch_npu/_inductor/utils.py @@ -32,8 +32,14 @@ def patch_is_same_tensor(): def patch_is_gpu(): from torch._inductor.utils import GPU_TYPES + GPU_TYPES.append('npu') + def _return_false(device_interface): + return False + + torch._inductor.scheduler.device_need_guard = _return_false + def patch_has_triton(): from torch.utils._triton import has_triton_package diff --git a/torch_npu/csrc/inductor/CMakeLists.txt b/torch_npu/csrc/inductor/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..51913f46b5e35f8e3f16a23eaf221ac8e39da398 --- /dev/null +++ b/torch_npu/csrc/inductor/CMakeLists.txt @@ -0,0 +1,9 @@ +FILE(GLOB _INDUCTOR_SRCS + *.cpp + aoti_runner/*.cpp + aoti_torch/*.cpp) + +LIST(APPEND INDUCTOR_SRCS ${_INDUCTOR_SRCS}) + +# Pass to parent +set(INDUCTOR_SRCS ${INDUCTOR_SRCS} PARENT_SCOPE) \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_package/model_package_loader.h b/torch_npu/csrc/inductor/aoti_package/model_package_loader.h new file mode 100644 index 0000000000000000000000000000000000000000..9f1fbde09b3a0b8ac8570c746559c31ec5ce0eb6 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_package/model_package_loader.h @@ -0,0 +1,40 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { +class TORCH_API AOTIModelPackageLoader { + public: + AOTIModelPackageLoader(const std::string &model_package_path, + const std::string &model_name = "model", + const bool run_single_threaded = 
false); + ~AOTIModelPackageLoader(); + + AOTIModelContainerRunner *get_runner(); + std::unordered_map get_metadata(); + + std::vector run(const std::vector &inputs, + void *stream_handle = nullptr); + + // boxed_run will steal the ownership of the input tensors + std::vector boxed_run(std::vector &&inputs, + void *stream_handle = nullptr); + + std::vector get_call_spec(); + void load_constants( + std::unordered_map &constants_map, + bool use_inactive, bool check_full_update); + std::vector get_constant_fqns(); + + private: + std::string temp_dir_; + std::unique_ptr runner_; + std::unordered_map metadata_; + + void load_metadata(const std::string &cpp_filename); +}; + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_package/pybind.h b/torch_npu/csrc/inductor/aoti_package/pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..71e6390878ff9cc3c95a56ad62534e67755577a7 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_package/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject *module); + +} // namespace torch::inductor diff --git a/torch_npu/csrc/inductor/aoti_runner/model_container_runner.h b/torch_npu/csrc/inductor/aoti_runner/model_container_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..9a62a8158b1f6c847cae2fdc9756fc330ab262da --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/model_container_runner.h @@ -0,0 +1,113 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include +#include + +// Forward declare DynamicLibrary +namespace at { +struct DynamicLibrary; +} + +namespace torch::inductor { +using TensorConstantMap = std::unordered_map; + +class TORCH_API AOTIModelContainerRunner { + public: + AOTIModelContainerRunner() = delete; + AOTIModelContainerRunner(const AOTIModelContainerRunner &other) = delete; + AOTIModelContainerRunner(AOTIModelContainerRunner &&other) = delete; + AOTIModelContainerRunner &operator=(const AOTIModelContainerRunner &other) = + delete; + AOTIModelContainerRunner &operator=(AOTIModelContainerRunner &&other) = + delete; + virtual ~AOTIModelContainerRunner(); + + std::vector run(const std::vector &inputs, + void *stream_handle = nullptr); + + // boxed_run will steal the ownership of the input tensors + std::vector boxed_run(std::vector &&inputs, + void *stream_handle = nullptr); + + std::unordered_map getConstantNamesToOriginalFQNs() + const; + std::unordered_map getConstantNamesToDtypes() const; + + void update_inactive_constant_buffer(const TensorConstantMap &const_map); + void update_constant_buffer( + std::unordered_map &tensor_map, + bool use_inactive, bool validate_full_updates); + void update_constant_buffer(const TensorConstantMap &const_map, + bool use_inactive, bool validate_full_updates); + void run_const_fold(bool use_inactive, + AOTInductorStreamHandle npu_stream_handle = nullptr); + void swap_constant_buffer(); + + std::vector get_call_spec(); + + protected: + AOTIModelContainerRunner(const std::string &model_so_path, size_t num_models, + const std::string &device_str, + const std::string &cubin_dir, + const bool run_single_threaded); + + virtual std::vector run_impl( + std::vector &input_handles, void *stream_handle); + + std::unique_ptr model_so_; + decltype(&AOTInductorModelContainerCreateWithDevice) create_func_{nullptr}; + decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr}; + decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{ + nullptr}; + 
decltype(&AOTInductorModelContainerRun) run_func_{nullptr}; + decltype(&AOTInductorModelContainerGetNumConstants) get_num_constants_func_{ + nullptr}; + decltype(&AOTInductorModelContainerGetConstantName) get_constant_name_func_{ + nullptr}; + decltype(&AOTInductorModelContainerGetConstantOriginalFQN) + get_constant_original_fqn_func_{nullptr}; + decltype(&AOTInductorModelContainerGetConstantDtype) get_constant_dtype_func_{ + nullptr}; + decltype(&AOTInductorModelContainerUpdateConstantBuffer) + update_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer) + update_inactive_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerRunConstantFolding) run_const_fold_func_{ + nullptr}; + decltype(&AOTInductorModelContainerSwapConstantBuffer) + swap_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr}; + + AOTInductorModelContainerHandle container_handle_ = nullptr; + + AOTIProxyExecutorHandle proxy_executor_handle_; + + private: + std::unique_ptr proxy_executor_; +}; + +using CreateAOTIModelRunnerFunc = std::unique_ptr (*)( + const std::string &model_so_path, size_t num_models, + const std::string &device_str, const std::string &bin_dir, + const bool run_single_threaded); + +// Return a global map "device name" -> "aoti model runner create function" for +// all registered in AOTI external backends +TORCH_API std::unordered_map & +getAOTIModelRunnerRegistry(); + +// To register a new external backend in AOTI one needs to create an instance of +// this struct. It is not thread-safe. Becase it is expected to be called during +// the initialization of the program. +struct TORCH_API RegisterAOTIModelRunner { + RegisterAOTIModelRunner( + const std::string &name, + CreateAOTIModelRunnerFunc create_aoti_model_runner_fn) { + getAOTIModelRunnerRegistry()[name] = create_aoti_model_runner_fn; + } +}; + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a3776f75ca2d81052fdb6c3e8c106980257e6b61 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp @@ -0,0 +1,87 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#include +#include + +#include + +#ifndef _WIN32 +#include +#else +#include +namespace fs = std::filesystem; +#endif + +namespace { +bool file_exists(std::string &path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc{}; + return lstat(path.c_str(), &rc) == 0; +#endif +} +} // namespace + +namespace torch::inductor { + +AOTIModelContainerRunnerNpu::AOTIModelContainerRunnerNpu( + const std::string &model_so_path, size_t num_models, + const std::string &device_str, const std::string &cubin_dir, + const bool run_single_threaded) + : AOTIModelContainerRunner(model_so_path, num_models, device_str, cubin_dir, + run_single_threaded) { + model_so_path_ = model_so_path; + init_flag_ = false; +} + +AOTIModelContainerRunnerNpu::~AOTIModelContainerRunnerNpu() = default; + +void AOTIModelContainerRunnerNpu::init_proxy_executor() { + if (init_flag_) return; + + init_flag_ = true; + size_t lastindex = model_so_path_.find_last_of('.'); + std::string json_filename = model_so_path_.substr(0, lastindex) + "_npu.json"; + + if (file_exists(json_filename)) { + proxy_executor_npu_ = + std::make_unique( + json_filename, false); + proxy_executor_handle_ = + 
reinterpret_cast(proxy_executor_npu_.get()); + } else { + proxy_executor_handle_ = nullptr; + } +} + +std::vector AOTIModelContainerRunnerNpu::run_impl( + std::vector &input_handles, void *stream_handle) { + init_proxy_executor(); + c10_npu::NPUStream npu_stream = c10_npu::getCurrentNPUStream(); + return AOTIModelContainerRunner::run_impl( + input_handles, reinterpret_cast(npu_stream.stream())); +} + +std::vector AOTIModelContainerRunnerNpu::run_with_npu_stream( + const std::vector &inputs, + const c10_npu::NPUStream &npu_stream) +{ + init_proxy_executor(); + c10_npu::NPUStream cur_npu_stream = c10_npu::getCurrentNPUStream(); + return run(inputs, reinterpret_cast(cur_npu_stream.stream())); +} + +namespace { +std::unique_ptr create_aoti_runner_npu( + const std::string &model_so_path, size_t num_models, + const std::string &device_str, const std::string &cubin_dir, + const bool run_single_threaded) { + return std::make_unique( + model_so_path, num_models, device_str, cubin_dir, run_single_threaded); +} +} // namespace + +RegisterAOTIModelRunner register_npu_runner("npu", &create_aoti_runner_npu); + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h new file mode 100644 index 0000000000000000000000000000000000000000..af16757fbbed0ab8d3aae5bfc34a386ddf7acec3 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h @@ -0,0 +1,41 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { + +// NOTICE: Following APIs are subject to change due to active development +// We provide NO BC guarantee for these APIs +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class AOTIModelContainerRunnerNpu + : public AOTIModelContainerRunner { +public: + // @param device_str: npu device string, e.g. 
"npu", "npu:0" + AOTIModelContainerRunnerNpu(const std::string &model_so_path, + size_t num_models = 1, + const std::string &device_str = "npu", + const std::string &cubin_dir = "", + const bool run_single_threaded = false); + + ~AOTIModelContainerRunnerNpu() override; + + std::vector run_impl(std::vector &input_handles, + void *stream_handle) override; + + std::vector run_with_npu_stream( + const std::vector &inputs, + const c10_npu::NPUStream &npu_stream); + void init_proxy_executor(); + + void set_proxy_executor(AOTIProxyExecutorHandle handle); + +private: + std::string model_so_path_; + bool init_flag_; + std::unique_ptr proxy_executor_npu_; +}; + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_runner/pybind.h b/torch_npu/csrc/inductor/aoti_runner/pybind.h new file mode 100644 index 0000000000000000000000000000000000000000..93855eaab25030c6e1a5f838f19a45251896d56d --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIRunnerBindings(PyObject *module); + +} // namespace torch::inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h b/torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..5c451645a9c429ead457f91ed435a7f00b3dd54d --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h @@ -0,0 +1,327 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +// Can't use c10::ArrayRef because it's not truly header-only and +// pulls in other c10 headers. This is (sadly) copy-pasted and +// adapted. +template +class MiniArrayRef final { + public: + using iterator = T *; + using const_iterator = const T *; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + T *Data; + + /// The number of elements. + size_type Length; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty MiniArrayRef. + constexpr MiniArrayRef() : Data(nullptr), Length(0) {} + + /// Construct an MiniArrayRef from a single element. + // TODO Make this explicit + constexpr MiniArrayRef(const T &OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an MiniArrayRef from a pointer and length. + constexpr MiniArrayRef(T *data, size_t length) : Data(data), Length(length) {} + + /// Construct an MiniArrayRef from a range. + constexpr MiniArrayRef(T *begin, T *end) : Data(begin), Length(end - begin) {} + + template ().data())>, + T *>>> + MiniArrayRef(Container &container) + : Data(container.data()), Length(container.size()) {} + + /// Construct an MiniArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because MiniArrayRef can't work on a std::vector + // bitfield. + template + MiniArrayRef(const std::vector &Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert(!std::is_same_v, + "MiniArrayRef cannot be constructed from a " + "std::vector bitfield."); + } + + /// Construct an MiniArrayRef from a std::array + template + constexpr MiniArrayRef(std::array &Arr) : Data(Arr.data()), Length(N) {} + + /// Construct an MiniArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-array*) + constexpr MiniArrayRef(T (&Arr)[N]) : Data(Arr), Length(N) {} + + // /// Construct an MiniArrayRef from an empty C array. 
+ constexpr MiniArrayRef(const volatile void *Arr) : Data(nullptr), Length(0) {} + + /// Construct an MiniArrayRef from a std::initializer_list. + constexpr MiniArrayRef(const std::initializer_list &Vec) + : Data(std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { return Data; } + constexpr iterator end() const { return Data + Length; } + + // These are actually the same as iterator, since MiniArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { return Data; } + constexpr const_iterator cend() const { return Data + Length; } + + constexpr reverse_iterator rbegin() const { return reverse_iterator(end()); } + constexpr reverse_iterator rend() const { return reverse_iterator(begin()); } + + /// empty - Check if the array is empty. + constexpr bool empty() const { return Length == 0; } + + constexpr T *data() const { return Data; } + + /// size - Get the array size. + constexpr size_t size() const { return Length; } + + /// equals - Check for element-wise equality. + constexpr bool equals(MiniArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T &operator[](size_t Index) const { return Data[Index]; } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MiniArrayRef> &operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U &&Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MiniArrayRef> &operator=( + std::initializer_list) = delete; +}; + +using MiniIntArrayRef = MiniArrayRef; + +static_assert(sizeof(MiniIntArrayRef) == sizeof(void *) + sizeof(size_t), + "changing the size of MiniArrayRef breaks ABI compatibility!"); + +inline bool is_contiguous_strides_for_shape(int64_t ndim, + const int64_t *strides_ptr, + const int64_t *sizes_ptr) { + int64_t z = 1; + for (int64_t d = ndim - 1; d >= 0; d--) { + const auto &size_d = sizes_ptr[d]; + if (size_d != 1) { + if (strides_ptr[d] == z) { + z *= size_d; + } else { + return false; + } + } + } + return true; +} + +// Shim for AOTI generated code to pretend a raw array works like an +// AtenTensorHandle. 
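+//
+// For illustration only (the buffer, shapes, and values below are
+// hypothetical, not part of this header): a small contiguous host buffer can
+// be viewed through this shim without allocating, roughly like
+//   float buf[6] = {0, 1, 2, 3, 4, 5};
+//   int64_t sizes[] = {2, 3};
+//   int64_t strides[] = {3, 1};  // contiguous, as the constructor asserts
+//   ArrayRefTensor<float> view(MiniArrayRef<float>(buf, 6),
+//                              MiniArrayRef<const int64_t>(sizes, 2),
+//                              MiniArrayRef<const int64_t>(strides, 2),
+//                              aoti_torch_device_type_cpu(), /*device_idx=*/0);
+// Only view.expensiveCopyToTensor() materializes a real tensor copy.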
+template +class ArrayRefTensor { + public: + ArrayRefTensor() = default; + + explicit ArrayRefTensor(MiniArrayRef arr, + MiniArrayRef sizes, + MiniArrayRef strides, + int32_t device_type, int32_t device_idx) + : arrayRef_(arr), + sizes_(sizes), + strides_(strides), + device_type_(device_type), + device_idx_(device_idx) { + assert(sizes.size() == strides.size()); + assert(is_contiguous_strides_for_shape(sizes.size(), strides.data(), + sizes.data())); + } + + AtenTensorHandle expensiveCopyToTensor() const { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_empty_strided(sizes_.size(), sizes_.data(), strides_.data(), + aoti_torch_dtype>(), + device_type_, device_idx_, &result)); + void *dataPtr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(result, &dataPtr)); + std::memcpy(dataPtr, data(), numel() * sizeof(T)); + return result; + } + + // We need to look the same as RAIIAtenTensorHandle, which returns + // an owning AtenTensorHandle from release(). So, we allocate one! + AtenTensorHandle release() { return expensiveCopyToTensor(); } + + AtenTensorHandle borrowAsTensor() const { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + data(), sizes_.size(), sizes_.data(), strides_.data(), 0, + aoti_torch_dtype>(), device_type_, device_idx_, + &result, aoti_torch_layout_strided(), nullptr, 0)); + return result; + } + + // We don't need to free any memory. + void reset() {} + + auto sizes() const { return sizes_; } + + auto strides() const { return strides_; } + + auto device_type() const { return device_type_; } + + auto device_idx() const { return device_idx_; } + + T *data() const { return arrayRef_.data(); } + + auto numel() const { return arrayRef_.size(); } + + void set_arrayref(MiniArrayRef new_arrayref) { arrayRef_ = new_arrayref; } + + private: + MiniArrayRef arrayRef_; + // We expect generated code to have statically available sizes & + // strides for us. + MiniArrayRef sizes_; + MiniArrayRef strides_; + int32_t device_type_ = 0; + int32_t device_idx_ = 0; + // We continue to zero-initialize this field in case we repurpose + // the space later; having predictable contents can only help. + int32_t unusedDoNotRemoveForABICompatibility_ = 0; +}; + +static_assert(sizeof(ArrayRefTensor) == + 3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) + + (alignof(ArrayRefTensor) > 4 ? sizeof(int32_t) : 0), + "changing the size of ArrayRefTensor breaks ABI compatibility!"); + +template +inline ArrayRefTensor reinterpret_tensor_wrapper( + const ArrayRefTensor &self, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset) { + // REVIEW: we should add a way to build the DSO in debug mode during + // tests so we can have checks like this! 
+ assert(is_contiguous_strides_for_shape(ndim, strides_ptr, sizes_ptr)); + return ArrayRefTensor(MiniArrayRef(self.data() + storage_offset, + self.numel() - storage_offset), + MiniArrayRef(sizes_ptr, ndim), + MiniArrayRef(strides_ptr, ndim), + self.device_type(), self.device_idx()); +} + +template +inline T *get_data_ptr_wrapper(ArrayRefTensor &tensor) { + return tensor.data(); +} + +template +inline T *get_data_ptr_wrapper(const MiniArrayRef &arr) { + return arr.data(); +} + +template +inline const ArrayRefTensor &unwrap_raii_handle_if_needed( + const ArrayRefTensor &tensor) { + return tensor; +} + +template +inline ArrayRefTensor &unwrap_raii_handle_if_needed( + ArrayRefTensor &tensor) { + return tensor; +} + +template +inline const ArrayRefTensor &wrap_with_raii_handle_if_needed( + const ArrayRefTensor &tensor) { + return tensor; +} + +template +inline ArrayRefTensor &wrap_with_raii_handle_if_needed( + ArrayRefTensor &tensor) { + return tensor; +} + +template +inline ArrayRefTensor wrap_with_raii_handle_if_needed( + ArrayRefTensor &&tensor) { + return std::move(tensor); +} + +template +inline RAIIAtenTensorHandle expensive_copy_to_tensor_if_needed( + const ArrayRefTensor &tensor) { + return tensor.expensiveCopyToTensor(); +} + +inline AtenTensorHandle expensive_copy_to_tensor_if_needed( + AtenTensorHandle handle) { + return handle; +} + +template +const T ©_arrayref_tensor_to_tensor(const T &t) { + return t; +} + +template +RAIIAtenTensorHandle copy_arrayref_tensor_to_tensor( + const ArrayRefTensor &art) { + return art.expensiveCopyToTensor(); +} + +template +const T &borrow_arrayref_tensor_as_tensor(const T &t) { + return t; +} + +template +RAIIAtenTensorHandle borrow_arrayref_tensor_as_tensor( + const ArrayRefTensor &art) { + return art.borrowAsTensor(); +} + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/device_utils.h b/torch_npu/csrc/inductor/aoti_runtime/device_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..ed0ae6b10359f96b41b599117f540ee0fdc64226 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/device_utils.h @@ -0,0 +1,43 @@ +#pragma once + +#if defined(USE_NPU) + +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" + +typedef void *NPUdeviceptr; + +typedef void *NPUfunction; + +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + do { \ + const aclError code = EXPR; \ + if (code != ACL_SUCCESS) { \ + throw std::runtime_error(std::string("NPU error core: ") + \ + std::to_string(code) + std::string(" ") + \ + std::string(__FILE__) + std::string(":") + \ + std::to_string(__LINE__)); \ + } \ + } while (0) + +namespace torch::aot_inductor { + +using DeviceStreamType = aclrtStream; + +} // namespace torch::aot_inductor + +#else + +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error("CPU runtime error"); \ + } + +namespace torch::aot_inductor { + +using DeviceStreamType = void *; + +} // namespace torch::aot_inductor + +#endif // USE_NPU diff --git a/torch_npu/csrc/inductor/aoti_runtime/interface.h b/torch_npu/csrc/inductor/aoti_runtime/interface.h new file mode 100644 index 0000000000000000000000000000000000000000..4ab14859459930693cbd5e10fd1b4882c10d0b50 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/interface.h @@ -0,0 +1,183 @@ +#pragma once + +// WARNING: Be careful when adding new includes here. 
This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +extern "C" { +struct AOTInductorModelOpaque; +using AOTInductorModelHandle = AOTInductorModelOpaque *; + +struct AOTInductorModelContainerOpaque; +using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque *; + +struct AOTInductorStreamOpaque; +using AOTInductorStreamHandle = AOTInductorStreamOpaque *; + +struct AOTInductorConstantMap; +using AOTInductorConstantMapHandle = AOTInductorConstantMap *; + +// TODO: Deprecate this API. This was kept for BC compatibility. +// Please use AOTInductorModelContainerCreateWithDevice instead. +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle *container_handle, size_t num_models, + bool is_cpu, const char *cubin_dir); + +// Creates an AOTInductor model container. The parameter num_models +// specifies the number of model instances that may be run concurrently for +// the same input model. +// `device_str` MUST NOT be nullptr. It must be a valid device string, e.g. +// "cpu", "npu", "npu:0", etc. +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle *container_handle, size_t num_models, + const char *device_str, const char *cubin_dir); + +// Deletes the AOTInductor model container. +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle); + +// Runs the inference. +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Single-threaded variant of previous. +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Retrieves the number of constants for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, size_t *num_constants); + +// Retrieves a constant's name. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, size_t idx, + const char **name); + +// Retrieves a constant's original FQN. +// idx is the index of the internal's constants. 
+// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, size_t idx, + const char **original_fqn); + +// Retrieves whether a constant is from folded. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, size_t idx, + bool *from_folded); + +// Retrieves the inductor constant type. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *type); + +// Retrieves a constant's dtype. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *dtype); + +// Setup the constant buffer in model container with provided ConstantMap +// use_inactive should be set as true if the inactive buffer is to be updated. +// validate_full_update checks if all constants are included in the ConstantMap +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, + bool validate_full_update); + +// Setup the inactive constant buffer in model container with provided +// ConstantMap +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Run constant folding on constant buffer. +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Swap the constant buffer being used to the inactive one. +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle); + +// Retrieves the number of inputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, size_t *ret_num_inputs); + +// Retrieves the input name at the given index. +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, size_t input_idx, + const char **ret_input_names); + +// Retrieves the number of outputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, size_t *ret_num_outputs); + +// Retrieves the output name at the given index. +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, size_t output_idx, + const char **ret_output_names); + +// Creates an AOTInductorModel instance. This is a thin and light wrapper +// around the compiled model; it doesn't handle concurrency, queueing, device +// management, etc. Use this if bare-metal performance is needed and you are +// willing to handle other "management" aspects yourself. +// +// constant_map_handle is an opaque type to satisfy the C ABI. 
It should be a +// std::unordered_map*. +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle *model_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Run an AOTInductorModel (see AOTInductorModelCreate for when one should use +// this function versus AOTInductorModelContainerRun). +AOTIRuntimeError AOTInductorModelRun(AOTInductorModelHandle model_handle, + AtenTensorHandle *input_handles, + AtenTensorHandle *output_handles); + +// Replace AOTInductorModel's constant map. Note it doesn't handle concurrency +// so be sure to handle ordering if AOTInductorModelRun is ran concurrently. +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Delete an AOTInductorModel created by AOTInductorModelCreate. +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, size_t *ret_num_outputs); + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, const char **in_spec, + const char **out_spec); + +} // extern "C" diff --git a/torch_npu/csrc/inductor/aoti_runtime/model.h b/torch_npu/csrc/inductor/aoti_runtime/model.h new file mode 100644 index 0000000000000000000000000000000000000000..4b0e58015130022db37a557cd4e3c4f710eebd77 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/model.h @@ -0,0 +1,592 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include +#include + +#define AOTI_RUNTIME_CHECK(EXPR, MSG) \ + do { \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error(MSG); \ + } \ + } while (0) + +// At codegen time, we write out a binary file called constants.bin. +// We then turn the raw binary to an object file that exposes this +// symbol and link it into the final .so. +// The constants are NOT readonly because they may be mutated. +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_start[]; +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_end[]; + +#define AOTI_CONST_ALIGNMENT 64 + +namespace { + +using RAIIDataPtr = std::unique_ptr>; + +#ifdef USE_NPU + +using RAIIDataPtr = std::unique_ptr>; + +RAIIDataPtr RAII_npuMalloc(size_t num_bytes) { + void *data_ptr; + // aclrtMalloc doesn't support allocate 0-bytes. In this case, + // e.g, model has no weight, we should do padding. 
+ size_t padding_bytes = 32; + if (num_bytes == 0) num_bytes = padding_bytes; + AOTI_RUNTIME_DEVICE_CHECK( + aclrtMalloc((void **)&data_ptr, num_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + auto deleter = [](void *ptr) { AOTI_RUNTIME_DEVICE_CHECK(aclrtFree(ptr)); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#endif // USE_NPU + +RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { + void *data_ptr = std::malloc(num_bytes); + if (!data_ptr) { + throw std::bad_alloc(); + } + auto deleter = [](void *ptr) { std::free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +} // anonymous namespace + +namespace torch::aot_inductor { +enum ConstantType : uint8_t { + Unknown = 0, + Parameter = 1, + Buffer = 2, + TensorConstant = 3, + FoldedConstant = 4, +}; + +using ConstantMap = std::unordered_map; + +// valid device strs are: cpu, npu, npu:0, npu:1, ... +// Update the list here if more devices are supported in the future +inline void parse_device_str(const std::string &device_str, + int32_t &device_type, int32_t &device_idx) { + std::regex re("(cpu|npu)(:([0-9]+))?"); + std::smatch sm; + bool matched = std::regex_match(device_str, sm, re); + AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str); + + if (sm[1].str() == "cpu") { + device_type = aoti_torch_device_type_cpu(); +#ifdef USE_NPU + } else if (sm[1].str() == "npu") { + device_type = aoti_torch_device_type_npu(); +#endif + } else { + AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str); + } + const size_t default_sm = 3; + if (sm[default_sm].matched) { + device_idx = stoi(sm[default_sm].str()); + } else { + device_idx = -1; + } +} + +// Defines the base class for AOTInductorModel, which is generated by the +// AOTInductor cpp codegen. Since we do not need dynamic dispatch, we rely +// on curiously recurring template pattern (CRTP) to save some runtime +// v-table overhead. The generated AOTInductorModel is specialized with +// methods such as run_impl. 
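+//
+// As a rough sketch of the CRTP contract (simplified; the real specialization
+// and its kernels are produced by the codegen):
+//   class AOTInductorModel : public AOTInductorModelBase<AOTInductorModel> {
+//    public:
+//     void run_impl(AtenTensorHandle* input_handles,
+//                   AtenTensorHandle* output_handles,
+//                   DeviceStreamType stream,
+//                   AOTIProxyExecutorHandle proxy_executor);
+//   };
+// run() below then calls static_cast<Model*>(this)->run_impl(...) directly,
+// with no virtual dispatch involved.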
+template +class AOTInductorModelBase { + public: + AOTInductorModelBase(size_t num_inputs, size_t num_outputs, + size_t num_constants, const std::string &device_str, + std::optional cubin_dir, + bool include_weights = true) + : inputs_info_(num_inputs), + outputs_info_(num_outputs), + constants_info_(num_constants), + cubin_dir_(std::move(cubin_dir)), + include_weights(include_weights) { + parse_device_str(device_str, device_type_, device_idx_); + +#ifdef USE_NPU + if (device_idx_ == -1) { + AOTI_RUNTIME_DEVICE_CHECK(aclrtSetDevice(0)); + AOTI_RUNTIME_DEVICE_CHECK(aclrtGetDevice(&device_idx_)); + } else { + AOTI_RUNTIME_DEVICE_CHECK(aclrtSetDevice(device_idx_)); + } +#endif // USE_NPU + } + + // NOLINTNEXTLINE(modernize-use-equals-default) + ~AOTInductorModelBase() { +#ifdef USE_NPU + if (run_finished_) { + auto code = aclrtDestroyEvent(*run_finished_); + if (code != ACL_SUCCESS) { + std::cerr + << "Failed to destroy NPU event in AOTInductor model erorr code: " + << code << std::endl; + } + } +#endif // USE_NPU + } + + AOTInductorModelBase(AOTInductorModelBase &&) = delete; + AOTInductorModelBase &operator=(AOTInductorModelBase &&) = delete; + AOTInductorModelBase(const AOTInductorModelBase &) = delete; + AOTInductorModelBase &operator=(const AOTInductorModelBase &) = delete; + + void run(AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the + // array itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { +#if defined(USE_NPU) + if (!run_finished_) { + aclrtEvent run_finished; + AOTI_RUNTIME_DEVICE_CHECK(aclrtCreateEvent(&run_finished)); + run_finished_.emplace(run_finished); + } +#else + run_finished_ = false; +#endif + + auto *model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK(aclrtRecordEvent(*run_finished_, stream)); +#else + run_finished_ = true; +#endif + } + + // Non-thread-aware variant of run(). 
Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the array + // itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { + // don't bother with any of the run_finished stuff; this is unsafe to call + // in a threaded context + auto *model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + } + + std::unordered_map run_const_fold( + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor, + bool initialization = false) { +#if defined(USE_NPU) + if (!run_finished_) { + aclrtEvent run_finished; + AOTI_RUNTIME_DEVICE_CHECK(aclrtCreateEvent(&run_finished)); + run_finished_.emplace(run_finished); + } +#else + run_finished_ = false; +#endif + + auto *model = static_cast(this); + auto folded_constants = + model->const_run_impl(stream, proxy_executor, initialization); + +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK(aclrtRecordEvent(*run_finished_, stream)); +#else + run_finished_ = true; +#endif + return folded_constants; + } + + void load_constants() { + size_t num_constants = this->num_constants(); + constants_map_->reserve(num_constants); + + std::vector constants_internal_offset(num_constants); + size_t blob_size = 0; + compute_constant_blob(blob_size, constants_internal_offset); +#if defined(USE_NPU) + constant_blob_ = RAII_npuMalloc(blob_size); +#else + constant_blob_ = RAII_cpuMalloc(blob_size); +#endif + if (!include_weights) { + return; + } + + size_t bytes_read = 0; + for (size_t i = 0; i < num_constants; i++) { + bool from_folded = this->constant_from_folded(i); + if (from_folded) { + continue; + } + std::string name = this->constant_name(i); + size_t data_size = this->constant_data_size(i); + uint8_t *internal_ptr = + (data_size != 0) ? constant_ptr(constants_internal_offset[i], + bytes_read, data_size, from_folded) + : nullptr; + bytes_read += data_size; + + // Create at::Tensor from copied memory. + auto dtype = this->constant_dtype(i); + auto ndim = this->constant_ndim(i); + auto size = this->constant_shape(i); + auto stride = this->constant_stride(i); + auto offset = this->constant_offset(i); + auto layout = this->constant_layout(i); + auto opaque_metadata_ptr = this->opaque_metadata(i); + auto opaque_metadata_size = this->opaque_metadata_size(i); + + AtenTensorHandle tensor_handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu_v2( + internal_ptr, ndim, size, stride, offset, dtype, device_type_, + device_idx_, &tensor_handle, layout, opaque_metadata_ptr, + opaque_metadata_size)); + constants_map_->emplace(std::move(name), tensor_handle); + } + if (constants_map_) { + this->update_constants_array_from_map(); + } + } + + RAIIDataPtr &&release_constant_blob() { return std::move(constant_blob_); } + + std::shared_ptr> get_constants_array() { + return constants_; + } + + int32_t get_device_type() const { return device_type_; } + + int32_t get_device_idx() const { return device_idx_; } + + uint8_t *constant_ptr(size_t constant_offset, size_t bytes_read, + size_t data_size, bool skip_copy) { + auto *constants_ptr = static_cast(constant_blob_.get()); + uint8_t *internal_ptr = constants_ptr + constant_offset; + // TODO: Handle shared storage case. 
+ if (!skip_copy) { +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK(aclrtMemcpy( + internal_ptr, data_size, _get_constants_start() + bytes_read, + data_size, ACL_MEMCPY_HOST_TO_DEVICE)); +#else + memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size); +#endif + } + return internal_ptr; + } + + void compute_constant_blob(size_t &blob_size, + std::vector &constants_internal_offset) { + size_t num_constants = this->num_constants(); + blob_size = 0; + for (size_t i = 0; i < num_constants; i++) { + size_t data_size = this->constant_data_size(i); + if (data_size % AOTI_CONST_ALIGNMENT) { + data_size = AOTI_CONST_ALIGNMENT + + (data_size / AOTI_CONST_ALIGNMENT) * AOTI_CONST_ALIGNMENT; + } + constants_internal_offset[i] = blob_size; + blob_size += data_size; + } + } + + size_t num_inputs() const { return inputs_info_.size(); } + + size_t num_outputs() const { return outputs_info_.size(); } + + size_t num_constants() const { return constants_info_.size(); } + + const char *input_name(int64_t idx) const { + return inputs_info_.at(idx).name; + } + + const char *output_name(int64_t idx) const { + return outputs_info_.at(idx).name; + } + + const char *constant_name(int64_t idx) const { + return constants_info_.at(idx).name; + } + + size_t constant_ndim(int64_t idx) { + return constants_info_.at(idx).shape.size(); + } + + const int64_t *constant_shape(int64_t idx) const { + return constants_info_.at(idx).shape.data(); + } + + const int64_t *constant_stride(int64_t idx) const { + return constants_info_.at(idx).stride.data(); + } + + int32_t constant_dtype(int64_t idx) const { + return constants_info_.at(idx).dtype; + } + + int32_t constant_layout(int64_t idx) const { + return constants_info_.at(idx).layout; + } + + size_t constant_offset(int64_t idx) const { + return constants_info_.at(idx).offset; + } + + size_t constant_data_size(int64_t idx) const { + return constants_info_.at(idx).data_size; + } + + const char *constant_original_fqn(int64_t idx) const { + return constants_info_.at(idx).original_fqn; + } + + const uint8_t *opaque_metadata(int64_t idx) const { + return constants_info_.at(idx).opaque_metadata.data(); + } + + size_t opaque_metadata_size(int64_t idx) { + return constants_info_.at(idx).opaque_metadata.size(); + } + + bool constant_from_folded(int64_t idx) const { + return constants_info_.at(idx).from_folded; + } + + int32_t constant_type(int64_t idx) const { + return constants_info_.at(idx).type; + } + + const char *get_in_spec() const { return in_spec_.c_str(); } + + const char *get_out_spec() const { return out_spec_.c_str(); } + + void update_constants_array_from_map() { + if (!constants_map_) { + throw std::runtime_error{ + "constants_map_ was not ready when constants_ is trying to be " + "constructed from it!"}; + } + if (!constants_) { + constants_ = + std::make_shared>(constants_info_.size()); + } else { + constants_->resize(constants_info_.size()); + } + int idx = 0; + for (const auto &info : constants_info_) { + const auto it = constants_map_->find(info.name); + if (it != constants_map_->end()) { + constants_->at(idx) = ConstantHandle(it->second); + } + idx++; + } + } + + void update_constants_map(std::shared_ptr constants_map, + bool remap_constants_array = true) { + constants_map_ = std::move(constants_map); + if (remap_constants_array) { + update_constants_array_from_map(); + } + } + + // This function allows us to update the constants_ that is used to look up + // the corresponding constant tensor during runtime. 
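+  // For example, when the container in model_container.h swaps to the
+  // inactive constant buffer it calls (argument names illustrative):
+  //   model->update_constants_map(inactive_map, /*remap_constants_array=*/false);
+  //   model->update_constants_array(inactive_array);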
+ void update_constants_array( + std::shared_ptr> constants_array) { + constants_ = std::move(constants_array); + } + + /// Returns true if the model is complete. + bool is_finished() { +#if defined(USE_NPU) + if (!run_finished_) { + throw std::runtime_error{"Model NPU event was not initialized"}; + } + aclrtEventRecordedStatus recordStatus = ACL_EVENT_RECORDED_STATUS_NOT_READY; + AOTI_RUNTIME_DEVICE_CHECK( + aclrtQueryEventStatus(*run_finished_, &recordStatus)); + + if (recordStatus == ACL_EVENT_RECORDED_STATUS_COMPLETE) { + return true; + } else { + return false; + } +#else + return run_finished_; +#endif + } + + /// Synchronizes completion event. + void wait_for_completion() {} + + protected: + uint8_t *_get_constants_start() { +#ifndef USE_MMAP_SELF + // NOLINTNEXTLINE(*const-cast*) + return const_cast(_binary_constants_bin_start); +#else + if (self_mmap) { + return self_mmap; + } + Dl_info dl_info; + // get pointer to constant which are appended to the binary + AOTI_RUNTIME_CHECK(dladdr(__func__, &dl_info), + "Can't find shared library name"); + int fd = open(dl_info.dli_fname, O_RDONLY); + AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened"); + auto fsize = lseek(fd, 0, SEEK_END); + auto weights_size = + reinterpret_cast(_binary_constants_bin_start)[0]; + auto magic_number = + reinterpret_cast(_binary_constants_bin_start)[1]; + auto weights_offset = fsize - weights_size; + AOTI_RUNTIME_CHECK((weights_offset & 0x3fff) == 0, + "weights_offset must be aligned to 16K boundary"); + auto ptr = mmap(NULL, weights_size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, + weights_offset); + close(fd); + AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed"); + self_mmap = static_cast(ptr); + AOTI_RUNTIME_CHECK( + reinterpret_cast(self_mmap + weights_size - + sizeof(uint64_t))[0] == magic_number, + "Weigths data seems corrupt"); + return self_mmap; +#endif + } + struct ParamInfo { + const char *name = nullptr; + }; + + struct ConstInfo { + const char *name = nullptr; + std::vector shape; + std::vector stride; + int32_t dtype{}; + int64_t offset{}; + size_t data_size{}; + int32_t layout{}; + std::vector opaque_metadata; + int64_t opaque_metadata_size{}; + const char *original_fqn = nullptr; + bool from_folded{}; + int32_t type{}; + }; + + std::vector inputs_info_; + std::vector outputs_info_; + std::vector constants_info_; + std::string in_spec_; + std::string out_spec_; + + std::shared_ptr constants_map_; + std::shared_ptr> constants_; + + // Holds the blob storage for constants' at::Tensor. + RAIIDataPtr constant_blob_; + +#ifdef USE_MMAP_SELF + uint8_t *self_mmap = NULL; +#endif + + // A directory with npu binary files, e.g. compiled kernels, etc. + const std::optional cubin_dir_; + + // This is the flag that implies whether the weight is included in the model. + // If True, we would prepare the weight when loading the model, otherwise the + // model will be loaded without weights, and need to be provided by the user. + bool include_weights; + + // Record if the model finishes an inference run so that its owning + // AOTModelContainer can re-use this instance. +#if defined(USE_NPU) + std::optional run_finished_; +#else + bool run_finished_{}; +#endif + + int32_t device_type_{}; + int32_t device_idx_{}; +}; + +// Codegen-ed classes can derive from this to keep pointers to loaded kernels. 
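+//
+// For instance, generated code might declare something along these lines
+// (the derived class and member names here are hypothetical):
+//   class AOTInductorModelKernels : public AOTInductorModelKernelsBase {
+//    public:
+//     void* triton_poi_fused_add_0{nullptr};  // handle to a loaded kernel
+//   };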
+class AOTInductorModelKernelsBase { + public: + virtual ~AOTInductorModelKernelsBase() = default; +}; + +class AOTInductorModel : public AOTInductorModelBase { + public: + AOTInductorModel(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string &device_str, + std::optional cubin_dir, + bool include_weights = true); + + std::unordered_map const_run_impl( + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor, + bool initialization = false); + + void _const_run_impl(std::vector &output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + void run_impl( + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the array + // itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor); + + template + Outputs run_impl_minimal_arrayref_interface( + const Inputs &inputs, DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + static std::unique_ptr Create( + std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string &device_str, std::optional cubin_dir) { + return std::make_unique(std::move(constants_map), + std::move(constants_array), + device_str, std::move(cubin_dir)); + } + + private: + std::unique_ptr kernels_; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/model_container.h b/torch_npu/csrc/inductor/aoti_runtime/model_container.h new file mode 100644 index 0000000000000000000000000000000000000000..51a5172aff1e6832d6b2d1c357c12d4a876a0160 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/model_container.h @@ -0,0 +1,589 @@ +#pragma once + +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +namespace torch::aot_inductor { + +class AOTInductorModelContainer { + public: + AOTInductorModelContainer( + size_t num_models, const std::string &device_str, + const std::optional &cubin_dir = std::nullopt) { + constants_map_ = std::make_shared(); + constants_array_ = std::make_shared>(); + + models_.reserve(num_models); + available_models_.reserve(num_models); + for (size_t i = 0; i < num_models; ++i) { + models_.push_back(AOTInductorModel::Create( + constants_map_, constants_array_, device_str, cubin_dir)); + available_models_.push_back(models_.back().get()); + } + + // Note that the all following fields (input_names_, output_names, + // etc) can be filled in by the AOT + // codegen. 
However, we choose to query such information from + // the owned AOTInductorModel for a couple of reasons: + // * simplify the codegen templates + // * reduce information fragmentation and duplication + // * the initialization process below is done only once when the container + // is constructed, so it would have little performance impact + auto *model = available_models_[0]; + size_t num_inputs = model->num_inputs(); + input_names_.reserve(num_inputs); + for (size_t i = 0; i < num_inputs; i++) { + input_names_.emplace_back(model->input_name(static_cast(i))); + } + + size_t num_outputs = model->num_outputs(); + output_names_.reserve(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + output_names_.emplace_back(model->output_name(static_cast(i))); + } + model->load_constants(); + constant_blob_ = model->release_constant_blob(); + constants_internal_offset_.resize(model->num_constants()); + model->compute_constant_blob(blob_size_, constants_internal_offset_); + + for (auto &model : models_) { + model->update_constants_map(constants_map_); + } + + in_spec_ = model->get_in_spec(); + out_spec_ = model->get_out_spec(); + } + + void run(AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the + // array itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { + std::shared_lock model_lk(model_exec_mutex_); + auto *model = get_available_model(); + + if (!constant_folded_) { + // At this point, constant is not ready yet. We need to call constant + // folding before we execute the model. We obtain a unique lock at this + // point to make sure constant is ready for all. + model_lk.unlock(); + std::unique_lock constants_folding_lk(model_exec_mutex_); + // Double locking to make sure constant folding is only ran once. + if (!constant_folded_) { + auto folded_const_map = model->run_const_fold( + stream, proxy_executor, /* initialization = */ true); + update_constant_buffer(std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + constant_folded_ = true; + } + constants_folding_lk.unlock(); + model_lk.lock(); + } + + try { + model->run(input_handles, output_handles, stream, proxy_executor); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + + { + std::lock_guard lk(models_mutex_); + pending_models_.push_back(model); + } + pending_models_available_.notify_one(); + } + + // Non-thread-aware variant of run(). 
Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the array + // itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { + auto *model = available_models_[0]; + + if (!constant_folded_) { + auto folded_const_map = + model->run_const_fold(stream, proxy_executor, true); + update_constant_buffer(std::move(folded_const_map), false, false); + constant_folded_ = true; + } + + model->run_single_threaded(input_handles, output_handles, stream, + proxy_executor); + } + + size_t num_constants() const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->num_constants(); + } + + // retrieve the constant name of constants_info_[idx] + const char *constant_name(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_name(static_cast(idx)); + } + + // retrieve original FQN of constants_info_[idx] + const char *constant_original_fqn(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_original_fqn(static_cast(idx)); + } + + // retrieve whether constant is from folded of constants_info_[idx] + bool constant_from_folded(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_from_folded(static_cast(idx)); + } + + // retrieve type of constants_info_[idx] + int32_t constant_type(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_type(static_cast(idx)); + } + + // retrieve dtype of constants_info_[idx] + int32_t constant_dtype(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_dtype(static_cast(idx)); + } + + void run_const_fold(bool inactive_buffer, DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + std::shared_lock model_lk(model_exec_mutex_); + auto *model = get_available_model(); + + if (!inactive_buffer) { + // We would need to acquire a unique lock if we want to run constant + // folding on the active buffer. + model_lk.unlock(); + std::unique_lock constants_folding_lk(model_exec_mutex_); + try { + auto folded_const_map = model->run_const_fold(stream, proxy_executor); + update_constant_buffer(std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + constants_folding_lk.unlock(); + model_lk.lock(); + } else { + // We swap the constant mapping to the inactive buffer in the model to run + // const run. 
+ auto constants_map = get_constants_map(/* get_inactive= */ true); + auto constants_array = get_constants_array(/* get_inactive= */ true); + + try { + model->update_constants_map(constants_map, + /* remap_constants_array= */ false); + model->update_constants_array(constants_array); + + auto folded_const_map = model->run_const_fold(stream, proxy_executor); + update_constant_buffer(std::move(folded_const_map), + /* use_inactive = */ true, + /* validate_full_update = */ false); + + // Swap back the model's constants mapping + constants_map = get_constants_map(/* get_inactive= */ false); + constants_array = get_constants_array(/* get_inactive= */ false); + model->update_constants_map(constants_map, + /* remap_constants_array= */ false); + model->update_constants_array(constants_array); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + } + + { + std::lock_guard lk(models_mutex_); + pending_models_.push_back(model); + } + pending_models_available_.notify_one(); + } + + bool _should_skip_update(const size_t idx) const { + auto constant_type = models_[0]->constant_type(static_cast(idx)); + // We should skip constants + return constant_type == ConstantType::TensorConstant; + } + + bool _could_skip_update(const size_t idx) const { + auto constant_type = models_[0]->constant_type(static_cast(idx)); + // Buffer can be optionally skipped, so if it not provided by upstream + // services, it is OK to relax the check. + return constant_type == ConstantType::Buffer; + } + + void assert_all_constants( + const std::unordered_map &constants_map) { + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + if (models_[0]->constant_from_folded(static_cast(idx))) { + continue; + } + + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end()) { + if (_should_skip_update(idx) || _could_skip_update(idx)) { + // tracing sometimes creates tensors that are non-existent in + // original graph. We could skip those and do a direct copy. + std::cerr << "[WARNING] Found constant or module state buffer " + << constant_name + << " in model, but not provided by user!\n"; + continue; + } + throw std::runtime_error(std::string("Cannot find constants ") + + constant_name + + std::string(" in constants_map!")); + } + } + } + + // We directly take ownership from AtenTensorHandle if constants are moved. 
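+  // A minimal illustration of the move-based update below (the tensor name
+  // and handle are hypothetical):
+  //   std::unordered_map<std::string, AtenTensorHandle> folded;
+  //   folded.emplace("fc1_weight", some_handle);  // handle ownership moves in
+  //   container.update_constant_buffer(std::move(folded),
+  //                                    /*use_inactive=*/true,
+  //                                    /*validate_full_update=*/false);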
+ void update_constant_buffer( + std::unordered_map &&constants_map, + bool use_inactive, bool validate_full_update) { + if (this->num_models() == 0) { + throw std::runtime_error("No model available in container!"); + } + if (validate_full_update) { + assert_all_constants(constants_map); + } + + auto original_constants_map = get_constants_map(!use_inactive); + auto constants_map_to_update = get_constants_map(use_inactive); + + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end() && !use_inactive) { + continue; + } + + AtenTensorHandle tensor; + if (it == constants_map.end() && use_inactive) { + aoti_torch_clone( + original_constants_map->find(constant_name)->second.get(), &tensor); + } else { + tensor = it->second; + } + + constants_map_to_update->insert_or_assign(constant_name, tensor); + } + // Update the inactive constant array. + update_array_from_map(get_constants_array(use_inactive), + constants_map_to_update); + } + + // This function updates the buffer for storing constants. + // It will update the buffer, the mapping and the array mapping. + void update_constant_buffer( + const std::unordered_map &constants_map, + bool use_inactive, bool validate_full_update) { + if (this->num_models() == 0) { + throw std::runtime_error("No model available in container!"); + } + if (validate_full_update) { + assert_all_constants(constants_map); + } + + auto original_constants_map = get_constants_map(!use_inactive); + auto constants_map_to_update = get_constants_map(use_inactive); + + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end() && !use_inactive) { + continue; + } + + AtenTensorHandle tensor; + if (it == constants_map.end() && use_inactive) { + tensor = original_constants_map->find(constant_name)->second.get(); + } else { + tensor = it->second; + } + auto *constants_blob_ptr = + static_cast(get_constant_blob_ptr(use_inactive)); + + // Move the data to container handled blob. + uint8_t *internal_constants_ptr = + constants_blob_ptr + constants_internal_offset_[idx]; + void *user_constant_ptr; + int64_t constant_size; + aoti_torch_get_data_ptr(tensor, &user_constant_ptr); + aoti_torch_get_storage_size(tensor, &constant_size); +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK( + aclrtMemcpy(internal_constants_ptr, constant_size, user_constant_ptr, + constant_size, ACL_MEMCPY_HOST_TO_DEVICE)); +#else + memcpy(internal_constants_ptr, user_constant_ptr, constant_size); +#endif + // Generate Tensor from container handled blob. + // We extract stride and offset from provided Tensor since we do not + // guarantee that the tensor is contiguous. 
+ AtenTensorHandle tensor_handle; + int64_t *stride; + int64_t offset; + int device_type = models_[0]->get_device_type(); + int device_idx = models_[0]->get_device_idx(); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + internal_constants_ptr, models_[0]->constant_ndim(idx), + models_[0]->constant_shape(idx), stride, offset, + models_[0]->constant_dtype(idx), device_type, device_idx, + &tensor_handle)); + + // Now place the tensor to constants_map. Note at this point the ownership + // of the tensor_handle will be taken over. + constants_map_to_update->insert_or_assign(constant_name, tensor_handle); + } + // Update the inactive constant array. + update_array_from_map(get_constants_array(use_inactive), + constants_map_to_update); + } + + void update_array_from_map( + const std::shared_ptr> &constants_array, + const std::shared_ptr &constants_map) { + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + if (constants_map->find(models_[0]->constant_name( + static_cast(idx))) != constants_map->end()) { + constants_array->at(idx) = ConstantHandle( + constants_map + ->find(models_[0]->constant_name(static_cast(idx))) + ->second); + } + } + } + + void swap_constant_buffer() { + std::lock_guard unique_lk(model_exec_mutex_); + + auto constants_map = get_constants_map(/* get_inactive= */ true); + auto constants_array = get_constants_array(/* get_inactive= */ true); + + for (auto &model : models_) { + model->update_constants_map(constants_map, + /* remap_constants_array = */ false); + model->update_constants_array(constants_array); + } + + use_secondary_ = !use_secondary_; + } + + size_t num_inputs() const { return input_names_.size(); } + + size_t num_outputs() const { return output_names_.size(); } + + const char *input_name(size_t idx) const { + return input_names_.at(idx).c_str(); + } + + const char *output_name(size_t idx) const { + return output_names_.at(idx).c_str(); + } + + size_t num_models() const { return models_.size(); } + + const char *get_in_spec() const { return in_spec_; } + + const char *get_out_spec() const { return out_spec_; } + + private: + std::vector input_names_; + std::vector output_names_; + const char *in_spec_; + const char *out_spec_; + + // Holds the blob storage for constants' at::Tensor within the container. + // This blob of memory will be managed by the container. + RAIIDataPtr constant_blob_; + RAIIDataPtr constant_blob_secondary_; + + size_t blob_size_; + std::vector constants_internal_offset_; + + // Determine which constants is being used for the model. + // If true, + // constants_map_secondary/constant_blob_secondary/constants_array_secondary + // is being used. + bool use_secondary_{false}; + + // Determine whether we have ran constant folding + bool constant_folded_{false}; + + // Holds the mapping of constants to at::Tensor. + std::shared_ptr constants_map_; + std::shared_ptr constants_map_secondary_; + + // Holds the indexed array of constant for faster lookup during runtime. + std::shared_ptr> constants_array_; + std::shared_ptr> constants_array_secondary_; + + // Holds all the AOTInductorModel instances owned by this container. + std::vector> models_; + + // Holds the AOTInductorModel instances available for inference. 
+ std::vector available_models_; + + // Holds the AOTInductorModel instances that have started running + // inference and can be placed onto available_models_ upon their + // completion. + std::deque pending_models_; + + // Protects available_models_ and pending_models_. + std::mutex models_mutex_; + + // Notified whenever a model is placed onto pending_models_. + std::condition_variable pending_models_available_; + + AOTInductorModel *get_available_model() { + std::unique_lock lk(models_mutex_); + if (available_models_.empty()) { + reclaim_finished_models(lk); + } + auto *result = available_models_.back(); + available_models_.pop_back(); + return result; + } + + // This mutex is used to protect execution of model. + // We acquire the mutex in shared mode if we allow concurrent execution. + // We acquire the mutex in unique mode when we want exclusive access of the + // model. One such case is when we want to do a weight swapping. We want to + // make sure no one is executing the model. + std::shared_mutex model_exec_mutex_; + + void *get_constant_blob_ptr(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constant_blob_.get(); + } else { + if (!constant_blob_secondary_) { +#if defined(USE_NPU) + constant_blob_secondary_ = RAII_npuMalloc(blob_size_); +#else + constant_blob_secondary_ = RAII_cpuMalloc(blob_size_); +#endif // USE_NPU + } + return constant_blob_secondary_.get(); + } + } + + std::shared_ptr get_constants_map(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constants_map_; + } else { + if (!constants_map_secondary_) { + constants_map_secondary_ = std::make_shared(); + } + return constants_map_secondary_; + } + } + + std::shared_ptr> get_constants_array( + bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constants_array_; + } else { + if (!constants_array_secondary_) { + constants_array_secondary_ = + std::make_shared>( + models_[0]->num_constants()); + } + return constants_array_secondary_; + } + } + + void reclaim_finished_models(std::unique_lock &lk) { +#ifdef __aarch64__ + // push finished model instances to the end of pending_models_ + auto it = + std::partition(pending_models_.begin(), pending_models_.end(), + [](AOTInductorModel *m) { return !m->is_finished(); }); +#else + // push finished model instances to the end of pending_models_ + auto it = std::stable_partition( + pending_models_.begin(), pending_models_.end(), + [](AOTInductorModel *m) { return !m->is_finished(); }); +#endif + + if (it != pending_models_.end()) { + // We have finished model instances that can be pushed into + // available_models_ so that we don't have to be blocked on waiting + // the pending_models_available_ condition. + available_models_.insert(available_models_.end(), it, + pending_models_.end()); + pending_models_.erase(it, pending_models_.end()); + return; + } + + pending_models_available_.wait( + lk, [this]() { return !pending_models_.empty(); }); + // Let's make the schedule simple first. We always wait on the first + // pending_models_ to be complete. + auto *model = pending_models_.front(); + pending_models_.pop_front(); + lk.unlock(); + try { + model->wait_for_completion(); + } catch (...) 
{ + lk.lock(); + available_models_.push_back(model); + throw; + } + lk.lock(); + available_models_.push_back(model); + } +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h b/torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..3f59d40711687e04ad50a57b7832a821d9154df1 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +template +inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) { + throw std::runtime_error("Unsupported scalar_to_tensor_handle"); +} + +// Specialize for supported C++ primitive types +#define AOTI_RUNTIME_SCALAR_TO_TENSOR(dtype, ctype) \ + template <> \ + inline RAIIAtenTensorHandle scalar_to_tensor_handle(ctype value) { \ + AtenTensorHandle tensor_handle; \ + AOTI_TORCH_ERROR_CODE_CHECK( \ + aoti_torch_scalar_to_tensor_##dtype(value, &tensor_handle)); \ + return RAIIAtenTensorHandle(tensor_handle); \ + } + +AOTI_RUNTIME_SCALAR_TO_TENSOR(float32, float) +AOTI_RUNTIME_SCALAR_TO_TENSOR(float64, double) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint8, uint8_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint16, uint16_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint32, uint32_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint64, uint64_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int8, int8_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int16, int16_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int32, int32_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int64, int64_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(bool, bool) +AOTI_RUNTIME_SCALAR_TO_TENSOR(complex64, c10::complex) +AOTI_RUNTIME_SCALAR_TO_TENSOR(complex128, c10::complex) +#undef AOTI_RUNTIME_SCALAR_TO_TENSOR + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/thread_local.h b/torch_npu/csrc/inductor/aoti_runtime/thread_local.h new file mode 100644 index 0000000000000000000000000000000000000000..303f34c3c789c8378c1319867c6c1cdd07603b60 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/thread_local.h @@ -0,0 +1,144 @@ +#pragma once + +#include + +namespace torch::aot_inductor { + +template +struct ThreadLocalCachedOutputTensor; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle &) {} + void copy_data_from(const RAIIAtenTensorHandle &handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } +}; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle &) {} + void copy_data_from(const AtenTensorHandle &handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } +}; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const ConstantHandle &) {} + void copy_data_from(const ConstantHandle &handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } +}; + +template +struct ThreadLocalCachedOutputTensor> { + explicit ThreadLocalCachedOutputTensor(const ArrayRefTensor &t) { + realloc(t); + } + + void copy_data_from(const ArrayRefTensor &t) { + if (t.numel() > capacity_) { + realloc(t); + } + std::copy(t.data(), t.data() + t.numel(), storage_.get()); + } + + AtenTensorHandle tensor() const { return tensor_.get(); } + + 
private: + void realloc(const ArrayRefTensor &t) { + capacity_ = t.numel(); + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(t.numel()); + AtenTensorHandle handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + storage_.get(), t.sizes().size(), t.sizes().data(), t.strides().data(), + 0, aoti_torch_dtype>(), t.device_type(), + t.device_idx(), &handle)); + tensor_ = handle; + } + + // NOLINTNEXTLINE(*arrays*) + std::unique_ptr storage_; + int64_t capacity_ = 0; + RAIIAtenTensorHandle tensor_; +}; + +template +struct ThreadLocalCachedOutputArray; + +// Just needs to compile, doesn't need to do anything. +template <> +struct ThreadLocalCachedOutputArray { + explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle &) { + throw std::runtime_error("can't happen"); + } + + // Not supported yet! We would need to put contiguous() or + // expect_contiguous() into the ABI. + void copy_data_from(const RAIIAtenTensorHandle &) { + throw std::runtime_error("can't happen"); + } + + template + ArrayRefTensor arrayref_tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +// Just needs to compile, doesn't need to do anything. +template <> +struct ThreadLocalCachedOutputArray { + explicit ThreadLocalCachedOutputArray(const ConstantHandle &) { + throw std::runtime_error("can't happen"); + } + + // Not supported yet! We would need to put contiguous() or + // expect_contiguous() into the ABI. + void copy_data_from(const ConstantHandle &) { + throw std::runtime_error("can't happen"); + } + + template + ArrayRefTensor arrayref_tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +template +struct ThreadLocalCachedOutputArray> { + explicit ThreadLocalCachedOutputArray(const ArrayRefTensor &t) {} + + template , + std::remove_const_t>, + bool> = true> + ArrayRefTensor arrayref_tensor() const { + return tensor_; + } + + void copy_data_from(const ArrayRefTensor &t) { + if (t.numel() > capacity_) { + capacity_ = t.numel(); + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(capacity_); + } + std::copy(t.data(), t.data() + t.numel(), storage_.get()); + tensor_ = t; + tensor_.set_arrayref(MiniArrayRef(storage_.get(), t.numel())); + } + + private: + // NOLINTNEXTLINE(*arrays*) + std::unique_ptr storage_; + uint32_t capacity_ = 0; + ArrayRefTensor tensor_; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/utils.h b/torch_npu/csrc/inductor/aoti_runtime/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..72c78e7ac176b5ed6eb7b30b578fd4e0269f76b0 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/utils.h @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. 
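Editor's note on the caching helpers defined in thread_local.h above: they are intended to be instantiated by generated wrapper code, one cache per output and per thread. A minimal sketch of the assumed usage follows (output0 and its element type are illustrative and not taken from this patch):

    // One thread-local cache per output; the backing storage is reused across
    // calls on the same thread and reallocated only when the output grows.
    thread_local ThreadLocalCachedOutputTensor<ArrayRefTensor<float>>
        cached_output0(output0);
    cached_output0.copy_data_from(output0);
    AtenTensorHandle out0 = cached_output0.tensor();

Reusing the thread-local storage avoids allocating a fresh output buffer on every invocation as long as the output size does not increase.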
+#include + +#if defined(__GNUC__) || defined(__clang__) +#define AOTI_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define AOTI_NOINLINE __declspec(noinline) +#else +#define AOTI_NOINLINE +#endif + +AOTI_NOINLINE static void throw_exception(const char *call, const char *file, + int64_t line) { + std::stringstream ss; + ss << call << " API call failed at " << file << ", line " << line; + throw std::runtime_error(ss.str()); +} + +#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_TORCH_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ + } + +using AOTIRuntimeError = int32_t; +#define AOTI_RUNTIME_SUCCESS 0 +#define AOTI_RUNTIME_FAILURE 1 + +#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_RUNTIME_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ + } + +namespace torch::aot_inductor { + +using DeleterFnPtr = void (*)(void *); + +inline void noop_deleter(void *) {} + +inline void delete_tensor_object(void *ptr) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_delete_tensor_object(reinterpret_cast(ptr))); +} + +// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI +class RAIIAtenTensorHandle { + public: + RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {} + RAIIAtenTensorHandle(const RAIIAtenTensorHandle &other) = delete; + RAIIAtenTensorHandle &operator=(const RAIIAtenTensorHandle &other) = delete; + + // Steal the ownership from another RAIIAtenTensorHandle using std::move + RAIIAtenTensorHandle(RAIIAtenTensorHandle &&other) = default; + RAIIAtenTensorHandle &operator=(RAIIAtenTensorHandle &&other) = default; + + // Steal the ownership from raw AtenTensorHandle + RAIIAtenTensorHandle(AtenTensorHandle handle) + : handle_(handle, delete_tensor_object) {} + + ~RAIIAtenTensorHandle() { handle_.reset(); } + + // Return a raw AtenTensorHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenTensorHandle() const { return handle_.get(); } + + AtenTensorHandle release() { return handle_.release(); } + + AtenTensorHandle get() const { return handle_.get(); } + + void reset() { handle_.reset(); } + + int64_t size(int64_t d) { + int64_t size = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size)); + return size; + } + + int64_t stride(int64_t d) { + int64_t stride = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_stride(handle_.get(), d, &stride)); + return stride; + } + + int64_t storage_offset() { + int64_t storage_offset = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(handle_.get(), &storage_offset)); + return storage_offset; + } + + void *data_ptr() const { + void *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_data_ptr(handle_.get(), &result)); + return result; + } + + int64_t *sizes() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_.get(), &result)); + return result; + } + + int64_t *strides() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_.get(), &result)); + return result; + } + + private: + std::unique_ptr handle_; +}; + +// Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle +inline std::vector steal_from_raw_handles_to_raii_handles( + AtenTensorHandle *handles, size_t size) { + std::vector result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + result.emplace_back(handles[i]); + handles[i] = nullptr; + } + return result; +} + +inline 
AtenTensorHandle reinterpret_tensor_wrapper(AtenTensorHandle self, + int64_t ndim, + const int64_t *sizes_ptr, + const int64_t *strides_ptr, + int64_t storage_offset) { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor( + self, ndim, sizes_ptr, strides_ptr, storage_offset, &result)); + return result; +} + +inline void *get_data_ptr_wrapper(AtenTensorHandle tensor) { + void *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(tensor, &result)); + return result; +} + +inline AtenTensorHandle unwrap_raii_handle_if_needed( + const RAIIAtenTensorHandle &handle) { + return handle.get(); +} + +inline RAIIAtenTensorHandle wrap_with_raii_handle_if_needed( + AtenTensorHandle handle) { + return RAIIAtenTensorHandle(handle); +} + +class ConstantHandle { + public: + ConstantHandle() = default; + + explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_)); + } + + operator AtenTensorHandle() const { return handle_; } + + AtenTensorHandle tensor() const { return handle_; } + + AtenTensorHandle get() const { return handle_; } + + void *data_ptr() const { return data_; } + + int64_t *sizes() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_, &result)); + return result; + } + + int64_t *strides() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_, &result)); + return result; + } + + private: + AtenTensorHandle handle_{}; + void *data_ = nullptr; +}; + +inline void *get_data_ptr_wrapper(const ConstantHandle &constant) { + return constant.data_ptr(); +} + +inline const ConstantHandle &unwrap_raii_handle_if_needed( + const ConstantHandle &handle) { + return handle; +} + +// Shouldn't be called. +inline AtenTensorHandle wrap_with_raii_handle_if_needed( + const ConstantHandle &handle) = delete; + +// DANGEROUS. Do not call unless you explicitly intend to get a reference to a +// temporary value, which will expire at the end of the current expression. +// This should only be called in cases where the C-shim API expects an optional +// input argument (passed by pointer), and a temporary needs to be passed to it. +template +T &temporary_reference(T &&t) { + return t; +} + +#define CACHE_TORCH_DTYPE(typename) \ + static auto cached_torch_dtype_##typename = aoti_torch_dtype_##typename() + +#define CACHE_TORCH_DEVICE(device) \ + static auto cached_torch_device_type_##device = \ + aoti_torch_device_type_##device() + +#define CACHE_TORCH_LAYOUT(layout) \ + static auto cached_torch_layout_##layout = aoti_torch_layout_##layout() + +#define CACHE_TORCH_MEMORY_FORMAT(format) \ + static auto cached_torch_memory_format_##format = \ + aoti_torch_memory_format_##format() + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/c/shim.h b/torch_npu/csrc/inductor/aoti_torch/c/shim.h new file mode 100644 index 0000000000000000000000000000000000000000..e86fc28330c8bcbd9ef86df6a98b6b3c48ddfc37 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/c/shim.h @@ -0,0 +1,653 @@ +#ifndef AOTI_TORCH_SHIM +#define AOTI_TORCH_SHIM + +#include +#include + +// This header defines a stable C API for certain ATen functionality in +// libtorch. The AOTInductor compiled model.so will only refer to this header +// instead of other headers from aten/c10, which means it will NOT be able to +// directly use any data structures or call functions from libtorch. 
+// +// What problems are we trying to solve here? Direct use of aten/c10 APIs +// means use of C++ APIs on a library that doesn't have any ABI compatibility +// guarantees. However, we want model.so to remain usable across updates +// to the PyTorch C++ libraries, which requires a stable ABI. By introducing +// a C shim layer, we can minimize the surface that will cause breakage. The +// corresponding software stack can be illustrated as follows: +// +// |--------------------------------| +// | inference service code | +// |--------------------------------| +// | model.so | +// |--------------|-----------------| +// | | +// | libtorch.so | +// |--------------------------------| +// +// The general guidelines for the C API: +// +// - No exceptions, return an explicit error code to be checked at call site +// - Only pointers (AtenTensorHandle counts), integers and floats in headers +// +// If you want to make changes to this header, you MUST MAINTAIN ABI +// compatibility. Typically, this means you will have to add a _v2 version +// of a function that you, e.g., want to add a new function parameter to, and +// maintain the old and new versions of the APIs until all old model.so +// go out of use. + +#ifdef __GNUC__ +#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) +#else // !__GNUC__ +#ifdef _WIN32 +// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead +// to symbol clashes at link time if libtorch is included in a DLL and binary +// that depends on the DLL. As a short term fix, we don't export the symbols. +// In the long term, this will need to be addressed when Windows is supported. +#ifdef OVRSOURCE +// Do not export AOTI on Windows for internal builds +#define AOTI_TORCH_EXPORT +#else /* OVRSOURCE */ +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_TORCH_EXPORT __declspec(dllexport) +#else +#define AOTI_TORCH_EXPORT __declspec(dllimport) +#endif +#endif /* OVRSOURCE */ +#else // !_WIN32 +#define AOTI_TORCH_EXPORT +#endif // _WIN32 +#endif // __GNUC__ + +// The following files are implemented in a header-only way and are guarded by +// test/cpp/aoti_abi_check +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// AtenTensorHandle represents an abstract notion of Tensor that can be passed +// between model.so and libtorch.so. The contents of the structure itself +// are private; model.so is not allowed to access any fields directly, it must +// go through functions defined in this ABI. Under the hood, this is +// represented as at::Tensor*, but we reserve the right to change this (and in +// fact, we probably should change it to at::TensorImpl* at least). +// +// An AtenTensorHandle can be owning (please check the API reference for exact +// ownership/borrow semantics). If you have an owning AtenTensorHandle +// in model.so, you are obligated to aoti_torch_delete_tensor_object when you +// are done. You can use the helper C++ class RAIIAtenTensorHandle +// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style +// (note that RAIIAtenTensorHandle is private to model.so, and never crosses +// the ABI boundary.) 
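To make the "no exceptions, explicit error codes" guideline concrete, here is an editor's sketch (not part of this patch; `handle` is assumed to be a valid AtenTensorHandle) of the call-site pattern the shim expects:

    int32_t dtype = 0;
    AOTITorchError err = aoti_torch_get_dtype(handle, &dtype);
    if (err != AOTI_TORCH_SUCCESS) {
        // No C++ exception ever crosses the ABI boundary; generated model code
        // typically wraps such calls in AOTI_TORCH_ERROR_CODE_CHECK(...) from
        // aoti_runtime/utils.h instead of checking manually.
        return err;
    }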
+struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque *; + +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque *; + +struct AOTIProxyExecutorOpaque; +using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque *; + +using AOTITorchError = int32_t; +#define AOTI_TORCH_SUCCESS 0 +#define AOTI_TORCH_FAILURE 1 + +// Getter functions for retrieving various constants from the runtime, that +// can subsequently be passed to other aoti_* functions. By hiding these +// behind functions, the precise value of device/dtype is NOT part of the +// ABI contract. (In practice, aten/c10 is pretty good about not renumbering +// these, so we probably could later switch to having these in the ABI, if +// desired for perf reasons.) +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_meta(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_npu(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint8(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_csr(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_csc(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_bsr(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_bsc(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_jagged(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_contiguous_format(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_channels_last(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_channels_last_3d(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_preserve_format(); + +// Get TORCH_ABI_VERSION of the built libtorch.so +AOTI_TORCH_EXPORT uint64_t aoti_torch_abi_version(); + +// Functions for converting a single-element tensor to a scalar value +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float16(AtenTensorHandle tensor, c10::Half *ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float32(AtenTensorHandle tensor, float *ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float64(AtenTensorHandle tensor, double *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint8(AtenTensorHandle tensor, + uint8_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint16(AtenTensorHandle tensor, + uint16_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError 
aoti_torch_item_uint32(AtenTensorHandle tensor, + uint32_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint64(AtenTensorHandle tensor, + uint64_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int8(AtenTensorHandle tensor, + int8_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int16(AtenTensorHandle tensor, + int16_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int32(AtenTensorHandle tensor, + int32_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int64(AtenTensorHandle tensor, + int64_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_bool(AtenTensorHandle tensor, + bool *ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16 *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64( + AtenTensorHandle tensor, c10::complex *ret_value); + +// Functions for wrapping a scalar value to a single-element tensor +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32( + float value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float64( + double value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint8( + uint8_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint16( + uint16_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint32( + uint32_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint64( + uint64_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int8( + int8_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int16( + int16_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int32( + int32_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int64( + int64_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_scalar_to_tensor_bool(bool value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex64( + c10::complex value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex128( + c10::complex value, AtenTensorHandle *ret_new_tensor); + +AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled(); +AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled); + +// Free the tensor object +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_tensor_object(AtenTensorHandle tensor); + +// Get a pointer to the underlying storage data +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr( + AtenTensorHandle tensor, + void **ret_data_ptr // returns borrowed reference +); + +// Get the nbytes of the underlying storage +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_storage_size(AtenTensorHandle tensor, int64_t *ret_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_dim(AtenTensorHandle tensor, + int64_t *ret_dim); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_numel(AtenTensorHandle tensor, + int64_t *ret_numel); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_storage_numel(AtenTensorHandle tensor, int64_t *ret_numel); + +AOTI_TORCH_EXPORT AOTITorchError 
aoti_torch_get_sizes(
+    AtenTensorHandle tensor,
+    int64_t **ret_sizes // returns borrowed reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_size(AtenTensorHandle tensor,
+                                                     int64_t d,
+                                                     int64_t *ret_size);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_strides(
+    AtenTensorHandle tensor,
+    int64_t **ret_strides // returns borrowed reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_stride(AtenTensorHandle tensor,
+                                                       int64_t d,
+                                                       int64_t *ret_stride);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_dtype(AtenTensorHandle tensor,
+                                                      int32_t *ret_dtype);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t *ret_device_type);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t *ret_device_index);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset(
+    AtenTensorHandle tensor, int64_t *ret_storage_offset);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle(
+    AtenTensorHandle orig_handle, AtenTensorHandle *new_handle);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch__alloc_from_pool(
+    AtenTensorHandle self, int64_t offset_bytes, int32_t dtype, int64_t ndim,
+    const int64_t *sizes_ptr, const int64_t *strides_ptr,
+    AtenTensorHandle *ret_new_tensor);
+
+// This function will create a new tensor object and its pointer is returned
+// through *out. The caller is responsible for wrapping the tensor pointer
+// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object
+// when going out of scope.
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch__reinterpret_tensor(
+    AtenTensorHandle self, int64_t ndim, const int64_t *sizes_ptr,
+    const int64_t *strides_ptr, int64_t storage_offset,
+    AtenTensorHandle *ret_new_tensor // returns new reference
+);
+
+// This function will create a new tensor object and its pointer is returned
+// through *out. The caller is responsible for wrapping the tensor pointer
+// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object
+// when going out of scope.
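For example, a caller of the allocation entry point declared next would typically wrap the returned "new reference" right away. This is an editor's sketch, not part of the patch: AOTI_TORCH_ERROR_CODE_CHECK and RAIIAtenTensorHandle come from aoti_runtime/utils.h, and the shape and device index are made up.

    int64_t sizes[2] = {8, 16};
    int64_t strides[2] = {16, 1};
    AtenTensorHandle raw = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_empty_strided(
        2, sizes, strides, aoti_torch_dtype_float32(),
        aoti_torch_device_type_npu(), /*device_index=*/0, &raw));
    // Ownership moves into the RAII wrapper, which calls
    // aoti_torch_delete_tensor_object when `owned` goes out of scope.
    RAIIAtenTensorHandle owned(raw);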
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( + int64_t ndim, const int64_t *sizes_ptr, const int64_t *strides_ptr, + int32_t dtype, int32_t device_type, int32_t device_index, + AtenTensorHandle *ret_new_tensor // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_as_strided(AtenTensorHandle self, const int64_t *sizes_ptr, + const int64_t *strides_ptr, AtenTensorHandle *ret); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret, // returns new reference + int32_t layout, const uint8_t *opaque_metadata, + int64_t opaque_metadata_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_npu( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_npu_v2( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret, // returns new reference + int32_t layout, const uint8_t *opaque_metadata, + int64_t opaque_metadata_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( + AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, + int32_t scale_grad_by_freq, int32_t mode, int32_t sparse, + AtenTensorHandle per_sample_weights, // optional argument + int32_t include_last_offset, int32_t padding_idx, + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( + AtenTensorHandle self, const int64_t *dim_ptr, int64_t dim_size, + int64_t normalization, int32_t forward, + AtenTensorHandle *ret // returns new reference +); + +// This version is deprecated. 
We will remove it later +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, + double dropout_p, bool is_causal, bool return_debug_mask, double scale, + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3, // returns new reference + int64_t *ret4, int64_t *ret5, + AtenTensorHandle *ret6, // returns new reference + AtenTensorHandle *ret7, // returns new reference + AtenTensorHandle *ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, + double dropout_p, int is_causal, int return_debug_mask, + double *scale, // optional argument + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3, // returns new reference + int64_t *ret4, int64_t *ret5, + AtenTensorHandle *ret6, // returns new reference + AtenTensorHandle *ret7, // returns new reference + AtenTensorHandle *ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_efficient_attention( + AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, + AtenTensorHandle attn_bias, // optional argument + int compute_log_sumexp, double dropout_p, int is_causal, + double *scale, // optional argument + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( + AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle bias, + int32_t *out_dtype, AtenTensorHandle scale_a, AtenTensorHandle scale_b, + AtenTensorHandle scale_result, int8_t use_fast_accum, + AtenTensorHandle *ret0, AtenTensorHandle *ret1); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( + AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, + AtenTensorHandle scale_b, AtenTensorHandle bias, + AtenTensorHandle scale_result, int32_t *out_dtype, int8_t use_fast_accum, + AtenTensorHandle *ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( + AtenTensorHandle input, AtenTensorHandle weight, + AtenTensorHandle bias, // optional argument + const int64_t *stride_ptr, int64_t stride_size, const int64_t *padding_ptr, + int64_t padding_size, const int64_t *dilation_ptr, int64_t dilation_size, + int transposed, const int64_t *output_padding_ptr, + int64_t output_padding_size, int64_t groups, + AtenTensorHandle *ret // returns new reference +); + +// This function will create a new uninitialized tensor object +// and its pointer is returned through *ret. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_new_uninitialized_tensor(AtenTensorHandle *ret); + +// WARNING: This will be deprecated. Use aoti_torch_copy_ instead. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_tensor_copy_(AtenTensorHandle src, + AtenTensorHandle dst); + +// Make the tensor referred to by dst an alias for the tensor referred +// to by src. The two tensors must still be deleted with +// aoti_torch_delete_tensor separately (or not) as before the call. 
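An illustrative use of the aliasing helpers declared just below (editor's sketch; src and dst are assumed to be valid owning handles, and AOTI_TORCH_ERROR_CODE_CHECK comes from aoti_runtime/utils.h):

    // After this call both handles refer to the same underlying tensor, and
    // each handle must still be deleted exactly once, as stated above.
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors(src, dst));

    // Or let the shim create the destination handle itself as a shallow copy.
    AtenTensorHandle dst2 = nullptr;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_assign_tensors_out(src, &dst2));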
+AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst); + +// Make a shallow copy of the tensor referred to by src and assign +// it to the handle in the ret_dst. This is similar to the above +// aoti_torch_assign_tensors function, but creates and sets the +// ret_dst from within. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors_out(AtenTensorHandle src, AtenTensorHandle *ret_dst); + +// This function will create a new tensor object and its pointer is returned +// through *ret. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_clone(AtenTensorHandle self, + AtenTensorHandle *ret); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(AtenTensorHandle self, AtenTensorHandle *ret); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out(AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat1, + AtenTensorHandle mat2, + float beta, float alpha); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out(AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_(AtenTensorHandle self, + AtenTensorHandle src, + int32_t non_blocking); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out(AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out( + AtenTensorHandle out, AtenTensorHandle a, AtenTensorHandle b, + AtenTensorHandle c, AtenTensorHandle d); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16(AtenTensorHandle weight, + AtenTensorHandle *out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( + AtenTensorHandle weight, AtenTensorHandle weight_scale, + AtenTensorHandle weight_zero_point, AtenTensorHandle bias, + AtenTensorHandle *out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, + int64_t out_channel, + AtenTensorHandle *out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. 
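Returning to the out-variant math helpers declared above (aoti_torch_addmm_out, aoti_torch_bmm_out, aoti_torch_mm_out): they write into a caller-allocated output rather than returning a new reference. A hedged sketch, with all handles assumed to be allocated elsewhere:

    // out = beta * self + alpha * (mat1 @ mat2), written in place into `out`.
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_addmm_out(
        out, self, mat1, mat2, /*beta=*/1.0f, /*alpha=*/1.0f));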
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_cpu__wrapped_quantized_linear_prepacked(
+    AtenTensorHandle input, AtenTensorHandle input_scale,
+    AtenTensorHandle input_zero_point, AtenTensorHandle weight,
+    AtenTensorHandle out_scale, AtenTensorHandle out_zeropoint,
+    int64_t out_channel, AtenTensorHandle *out);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_nonzero(AtenTensorHandle self,
+                                                    AtenTensorHandle *out);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor(
+    AtenTensorHandle repeats, int64_t *output_size, AtenTensorHandle *out);
+
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_check_inf_and_nan(const char *tensor_name, AtenTensorHandle tensor);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out(AtenTensorHandle out,
+                                                        AtenTensorHandle self,
+                                                        int64_t dim,
+                                                        AtenTensorHandle index,
+                                                        AtenTensorHandle src);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_reduce_out(
+    AtenTensorHandle out, AtenTensorHandle self, int64_t dim,
+    AtenTensorHandle index, AtenTensorHandle src, const char *reduce,
+    int32_t include_self);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out(
+    AtenTensorHandle out, AtenTensorHandle self,
+    const AtenTensorHandle *indices, const uint32_t num_indices,
+    const AtenTensorHandle values, bool accumulate);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real(
+    AtenTensorHandle self,
+    AtenTensorHandle *ret // returns new reference
+);
+
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype(
+    AtenTensorHandle self, int32_t dtype,
+    AtenTensorHandle *ret // returns new reference
+);
+
+AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle(AtenTensorHandle self,
+                                                      const char *msg);
+
+// When the AOTI debug printer option is enabled, this function is invoked to
+// save the intermediate tensor in torch pickle format for debugging purposes.
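Together with the pickle-dump helper declared next, the printing and inf/NaN checks above give generated code a few debugging hooks. An illustrative sketch (the buffer and its label are hypothetical):

    aoti_torch_print_tensor_handle(buf3, "buf3 after fused matmul");
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_check_inf_and_nan("buf3", buf3));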
+AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(AtenTensorHandle self, + const char *tensor_name, + const char *launch_prefix, + const char *kernel_name); + +// helpers for converting between StableIValue and actual IValues +using StableIValue = uint64_t; + +class TorchLibraryOpaque; +using TorchLibraryHandle = TorchLibraryOpaque *; + +// stable corollary to torch::Library constructor with Kind::IMPL +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl( + const char *ns, const char *k, const char *file, uint32_t line, + TorchLibraryHandle *ret_new_torch_lib); + +// stable corollary to torch::Library constructor with Kind::DEF +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_library_init_def(const char *ns, const char *file, uint32_t line, + TorchLibraryHandle *ret_new_torch_lib); + +// stable corollary to torch::Library constructor with Kind::FRAGMENT +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment( + const char *ns, const char *file, uint32_t line, + TorchLibraryHandle *ret_new_torch_lib); + +// stable corollary to torch::Library method m.impl(), should be +// called from StableLibrary +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_library_impl(TorchLibraryHandle self, const char *name, + void (*fn)(StableIValue *, uint64_t, uint64_t)); + +// stable corollary to torch::Library method m.def(), should be +// called from StableLibrary +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_def(TorchLibraryHandle self, + const char *schema); + +// the above stable constructors for torch::Library add Library objects +// to the heap. if you are calling those functions directly, please use +// this function to free the Library's memory. The more user friendly +// alternative is to use StableLibrary, which will free its handle upon +// destruction +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_library_object(TorchLibraryHandle tlh); + +// calls the op overload defined by a given opName, overloadName, and a +// stack of StableIValues. This call will populate any return values of the +// op into the stack in their StableIValue form, with ret0 at index 0, ret1 +// at index 1, and so on. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher( + const char *opName, const char *overloadName, StableIValue *stack); + + +// See `ProxyExecutor Design Note` in ir.py for more details +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function( + AOTIProxyExecutorHandle proxy_executor, int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args); + +AOTI_TORCH_EXPORT void aoti_torch_check(bool cond, const char *func, + const char *file, uint32_t line, + const char *msg); + +#ifdef STRIP_ERROR_MESSAGES +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check(false, __func__, __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } +#else +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check(false, __func__, __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ + } +#endif + +AOTI_TORCH_EXPORT void aoti_torch_warn(const char *func, const char *file, + uint32_t line, const char *msg); + +#ifdef DISABLE_WARN +#define AOTI_TORCH_WARN(...) ((void)0); +#else +#define AOTI_TORCH_WARN(...) 
\ + aoti_torch_warn(__func__, __FILE__, static_cast(__LINE__), \ + #__VA_ARGS__); +#endif + +#ifdef __cplusplus +} // extern "C" + +template +int32_t aoti_torch_dtype() = delete; + +#define DEFINE_DTYPE_SPECIALIZATION(ctype, typename) \ + template <> \ + inline int32_t aoti_torch_dtype() { \ + return aoti_torch_dtype_##typename(); \ + } + +namespace c10 { +struct BFloat16; +struct Half; +} // namespace c10 + +DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) +DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16) +DEFINE_DTYPE_SPECIALIZATION(c10::complex, complex64) +DEFINE_DTYPE_SPECIALIZATION(float, float32) +DEFINE_DTYPE_SPECIALIZATION(double, float64) +DEFINE_DTYPE_SPECIALIZATION(uint8_t, uint8) +DEFINE_DTYPE_SPECIALIZATION(int8_t, int8) +DEFINE_DTYPE_SPECIALIZATION(int16_t, int16) +DEFINE_DTYPE_SPECIALIZATION(int32_t, int32) +DEFINE_DTYPE_SPECIALIZATION(int64_t, int64) +DEFINE_DTYPE_SPECIALIZATION(bool, bool) + +#endif +#endif // AOTI_TORCH_SHIM diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7495569f8e3ad6dce495ef23610d63bfd456ed36 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -0,0 +1,545 @@ +#include +#include + +#include +#include +#include +#include + +namespace { +at::Tensor *tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} +} // namespace + +namespace torch::aot_inductor { + +void OSSProxyExecutor::prefill_stack_with_static_arguments( + size_t index, const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, OSSOpKernel &op_kernel) { + auto &stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + TORCH_CHECK(serialized_arg.size() == 1); + std::string serialized_arg_type = serialized_arg.begin().key(); + auto &serialized_arg_val = serialized_arg.begin().value(); + + switch (schema_arg_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_arg_type == "as_tensor", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensor for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::IntType: { + TORCH_CHECK(serialized_arg_type == "as_int", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_int for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::SymIntType: { + TORCH_CHECK(serialized_arg_type == "as_int" || + serialized_arg_type == "as_sym_int", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_int or as_sym_int for " + "argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::FloatType: { + TORCH_CHECK(serialized_arg_type == "as_float", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_float for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::BoolType: { + TORCH_CHECK(serialized_arg_type == "as_bool", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_bool for argument ", + index, " but got ", serialized_arg_type); + 
stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::NumberType: { + if (serialized_arg_type == "as_int") { + // Only int Scalar is treated as dynamic arg for now + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + } else if (serialized_arg_type == "as_float") { + stack.at(index) = serialized_arg_val.get(); + } else if (serialized_arg_type == "as_bool") { + stack.at(index) = serialized_arg_val.get(); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a scalar input for argument ", index, " but got ", + serialized_arg_type); + } + break; + } + case c10::TypeKind::StringType: { + TORCH_CHECK(serialized_arg_type == "as_string", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_string for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::ScalarTypeType: { + TORCH_CHECK( + serialized_arg_type == "as_scalar_type", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_scalar_type for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::MemoryFormatType: { + TORCH_CHECK( + serialized_arg_type == "as_memory_format", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_memory_format for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::LayoutType: { + TORCH_CHECK(serialized_arg_type == "as_layout", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_layout for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::DeviceObjType: { + TORCH_CHECK(serialized_arg_type == "as_device", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_device for argument ", + index, " but got ", serialized_arg_type); + + std::string device_string = serialized_arg_val["type"].get(); + if (serialized_arg_val.contains("index") && + serialized_arg_val["index"].is_number()) { + device_string += ":" + serialized_arg_val["index"].get(); + } + + c10::Device device(device_string); + + if (device != *device_) { + VLOG(1) << "ProxyExecutor is using " << *device_ << " for " + << op_kernel.target_ << " argument #" << index + << ", which is different from the one serialized in thrift: " + << device << ". 
Please ensure this is intentional."; + } + + stack.at(index) = *device_; + break; + } + case c10::TypeKind::ListType: { + if (schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK( + serialized_arg_type == "as_tensors", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensors for argument ", index, + " but got ", serialized_arg_type); + TORCH_CHECK(serialized_arg_type == "as_tensors"); + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_ints for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints" || + serialized_arg_type == "as_sym_ints", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_ints or as_sym_ints " + "for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) { + TORCH_CHECK(serialized_arg_type == "as_floats", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_floats for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) { + TORCH_CHECK(serialized_arg_type == "as_bools", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_bools for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) { + if (serialized_arg_type == "as_ints") { + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (serialized_arg_type == "as_floats") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else if (serialized_arg_type == "as_bools") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a List[Scalar] input for argument ", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf( + at::ListType::ofOptionalTensors())) { + if (serialized_arg_type == "as_optional_tensors") { + std::vector list_item_types; + for (const auto &arg : serialized_arg_val) { + list_item_types.push_back(arg.begin().key()); + } + dynamic_args.emplace_back(index, + DynamicArgType::ListOptionalTensorType, + serialized_arg_val.size(), list_item_types); + } else if (serialized_arg_type == "as_tensors") { + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a Tensor?[] input for argument 
", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofStrings())) { + TORCH_CHECK( + serialized_arg_type == "as_strings", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_strings for argument ", index, + " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "NYI: Unsupported list type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } + break; + } + case c10::TypeKind::OptionalType: { + auto inner_type = + schema_arg_type->castRaw()->getElementType(); + + if (serialized_arg_type == "as_none") { + stack.at(index) = c10::IValue{}; + if (inner_type->kind() == c10::TypeKind::TensorType) { + // Tensor is None + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::IntType || + inner_type->kind() == c10::TypeKind::SymIntType) { + // Int or SymInt is None + dynamic_args.emplace_back(index, DynamicArgType::IntType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + // List[Tensor] is None + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + // List[SymInt] is None + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, 0); + } + } else { + prefill_stack_with_static_arguments(index, inner_type, serialized_arg, + op_kernel); + } + break; + } + default: + TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } +} + +// Populates op_kernel.stack_, op_kernel.dynamic_args_ +void OSSProxyExecutor::get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector filled(schema_args.size(), false); + TORCH_CHECK(op_kernel.stack_.size() == 0); + op_kernel.stack_.resize(schema_args.size()); + for (const auto &named_argument : serialized_node["inputs"]) { + const auto &arg = named_argument["arg"]; + const auto &name = named_argument["name"].get(); + + // Doing a linear lookup in the schema to find the index + // of a static argument. Should be fine performance wise + // because we usually only have small amount of arguments. + for (size_t index = 0; index < schema_args.size(); index++) { + auto &schema_arg = schema_args[index]; + if (schema_arg.name() == name) { + prefill_stack_with_static_arguments(index, schema_arg.real_type(), arg, + op_kernel); + filled[index] = true; + break; + } + } + } + + // If an argument is not filled and has a default value, we should + // also prefill the default value. 
+ for (size_t index = 0; index < schema_args.size(); index++) { + if (!filled[index] && schema_args[index].default_value()) { + auto default_value = *schema_args[index].default_value(); + op_kernel.stack_.at(index) = default_value; + } + } +} + +// Populates op_kernel.outputs_ +void OSSProxyExecutor::get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector &outputs = op_kernel.outputs_; + + TORCH_CHECK(schema_returns.size() == serialized_node["outputs"].size(), + "Serialized node doesn't match operator ", + serialized_node["target"], "'s schema outputs."); + + size_t output_index = 0; + for (const auto &serialized_output : serialized_node["outputs"]) { + TORCH_CHECK(serialized_output.size() == 1); + std::string serialized_output_type = serialized_output.begin().key(); + auto &serialized_output_val = serialized_output.begin().value(); + + auto &schema_return = schema_returns[output_index]; + const at::TypePtr &schema_return_type = schema_return.real_type(); + + switch (schema_return_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_output_type == "as_tensor", + "Expected extern kernel ", serialized_node["target"], + " to have serialized output type as_tensor, ", " but got ", + serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::ListType: { + if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK(serialized_output_type == "as_tensors", + "Expected extern kernel ", serialized_node["target"], + " to have serialized output type as_tensors, ", + " but got ", serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::ListTensorType, + serialized_output_val.size()); + } else { + TORCH_CHECK(false, "Unsupported return list type ", + schema_return_type->repr_str()); + } + break; + } + default: { + TORCH_CHECK(false, "Unsupported return type ", + schema_return_type->repr_str(), " for extern kernel ", + op_kernel.target_); + } + } + + output_index++; + } +} + +OSSProxyExecutor::OSSProxyExecutor(const std::string &json_path, bool is_cpu) { + if (is_cpu) { + device_ = std::make_unique(c10::DeviceType::CPU); + } else { + int device_idx = -1; + // Mock a cuda device for now. 
+ device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); + } + + std::string extern_kernel_nodes_serialized; + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open(), "Unable to open file ", json_path); + + // Parse file into a json object + nlohmann::json json_obj; + json_file >> json_obj; + + // Access data + for (auto const &serialized_extern_node : json_obj["nodes"]) { + auto const &serialized_node = serialized_extern_node["node"]; + + const std::string &target = serialized_node["target"]; + + std::string opName; + std::string overloadName; + size_t pos = target.find('.'); + if (pos == std::string::npos) { + opName = target; + overloadName = ""; + } else { + // There should be no more periods + size_t pos2 = target.find('.', pos + 1); + TORCH_CHECK(pos2 == std::string::npos); + + opName = target.substr(0, pos); + overloadName = target.substr(pos + 1, target.length() - pos); + } + + c10::OperatorHandle op_handle = + c10::Dispatcher::singleton().findSchemaOrThrow(opName.c_str(), + overloadName.c_str()); + const c10::FunctionSchema &schema = op_handle.schema(); + + const auto &schema_args = schema.arguments(); + const auto &schema_returns = schema.returns(); + + OSSOpKernel op_kernel(target, op_handle); + get_input_info_from_serialized(schema_args, serialized_node, op_kernel); + get_output_info_from_serialized(schema_returns, serialized_node, op_kernel); + + op_kernels_.emplace_back(std::move(op_kernel)); + } +} + +void OSSProxyExecutor::call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) { + TORCH_CHECK(extern_node_index < static_cast(op_kernels_.size()), + "Invalid extern node index"); + OSSOpKernel &op_kernel = op_kernels_[extern_node_index]; + + std::vector stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + int tensor_id = 0; + int int_id = 0; + for (auto &dynamic_arg : dynamic_args) { + int arg_index = dynamic_arg.arg_index; + DynamicArgType dynamic_arg_type = dynamic_arg.arg_type; + int length = dynamic_arg.length; + + if (length == 0) { + continue; + } + + switch (dynamic_arg_type) { + case DynamicArgType::TensorType: { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + stack[arg_index] = *tensor; + break; + } + case DynamicArgType::IntType: { + int64_t val = flatten_int_args[int_id++]; + stack[arg_index] = val; + break; + } + case DynamicArgType::ListTensorType: { + std::vector tensor_list; + for (int j = 0; j < length; j++) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + tensor_list.push_back(*tensor); + } + stack[arg_index] = tensor_list; + break; + } + case DynamicArgType::ListOptionalTensorType: { + std::vector> optional_tensor_list; + auto &list_item_types = dynamic_arg.list_item_types; + TORCH_CHECK( + list_item_types.has_value(), + "Could not find list of item types for optional tensor list input"); + + for (const std::string &item_type : list_item_types.value()) { + if (item_type == "as_tensor") { + at::Tensor *tensor = tensor_handle_to_tensor_pointer( + flatten_tensor_args[tensor_id++]); + optional_tensor_list.emplace_back(*tensor); + } else if (item_type == "as_none") { + optional_tensor_list.emplace_back(std::nullopt); + } + } + stack[arg_index] = optional_tensor_list; + break; + } + case DynamicArgType::ListIntType: { + std::vector vals; + vals.reserve(length); + for (int j = 0; j < length; j++) { + vals.push_back(flatten_int_args[int_id++]); + 
} + stack[arg_index] = vals; + break; + } + default: + TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type); + } + } + + int num_output_tensors = op_kernel.num_output_tensors(); + TORCH_CHECK(tensor_id == num_tensors - num_output_tensors, + "Mismatch between tensors consumed and num of input tensor, got " + "tensor_id = .", + tensor_id, ", expected num = ", num_tensors - num_output_tensors); + TORCH_CHECK(int_id == num_ints, + "Mismatch between ints consumed and num_ints, got int_id = ", + int_id, ", num_ints = ", num_ints); + + // Call the op with the prepared stack. + const c10::OperatorHandle &op = op_kernel.op_handle_; + op.callBoxed(stack); + + const c10::FunctionSchema &schema = op.schema(); + const auto &schema_returns = schema.returns(); + + TORCH_CHECK(op_kernel.outputs_.size() == stack.size()); + TORCH_CHECK(stack.size() == schema_returns.size()); + + int index = 0; + for (const auto &schema_return : schema_returns) { + if (schema_return.type()->kind() == c10::TypeKind::TensorType) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = stack[index++].toTensor(); + } else if (schema_return.type()->kind() == c10::TypeKind::ListType && + schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) { + auto tensors = stack[index++].toTensorList(); + for (auto &&t : tensors) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = t; + } + } else { + TORCH_CHECK(false, "NYI: Unsupported return type for schema: ", + schema_return.type()->repr_str()); + } + } + + TORCH_CHECK( + tensor_id == num_tensors, + "Mismatch between tensors consumed and num_tensors, got tensor_id = ", + tensor_id, ", expected num = ", num_tensors); +} + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..7dbadc99b7211d6d3e27318f3bc1b89828b93084 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +enum class DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, +}; + +inline std::ostream &operator<<(std::ostream &os, DynamicArgType arg_type) { + os << static_cast(arg_type); + return os; +} + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + +struct OSSDynamicArg { + OSSDynamicArg( + int arg_index, DynamicArgType arg_type, int length, + std::optional> list_item_types = std::nullopt) + : arg_index(arg_index), + arg_type(arg_type), + length(length), + list_item_types(std::move(list_item_types)) {} + int arg_index; + DynamicArgType arg_type; + int length; + std::optional> + list_item_types; // only used for parsing list of optional tensors +}; + +struct OSSOpKernel { + OSSOpKernel(std::string target, c10::OperatorHandle op_handle) + : target_(std::move(target)), op_handle_(std::move(op_handle)) {} + + std::string target_; + c10::OperatorHandle op_handle_; + std::vector dynamic_args_; + std::vector outputs_; + std::vector stack_; + + int num_output_tensors() const { + int num_output_tensors = 0; + for (const auto &output : 
outputs_) { + if (isTensorType(output.arg_type)) { + num_output_tensors += output.length; + } + } + return num_output_tensors; + } +}; + +class OSSProxyExecutor : public ProxyExecutor { + public: + explicit OSSProxyExecutor(const std::string &json_path, bool is_cpu); + + void call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) override; + + private: + void prefill_stack_with_static_arguments(size_t index, + const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, + OSSOpKernel &op_kernel); + + void get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + void get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + std::vector op_kernels_; + std::unique_ptr device_; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f7404b389c415834bd34297c2fd0b4cfdec94355 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp @@ -0,0 +1,537 @@ +#include + +#include +#include +#include + +namespace { +at::Tensor *tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} +} // namespace + +namespace torch::aot_inductor { +void OSSProxyExecutorNpu::prefill_stack_with_static_arguments( + size_t index, const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, OSSOpKernel &op_kernel) { + auto &stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + TORCH_CHECK(serialized_arg.size() == 1); + std::string serialized_arg_type = serialized_arg.begin().key(); + auto &serialized_arg_val = serialized_arg.begin().value(); + + switch (schema_arg_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_arg_type == "as_tensor", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensor for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::IntType: { + TORCH_CHECK(serialized_arg_type == "as_int", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_int for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::SymIntType: { + TORCH_CHECK(serialized_arg_type == "as_int" || + serialized_arg_type == "as_sym_int", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_int or as_sym_int for " + "argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::FloatType: { + TORCH_CHECK(serialized_arg_type == "as_float", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_float for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::BoolType: { + TORCH_CHECK(serialized_arg_type == "as_bool", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_bool for argument ", + index, " but got ", 
serialized_arg_type);
+      stack.at(index) = serialized_arg_val.get<bool>();
+      break;
+    }
+    case c10::TypeKind::NumberType: {
+      if (serialized_arg_type == "as_int") {
+        // Only int Scalar is treated as dynamic arg for now
+        dynamic_args.emplace_back(index, DynamicArgType::IntType, 1);
+      } else if (serialized_arg_type == "as_float") {
+        stack.at(index) = serialized_arg_val.get<double>();
+      } else if (serialized_arg_type == "as_bool") {
+        stack.at(index) = serialized_arg_val.get<bool>();
+      } else {
+        TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_,
+                    " to have a scalar input for argument ", index, " but got ",
+                    serialized_arg_type);
+      }
+      break;
+    }
+    case c10::TypeKind::StringType: {
+      TORCH_CHECK(serialized_arg_type == "as_string", "Expected extern kernel ",
+                  op_kernel.target_,
+                  " to have serialized argument type as_string for argument ",
+                  index, " but got ", serialized_arg_type);
+      stack.at(index) = serialized_arg_val.get<std::string>();
+      break;
+    }
+    case c10::TypeKind::ScalarTypeType: {
+      TORCH_CHECK(
+          serialized_arg_type == "as_scalar_type", "Expected extern kernel ",
+          op_kernel.target_,
+          " to have serialized argument type as_scalar_type for argument ",
+          index, " but got ", serialized_arg_type);
+      stack.at(index) = serialized_arg_val.get<c10::ScalarType>();
+      break;
+    }
+    case c10::TypeKind::MemoryFormatType: {
+      TORCH_CHECK(
+          serialized_arg_type == "as_memory_format", "Expected extern kernel ",
+          op_kernel.target_,
+          " to have serialized argument type as_memory_format for argument ",
+          index, " but got ", serialized_arg_type);
+      stack.at(index) = serialized_arg_val.get<c10::MemoryFormat>();
+      break;
+    }
+    case c10::TypeKind::LayoutType: {
+      TORCH_CHECK(serialized_arg_type == "as_layout", "Expected extern kernel ",
+                  op_kernel.target_,
+                  " to have serialized argument type as_layout for argument ",
+                  index, " but got ", serialized_arg_type);
+      stack.at(index) = serialized_arg_val.get<c10::Layout>();
+      break;
+    }
+    case c10::TypeKind::DeviceObjType: {
+      TORCH_CHECK(serialized_arg_type == "as_device", "Expected extern kernel ",
+                  op_kernel.target_,
+                  " to have serialized argument type as_device for argument ",
+                  index, " but got ", serialized_arg_type);
+
+      std::string device_string = serialized_arg_val["type"].get<std::string>();
+      if (serialized_arg_val.contains("index") &&
+          serialized_arg_val["index"].is_number()) {
+        device_string +=
+            ":" + std::to_string(serialized_arg_val["index"].get<int64_t>());
+      }
+
+      c10::Device device(device_string);
+
+      if (device != *device_) {
+        VLOG(1) << "ProxyExecutor is using " << *device_ << " for "
+                << op_kernel.target_ << " argument #" << index
+                << ", which is different from the one serialized in the json: "
+                << device << ". 
Please ensure this is intentional."; + } + + stack.at(index) = *device_; + break; + } + case c10::TypeKind::ListType: { + if (schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK( + serialized_arg_type == "as_tensors", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensors for argument ", index, + " but got ", serialized_arg_type); + TORCH_CHECK(serialized_arg_type == "as_tensors"); + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_ints for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints" || + serialized_arg_type == "as_sym_ints", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_ints or as_sym_ints " + "for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) { + TORCH_CHECK(serialized_arg_type == "as_floats", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_floats for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) { + TORCH_CHECK(serialized_arg_type == "as_bools", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_bools for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) { + if (serialized_arg_type == "as_ints") { + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (serialized_arg_type == "as_floats") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else if (serialized_arg_type == "as_bools") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a List[Scalar] input for argument ", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf( + at::ListType::ofOptionalTensors())) { + if (serialized_arg_type == "as_optional_tensors") { + std::vector list_item_types; + for (const auto &arg : serialized_arg_val) { + list_item_types.push_back(arg.begin().key()); + } + dynamic_args.emplace_back(index, + DynamicArgType::ListOptionalTensorType, + serialized_arg_val.size(), list_item_types); + } else if (serialized_arg_type == "as_tensors") { + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a Tensor?[] input for argument 
", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofStrings())) { + TORCH_CHECK( + serialized_arg_type == "as_strings", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_strings for argument ", index, + " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "NYI: Unsupported list type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } + break; + } + case c10::TypeKind::OptionalType: { + auto inner_type = + schema_arg_type->castRaw()->getElementType(); + + if (serialized_arg_type == "as_none") { + stack.at(index) = c10::IValue{}; + if (inner_type->kind() == c10::TypeKind::TensorType) { + // Tensor is None + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::IntType || + inner_type->kind() == c10::TypeKind::SymIntType) { + // Int or SymInt is None + dynamic_args.emplace_back(index, DynamicArgType::IntType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + // List[Tensor] is None + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + // List[SymInt] is None + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, 0); + } + } else { + prefill_stack_with_static_arguments(index, inner_type, serialized_arg, + op_kernel); + } + break; + } + default: + TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } +} + +// Populates op_kernel.stack_, op_kernel.dynamic_args_ +void OSSProxyExecutorNpu::get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector filled(schema_args.size(), false); + TORCH_CHECK(op_kernel.stack_.size() == 0); + op_kernel.stack_.resize(schema_args.size()); + for (const auto &named_argument : serialized_node["inputs"]) { + const auto &arg = named_argument["arg"]; + const auto &name = named_argument["name"].get(); + + // Doing a linear lookup in the schema to find the index + // of a static argument. Should be fine performance wise + // because we usually only have small amount of arguments. + for (size_t index = 0; index < schema_args.size(); index++) { + auto &schema_arg = schema_args[index]; + if (schema_arg.name() == name) { + prefill_stack_with_static_arguments(index, schema_arg.real_type(), arg, + op_kernel); + filled[index] = true; + break; + } + } + } + + // If an argument is not filled and has a default value, we should + // also prefill the default value. 
+  for (size_t index = 0; index < schema_args.size(); index++) {
+    if (!filled[index] && schema_args[index].default_value()) {
+      auto default_value = *schema_args[index].default_value();
+      op_kernel.stack_.at(index) = default_value;
+    }
+  }
+}
+
+// Populates op_kernel.outputs_
+void OSSProxyExecutorNpu::get_output_info_from_serialized(
+    const std::vector<c10::Argument> &schema_returns,
+    const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) {
+  std::vector<OSSDynamicArg> &outputs = op_kernel.outputs_;
+
+  TORCH_CHECK(schema_returns.size() == serialized_node["outputs"].size(),
+              "Serialized node doesn't match operator ",
+              serialized_node["target"], "'s schema outputs.");
+
+  size_t output_index = 0;
+  for (const auto &serialized_output : serialized_node["outputs"]) {
+    TORCH_CHECK(serialized_output.size() == 1);
+    std::string serialized_output_type = serialized_output.begin().key();
+    auto &serialized_output_val = serialized_output.begin().value();
+
+    auto &schema_return = schema_returns[output_index];
+    const at::TypePtr &schema_return_type = schema_return.real_type();
+
+    switch (schema_return_type->kind()) {
+      case c10::TypeKind::TensorType: {
+        TORCH_CHECK(serialized_output_type == "as_tensor",
+                    "Expected extern kernel ", serialized_node["target"],
+                    " to have serialized output type as_tensor, but got ",
+                    serialized_output_type);
+        outputs.emplace_back(output_index, DynamicArgType::TensorType, 1);
+        break;
+      }
+      case c10::TypeKind::ListType: {
+        if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) {
+          TORCH_CHECK(serialized_output_type == "as_tensors",
+                      "Expected extern kernel ", serialized_node["target"],
+                      " to have serialized output type as_tensors, but got ",
+                      serialized_output_type);
+          outputs.emplace_back(output_index, DynamicArgType::ListTensorType,
+                               serialized_output_val.size());
+        } else {
+          TORCH_CHECK(false, "Unsupported return list type ",
+                      schema_return_type->repr_str());
+        }
+        break;
+      }
+      default: {
+        TORCH_CHECK(false, "Unsupported return type ",
+                    schema_return_type->repr_str(), " for extern kernel ",
+                    op_kernel.target_);
+      }
+    }
+
+    output_index++;
+  }
+}
+
+OSSProxyExecutorNpu::OSSProxyExecutorNpu(const std::string &json_path,
+                                         bool is_cpu) {
+  if (is_cpu) {
+    device_ = std::make_unique<c10::Device>(c10::DeviceType::CPU);
+  } else {
+    int device_idx = -1;
+    // Mock a cuda device for now. 
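+    // Note: device_ is only a placeholder. It is substituted into Device-typed
+    // arguments and compared against the device recorded in the serialized
+    // node; constructing it does not touch any device runtime.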
+ device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); + } + + std::string extern_kernel_nodes_serialized; + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open(), "Unable to open file ", json_path); + + // Parse file into a json object + nlohmann::json json_obj; + json_file >> json_obj; + + // Access data + for (auto const &serialized_extern_node : json_obj["nodes"]) { + auto const &serialized_node = serialized_extern_node["node"]; + + const std::string &target = serialized_node["target"]; + + std::string opName; + std::string overloadName; + size_t pos = target.find('.'); + if (pos == std::string::npos) { + opName = target; + overloadName = ""; + } else { + // There should be no more periods + size_t pos2 = target.find('.', pos + 1); + TORCH_CHECK(pos2 == std::string::npos); + + opName = target.substr(0, pos); + overloadName = target.substr(pos + 1, target.length() - pos); + } + + c10::OperatorHandle op_handle = + c10::Dispatcher::singleton().findSchemaOrThrow(opName.c_str(), + overloadName.c_str()); + const c10::FunctionSchema &schema = op_handle.schema(); + + const auto &schema_args = schema.arguments(); + const auto &schema_returns = schema.returns(); + + OSSOpKernel op_kernel(target, op_handle); + get_input_info_from_serialized(schema_args, serialized_node, op_kernel); + get_output_info_from_serialized(schema_returns, serialized_node, op_kernel); + + op_kernels_.emplace_back(std::move(op_kernel)); + } +} + +void OSSProxyExecutorNpu::call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, + int num_tensors, + AtenTensorHandle *flatten_tensor_args) { + TORCH_CHECK(extern_node_index < static_cast(op_kernels_.size()), + "Invalid extern node index"); + OSSOpKernel &op_kernel = op_kernels_[extern_node_index]; + + std::vector stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + int tensor_id = 0; + int int_id = 0; + for (auto &dynamic_arg : dynamic_args) { + int arg_index = dynamic_arg.arg_index; + DynamicArgType dynamic_arg_type = dynamic_arg.arg_type; + int length = dynamic_arg.length; + + if (length == 0) { + continue; + } + + switch (dynamic_arg_type) { + case DynamicArgType::TensorType: { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + stack[arg_index] = *tensor; + break; + } + case DynamicArgType::IntType: { + int64_t val = flatten_int_args[int_id++]; + stack[arg_index] = val; + break; + } + case DynamicArgType::ListTensorType: { + std::vector tensor_list; + for (int j = 0; j < length; j++) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + tensor_list.push_back(*tensor); + } + stack[arg_index] = tensor_list; + break; + } + case DynamicArgType::ListOptionalTensorType: { + std::vector> optional_tensor_list; + auto &list_item_types = dynamic_arg.list_item_types; + TORCH_CHECK( + list_item_types.has_value(), + "Could not find list of item types for optional tensor list input"); + + for (const std::string &item_type : list_item_types.value()) { + if (item_type == "as_tensor") { + at::Tensor *tensor = tensor_handle_to_tensor_pointer( + flatten_tensor_args[tensor_id++]); + optional_tensor_list.emplace_back(*tensor); + } else if (item_type == "as_none") { + optional_tensor_list.emplace_back(std::nullopt); + } + } + stack[arg_index] = optional_tensor_list; + break; + } + case DynamicArgType::ListIntType: { + std::vector vals; + vals.reserve(length); + for (int j = 0; j < length; j++) { + 
vals.push_back(flatten_int_args[int_id++]); + } + stack[arg_index] = vals; + break; + } + default: + TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type); + } + } + + int num_output_tensors = op_kernel.num_output_tensors(); + + // Call the op with the prepared stack. + + const c10::OperatorHandle &op = op_kernel.op_handle_; + op.callBoxed(stack); + + const c10::FunctionSchema &schema = op.schema(); + const auto &schema_returns = schema.returns(); + + TORCH_CHECK(op_kernel.outputs_.size() == stack.size()); + TORCH_CHECK(stack.size() == schema_returns.size()); + + int index = 0; + for (const auto &schema_return : schema_returns) { + if (schema_return.type()->kind() == c10::TypeKind::TensorType) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = stack[index++].toTensor(); + } else if (schema_return.type()->kind() == c10::TypeKind::ListType && + schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) { + auto tensors = stack[index++].toTensorList(); + for (auto &&t : tensors) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = t; + } + } else { + TORCH_CHECK(false, "NYI: Unsupported return type for schema: ", + schema_return.type()->repr_str()); + } + } + + TORCH_CHECK( + tensor_id == num_tensors, + "Mismatch between tensors consumed and num_tensors, got tensor_id = ", + tensor_id, ", expected num = ", num_tensors); +} +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h new file mode 100644 index 0000000000000000000000000000000000000000..f2478bbeaa967452d3d4516f3ba70a60493ee050 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +class OSSProxyExecutorNpu : public ProxyExecutor { + public: + explicit OSSProxyExecutorNpu(const std::string &json_path, bool is_cpu); + + void call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) override; + + private: + void prefill_stack_with_static_arguments(size_t index, + const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, + OSSOpKernel &op_kernel); + + void get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + void get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + std::vector op_kernels_; + std::unique_ptr device_; +}; + +} // namespace torch::aot_inductor \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_torch/proxy_executor.h b/torch_npu/csrc/inductor/aoti_torch/proxy_executor.h new file mode 100644 index 0000000000000000000000000000000000000000..486fc802353b09b1cc0d0cfa57369bb0f8f9b666 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/proxy_executor.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +namespace torch::aot_inductor { + +class ProxyExecutor { + public: + ProxyExecutor() = default; + virtual ~ProxyExecutor() = default; + + virtual void call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) = 0; +}; + +} // 
namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp b/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f920bae6e981c90dd651e05e8a647f205648dac6 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp @@ -0,0 +1,69 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif +int32_t aoti_torch_device_type_npu() { + return (int32_t)c10::DeviceType::PrivateUse1; +} + +#ifdef __cplusplus +} // extern "C" +#endif + +namespace { +static c10::Device c10_device(int32_t device_type, int32_t device_index) { + if (device_type == aoti_torch_device_type_cpu()) { + return c10::Device(static_cast(device_type)); + } else { + return c10::Device(static_cast(device_type), + static_cast(device_index)); + } +} +} // namespace + +AOTITorchError aoti_torch_create_tensor_from_blob_npu( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret_new_tensor) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::IntArrayRef sizes(sizes_ptr, ndim); + c10::IntArrayRef strides(strides_ptr, ndim); + c10::Device device = c10_device(device_type, device_index); + c10::TensorOptions options = c10::TensorOptions().device(device).dtype( + static_cast(dtype)); + *ret_new_tensor = torch::aot_inductor::new_tensor_handle( + // data == nullptr can happen for a 0-size tensor + (data != nullptr) + ? at_npu::native::from_blob(data, sizes, strides, storage_offset, + options, device) + : at::empty_strided(sizes, strides, options)); + }); +} + +AOTITorchError aoti_torch_create_tensor_from_blob_npu_v2( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, AtenTensorHandle *ret_new_tensor, + int32_t layout, const uint8_t *opaque_metadata, + int64_t opaque_metadata_size) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + if (layout == static_cast(at::kMkldnn)) { + throw std::runtime_error("do not support mkldnn on npu."); + } else { + aoti_torch_create_tensor_from_blob_npu(data, ndim, sizes_ptr, strides_ptr, + storage_offset, dtype, device_type, + device_index, ret_new_tensor); + } + }); +} \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_torch/tensor_converter.h b/torch_npu/csrc/inductor/aoti_torch/tensor_converter.h new file mode 100644 index 0000000000000000000000000000000000000000..f92d556eede02105dac8a175179691e3c907775e --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/tensor_converter.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +// Functions declared here are not meant to be called from the AOTInductor +// generated model.so + +// unsafe_alloc_new_handles_from_tensors is used for allocating new aten +// tensor objects and return them as a vector of AtenTensorHandle (raw +// pointers), and those pointers will be stolen by model.so. +TORCH_API std::vector unsafe_alloc_new_handles_from_tensors( + const std::vector &tensors); + +// alloc_tensors_by_stealing_from_handles is used for creating a vector of aten +// tensors by stealing from an array of handles. Only the handles are stolen, +// and the array itself is borrowed. 
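+// In other words, the caller keeps ownership of the handles array itself, but
+// must not reuse or free the individual handles after this call.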
+// +// WARNING: Can NOT be called in model.so +TORCH_API std::vector alloc_tensors_by_stealing_from_handles( + AtenTensorHandle *handles, size_t length); + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/utils.h b/torch_npu/csrc/inductor/aoti_torch/utils.h new file mode 100644 index 0000000000000000000000000000000000000000..676429658241da1d9d8f6b01ebade15cc6a3ce89 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/utils.h @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception &e) { \ + LOG(ERROR) << "Exception in aoti_torch: " << e.what(); \ + return AOTI_TORCH_FAILURE; \ + } catch (...) { \ + LOG(ERROR) << "Exception in aoti_torch: UNKNOWN"; \ + return AOTI_TORCH_FAILURE; \ + } \ + return AOTI_TORCH_SUCCESS; + +namespace torch::aot_inductor { + +inline at::Tensor *tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor *tensor) { + return reinterpret_cast(tensor); +} + +inline at::Tensor resolve_tensor_dispatch_flags(AtenTensorHandle handle) { + at::Tensor *tensor{tensor_handle_to_tensor_pointer(handle)}; + if (tensor->is_conj() || tensor->is_neg()) { + // If the conjugation or negation dispatch flags are set, runtime dispatch + // handles them by cloning the tensor before passing them to the native ATen + // function. Since the C-shim calls the native function directly, we have + // to handle the flags ourselves, or results will be silently incorrect. + return tensor->clone(); + } + return *tensor; +} + +inline std::optional resolve_tensor_dispatch_flags( + const AtenTensorHandle *handle) { + return handle ? std::make_optional(resolve_tensor_dispatch_flags(*handle)) + : std::nullopt; +} + +inline std::vector resolve_tensor_list_dispatch_flags( + const AtenTensorHandle *handle, int64_t len) { + std::vector ret{}; + ret.reserve(len); + for (int64_t i{0}; i < len; ++i) { + ret.emplace_back(resolve_tensor_dispatch_flags(handle[i])); + } + return ret; +} + +inline std::vector> +resolve_tensor_list_dispatch_flags(const AtenTensorHandle **handle, + int64_t len) { + std::vector> ret{}; + ret.reserve(len); + for (int64_t i{0}; i < len; ++i) { + ret.emplace_back(resolve_tensor_dispatch_flags(handle[i])); + } + return ret; +} + +inline at::Generator *generator_handle_to_generator_pointer( + AtenGeneratorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenGeneratorHandle generator_pointer_to_generator_handle( + at::Generator *generator) { + return reinterpret_cast(generator); +} + +inline AtenTensorHandle new_tensor_handle(at::Tensor &&tensor) { + at::Tensor *new_tensor = new at::Tensor(std::move(tensor)); + return tensor_pointer_to_tensor_handle(new_tensor); +} + +inline void assert_inf_and_nan(const std::string &tensor_name, + at::Tensor &check_tensor) { + auto isnan_tensor = check_tensor.isnan(); + if (isnan_tensor.any().item()) { + throw std::runtime_error("At least one NaN in " + tensor_name); + } + auto isinf_tensor = check_tensor.isinf(); + if (isinf_tensor.any().item()) { + throw std::runtime_error("At least one INF in " + tensor_name); + } +} + +// utility functions to convert a pointer to an optional value +template +inline std::optional pointer_to_optional(T *ptr) { + return ptr ? 
std::make_optional(*ptr) : std::nullopt; +} + +template >> +inline std::optional pointer_to_optional(U *ptr) { + return ptr ? std::make_optional(T(*ptr)) : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional(AtenTensorHandle *ptr) { + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional( + const AtenTensorHandle *ptr) { + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional( + AtenGeneratorHandle *ptr) { + return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : std::nullopt; +} + +inline std::optional pointer_to_optional_device( + int32_t *device_type, int32_t device_index) { + return device_type ? std::make_optional(c10::Device( + static_cast(*device_type), + static_cast(device_index))) + : std::nullopt; +} + +// utility functions to convert a pointer to a list +template +struct is_optional : std::false_type {}; +template +struct is_optional> : std::true_type {}; + +template +inline c10::ArrayRef pointer_to_list(T *ptr, int64_t len) { + return c10::ArrayRef(ptr, len); +} + +template >, + typename = std::enable_if_t::value>> +inline std::vector pointer_to_list(U *ptr, int64_t len) { + // std::vector will be implicitly converted to c10::ArrayRef at the call + // site + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(T(ptr[i])); + } + return result; +} + +template ::value>> +inline std::vector pointer_to_list(U **ptr, int64_t len) { + // Here U** denotes a list of optional arguments + // std::vector will be implicitly converted to c10::ArrayRef at the call + // site + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(pointer_to_optional(ptr[i])); + } + return result; +} + +template <> +inline std::vector pointer_to_list(const AtenTensorHandle *ptr, + int64_t len) { + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(*tensor_handle_to_tensor_pointer(ptr[i])); + } + return result; +} + +template <> +inline std::vector> pointer_to_list( + const AtenTensorHandle **ptr, int64_t len) { + std::vector> result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(pointer_to_optional(ptr[i])); + } + return result; +} + +template +inline std::array pointer_to_list(const int32_t *ptr) { + std::array result; + std::copy(ptr, ptr + N, result.begin()); + return result; +} + +// Utility function to convert a pointer to an optional list of values +template +inline std::optional> pointer_to_optional_list(U **ptr, + int64_t len) { + return ptr ? 
std::make_optional>( + pointer_to_list(*ptr, len)) + : std::nullopt; +} + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/array_ref_impl.h b/torch_npu/csrc/inductor/array_ref_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..1c0aeb549e1bc1b53fce06bd0f5ba4138a765ea0 --- /dev/null +++ b/torch_npu/csrc/inductor/array_ref_impl.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::aot_inductor { +template +void convert_output_to_handle(const ArrayRefTensor &output, + AtenTensorHandle &handle) { + handle = output.expensiveCopyToTensor(); +} + +template +void convert_outputs_to_handles_helper( + const std::tuple...> &outputs, + AtenTensorHandle *output_handles, std::index_sequence) { + (convert_output_to_handle(std::get(outputs), output_handles[Is]), ...); +} +template +void convert_outputs_to_handles( + const std::tuple...> &outputs, + AtenTensorHandle *output_handles) { + convert_outputs_to_handles_helper(outputs, output_handles, + std::make_index_sequence()); +} + +template +void convert_handle_to_arrayref_tensor(AtenTensorHandle handle, + ArrayRefTensor &input) { + void *data_ptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr)); + int64_t dim; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim)); + int64_t numel; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel)); + int64_t *sizes; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes)); + int64_t *strides; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides)); + int32_t dtype; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype)); + int32_t device_type; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type)); + int32_t device_index; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(handle, &device_index)); + + input = ArrayRefTensor( + MiniArrayRef(reinterpret_cast(data_ptr), numel), + MiniArrayRef(sizes, dim), + MiniArrayRef(strides, dim), device_type, device_index); +} + +template +void convert_handles_to_inputs_helper(AtenTensorHandle *input_handles, + std::tuple...> &inputs, + std::index_sequence) { + (convert_handle_to_arrayref_tensor(input_handles[Is], std::get(inputs)), + ...); +} + +template +void convert_handles_to_inputs(AtenTensorHandle *input_handles, + std::tuple...> &inputs) { + convert_handles_to_inputs_helper(input_handles, inputs, + std::make_index_sequence()); +} + +template +void assert_numel(const ArrayRefTensor &tensor, uint64_t numel) { + if (tensor.numel() != numel) { + std::stringstream err; + err << "incorrect numel for input tensor. expected " << numel << ", got " + << tensor.numel(); + throw std::runtime_error(err.str()); + } +} +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/inductor_ops.h b/torch_npu/csrc/inductor/inductor_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..924d54050ce26d6221e7259c87b3b0a9ee51fe16 --- /dev/null +++ b/torch_npu/csrc/inductor/inductor_ops.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace torch::inductor { + +TORCH_API at::Tensor _mm_plus_mm_out(at::Tensor &out, const at::Tensor &a, + const at::Tensor &b, const at::Tensor &c, + const at::Tensor &d); + +// After adding _mm_plus_mm_out, this should not be exposed and called by model +// code. Keeping it around for backward compatibility. Will be deprecated later. 
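+// Computes mm(a, b) + mm(c, d) into out (matching the upstream torch::inductor
+// op of the same name) and returns it.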
+TORCH_API at::Tensor _mm_plus_mm(const at::Tensor &a, const at::Tensor &b, + const at::Tensor &c, const at::Tensor &d, + at::Tensor &out); + +TORCH_API at::Tensor _alloc_from_pool(const at::Tensor &self, + int64_t offset_bytes, + at::ScalarType dtype, + at::IntArrayRef size, + at::IntArrayRef stride); + +// Similar to as_strided with the following differences +// - offset is added to the existing offset (rather than replacing it) +// - view tracking is disabled similar to unsafe_view +TORCH_API at::Tensor _reinterpret_tensor(const at::Tensor &self, + at::IntArrayRef size, + at::IntArrayRef stride, + int64_t offset_increment = 0); + +} // namespace torch::inductor
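
For reference, a minimal usage sketch of the proxy executor introduced by this patch. It is illustrative only: the include paths, the "model.json" file name, the extern-node layout, and the node index are assumptions rather than anything generated by this diff; only the OSSProxyExecutorNpu / call_function API and the handle helpers from utils.h come from the patch itself.

// Sketch: drive extern node #0 the way AOTI-generated wrapper code would.
// Assumes model.json sits next to the compiled .so and has the shape
//   {"nodes": [{"node": {"target": "...", "inputs": [...], "outputs": [...]}}]}
#include <vector>

#include <ATen/ATen.h>

#include "torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h"
#include "torch_npu/csrc/inductor/aoti_torch/utils.h"

int main() {
  using namespace torch::aot_inductor;

  // Parses the serialized extern kernel nodes once, up front.
  OSSProxyExecutorNpu executor("model.json", /*is_cpu=*/false);

  // The generated wrapper flattens dynamic arguments: all ints into one array,
  // all tensors (inputs first, then outputs) into another.
  at::Tensor input = at::randn({4, 4});
  at::Tensor output = at::empty({4, 4});
  std::vector<int64_t> flat_ints;  // assume node #0 takes no dynamic ints
  std::vector<AtenTensorHandle> flat_tensors = {
      tensor_pointer_to_tensor_handle(&input),
      tensor_pointer_to_tensor_handle(&output)};

  executor.call_function(/*extern_node_index=*/0,
                         static_cast<int>(flat_ints.size()), flat_ints.data(),
                         static_cast<int>(flat_tensors.size()),
                         flat_tensors.data());
  return 0;
}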