From 36a54e5263a345d126a42c5b9d95f4b894a27bad Mon Sep 17 00:00:00 2001 From: Zichun Ye Date: Mon, 1 Sep 2025 10:56:11 +0800 Subject: [PATCH] add aot inductor for v2.7.1 code clean code clean code clean code clean fix stream issue drop code fix typo fix code check del code drop gpu code fix typo fix code check fix code check fix code clean --- CMakeLists.txt | 7 +- build_libtorch_npu.py | 3 +- setup.py | 22 + torch_npu/_inductor/codecache.py | 28 +- .../codegen/aoti_runtime/interface.cpp | 272 +++++++ torch_npu/_inductor/codegen/cpp_wrapper.py | 677 +++++++++--------- torch_npu/_inductor/lowering.py | 1 - torch_npu/_inductor/lowering_op_list.py | 2 +- torch_npu/_inductor/npu_triton_heuristics.py | 24 +- torch_npu/_inductor/utils.py | 6 + torch_npu/csrc/inductor/CMakeLists.txt | 9 + .../aoti_package/model_package_loader.h | 40 ++ torch_npu/csrc/inductor/aoti_package/pybind.h | 7 + .../aoti_runner/model_container_runner.h | 113 +++ .../model_container_runner_npu.cpp | 87 +++ .../aoti_runner/model_container_runner_npu.h | 41 ++ torch_npu/csrc/inductor/aoti_runner/pybind.h | 7 + .../inductor/aoti_runtime/arrayref_tensor.h | 327 +++++++++ .../csrc/inductor/aoti_runtime/device_utils.h | 43 ++ .../csrc/inductor/aoti_runtime/interface.h | 183 +++++ torch_npu/csrc/inductor/aoti_runtime/model.h | 592 +++++++++++++++ .../inductor/aoti_runtime/model_container.h | 589 +++++++++++++++ .../inductor/aoti_runtime/scalar_to_tensor.h | 38 + .../csrc/inductor/aoti_runtime/thread_local.h | 144 ++++ torch_npu/csrc/inductor/aoti_runtime/utils.h | 234 ++++++ torch_npu/csrc/inductor/aoti_torch/c/shim.h | 653 +++++++++++++++++ .../aoti_torch/oss_proxy_executor.cpp | 545 ++++++++++++++ .../inductor/aoti_torch/oss_proxy_executor.h | 96 +++ .../aoti_torch/oss_proxy_executor_npu.cpp | 537 ++++++++++++++ .../aoti_torch/oss_proxy_executor_npu.h | 42 ++ .../csrc/inductor/aoti_torch/proxy_executor.h | 19 + .../csrc/inductor/aoti_torch/shim_npu.cpp | 69 ++ .../inductor/aoti_torch/tensor_converter.h | 25 + torch_npu/csrc/inductor/aoti_torch/utils.h | 217 ++++++ torch_npu/csrc/inductor/array_ref_impl.h | 80 +++ torch_npu/csrc/inductor/inductor_ops.h | 31 + 36 files changed, 5437 insertions(+), 373 deletions(-) create mode 100644 torch_npu/_inductor/codegen/aoti_runtime/interface.cpp create mode 100644 torch_npu/csrc/inductor/CMakeLists.txt create mode 100644 torch_npu/csrc/inductor/aoti_package/model_package_loader.h create mode 100644 torch_npu/csrc/inductor/aoti_package/pybind.h create mode 100644 torch_npu/csrc/inductor/aoti_runner/model_container_runner.h create mode 100644 torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp create mode 100644 torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h create mode 100644 torch_npu/csrc/inductor/aoti_runner/pybind.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/device_utils.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/interface.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/model.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/model_container.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/thread_local.h create mode 100644 torch_npu/csrc/inductor/aoti_runtime/utils.h create mode 100644 torch_npu/csrc/inductor/aoti_torch/c/shim.h create mode 100644 torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp create mode 100644 torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h create mode 100644 torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp create mode 100644 torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h create mode 100644 torch_npu/csrc/inductor/aoti_torch/proxy_executor.h create mode 100644 torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp create mode 100644 torch_npu/csrc/inductor/aoti_torch/tensor_converter.h create mode 100644 torch_npu/csrc/inductor/aoti_torch/utils.h create mode 100644 torch_npu/csrc/inductor/array_ref_impl.h create mode 100644 torch_npu/csrc/inductor/inductor_ops.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 935f23c933..6b12fb0619 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -203,6 +203,7 @@ endif() include_directories(${PROJECT_SOURCE_DIR}) include_directories(${PROJECT_SOURCE_DIR}/torch_npu/csrc/aten) +include_directories(${PROJECT_SOURCE_DIR}/torch_npu/csrc/inductor) include_directories(${PROJECT_SOURCE_DIR}/third_party/hccl/inc) include_directories(${PROJECT_SOURCE_DIR}/third_party/acl/inc) include_directories(${PROJECT_SOURCE_DIR}/third_party/Tensorpipe) @@ -229,6 +230,7 @@ set(ATEN_SRCS) set(CORE_SRCS) set(FRAMEWORK_SRCS) set(LOGGING_SRCS) +set(INDUCTOR_SRCS) if (NOT DEFINED BUILD_LIBTORCH) set(DIST_SRCS) @@ -251,6 +253,7 @@ add_subdirectory(${TORCHNPU_ROOT}/framework) add_subdirectory(${TORCHNPU_ROOT}/flopcount) add_subdirectory(${TORCHNPU_ROOT}/logging) add_subdirectory(${TORCHNPU_ROOT}/custom_dtype) +add_subdirectory(${TORCHNPU_ROOT}/inductor) if (NOT DEFINED BUILD_LIBTORCH) add_subdirectory(${TORCHNPU_ROOT}/distributed) @@ -279,10 +282,10 @@ if (DEFINED BUILD_TENSORPIPE) endif() if (DEFINED BUILD_LIBTORCH) - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${FRAMEWORK_SRCS} ${LOGGING_SRCS} ${NPU_CPP_LIBS_SRCS}) else() # Compile code with pybind11 - set(CPP_SRCS ${ATEN_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS} ${AFD_SRCS}) + set(CPP_SRCS ${ATEN_SRCS} ${INDUCTOR_SRCS} ${CORE_SRCS} ${OPS_PLUGIN_SRCS} ${DIST_SRCS} ${FLOP_SRCS} ${CUS_DTYPE_SRCS} ${LOGGING_SRCS} ${FRAMEWORK_SRCS} ${NPU_SRCS} ${PROF_SRCS} ${IPC_SRCS} ${UTILS_SRCS} ${SAN_SRCS} ${AFD_SRCS}) endif() add_library(${PLUGIN_NAME} SHARED ${CPP_SRCS}) diff --git a/build_libtorch_npu.py b/build_libtorch_npu.py index b91c2a28e6..77f7a0b3bd 100644 --- a/build_libtorch_npu.py +++ b/build_libtorch_npu.py @@ -225,7 +225,8 @@ def copy_hpp(): "torch_npu/csrc/framework/*.h", "torch_npu/csrc/framework/*/*.h", "torch_npu/csrc/framework/*/*/*.h", - "torch_npu/csrc/libs/*.h" + "torch_npu/csrc/libs/*.h", + "torch_npu/csrc/inductor/**/*.h", ] glob_header_files = [] for regex_pattern in header_files: diff --git a/setup.py b/setup.py index 4bbf41448f..5df0b17e27 100644 --- a/setup.py +++ b/setup.py @@ -481,6 +481,28 @@ def get_src_py_and_dst(): os.path.relpath(src, os.path.join(BASE_DIR, "patch/include"))) os.makedirs(os.path.dirname(dst), exist_ok=True) ret.append((src, dst)) + + aot_inductor_files = [ + # Follow torch v2.6.0. + # These aoti_runtime/*.cpp don't compile to libtorch_npu, + # but act like header files when generate cppwrapper in aot-inductor. + "torch_npu/_inductor/codegen/aoti_runtime/*.cpp" + ] + glob_aoti_files = [] + for regex_pattern in aot_inductor_files: + glob_aoti_files += glob.glob( + os.path.join(BASE_DIR, regex_pattern), recursive=True + ) + + for src in glob_aoti_files: + # Dst: torch_npu/_inductor/codegen/aoti_runtime/*.cpp + dst = os.path.join( + os.path.join(BASE_DIR, "build/packages/torch_npu/"), + os.path.relpath(src, os.path.join(BASE_DIR, "torch_npu")), + ) + os.makedirs(os.path.dirname(dst), exist_ok=True) + ret.append((src, dst)) + return ret diff --git a/torch_npu/_inductor/codecache.py b/torch_npu/_inductor/codecache.py index 1efec225f2..3379bcc421 100644 --- a/torch_npu/_inductor/codecache.py +++ b/torch_npu/_inductor/codecache.py @@ -88,29 +88,39 @@ def patch_aot_code_compiler_compile(): # which could not be skipped, so here we try to create a new npu op_json, # and clear the content of default op_json. from torch._inductor.codecache import AotCodeCompiler + AotCodeCompiler.src_compile = AotCodeCompiler.compile @classmethod def compile_npu( cls, graph: GraphLowering, - source_code: str, + wrapper_code: str, + kernel_code: str, serialized_extern_kernel_nodes: Optional[str], + *, device_type: str, - additional_files: List[str], + additional_files: list[str], ) -> Union[List[str], str]: result = cls.src_compile( - graph, source_code, serialized_extern_kernel_nodes, - device_type, additional_files + graph, + wrapper_code, + kernel_code, + serialized_extern_kernel_nodes, + device_type=device_type, + additional_files=additional_files, ) generated_files = additional_files if not config.aot_inductor.package: return result - + output_so = [r for r in result if r.endswith(".so")] if len(output_so) > 1: - raise RuntimeError(f"Could not generate npu op json, because there are" - f"more than one so in generated files: {result}" + pta_error(ErrCode.INTERNAL)) + raise RuntimeError( + f"Could not generate npu op json, because there are" + f"more than one so in generated files: {result}" + + pta_error(ErrCode.INTERNAL) + ) output_so = output_so[0] key = os.path.basename(output_so)[0].replace(".", "_") dir_basename = os.path.splitext(output_so)[0] @@ -120,11 +130,11 @@ def patch_aot_code_compiler_compile(): with open(extern_kernel_nodes_json, "w") as f: f.write(serialized_extern_kernel_nodes) generated_files.append(extern_kernel_nodes_json) - + if serialized_extern_kernel_nodes: source_json_file = dir_basename + ".json" with open(source_json_file, "w") as f: f.write(empty_json) return generated_files + AotCodeCompiler.compile = compile_npu - \ No newline at end of file diff --git a/torch_npu/_inductor/codegen/aoti_runtime/interface.cpp b/torch_npu/_inductor/codegen/aoti_runtime/interface.cpp new file mode 100644 index 0000000000..2d583effa3 --- /dev/null +++ b/torch_npu/_inductor/codegen/aoti_runtime/interface.cpp @@ -0,0 +1,272 @@ +// Definition of AOTI runtime interface functions + +#include +#include + +#include +#include +#include +#include + +#define CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception &e) { \ + std::cerr << "Error: " << e.what() << std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } catch (...) { \ + std::cerr << "Unknown exception occurred." << std::endl; \ + return AOTI_RUNTIME_FAILURE; \ + } \ + return AOTI_RUNTIME_SUCCESS; + +#define AOTI_VECTOR_SIZE_CHECK(actual_size, expected_size, name) \ + do { \ + AOTI_RUNTIME_CHECK(actual_size == expected_size, "expected " + std::string(name) + " vector size to be " + \ + std::to_string(expected_size) + ", but got " + \ + std::to_string(actual_size)); \ + } while (0) + +// AOTInductor uses at::addmm_out, which doesn't supports +// arguments that requires gradient. For this reason, we +// enforce no_grad context for run APIs. +// +// A RAII, thread local (!) guard that enables or disables grad mode upon +// construction, and sets it back to the original value upon destruction. +struct AOTINoGradGuard { + AOTINoGradGuard() : prev_mode(aoti_torch_grad_mode_is_enabled()) { aoti_torch_grad_mode_set_enabled(false); } + ~AOTINoGradGuard() { aoti_torch_grad_mode_set_enabled(prev_mode); } + bool prev_mode; +}; + +extern "C" { + +AOTIRuntimeError AOTInductorModelContainerCreate(AOTInductorModelContainerHandle *container_handle, size_t num_models, + bool is_cpu, const char *cubin_dir) { + return AOTInductorModelContainerCreateWithDevice(container_handle, num_models, is_cpu ? "cpu" : "npu", cubin_dir); +} + +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice(AOTInductorModelContainerHandle *container_handle, + size_t num_models, const char *device_str, + const char *cubin_dir) { + if (num_models == 0) { + std::cerr << "Error: num_models must be positive, but got 0" << std::endl; + return AOTI_RUNTIME_FAILURE; + } + CONVERT_EXCEPTION_TO_ERROR_CODE({ + std::optional cubin_dir_opt; + if (cubin_dir != nullptr) { + cubin_dir_opt.emplace(cubin_dir); + } + auto *container = + new torch::aot_inductor::AOTInductorModelContainer(num_models, std::string(device_str), cubin_dir_opt); + *container_handle = reinterpret_cast(container); + }) +} + +AOTIRuntimeError AOTInductorModelContainerDelete(AOTInductorModelContainerHandle container_handle) { + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto *container = reinterpret_cast(container_handle); + delete container; + }); +} + +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle) { + auto *container = reinterpret_cast(container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run(input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, AOTIProxyExecutorHandle proxy_executor_handle) { + auto *container = reinterpret_cast(container_handle); + AOTI_VECTOR_SIZE_CHECK(num_inputs, container->num_inputs(), "inputs"); + AOTI_VECTOR_SIZE_CHECK(num_outputs, container->num_outputs(), "outputs"); + + auto stream = reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_single_threaded(input_handles, output_handles, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumConstants(AOTInductorModelContainerHandle container_handle, + size_t *num_constants) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *num_constants = container->num_constants(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantName(AOTInductorModelContainerHandle container_handle, size_t idx, + const char **name) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *name = container->constant_name(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN(AOTInductorModelContainerHandle container_handle, + size_t idx, const char **original_fqn) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *original_fqn = container->constant_original_fqn(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded(AOTInductorModelContainerHandle container_handle, + size_t idx, bool *from_folded) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *from_folded = container->constant_from_folded(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantType(AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *type) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *type = container->constant_type(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype(AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *dtype) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *dtype = container->constant_dtype(idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer(AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, + bool use_inactive, bool validate_full_update) { + auto *container = reinterpret_cast(container_handle); + auto input_map = reinterpret_cast *>(constant_map_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE( + { container->update_constant_buffer(*input_map, use_inactive, validate_full_update); }) +} + +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, AOTInductorConstantMapHandle constant_map_handle) { + return AOTInductorModelContainerUpdateConstantBuffer(container_handle, constant_map_handle, + /*use_inactive*/ true, + /*validate_full_update*/ true); +} + +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding(AOTInductorModelContainerHandle container_handle, + bool use_inactive, AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle) { + auto *container = reinterpret_cast(container_handle); + auto stream = reinterpret_cast(stream_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + container->run_const_fold(use_inactive, stream, proxy_executor_handle); + }) +} + +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer(AOTInductorModelContainerHandle container_handle) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ container->swap_constant_buffer(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumInputs(AOTInductorModelContainerHandle container_handle, + size_t *ret_num_inputs) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_num_inputs = container->num_inputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetInputName(AOTInductorModelContainerHandle container_handle, + size_t input_idx, const char **ret_input_names) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_input_names = container->input_name(input_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs(AOTInductorModelContainerHandle container_handle, + size_t *ret_num_outputs) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_num_outputs = container->num_outputs(); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetOutputName(AOTInductorModelContainerHandle container_handle, + size_t output_idx, const char **ret_output_names) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ *ret_output_names = container->output_name(output_idx); }) +} + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec(AOTInductorModelContainerHandle container_handle, + const char **in_spec, const char **out_spec) { + auto *container = reinterpret_cast(container_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + *in_spec = container->get_in_spec(); + *out_spec = container->get_out_spec(); + }) +} + +AOTIRuntimeError AOTInductorModelCreate(AOTInductorModelHandle *model_handle, + AOTInductorConstantMapHandle constant_map_handle){ + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto constant_array = std::make_shared>(); + auto input_map = reinterpret_cast *>(constant_map_handle); + + auto model = new torch::aot_inductor::AOTInductorModel( + constant_map, constant_array, + "cpu", // device_str is hardcoded, as AOTInductorModelCreate is only use for CPU models + ""); + + if (input_map) { + for (auto const &kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + } else { + model->load_constants(); + } + + *model_handle = reinterpret_cast(model); + })} + +AOTIRuntimeError AOTInductorModelRun(AOTInductorModelHandle model_handle, AtenTensorHandle *input_handles, + AtenTensorHandle *output_handles) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + AOTINoGradGuard guard; + model->run_impl(input_handles, output_handles, (torch::aot_inductor::DeviceStreamType) nullptr, nullptr); + }) +} + +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle){CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + delete model; +})} + +AOTIRuntimeError AOTInductorModelGetNumOutputs(AOTInductorModelHandle model_handle, + size_t *ret_num_outputs){CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto model = reinterpret_cast(model_handle); + *ret_num_outputs = model->num_outputs(); +})} + +AOTIRuntimeError AOTInductorModelUpdateConstantsMap(AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle) { + auto model = reinterpret_cast(model_handle); + CONVERT_EXCEPTION_TO_ERROR_CODE({ + auto constant_map = std::make_shared(); + auto input_map = reinterpret_cast *>(constant_map_handle); + + for (auto const &kv : *input_map) { + constant_map->emplace(kv.first, kv.second); + } + model->update_constants_map(std::move(constant_map)); + }) +} + +} // extern "C" diff --git a/torch_npu/_inductor/codegen/cpp_wrapper.py b/torch_npu/_inductor/codegen/cpp_wrapper.py index 9a16cfabf8..6511b550dc 100644 --- a/torch_npu/_inductor/codegen/cpp_wrapper.py +++ b/torch_npu/_inductor/codegen/cpp_wrapper.py @@ -1,4 +1,6 @@ import functools +import re +import dataclasses import os import sys from itertools import chain, count, zip_longest @@ -12,12 +14,13 @@ from torch._inductor.codecache import get_cpp_wrapper_cubin_path_name from torch._inductor.codegen.aoti_hipify_utils import maybe_hipify_code_wrapper from torch._inductor.codegen.common import get_device_op_overrides from torch._inductor.codegen.cpp_utils import cexpr, DTYPE_TO_CPP, DEVICE_TO_ATEN +from torch._inductor.codegen.triton_utils import should_unwrap_unspec_arg from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu from torch._inductor.codegen.multi_kernel import MultiKernelCall from torch._inductor.codegen.wrapper import PythonWrapperCodegen, SymbolicCallArg -from torch._inductor.ir import IRNode, TensorBox +from torch._inductor.ir import IRNode, TensorBox, GraphPartitionSignature from torch._inductor.runtime.runtime_utils import dynamo_timed -from torch._inductor.utils import DeferredLineBase +from torch._inductor.utils import DeferredLineBase, IndentedBuffer from torch._inductor.virtualized import V from torch._inductor.utils import _align, ALIGN_BYTES @@ -34,153 +37,164 @@ def checkIfTrue(value, msg): return True -class DeferredNpuKernelLine(DeferredLineBase): - """ - When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels - to be tuned and stored as cubin files, so use a deferred line to backfill those information - """ - - def __init__( - self, - kernel_name: str, - line_template: str, - keys: Tuple[str, ...], - additional_files: List[str], - ): - super().__init__(line_template) - checkIfTrue(not isinstance(line_template, DeferredLineBase), "line template can not be DeferredLineBase") - self.additional_files = additional_files - self.kernel_name = kernel_name - self.line_template = line_template - self.keys = keys - - def __call__(self): - if self.kernel_name.startswith("multi_kernel_"): - # MultiKernel will select one kernel after running the autotune block - self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) - params = CudaKernelParamCache.get(self.kernel_name) - checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") - - for key in self.keys: - checkIfTrue(key in params, f"{key} not found in CudaKernelParamCache[{self.kernel_name}]") - - if key == get_cpp_wrapper_cubin_path_name(): - checkIfTrue(os.path.exists(params[key]), f"{params[key]} does not exist") - self.additional_files.append(params[key]) - - return self.line_template % tuple(params[key] for key in self.keys) - - def _new_line(self, line): - return DeferredNpuKernelLine( - self.kernel_name, line, self.keys, self.additional_files - ) - +_cpp_string_literal_escapes = { + "\\": "\\\\", + '"': '\\"', + "\n": "\\n", + "\t": "\\t", + "\r": "\\r", +} +_cpp_string_literal_pattern = re.compile(r'["\\\n\t\r]') -class DeferredNpuDefaultGrid: - """ - A container for the default grid, which may be used by DeferredNpuGridLine - """ - - def __init__( - self, - kernel_name: str, - grid, - grid_callable: Optional[Callable[..., Any]] = None, - **grid_extra_kwargs, - ): - self.kernel_name = kernel_name - self.grid = grid - self.grid_callable = grid_callable - self.grid_extra_kwargs = grid_extra_kwargs - - def __iter__(self): - # DeferredNpuDefaultGrid can be passed to the base class, PythonWrapperCodegen, - # to generate the autotune code block, and thus we need this iterator - return iter(self.grid) - - def _process_grid(self, grid: Union[List[Any], Tuple[Any, ...]]): - if isinstance(grid, (list, tuple)): - return [self._process_grid(e) for e in grid] - else: - return grid.inner_expr if isinstance(grid, SymbolicCallArg) else grid - - def __call__(self): - if self.kernel_name.startswith("multi_kernel_"): - # MultiKernel will select one kernel after running the autotune block - self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) - - grid = self.grid - checkIfTrue(isinstance(grid, (list, tuple)), f"expected {grid=} to be a list") - grid = self._process_grid(grid) +def cpp_string_literal(s: str) -> str: + escaped = _cpp_string_literal_pattern.sub( + lambda match: _cpp_string_literal_escapes[match.group(0)], s + ) + return f'"{escaped}"' - checkIfTrue(self.grid_callable is not None, "grid_callable can't be None") - if not self.grid_extra_kwargs: - grid_fn = self.grid_callable(*grid) - else: - grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs) - - params = CudaKernelParamCache.get(self.kernel_name) - checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") +@dataclasses.dataclass +class UnwrapUnspecArg: + """Marker that we need to call .item() on the tensor""" - return grid_fn(params["meta"]) + dtype: torch_dtype -class DeferredNpuGridLine(DeferredLineBase): +@dataclasses.dataclass +class DeferredNpuTritonCallWrapper: """ - When using cpp wrapper, NPU kernel load and launch needs to wait for Triton kernels - to be tuned and stored as cubin files, so use a deferred line to backfill those information + When using cpp wrapper, GPU kernel load and launch needs to wait for Triton kernels + to be tuned and stored as cubin files, so use a deferred generating the final wrapper around + the triton kernel until right before the prefix is written. """ - def __init__( - self, - kernel_name: str, - grid_var: str, - grid, - autotune_configs, - ): - super().__init__("") - self.kernel_name = kernel_name - self.grid_var = grid_var - self.grid = grid - self.autotune_configs = autotune_configs + wrapper_name: str + kernel_name: str + arg_types: list[Any] + kernel_id: int - def __call__(self): + def generate(self, wrapper): + prefix = wrapper.prefix if self.kernel_name.startswith("multi_kernel_"): # MultiKernel will select one kernel after running the autotune block self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name) - params = CudaKernelParamCache.get(self.kernel_name) + def_args = params["def_args"] + arg_types = self.arg_types + inductor_meta = params["inductor_meta"] + + if "extra_launcher_args" in inductor_meta and len(def_args) > len(arg_types): + # extra_launcher_args should already be in def_args + arg_types = arg_types + [SymbolicCallArg] * len( + inductor_meta["extra_launcher_args"] + ) - checkIfTrue(params is not None, f"{self.kernel_name} not found in CudaKernelParamCache") - - if self.autotune_configs is not None: - # This indicates the Triton kernel is a user-defined one. - grid = None - if len(self.grid) == 1: - grid = self.grid[0] - else: - for i, c in enumerate(self.autotune_configs): - if all(arg == params["meta"][key] for key, arg in c.kwargs.items()): - grid = self.grid[i] - break - checkIfTrue(grid is not None, "grid can not be None") - grid_args_str = ", ".join( - [cexpr(V.graph.sizevars.simplify(item)) for item in grid] + if not V.graph.aot_mode: + prefix.writeline( + maybe_hipify_code_wrapper( + f"static {wrapper.device_codegen.cpp_kernel_type()} {self.kernel_name} = nullptr;" + ) ) + kernel_var_name = self.kernel_name else: - launch_grid = (params['grid_x'], params['grid_y'], params['grid_z']) - grid_args_str = ", ".join( - [cexpr(item) for item in launch_grid] + kernel_var_name = f"kernels_.{self.kernel_name}" + + # tensors can be RAIIAtenTensorHandle or ConstantHandle, so make them template types + template_types = [ + f"typename {name}_type_" + for name, arg_type in zip(def_args, arg_types) + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)) + ] + if V.graph.aot_mode: + template_types.append("typename kernels_type_") + if template_types: + prefix.writeline(f"template <{', '.join(template_types)}>") + prefix.writeline(f"static inline void {self.wrapper_name}(") + with prefix.indent(): + for name, arg_type in zip(def_args, arg_types): + if isinstance(arg_type, (torch_dtype, UnwrapUnspecArg)): + prefix.writeline(f"const {name}_type_& {name},") + elif issubclass(arg_type, (SymbolicCallArg, sympy.Expr, int)): + prefix.writeline(f"int64_t {name},") + elif arg_type is float: + prefix.writeline(f"float {name},") + elif arg_type is bool: + prefix.writeline(f"bool {name},") + else: + raise ValueError(f"Unexpected arg type {arg_type}") + prefix.writeline(f"{wrapper.device_codegen.cpp_stream_type()} stream_,") + if V.graph.aot_mode: + prefix.writeline("kernels_type_& kernels_,") + prefix.writeline( + "const std::optional& cubin_dir_ = std::nullopt" ) + prefix.writeline("){") + with prefix.indent(): + self.generate_grid(prefix, inductor_meta, params) + self.generate_load_kernel(prefix, kernel_var_name, params) + self.generate_launch_kernel(prefix, wrapper, kernel_var_name, params) + prefix.writeline("}") + # Ensure the cubin file is included in the package + V.graph.wrapper_code.additional_files.append( + params[get_cpp_wrapper_cubin_path_name()] + ) - return f"\n Grid {self.grid_var} = Grid({grid_args_str});\n" + def generate_grid( + self, + prefix: IndentedBuffer, + inductor_meta: dict[str, Any], + params: dict[str, Any], + ): + from ..npu_triton_heuristics import GridExprNpu - def _new_line(self, line): - return DeferredNpuGridLine( - self.kernel_name, self.grid_var, self.grid, self.autotune_configs + numels = [arg for arg in params["call_args"] if "_numel" in arg] + grid = GridExprNpu.from_meta_and_set_numel( + inductor_meta, params["config"], numels, "cpp" + ) + for line in grid.prefix: + prefix.writeline(line) + prefix.splice( + f"""\ + uint32_t grid_0 = {grid.x_grid}; + uint32_t grid_1 = {grid.y_grid}; + uint32_t grid_2 = {grid.z_grid}; + """ + ) + prefix.writeline("if (grid_0 == 0 || grid_1 == 0 || grid_2 == 0) return;") + + def generate_load_kernel(self, prefix, kernel_var_name, params): + prefix.writeline(f"if ({kernel_var_name} == nullptr) {{") + with prefix.indent(): + load_kernel_args = [ + cpp_string_literal(params[get_cpp_wrapper_cubin_path_name()]), + cpp_string_literal(params["mangled_name"]), + str(params["shared_mem"]), + "cubin_dir_", + ] + prefix.writeline( + f"{kernel_var_name} = loadKernel({', '.join(load_kernel_args)}); " + ) + prefix.writeline("}") + + def generate_launch_kernel(self, prefix, wrapper, kernel_var_name, params): + triton_meta = params["triton_meta"] + arg_type_lookup = dict(zip(params["def_args"], self.arg_types)) + # difference between Python and C++ wrapper: C++ wrapper strips out equal_to_1 constants + call_args = [name for name in params["call_args"] if name not in triton_meta["constants"]] + arg_types = [arg_type_lookup[name] for name in call_args] + arg_signatures = [triton_meta["signature"][name] for name in call_args] + call, call_args_str = wrapper.generate_args_decl( + prefix, + call_args, + arg_types, + arg_signatures, + kernel_var_name, + self.kernel_id, ) + prefix.writeline(f"{call_args_str}") + + prefix.writeline(r"launchKernel({}, {});".format(call, f'"{kernel_var_name}"')) class CppWrapperNpu(CppWrapperCpu): @@ -189,24 +203,28 @@ class CppWrapperNpu(CppWrapperCpu): """ def __init__(self) -> None: - self.device = 'npu' + self.device = "npu" self.device_codegen = get_device_op_overrides(self.device) super().__init__() self.grid_id = count() self.visited_raii_handle = set() self.visited_handle_for_kernel_id = dict() + self._triton_call_wrappers: dict[str, DeferredNpuTritonCallWrapper] = {} @staticmethod def create( - is_subgraph: bool, subgraph_name: str, parent_wrapper: PythonWrapperCodegen + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, ): # comment at CppWrapperCpu `codegen_subgraph` function. return CppWrapperNpu() def super_write_header_rewrite(self): """Copied from CppWrapperCpu to: - (1) change __file__ path for cpython, so that we can use aoti_runtime in current path. - (2) rewrite include path of aoti header file. + (1) change __file__ path for cpython, so that we can use aoti_runtime in current path. + (2) rewrite include path of aoti header file. """ if V.graph.is_const_graph: # We do not write header for constant graph, it will be written by main module. @@ -343,11 +361,7 @@ class CppWrapperNpu(CppWrapperCpu): if V.graph.aot_mode and V.graph.inputs_to_check: for idx in V.graph.inputs_to_check: input_name = V.graph.graph_input_names[idx] - checkIfTrue(input_name in V.graph.graph_inputs, f"{input_name} not found in graph inputs") - value = V.graph.graph_inputs[input_name] - checkIfTrue(isinstance(value, TensorBox), - f"{input_name} is expected to be tensor but found as {type(value)}") self.prefix.splice( f""" @@ -360,11 +374,11 @@ class CppWrapperNpu(CppWrapperCpu): super().codegen_inputs() def define_kernel( - self, - kernel_name: str, - kernel_body: str, - metadata: Optional[str] = None, - gpu=True, + self, + kernel_name: str, + kernel_body: str, + metadata: Optional[str] = None, + gpu=True, ): if gpu: if config.triton.autotune_at_compile_time: @@ -382,10 +396,10 @@ class CppWrapperNpu(CppWrapperCpu): self.prefix.writeline("\n") if not V.graph.aot_mode: for kernel in chain( - sorted(self.src_to_kernel.values()), - sorted( - [entry[0] for entry in self.user_defined_kernel_cache.values()] - ), + sorted(self.src_to_kernel.values()), + sorted( + [entry[0] for entry in self.user_defined_kernel_cache.values()] + ), ): self.prefix.writeline( maybe_hipify_code_wrapper( @@ -396,17 +410,17 @@ class CppWrapperNpu(CppWrapperCpu): return super().generate(is_inference) def generate_user_defined_triton_kernel( - self, - kernel_name: str, - raw_args: List[Any], - grid: List[Any], - configs, - triton_meta, - constexprs, + self, + kernel_name: str, + raw_args: List[Any], + grid: List[Any], + configs, + triton_meta, + constexprs, ): if ( - config.triton.autotune_at_compile_time - and kernel_name not in self.kernel_autotune_names + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names ): # Call PythonWrapperCodegen to create the autotune code block PythonWrapperCodegen.generate_user_defined_triton_kernel( @@ -442,28 +456,8 @@ class CppWrapperNpu(CppWrapperCpu): autotune_configs=configs, ) - @functools.lru_cache(None) # noqa: B019 - def generate_load_kernel_once( - self, - kernel_name: str, - device_index, - graph: "GraphLowering", # for per-graph caching - ): - keys = (get_cpp_wrapper_cubin_path_name(), "mangled_name", "shared_mem") - kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name - self.writeline(f"if ({kernel_var_name} == nullptr) {{") - deferred_gpu_kernel_line = DeferredNpuKernelLine( - kernel_name, - " " + kernel_var_name + r' = loadKernel("%s", "%s", %s, this->cubin_dir_);', - keys, - self.additional_files, - ) - self.writeline(deferred_gpu_kernel_line) - self.writeline("}") - return kernel_var_name - def codegen_tensor_item_npu( - self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None + self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None ): dtype_str = str(dtype).split(".")[-1] writer = indented_buffer or self @@ -475,22 +469,22 @@ class CppWrapperNpu(CppWrapperCpu): f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar_tmp}));" ) writer.writeline(f"float {scalar} = float({scalar_tmp});") - struct_data = f'float {scalar} __attribute__((aligned(4)));' - arg_data = f'static_cast({scalar})' + struct_data = f"float {scalar} __attribute__((aligned(4)));" + arg_data = f"static_cast({scalar})" else: writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};") writer.writeline( f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));" ) - struct_data = f'{DTYPE_TO_CPP[dtype]} {scalar} __attribute__((aligned(sizeof({DTYPE_TO_CPP[dtype]} ))));' - arg_data = f'static_cast<{DTYPE_TO_CPP[dtype]}>({scalar})' + struct_data = f"{DTYPE_TO_CPP[dtype]} {scalar} __attribute__((aligned(sizeof({DTYPE_TO_CPP[dtype]} ))));" + arg_data = f"static_cast<{DTYPE_TO_CPP[dtype]}>({scalar})" return struct_data, arg_data def codegen_device(self, device): if device.type not in DEVICE_TO_ATEN: raise RuntimeError(device.type + "not found in DEVICE_TO_ATEN") - device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" + device_str = DEVICE_TO_ATEN[device.type][5:].lower() # remove "at::k" if device_str == "privateuse1": device_str = "npu" self.used_cached_devices.add(device_str) @@ -525,21 +519,21 @@ class CppWrapperNpu(CppWrapperCpu): return "" if kernel_id not in self.visited_handle_for_kernel_id: self.visited_handle_for_kernel_id[kernel_id] = set() - + def get_tensor_from_handle(h, t): if h in self.visited_handle_for_kernel_id[kernel_id]: return "" self.visited_handle_for_kernel_id[kernel_id].add(h) return f" auto {t} = *tensor_handle_to_tensor_pointer({h});\n" - + # Only dump tensor args, e.g, ['buf2', '8L', '4L'] => ['buf2'] tensor_args = [arg for arg in args if not arg[0].isdigit()] tensor_args_h = [f"{arg}_h" for arg in tensor_args] tensor_args_t = [f"{arg}_t" for arg in tensor_args] - handle_tensor_str = "".join([ - get_tensor_from_handle(h, t) for h, t in zip(tensor_args_h, tensor_args_t) - ]) + handle_tensor_str = "".join( + [get_tensor_from_handle(h, t) for h, t in zip(tensor_args_h, tensor_args_t)] + ) dump_path = npu_config.aot_inductor.dump_path_cpp return f""" @@ -549,16 +543,30 @@ class CppWrapperNpu(CppWrapperCpu): torch::save(arg_{mark}, "{dump_path}/{kernel_id}_{kernel_name}_{mark}.pt"); """ - def generate_launch_call( + def generate_args_decl( self, + code, call_args, arg_types, arg_signatures, + kernel_name, kernel_id, - grid_var, - kernel_name + is_triton_kernel=True, ): - kernel_val_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name + """ + Generates any declarations of args to pass into a kernel call, and then returns the arg names. + + In more detail: + * declarations: e.g. this function has a side effect of generating lines like `auto var_0 = ...;` + * returns: a string with the list of args, e.g. "var_0, var_1" + + call_args: list of call arguments + arg_types: list of argument types + arg_signatures: list with signatures of all the args + is_triton_kernel: whether these are passed into a triton kernel or not. In particular, + calls to triton kernels will have an additional global scratch space + arg injected at the front of the arg list. + """ new_args: list[str] = [] # Add more cases for other types as needed @@ -577,115 +585,105 @@ class CppWrapperNpu(CppWrapperCpu): "fp64": "double", } - - struct_def_body = '' - struct_arg_body = '' + struct_def_body = "" + struct_arg_body = "" def process_args(arg, arg_type, arg_signature=None): var_name = f"var_{next(self.arg_var_id)}" # ignore nvTmaDesc, as host-side TMA descriptors need # to be passed to the compiled Triton kernel by value - if isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": - if arg.endswith(".item()"): # scalar - # Need to declare a scalar in this case - arg = arg[:-7] - struct_data, arg_data = self.codegen_tensor_item_npu( - arg_type, - arg, - var_name, - ) - else: - # void* - device_ptr_type = self.device_codegen.cpp_device_ptr() - self.writeline( - maybe_hipify_code_wrapper( - f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" - ) + if isinstance(arg_type, UnwrapUnspecArg) and arg_signature != "nvTmaDesc": + self.codegen_tensor_item_npu( + arg_type.dtype, + arg, + var_name, + indented_buffer=code, + ) + elif isinstance(arg_type, torch_dtype) and arg_signature != "nvTmaDesc": + device_ptr_type = self.device_codegen.cpp_device_ptr() + code.writeline( + maybe_hipify_code_wrapper( + f"{device_ptr_type} {var_name} = reinterpret_cast<{device_ptr_type}>({arg}.data_ptr());" ) - if npu_config.aot_inductor.debug_kernel: - if arg not in self.visited_raii_handle: - self.writeline( - f"AtenTensorHandle {arg}_h = {arg}.get();" - ) - self.visited_raii_handle.add(arg) - struct_data = f'void* {var_name} __attribute__((aligned(8)));' - arg_data = f'static_cast({var_name})' - + ) + if npu_config.aot_inductor.debug_kernel: + if arg not in self.visited_raii_handle: + self.writeline(f"AtenTensorHandle {arg}_h = {arg}.get();") + self.visited_raii_handle.add(arg) + struct_data = f"void* {var_name} __attribute__((aligned(8)));" + arg_data = f"static_cast({var_name})" elif arg_type in (sympy.Integer, int): - # int - self.writeline(f"int {var_name} = {cexpr(arg)};") - struct_data = f'int {var_name} __attribute__((aligned(4)));' - arg_data = f'static_cast({var_name})' - + code.writeline(f"int {var_name} = {cexpr(arg)};") + struct_data = f"int {var_name} __attribute__((aligned(4)));" + arg_data = f"static_cast({var_name})" elif arg_type in (sympy.Float, float): - # float - self.writeline(f"float {var_name} = {cexpr(arg)};") - struct_data = f'float {var_name} __attribute__((aligned(4)));' - arg_data = f'static_cast({var_name})' - + code.writeline(f"float {var_name} = {cexpr(arg)};") + struct_data = f"float {var_name} __attribute__((aligned(4)));" + arg_data = f"static_cast({var_name})" # For symbolic call arguments, examine the arg signatures from triton meta # to explicitly cast to the right type # Reason: `auto` can infer unexpected type against kernel input signature. elif ( - isinstance(arg_type, type(SymbolicCallArg)) - and arg_signature is not None - and arg_signature in signature2dtype.keys() + isinstance(arg_type, type(SymbolicCallArg)) + and arg_signature is not None + and arg_signature in signature2dtype.keys() ): - # or scalar symbolic type,currently only support scalar symbolic type - self.writeline( + code.writeline( f"{signature2dtype[arg_signature]} {var_name} = {cexpr(arg)};" ) - struct_data = f'{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));' - arg_data = f'static_cast<{signature2dtype[arg_signature]}>({var_name})' + struct_data = f"{signature2dtype[arg_signature]} {var_name} __attribute__((aligned(sizeof({signature2dtype[arg_signature]}))));" + arg_data = f"static_cast<{signature2dtype[arg_signature]}>({var_name})" else: raise TypeError("Infer arg_type to cpp failed!") - nonlocal struct_def_body nonlocal struct_arg_body - struct_def_body += struct_data + ' ' - struct_arg_body += arg_data + ', ' + struct_def_body += struct_data + " " + struct_arg_body += arg_data + ", " for arg, arg_type, arg_signature in zip_longest( - call_args, arg_types, arg_signatures + call_args, arg_types, arg_signatures ): process_args(arg, arg_type, arg_signature) - debug_str_before_kernel = self.generate_debug_str(call_args, kernel_name, kernel_id, "before") - debug_str_after_kernel = self.generate_debug_str(call_args, kernel_name, kernel_id, "after") - + debug_str_before_kernel = self.generate_debug_str( + call_args, kernel_name, kernel_id, "before" + ) + debug_str_after_kernel = self.generate_debug_str( + call_args, kernel_name, kernel_id, "after" + ) launch_str = f""" auto launch_call_{kernel_id} = [=]() {{ - int32_t grid_x = {grid_var}.grid_x; - int32_t grid_y = {grid_var}.grid_y; - int32_t grid_z = {grid_var}.grid_z; rtError_t ret; void* ffts_addr = NULL; uint32_t ffts_len; ret = rtGetC2cCtrlAddr((uint64_t*)&ffts_addr, &ffts_len); if (ret != RT_ERROR_NONE) return ret; void* workspace_addr = NULL; + void* sync_block_lock = NULL; struct __attribute__((packed)) {{ void* ffts_addr __attribute__((aligned(8))); + void* sync_block_lock __attribute__((aligned(8))); void* workspace_addr __attribute__((aligned(8))); {struct_def_body} - int32_t grid_x __attribute__((aligned(4))); - int32_t grid_y __attribute__((aligned(4))); - int32_t grid_z __attribute__((aligned(4))); + int32_t grid_0 __attribute__((aligned(4))); + int32_t grid_1 __attribute__((aligned(4))); + int32_t grid_2 __attribute__((aligned(4))); }} kernel_args = {{ static_cast(ffts_addr), + static_cast(sync_block_lock), static_cast(workspace_addr), {struct_arg_body} - static_cast(grid_x), - static_cast(grid_y), - static_cast(grid_z) + static_cast(grid_0), + static_cast(grid_1), + static_cast(grid_2) }}; - uint32_t block_num = grid_x * grid_y * grid_z; + uint32_t block_num = grid_0 * grid_1 * grid_2; auto arg_ptr = static_cast(&kernel_args); auto arg_size = sizeof(kernel_args); {debug_str_before_kernel} - ret = rtKernelLaunch({kernel_val_name}, block_num, arg_ptr, arg_size, NULL, stream); + ret = rtKernelLaunch({kernel_name}, block_num, arg_ptr, arg_size, NULL, stream_); {debug_str_after_kernel} if (ret != RT_ERROR_NONE) return ret; return ret; @@ -693,43 +691,24 @@ class CppWrapperNpu(CppWrapperCpu): """ return f"launch_call_{kernel_id}", launch_str - def generate_default_grid( - self, - kernel_name: str, - grid_args: List[Any], - gpu: bool = True, - grid_callable: Optional[Callable[..., Any]] = None, - **grid_extra_kwargs, - ): - """ - Generate grid configs for launching a CUDA kernel using the grid - function from triton_heuristics. Because its computation needs - to read kernel config after autotune, it is done in a deferred way - using DeferredNpuDefaultGrid. - """ - checkIfTrue(gpu, "CppWrapperNpu.generate_default_grid does not support non-NPU") - return DeferredNpuDefaultGrid( - kernel_name, grid_args, grid_callable, **grid_extra_kwargs - ) - def generate_kernel_call_npu( - self, - kernel_name: str, - call_args, - grid=None, - device_index=None, - npu=True, - triton=True, - arg_types=None, - raw_args=None, - grid_fn: str = "grid", - triton_meta=None, - autotune_configs=None, - grid_extra_kwargs="", + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + npu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", ): if ( - config.triton.autotune_at_compile_time - and kernel_name not in self.kernel_autotune_names + config.triton.autotune_at_compile_time + and kernel_name not in self.kernel_autotune_names ): # Call PythonWrapperCodegen to create the autotune code block PythonWrapperCodegen.generate_kernel_call( @@ -759,69 +738,25 @@ class CppWrapperNpu(CppWrapperCpu): ) if triton: - device_index, call_args = self.prepare_triton_kernel_call( - device_index, call_args + call_args, arg_types = self.prepare_triton_wrapper_args( + call_args, arg_types ) - _ = self.generate_load_kernel_once(kernel_name, device_index, V.graph) - - # args with value 1 are added into equal_to_1 and constants - # in triton_meta (in the Python codegen) which makes them - # inlined in the PTX and compiled CUBIN - arg_signatures = [] - if ( - triton_meta is not None - and triton_meta.get("configs") - and triton_meta.get("signature") - ): - equal_to_1 = triton_meta["configs"][0].equal_to_1 - call_args = [ - arg - for i, arg in enumerate(call_args) - if i not in equal_to_1 - ] - arg_types = [t for i, t in enumerate(arg_types) if i not in equal_to_1] - # extract the arg signatures from triton_meta - arg_signatures = triton_meta["signature"].values() - arg_signatures = [ - v - for i, v in enumerate(arg_signatures) - if i not in equal_to_1 - ] - + wrapper_name = f"call_{kernel_name}" current_kernel_id = next(self.kernel_callsite_id) - current_grid_id = next(self.grid_id) - - # gen grids - grid_var = f"{kernel_name}_grid_{current_grid_id}" - self.writeline( - DeferredNpuGridLine(kernel_name, grid_var, grid, autotune_configs) - ) - - call, call_args_str = self.generate_launch_call( - call_args, arg_types, arg_signatures, current_kernel_id, grid_var, kernel_name - ) - self.writeline(f"{call_args_str}") - - # add debug printer code for all triton kernel related calls + if wrapper_name not in self._triton_call_wrappers: + self._triton_call_wrappers[wrapper_name] = DeferredNpuTritonCallWrapper( + wrapper_name, kernel_name, arg_types, current_kernel_id + ) + call_args.append(stream) + if V.graph.aot_mode: + call_args.append("kernels") + call_args.append("this->cubin_dir_") debug_printer_manager = V.graph.wrapper_code.debug_printer debug_printer_manager.set_printer_args( - call_args, kernel_name, arg_types, None + call_args[: len(arg_types)], kernel_name, arg_types, None ) with debug_printer_manager: - self.writeline(f"if ({grid_var}.is_non_zero()) {{") - self.writeline( - DeferredNpuKernelLine( - kernel_name, - r" launchKernel({}, {});".format( \ - call, - f'"{kernel_name}"', - ), - (), - self.additional_files, - ), - ) - - self.writeline("}\n") + self.writeline(f"{wrapper_name}({', '.join(call_args)});") else: casted = [] for arg_type, arg in zip(arg_types, call_args): @@ -833,19 +768,19 @@ class CppWrapperNpu(CppWrapperCpu): self.writeline(f"kernels.{kernel_name}({call_args_str}, {stream});") def generate_kernel_call( - self, - kernel_name: str, - call_args, - grid=None, - device_index=None, - gpu=True, - triton=True, - arg_types=None, - raw_args=None, - grid_fn: str = "grid", - triton_meta=None, - autotune_configs=None, - grid_extra_kwargs="", + self, + kernel_name: str, + call_args, + grid=None, + device_index=None, + gpu=True, + triton=True, + arg_types=None, + raw_args=None, + grid_fn: str = "grid", + triton_meta=None, + autotune_configs=None, + grid_extra_kwargs="", ): """ Override the default value of argument 'gpu' to True here. @@ -892,5 +827,37 @@ class CppWrapperNpu(CppWrapperCpu): grid_extra_kwargs, ) + def finalize_prefix(self): + """Define the triton kernels now that autotuning is finished""" + old_prefix = self.prefix # new content should go at start of prefix + self.prefix = IndentedBuffer() + super().finalize_prefix() + for kernel in self._triton_call_wrappers.values(): + self.prefix.writeline("\n") + kernel.generate(self) + self.prefix.writeline("\n") + self.prefix.splice(old_prefix) + + @staticmethod + def prepare_triton_wrapper_args( + call_args: list[Any], arg_types: list[Any] + ) -> tuple[list[Any], list[Any]]: + new_args = [] + new_args_types = [] + for arg, arg_type in zip(call_args, arg_types): + if isinstance(arg, str): + if isinstance(arg_type, torch_dtype) and should_unwrap_unspec_arg(arg): + # dynamo wraps unspec variable as 0d CPU tensor, need convert to scalar + arg_type = UnwrapUnspecArg(dtype=arg_type) + new_args.append(arg) + elif isinstance(arg, bool): + new_args.append(str(arg).lower()) + elif isinstance(arg, (int, float, SymbolicCallArg)): + new_args.append(str(arg)) + else: + new_args.append(cexpr(V.graph.sizevars.simplify(arg))) + new_args_types.append(arg_type) + return new_args, new_args_types + def make_zero_buffer(self, name): return f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_zero_({name}.get()));" diff --git a/torch_npu/_inductor/lowering.py b/torch_npu/_inductor/lowering.py index 2b47e091af..fa37ca8077 100644 --- a/torch_npu/_inductor/lowering.py +++ b/torch_npu/_inductor/lowering.py @@ -143,7 +143,6 @@ def _register_npu_inductor_fallbacks(): make_reduction("argmin", override_return_dtype=torch.int64) ) - @register_lowering(aten.max, type_promotion_kind=None) def reduce_max(x, dim=None, keepdim=False): if dim is not None: diff --git a/torch_npu/_inductor/lowering_op_list.py b/torch_npu/_inductor/lowering_op_list.py index db9c427e60..2ec1687399 100644 --- a/torch_npu/_inductor/lowering_op_list.py +++ b/torch_npu/_inductor/lowering_op_list.py @@ -74,7 +74,7 @@ GENERATE_LIST = [ aten.bitwise_and, aten.squeeze, aten.copy, - aten.reciprocal + aten.reciprocal, ] GENERATE_LIST2 = [ diff --git a/torch_npu/_inductor/npu_triton_heuristics.py b/torch_npu/_inductor/npu_triton_heuristics.py index 21275d77f0..877ceb7dad 100644 --- a/torch_npu/_inductor/npu_triton_heuristics.py +++ b/torch_npu/_inductor/npu_triton_heuristics.py @@ -157,20 +157,29 @@ def do_bench_using_profiling_npu(fn, warmup=2, rep=10, grad_to_none=None, quanti @dataclasses.dataclass class GridNpu(GridExpr): numels: List[str] = None + mode: Literal["python", "cpp"] = "python" def generate(self, meta: dict[str, int]) -> None: numel_args = [] split_axis = meta.get("split_axis", None) split_blocks = meta.get("split_blocks", None) if split_axis is None or split_blocks is None: - raise RuntimeError(f"Could not get split_axis or split_blocks from meta {meta}.") + raise RuntimeError( + f"Could not get split_axis or split_blocks from meta {meta}." + ) def grid_fn(i): if i >= len(split_axis): return "1" axis = split_axis[i] block = split_blocks[i] - return f"({self.numels[axis]} + {block} - 1) // {block}" + if block is None or block == 1: + return self.numels[axis] + if self.mode == "python": + return f"({self.numels[axis]} + {block} - 1) // {block}" + else: + return f"(({self.numels[axis]} + ({block} - 1)) / ({block}))" + self.x_grid = grid_fn(0) self.y_grid = grid_fn(1) self.z_grid = grid_fn(2) @@ -186,8 +195,10 @@ class GridExprNpu(GridExpr): ) -> GridExpr: grid_cls = globals()[inductor_meta["grid_type"]] if not issubclass(grid_cls, GridNpu): - raise AssertionError(f"grid_type in inductor_meta must be subclass of GridNpu" - f"but got {inductor_meta['grid_type']}") + raise AssertionError( + f"grid_type in inductor_meta must be subclass of GridNpu" + f"but got {inductor_meta['grid_type']}" + ) grid = grid_cls(inductor_meta=inductor_meta, mode=mode, numels=numels) if isinstance(cfg, Config): cfg = config_to_dict(cfg) @@ -586,6 +597,11 @@ class NPUCachingAutotuner(CachingAutotuner): "stream": input_stream, # User defined triton kernels will have arbitrary kwarg names "meta": input_launcher.config.kwargs, + "config": config_to_dict(input_launcher.config), + "inductor_meta": self.inductor_meta, + "triton_meta": self.triton_meta, + "def_args": input_launcher.def_args, + "call_args": input_launcher.call_args, } from torch._inductor.codecache import CudaKernelParamCache diff --git a/torch_npu/_inductor/utils.py b/torch_npu/_inductor/utils.py index 095f1f69cf..24aac04906 100644 --- a/torch_npu/_inductor/utils.py +++ b/torch_npu/_inductor/utils.py @@ -32,8 +32,14 @@ def patch_is_same_tensor(): def patch_is_gpu(): from torch._inductor.utils import GPU_TYPES + GPU_TYPES.append('npu') + def _return_false(device_interface): + return False + + torch._inductor.scheduler.device_need_guard = _return_false + def patch_has_triton(): from torch.utils._triton import has_triton_package diff --git a/torch_npu/csrc/inductor/CMakeLists.txt b/torch_npu/csrc/inductor/CMakeLists.txt new file mode 100644 index 0000000000..51913f46b5 --- /dev/null +++ b/torch_npu/csrc/inductor/CMakeLists.txt @@ -0,0 +1,9 @@ +FILE(GLOB _INDUCTOR_SRCS + *.cpp + aoti_runner/*.cpp + aoti_torch/*.cpp) + +LIST(APPEND INDUCTOR_SRCS ${_INDUCTOR_SRCS}) + +# Pass to parent +set(INDUCTOR_SRCS ${INDUCTOR_SRCS} PARENT_SCOPE) \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_package/model_package_loader.h b/torch_npu/csrc/inductor/aoti_package/model_package_loader.h new file mode 100644 index 0000000000..9f1fbde09b --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_package/model_package_loader.h @@ -0,0 +1,40 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { +class TORCH_API AOTIModelPackageLoader { + public: + AOTIModelPackageLoader(const std::string &model_package_path, + const std::string &model_name = "model", + const bool run_single_threaded = false); + ~AOTIModelPackageLoader(); + + AOTIModelContainerRunner *get_runner(); + std::unordered_map get_metadata(); + + std::vector run(const std::vector &inputs, + void *stream_handle = nullptr); + + // boxed_run will steal the ownership of the input tensors + std::vector boxed_run(std::vector &&inputs, + void *stream_handle = nullptr); + + std::vector get_call_spec(); + void load_constants( + std::unordered_map &constants_map, + bool use_inactive, bool check_full_update); + std::vector get_constant_fqns(); + + private: + std::string temp_dir_; + std::unique_ptr runner_; + std::unordered_map metadata_; + + void load_metadata(const std::string &cpp_filename); +}; + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_package/pybind.h b/torch_npu/csrc/inductor/aoti_package/pybind.h new file mode 100644 index 0000000000..71e6390878 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_package/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIPackageBindings(PyObject *module); + +} // namespace torch::inductor diff --git a/torch_npu/csrc/inductor/aoti_runner/model_container_runner.h b/torch_npu/csrc/inductor/aoti_runner/model_container_runner.h new file mode 100644 index 0000000000..9a62a8158b --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/model_container_runner.h @@ -0,0 +1,113 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include +#include + +// Forward declare DynamicLibrary +namespace at { +struct DynamicLibrary; +} + +namespace torch::inductor { +using TensorConstantMap = std::unordered_map; + +class TORCH_API AOTIModelContainerRunner { + public: + AOTIModelContainerRunner() = delete; + AOTIModelContainerRunner(const AOTIModelContainerRunner &other) = delete; + AOTIModelContainerRunner(AOTIModelContainerRunner &&other) = delete; + AOTIModelContainerRunner &operator=(const AOTIModelContainerRunner &other) = + delete; + AOTIModelContainerRunner &operator=(AOTIModelContainerRunner &&other) = + delete; + virtual ~AOTIModelContainerRunner(); + + std::vector run(const std::vector &inputs, + void *stream_handle = nullptr); + + // boxed_run will steal the ownership of the input tensors + std::vector boxed_run(std::vector &&inputs, + void *stream_handle = nullptr); + + std::unordered_map getConstantNamesToOriginalFQNs() + const; + std::unordered_map getConstantNamesToDtypes() const; + + void update_inactive_constant_buffer(const TensorConstantMap &const_map); + void update_constant_buffer( + std::unordered_map &tensor_map, + bool use_inactive, bool validate_full_updates); + void update_constant_buffer(const TensorConstantMap &const_map, + bool use_inactive, bool validate_full_updates); + void run_const_fold(bool use_inactive, + AOTInductorStreamHandle npu_stream_handle = nullptr); + void swap_constant_buffer(); + + std::vector get_call_spec(); + + protected: + AOTIModelContainerRunner(const std::string &model_so_path, size_t num_models, + const std::string &device_str, + const std::string &cubin_dir, + const bool run_single_threaded); + + virtual std::vector run_impl( + std::vector &input_handles, void *stream_handle); + + std::unique_ptr model_so_; + decltype(&AOTInductorModelContainerCreateWithDevice) create_func_{nullptr}; + decltype(&AOTInductorModelContainerDelete) delete_func_{nullptr}; + decltype(&AOTInductorModelContainerGetNumOutputs) get_num_outputs_func_{ + nullptr}; + decltype(&AOTInductorModelContainerRun) run_func_{nullptr}; + decltype(&AOTInductorModelContainerGetNumConstants) get_num_constants_func_{ + nullptr}; + decltype(&AOTInductorModelContainerGetConstantName) get_constant_name_func_{ + nullptr}; + decltype(&AOTInductorModelContainerGetConstantOriginalFQN) + get_constant_original_fqn_func_{nullptr}; + decltype(&AOTInductorModelContainerGetConstantDtype) get_constant_dtype_func_{ + nullptr}; + decltype(&AOTInductorModelContainerUpdateConstantBuffer) + update_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerUpdateInactiveConstantBuffer) + update_inactive_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerRunConstantFolding) run_const_fold_func_{ + nullptr}; + decltype(&AOTInductorModelContainerSwapConstantBuffer) + swap_constant_buffer_func_{nullptr}; + decltype(&AOTInductorModelContainerGetCallSpec) get_call_spec_func_{nullptr}; + + AOTInductorModelContainerHandle container_handle_ = nullptr; + + AOTIProxyExecutorHandle proxy_executor_handle_; + + private: + std::unique_ptr proxy_executor_; +}; + +using CreateAOTIModelRunnerFunc = std::unique_ptr (*)( + const std::string &model_so_path, size_t num_models, + const std::string &device_str, const std::string &bin_dir, + const bool run_single_threaded); + +// Return a global map "device name" -> "aoti model runner create function" for +// all registered in AOTI external backends +TORCH_API std::unordered_map & +getAOTIModelRunnerRegistry(); + +// To register a new external backend in AOTI one needs to create an instance of +// this struct. It is not thread-safe. Becase it is expected to be called during +// the initialization of the program. +struct TORCH_API RegisterAOTIModelRunner { + RegisterAOTIModelRunner( + const std::string &name, + CreateAOTIModelRunnerFunc create_aoti_model_runner_fn) { + getAOTIModelRunnerRegistry()[name] = create_aoti_model_runner_fn; + } +}; + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp new file mode 100644 index 0000000000..a3776f75ca --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.cpp @@ -0,0 +1,87 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#include +#include + +#include + +#ifndef _WIN32 +#include +#else +#include +namespace fs = std::filesystem; +#endif + +namespace { +bool file_exists(std::string &path) { +#ifdef _WIN32 + return fs::exists(path); +#else + struct stat rc{}; + return lstat(path.c_str(), &rc) == 0; +#endif +} +} // namespace + +namespace torch::inductor { + +AOTIModelContainerRunnerNpu::AOTIModelContainerRunnerNpu( + const std::string &model_so_path, size_t num_models, + const std::string &device_str, const std::string &cubin_dir, + const bool run_single_threaded) + : AOTIModelContainerRunner(model_so_path, num_models, device_str, cubin_dir, + run_single_threaded) { + model_so_path_ = model_so_path; + init_flag_ = false; +} + +AOTIModelContainerRunnerNpu::~AOTIModelContainerRunnerNpu() = default; + +void AOTIModelContainerRunnerNpu::init_proxy_executor() { + if (init_flag_) return; + + init_flag_ = true; + size_t lastindex = model_so_path_.find_last_of('.'); + std::string json_filename = model_so_path_.substr(0, lastindex) + "_npu.json"; + + if (file_exists(json_filename)) { + proxy_executor_npu_ = + std::make_unique( + json_filename, false); + proxy_executor_handle_ = + reinterpret_cast(proxy_executor_npu_.get()); + } else { + proxy_executor_handle_ = nullptr; + } +} + +std::vector AOTIModelContainerRunnerNpu::run_impl( + std::vector &input_handles, void *stream_handle) { + init_proxy_executor(); + c10_npu::NPUStream npu_stream = c10_npu::getCurrentNPUStream(); + return AOTIModelContainerRunner::run_impl( + input_handles, reinterpret_cast(npu_stream.stream())); +} + +std::vector AOTIModelContainerRunnerNpu::run_with_npu_stream( + const std::vector &inputs, + const c10_npu::NPUStream &npu_stream) +{ + init_proxy_executor(); + c10_npu::NPUStream cur_npu_stream = c10_npu::getCurrentNPUStream(); + return run(inputs, reinterpret_cast(cur_npu_stream.stream())); +} + +namespace { +std::unique_ptr create_aoti_runner_npu( + const std::string &model_so_path, size_t num_models, + const std::string &device_str, const std::string &cubin_dir, + const bool run_single_threaded) { + return std::make_unique( + model_so_path, num_models, device_str, cubin_dir, run_single_threaded); +} +} // namespace + +RegisterAOTIModelRunner register_npu_runner("npu", &create_aoti_runner_npu); + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h new file mode 100644 index 0000000000..af16757fbb --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/model_container_runner_npu.h @@ -0,0 +1,41 @@ +#if !defined(C10_MOBILE) && !defined(ANDROID) +#pragma once + +#include +#include + +namespace torch::inductor { + +// NOTICE: Following APIs are subject to change due to active development +// We provide NO BC guarantee for these APIs +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class AOTIModelContainerRunnerNpu + : public AOTIModelContainerRunner { +public: + // @param device_str: npu device string, e.g. "npu", "npu:0" + AOTIModelContainerRunnerNpu(const std::string &model_so_path, + size_t num_models = 1, + const std::string &device_str = "npu", + const std::string &cubin_dir = "", + const bool run_single_threaded = false); + + ~AOTIModelContainerRunnerNpu() override; + + std::vector run_impl(std::vector &input_handles, + void *stream_handle) override; + + std::vector run_with_npu_stream( + const std::vector &inputs, + const c10_npu::NPUStream &npu_stream); + void init_proxy_executor(); + + void set_proxy_executor(AOTIProxyExecutorHandle handle); + +private: + std::string model_so_path_; + bool init_flag_; + std::unique_ptr proxy_executor_npu_; +}; + +} // namespace torch::inductor +#endif diff --git a/torch_npu/csrc/inductor/aoti_runner/pybind.h b/torch_npu/csrc/inductor/aoti_runner/pybind.h new file mode 100644 index 0000000000..93855eaab2 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runner/pybind.h @@ -0,0 +1,7 @@ +#include + +namespace torch::inductor { + +void initAOTIRunnerBindings(PyObject *module); + +} // namespace torch::inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h b/torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h new file mode 100644 index 0000000000..5c451645a9 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/arrayref_tensor.h @@ -0,0 +1,327 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +// Can't use c10::ArrayRef because it's not truly header-only and +// pulls in other c10 headers. This is (sadly) copy-pasted and +// adapted. +template +class MiniArrayRef final { + public: + using iterator = T *; + using const_iterator = const T *; + using size_type = size_t; + using value_type = T; + + using reverse_iterator = std::reverse_iterator; + + private: + /// The start of the array, in an external buffer. + T *Data; + + /// The number of elements. + size_type Length; + + public: + /// @name Constructors + /// @{ + + /// Construct an empty MiniArrayRef. + constexpr MiniArrayRef() : Data(nullptr), Length(0) {} + + /// Construct an MiniArrayRef from a single element. + // TODO Make this explicit + constexpr MiniArrayRef(const T &OneElt) : Data(&OneElt), Length(1) {} + + /// Construct an MiniArrayRef from a pointer and length. + constexpr MiniArrayRef(T *data, size_t length) : Data(data), Length(length) {} + + /// Construct an MiniArrayRef from a range. + constexpr MiniArrayRef(T *begin, T *end) : Data(begin), Length(end - begin) {} + + template ().data())>, + T *>>> + MiniArrayRef(Container &container) + : Data(container.data()), Length(container.size()) {} + + /// Construct an MiniArrayRef from a std::vector. + // The enable_if stuff here makes sure that this isn't used for + // std::vector, because MiniArrayRef can't work on a std::vector + // bitfield. + template + MiniArrayRef(const std::vector &Vec) + : Data(Vec.data()), Length(Vec.size()) { + static_assert(!std::is_same_v, + "MiniArrayRef cannot be constructed from a " + "std::vector bitfield."); + } + + /// Construct an MiniArrayRef from a std::array + template + constexpr MiniArrayRef(std::array &Arr) : Data(Arr.data()), Length(N) {} + + /// Construct an MiniArrayRef from a C array. + template + // NOLINTNEXTLINE(*c-array*) + constexpr MiniArrayRef(T (&Arr)[N]) : Data(Arr), Length(N) {} + + // /// Construct an MiniArrayRef from an empty C array. + constexpr MiniArrayRef(const volatile void *Arr) : Data(nullptr), Length(0) {} + + /// Construct an MiniArrayRef from a std::initializer_list. + constexpr MiniArrayRef(const std::initializer_list &Vec) + : Data(std::begin(Vec) == std::end(Vec) ? static_cast(nullptr) + : std::begin(Vec)), + Length(Vec.size()) {} + + /// @} + /// @name Simple Operations + /// @{ + + constexpr iterator begin() const { return Data; } + constexpr iterator end() const { return Data + Length; } + + // These are actually the same as iterator, since MiniArrayRef only + // gives you const iterators. + constexpr const_iterator cbegin() const { return Data; } + constexpr const_iterator cend() const { return Data + Length; } + + constexpr reverse_iterator rbegin() const { return reverse_iterator(end()); } + constexpr reverse_iterator rend() const { return reverse_iterator(begin()); } + + /// empty - Check if the array is empty. + constexpr bool empty() const { return Length == 0; } + + constexpr T *data() const { return Data; } + + /// size - Get the array size. + constexpr size_t size() const { return Length; } + + /// equals - Check for element-wise equality. + constexpr bool equals(MiniArrayRef RHS) const { + return Length == RHS.Length && std::equal(begin(), end(), RHS.begin()); + } + + /// @} + /// @name Operator Overloads + /// @{ + constexpr const T &operator[](size_t Index) const { return Data[Index]; } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MiniArrayRef> &operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U &&Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MiniArrayRef> &operator=( + std::initializer_list) = delete; +}; + +using MiniIntArrayRef = MiniArrayRef; + +static_assert(sizeof(MiniIntArrayRef) == sizeof(void *) + sizeof(size_t), + "changing the size of MiniArrayRef breaks ABI compatibility!"); + +inline bool is_contiguous_strides_for_shape(int64_t ndim, + const int64_t *strides_ptr, + const int64_t *sizes_ptr) { + int64_t z = 1; + for (int64_t d = ndim - 1; d >= 0; d--) { + const auto &size_d = sizes_ptr[d]; + if (size_d != 1) { + if (strides_ptr[d] == z) { + z *= size_d; + } else { + return false; + } + } + } + return true; +} + +// Shim for AOTI generated code to pretend a raw array works like an +// AtenTensorHandle. +template +class ArrayRefTensor { + public: + ArrayRefTensor() = default; + + explicit ArrayRefTensor(MiniArrayRef arr, + MiniArrayRef sizes, + MiniArrayRef strides, + int32_t device_type, int32_t device_idx) + : arrayRef_(arr), + sizes_(sizes), + strides_(strides), + device_type_(device_type), + device_idx_(device_idx) { + assert(sizes.size() == strides.size()); + assert(is_contiguous_strides_for_shape(sizes.size(), strides.data(), + sizes.data())); + } + + AtenTensorHandle expensiveCopyToTensor() const { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_empty_strided(sizes_.size(), sizes_.data(), strides_.data(), + aoti_torch_dtype>(), + device_type_, device_idx_, &result)); + void *dataPtr = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(result, &dataPtr)); + std::memcpy(dataPtr, data(), numel() * sizeof(T)); + return result; + } + + // We need to look the same as RAIIAtenTensorHandle, which returns + // an owning AtenTensorHandle from release(). So, we allocate one! + AtenTensorHandle release() { return expensiveCopyToTensor(); } + + AtenTensorHandle borrowAsTensor() const { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_v2( + data(), sizes_.size(), sizes_.data(), strides_.data(), 0, + aoti_torch_dtype>(), device_type_, device_idx_, + &result, aoti_torch_layout_strided(), nullptr, 0)); + return result; + } + + // We don't need to free any memory. + void reset() {} + + auto sizes() const { return sizes_; } + + auto strides() const { return strides_; } + + auto device_type() const { return device_type_; } + + auto device_idx() const { return device_idx_; } + + T *data() const { return arrayRef_.data(); } + + auto numel() const { return arrayRef_.size(); } + + void set_arrayref(MiniArrayRef new_arrayref) { arrayRef_ = new_arrayref; } + + private: + MiniArrayRef arrayRef_; + // We expect generated code to have statically available sizes & + // strides for us. + MiniArrayRef sizes_; + MiniArrayRef strides_; + int32_t device_type_ = 0; + int32_t device_idx_ = 0; + // We continue to zero-initialize this field in case we repurpose + // the space later; having predictable contents can only help. + int32_t unusedDoNotRemoveForABICompatibility_ = 0; +}; + +static_assert(sizeof(ArrayRefTensor) == + 3 * sizeof(MiniIntArrayRef) + 3 * sizeof(int32_t) + + (alignof(ArrayRefTensor) > 4 ? sizeof(int32_t) : 0), + "changing the size of ArrayRefTensor breaks ABI compatibility!"); + +template +inline ArrayRefTensor reinterpret_tensor_wrapper( + const ArrayRefTensor &self, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset) { + // REVIEW: we should add a way to build the DSO in debug mode during + // tests so we can have checks like this! + assert(is_contiguous_strides_for_shape(ndim, strides_ptr, sizes_ptr)); + return ArrayRefTensor(MiniArrayRef(self.data() + storage_offset, + self.numel() - storage_offset), + MiniArrayRef(sizes_ptr, ndim), + MiniArrayRef(strides_ptr, ndim), + self.device_type(), self.device_idx()); +} + +template +inline T *get_data_ptr_wrapper(ArrayRefTensor &tensor) { + return tensor.data(); +} + +template +inline T *get_data_ptr_wrapper(const MiniArrayRef &arr) { + return arr.data(); +} + +template +inline const ArrayRefTensor &unwrap_raii_handle_if_needed( + const ArrayRefTensor &tensor) { + return tensor; +} + +template +inline ArrayRefTensor &unwrap_raii_handle_if_needed( + ArrayRefTensor &tensor) { + return tensor; +} + +template +inline const ArrayRefTensor &wrap_with_raii_handle_if_needed( + const ArrayRefTensor &tensor) { + return tensor; +} + +template +inline ArrayRefTensor &wrap_with_raii_handle_if_needed( + ArrayRefTensor &tensor) { + return tensor; +} + +template +inline ArrayRefTensor wrap_with_raii_handle_if_needed( + ArrayRefTensor &&tensor) { + return std::move(tensor); +} + +template +inline RAIIAtenTensorHandle expensive_copy_to_tensor_if_needed( + const ArrayRefTensor &tensor) { + return tensor.expensiveCopyToTensor(); +} + +inline AtenTensorHandle expensive_copy_to_tensor_if_needed( + AtenTensorHandle handle) { + return handle; +} + +template +const T ©_arrayref_tensor_to_tensor(const T &t) { + return t; +} + +template +RAIIAtenTensorHandle copy_arrayref_tensor_to_tensor( + const ArrayRefTensor &art) { + return art.expensiveCopyToTensor(); +} + +template +const T &borrow_arrayref_tensor_as_tensor(const T &t) { + return t; +} + +template +RAIIAtenTensorHandle borrow_arrayref_tensor_as_tensor( + const ArrayRefTensor &art) { + return art.borrowAsTensor(); +} + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/device_utils.h b/torch_npu/csrc/inductor/aoti_runtime/device_utils.h new file mode 100644 index 0000000000..ed0ae6b103 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/device_utils.h @@ -0,0 +1,43 @@ +#pragma once + +#if defined(USE_NPU) + +#include "third_party/acl/inc/acl/acl_base.h" +#include "third_party/acl/inc/acl/acl_rt.h" + +typedef void *NPUdeviceptr; + +typedef void *NPUfunction; + +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + do { \ + const aclError code = EXPR; \ + if (code != ACL_SUCCESS) { \ + throw std::runtime_error(std::string("NPU error core: ") + \ + std::to_string(code) + std::string(" ") + \ + std::string(__FILE__) + std::string(":") + \ + std::to_string(__LINE__)); \ + } \ + } while (0) + +namespace torch::aot_inductor { + +using DeviceStreamType = aclrtStream; + +} // namespace torch::aot_inductor + +#else + +#define AOTI_RUNTIME_DEVICE_CHECK(EXPR) \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error("CPU runtime error"); \ + } + +namespace torch::aot_inductor { + +using DeviceStreamType = void *; + +} // namespace torch::aot_inductor + +#endif // USE_NPU diff --git a/torch_npu/csrc/inductor/aoti_runtime/interface.h b/torch_npu/csrc/inductor/aoti_runtime/interface.h new file mode 100644 index 0000000000..4ab1485945 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/interface.h @@ -0,0 +1,183 @@ +#pragma once + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +extern "C" { +struct AOTInductorModelOpaque; +using AOTInductorModelHandle = AOTInductorModelOpaque *; + +struct AOTInductorModelContainerOpaque; +using AOTInductorModelContainerHandle = AOTInductorModelContainerOpaque *; + +struct AOTInductorStreamOpaque; +using AOTInductorStreamHandle = AOTInductorStreamOpaque *; + +struct AOTInductorConstantMap; +using AOTInductorConstantMapHandle = AOTInductorConstantMap *; + +// TODO: Deprecate this API. This was kept for BC compatibility. +// Please use AOTInductorModelContainerCreateWithDevice instead. +AOTIRuntimeError AOTInductorModelContainerCreate( + AOTInductorModelContainerHandle *container_handle, size_t num_models, + bool is_cpu, const char *cubin_dir); + +// Creates an AOTInductor model container. The parameter num_models +// specifies the number of model instances that may be run concurrently for +// the same input model. +// `device_str` MUST NOT be nullptr. It must be a valid device string, e.g. +// "cpu", "npu", "npu:0", etc. +AOTIRuntimeError AOTInductorModelContainerCreateWithDevice( + AOTInductorModelContainerHandle *container_handle, size_t num_models, + const char *device_str, const char *cubin_dir); + +// Deletes the AOTInductor model container. +AOTIRuntimeError AOTInductorModelContainerDelete( + AOTInductorModelContainerHandle container_handle); + +// Runs the inference. +AOTIRuntimeError AOTInductorModelContainerRun( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Single-threaded variant of previous. +AOTIRuntimeError AOTInductorModelContainerRunSingleThreaded( + AOTInductorModelContainerHandle container_handle, + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + size_t num_inputs, + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; handles + // will be stolen by the caller; the array itself is + // borrowed + size_t num_outputs, AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Retrieves the number of constants for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumConstants( + AOTInductorModelContainerHandle container_handle, size_t *num_constants); + +// Retrieves a constant's name. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantName( + AOTInductorModelContainerHandle container_handle, size_t idx, + const char **name); + +// Retrieves a constant's original FQN. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantOriginalFQN( + AOTInductorModelContainerHandle container_handle, size_t idx, + const char **original_fqn); + +// Retrieves whether a constant is from folded. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantFromFolded( + AOTInductorModelContainerHandle container_handle, size_t idx, + bool *from_folded); + +// Retrieves the inductor constant type. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantType( + AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *type); + +// Retrieves a constant's dtype. +// idx is the index of the internal's constants. +// Need idx < num_constants from AOTInductorModelContainerGetNumConstants +AOTIRuntimeError AOTInductorModelContainerGetConstantDtype( + AOTInductorModelContainerHandle container_handle, size_t idx, + int32_t *dtype); + +// Setup the constant buffer in model container with provided ConstantMap +// use_inactive should be set as true if the inactive buffer is to be updated. +// validate_full_update checks if all constants are included in the ConstantMap +AOTIRuntimeError AOTInductorModelContainerUpdateConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle, bool use_inactive, + bool validate_full_update); + +// Setup the inactive constant buffer in model container with provided +// ConstantMap +AOTIRuntimeError AOTInductorModelContainerUpdateInactiveConstantBuffer( + AOTInductorModelContainerHandle container_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Run constant folding on constant buffer. +AOTIRuntimeError AOTInductorModelContainerRunConstantFolding( + AOTInductorModelContainerHandle container_handle, bool use_inactive, + AOTInductorStreamHandle stream_handle, + AOTIProxyExecutorHandle proxy_executor_handle); + +// Swap the constant buffer being used to the inactive one. +AOTIRuntimeError AOTInductorModelContainerSwapConstantBuffer( + AOTInductorModelContainerHandle container_handle); + +// Retrieves the number of inputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumInputs( + AOTInductorModelContainerHandle container_handle, size_t *ret_num_inputs); + +// Retrieves the input name at the given index. +AOTIRuntimeError AOTInductorModelContainerGetInputName( + AOTInductorModelContainerHandle container_handle, size_t input_idx, + const char **ret_input_names); + +// Retrieves the number of outputs for the model. +AOTIRuntimeError AOTInductorModelContainerGetNumOutputs( + AOTInductorModelContainerHandle container_handle, size_t *ret_num_outputs); + +// Retrieves the output name at the given index. +AOTIRuntimeError AOTInductorModelContainerGetOutputName( + AOTInductorModelContainerHandle container_handle, size_t output_idx, + const char **ret_output_names); + +// Creates an AOTInductorModel instance. This is a thin and light wrapper +// around the compiled model; it doesn't handle concurrency, queueing, device +// management, etc. Use this if bare-metal performance is needed and you are +// willing to handle other "management" aspects yourself. +// +// constant_map_handle is an opaque type to satisfy the C ABI. It should be a +// std::unordered_map*. +AOTIRuntimeError AOTInductorModelCreate( + AOTInductorModelHandle *model_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Run an AOTInductorModel (see AOTInductorModelCreate for when one should use +// this function versus AOTInductorModelContainerRun). +AOTIRuntimeError AOTInductorModelRun(AOTInductorModelHandle model_handle, + AtenTensorHandle *input_handles, + AtenTensorHandle *output_handles); + +// Replace AOTInductorModel's constant map. Note it doesn't handle concurrency +// so be sure to handle ordering if AOTInductorModelRun is ran concurrently. +AOTIRuntimeError AOTInductorModelUpdateConstantsMap( + AOTInductorModelHandle model_handle, + AOTInductorConstantMapHandle constant_map_handle); + +// Delete an AOTInductorModel created by AOTInductorModelCreate. +AOTIRuntimeError AOTInductorModelDelete(AOTInductorModelHandle model_handle); + +AOTIRuntimeError AOTInductorModelGetNumOutputs( + AOTInductorModelHandle model_handle, size_t *ret_num_outputs); + +AOTIRuntimeError AOTInductorModelContainerGetCallSpec( + AOTInductorModelContainerHandle container_handle, const char **in_spec, + const char **out_spec); + +} // extern "C" diff --git a/torch_npu/csrc/inductor/aoti_runtime/model.h b/torch_npu/csrc/inductor/aoti_runtime/model.h new file mode 100644 index 0000000000..4b0e580151 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/model.h @@ -0,0 +1,592 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include +#include + +#define AOTI_RUNTIME_CHECK(EXPR, MSG) \ + do { \ + bool ok = EXPR; \ + if (!ok) { \ + throw std::runtime_error(MSG); \ + } \ + } while (0) + +// At codegen time, we write out a binary file called constants.bin. +// We then turn the raw binary to an object file that exposes this +// symbol and link it into the final .so. +// The constants are NOT readonly because they may be mutated. +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_start[]; +// NOLINTNEXTLINE(*array*) +extern uint8_t _binary_constants_bin_end[]; + +#define AOTI_CONST_ALIGNMENT 64 + +namespace { + +using RAIIDataPtr = std::unique_ptr>; + +#ifdef USE_NPU + +using RAIIDataPtr = std::unique_ptr>; + +RAIIDataPtr RAII_npuMalloc(size_t num_bytes) { + void *data_ptr; + // aclrtMalloc doesn't support allocate 0-bytes. In this case, + // e.g, model has no weight, we should do padding. + size_t padding_bytes = 32; + if (num_bytes == 0) num_bytes = padding_bytes; + AOTI_RUNTIME_DEVICE_CHECK( + aclrtMalloc((void **)&data_ptr, num_bytes, ACL_MEM_MALLOC_HUGE_FIRST)); + auto deleter = [](void *ptr) { AOTI_RUNTIME_DEVICE_CHECK(aclrtFree(ptr)); }; + return RAIIDataPtr(data_ptr, deleter); +} + +#endif // USE_NPU + +RAIIDataPtr RAII_cpuMalloc(size_t num_bytes) { + void *data_ptr = std::malloc(num_bytes); + if (!data_ptr) { + throw std::bad_alloc(); + } + auto deleter = [](void *ptr) { std::free(ptr); }; + return RAIIDataPtr(data_ptr, deleter); +} + +} // anonymous namespace + +namespace torch::aot_inductor { +enum ConstantType : uint8_t { + Unknown = 0, + Parameter = 1, + Buffer = 2, + TensorConstant = 3, + FoldedConstant = 4, +}; + +using ConstantMap = std::unordered_map; + +// valid device strs are: cpu, npu, npu:0, npu:1, ... +// Update the list here if more devices are supported in the future +inline void parse_device_str(const std::string &device_str, + int32_t &device_type, int32_t &device_idx) { + std::regex re("(cpu|npu)(:([0-9]+))?"); + std::smatch sm; + bool matched = std::regex_match(device_str, sm, re); + AOTI_RUNTIME_CHECK(matched, "Invalid device: " + device_str); + + if (sm[1].str() == "cpu") { + device_type = aoti_torch_device_type_cpu(); +#ifdef USE_NPU + } else if (sm[1].str() == "npu") { + device_type = aoti_torch_device_type_npu(); +#endif + } else { + AOTI_RUNTIME_CHECK(false, "Invalid device: " + device_str); + } + const size_t default_sm = 3; + if (sm[default_sm].matched) { + device_idx = stoi(sm[default_sm].str()); + } else { + device_idx = -1; + } +} + +// Defines the base class for AOTInductorModel, which is generated by the +// AOTInductor cpp codegen. Since we do not need dynamic dispatch, we rely +// on curiously recurring template pattern (CRTP) to save some runtime +// v-table overhead. The generated AOTInductorModel is specialized with +// methods such as run_impl. +template +class AOTInductorModelBase { + public: + AOTInductorModelBase(size_t num_inputs, size_t num_outputs, + size_t num_constants, const std::string &device_str, + std::optional cubin_dir, + bool include_weights = true) + : inputs_info_(num_inputs), + outputs_info_(num_outputs), + constants_info_(num_constants), + cubin_dir_(std::move(cubin_dir)), + include_weights(include_weights) { + parse_device_str(device_str, device_type_, device_idx_); + +#ifdef USE_NPU + if (device_idx_ == -1) { + AOTI_RUNTIME_DEVICE_CHECK(aclrtSetDevice(0)); + AOTI_RUNTIME_DEVICE_CHECK(aclrtGetDevice(&device_idx_)); + } else { + AOTI_RUNTIME_DEVICE_CHECK(aclrtSetDevice(device_idx_)); + } +#endif // USE_NPU + } + + // NOLINTNEXTLINE(modernize-use-equals-default) + ~AOTInductorModelBase() { +#ifdef USE_NPU + if (run_finished_) { + auto code = aclrtDestroyEvent(*run_finished_); + if (code != ACL_SUCCESS) { + std::cerr + << "Failed to destroy NPU event in AOTInductor model erorr code: " + << code << std::endl; + } + } +#endif // USE_NPU + } + + AOTInductorModelBase(AOTInductorModelBase &&) = delete; + AOTInductorModelBase &operator=(AOTInductorModelBase &&) = delete; + AOTInductorModelBase(const AOTInductorModelBase &) = delete; + AOTInductorModelBase &operator=(const AOTInductorModelBase &) = delete; + + void run(AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the + // array itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { +#if defined(USE_NPU) + if (!run_finished_) { + aclrtEvent run_finished; + AOTI_RUNTIME_DEVICE_CHECK(aclrtCreateEvent(&run_finished)); + run_finished_.emplace(run_finished); + } +#else + run_finished_ = false; +#endif + + auto *model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK(aclrtRecordEvent(*run_finished_, stream)); +#else + run_finished_ = true; +#endif + } + + // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the array + // itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { + // don't bother with any of the run_finished stuff; this is unsafe to call + // in a threaded context + auto *model = static_cast(this); + model->run_impl(input_handles, output_handles, stream, proxy_executor); + } + + std::unordered_map run_const_fold( + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor, + bool initialization = false) { +#if defined(USE_NPU) + if (!run_finished_) { + aclrtEvent run_finished; + AOTI_RUNTIME_DEVICE_CHECK(aclrtCreateEvent(&run_finished)); + run_finished_.emplace(run_finished); + } +#else + run_finished_ = false; +#endif + + auto *model = static_cast(this); + auto folded_constants = + model->const_run_impl(stream, proxy_executor, initialization); + +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK(aclrtRecordEvent(*run_finished_, stream)); +#else + run_finished_ = true; +#endif + return folded_constants; + } + + void load_constants() { + size_t num_constants = this->num_constants(); + constants_map_->reserve(num_constants); + + std::vector constants_internal_offset(num_constants); + size_t blob_size = 0; + compute_constant_blob(blob_size, constants_internal_offset); +#if defined(USE_NPU) + constant_blob_ = RAII_npuMalloc(blob_size); +#else + constant_blob_ = RAII_cpuMalloc(blob_size); +#endif + if (!include_weights) { + return; + } + + size_t bytes_read = 0; + for (size_t i = 0; i < num_constants; i++) { + bool from_folded = this->constant_from_folded(i); + if (from_folded) { + continue; + } + std::string name = this->constant_name(i); + size_t data_size = this->constant_data_size(i); + uint8_t *internal_ptr = + (data_size != 0) ? constant_ptr(constants_internal_offset[i], + bytes_read, data_size, from_folded) + : nullptr; + bytes_read += data_size; + + // Create at::Tensor from copied memory. + auto dtype = this->constant_dtype(i); + auto ndim = this->constant_ndim(i); + auto size = this->constant_shape(i); + auto stride = this->constant_stride(i); + auto offset = this->constant_offset(i); + auto layout = this->constant_layout(i); + auto opaque_metadata_ptr = this->opaque_metadata(i); + auto opaque_metadata_size = this->opaque_metadata_size(i); + + AtenTensorHandle tensor_handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu_v2( + internal_ptr, ndim, size, stride, offset, dtype, device_type_, + device_idx_, &tensor_handle, layout, opaque_metadata_ptr, + opaque_metadata_size)); + constants_map_->emplace(std::move(name), tensor_handle); + } + if (constants_map_) { + this->update_constants_array_from_map(); + } + } + + RAIIDataPtr &&release_constant_blob() { return std::move(constant_blob_); } + + std::shared_ptr> get_constants_array() { + return constants_; + } + + int32_t get_device_type() const { return device_type_; } + + int32_t get_device_idx() const { return device_idx_; } + + uint8_t *constant_ptr(size_t constant_offset, size_t bytes_read, + size_t data_size, bool skip_copy) { + auto *constants_ptr = static_cast(constant_blob_.get()); + uint8_t *internal_ptr = constants_ptr + constant_offset; + // TODO: Handle shared storage case. + if (!skip_copy) { +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK(aclrtMemcpy( + internal_ptr, data_size, _get_constants_start() + bytes_read, + data_size, ACL_MEMCPY_HOST_TO_DEVICE)); +#else + memcpy(internal_ptr, _get_constants_start() + bytes_read, data_size); +#endif + } + return internal_ptr; + } + + void compute_constant_blob(size_t &blob_size, + std::vector &constants_internal_offset) { + size_t num_constants = this->num_constants(); + blob_size = 0; + for (size_t i = 0; i < num_constants; i++) { + size_t data_size = this->constant_data_size(i); + if (data_size % AOTI_CONST_ALIGNMENT) { + data_size = AOTI_CONST_ALIGNMENT + + (data_size / AOTI_CONST_ALIGNMENT) * AOTI_CONST_ALIGNMENT; + } + constants_internal_offset[i] = blob_size; + blob_size += data_size; + } + } + + size_t num_inputs() const { return inputs_info_.size(); } + + size_t num_outputs() const { return outputs_info_.size(); } + + size_t num_constants() const { return constants_info_.size(); } + + const char *input_name(int64_t idx) const { + return inputs_info_.at(idx).name; + } + + const char *output_name(int64_t idx) const { + return outputs_info_.at(idx).name; + } + + const char *constant_name(int64_t idx) const { + return constants_info_.at(idx).name; + } + + size_t constant_ndim(int64_t idx) { + return constants_info_.at(idx).shape.size(); + } + + const int64_t *constant_shape(int64_t idx) const { + return constants_info_.at(idx).shape.data(); + } + + const int64_t *constant_stride(int64_t idx) const { + return constants_info_.at(idx).stride.data(); + } + + int32_t constant_dtype(int64_t idx) const { + return constants_info_.at(idx).dtype; + } + + int32_t constant_layout(int64_t idx) const { + return constants_info_.at(idx).layout; + } + + size_t constant_offset(int64_t idx) const { + return constants_info_.at(idx).offset; + } + + size_t constant_data_size(int64_t idx) const { + return constants_info_.at(idx).data_size; + } + + const char *constant_original_fqn(int64_t idx) const { + return constants_info_.at(idx).original_fqn; + } + + const uint8_t *opaque_metadata(int64_t idx) const { + return constants_info_.at(idx).opaque_metadata.data(); + } + + size_t opaque_metadata_size(int64_t idx) { + return constants_info_.at(idx).opaque_metadata.size(); + } + + bool constant_from_folded(int64_t idx) const { + return constants_info_.at(idx).from_folded; + } + + int32_t constant_type(int64_t idx) const { + return constants_info_.at(idx).type; + } + + const char *get_in_spec() const { return in_spec_.c_str(); } + + const char *get_out_spec() const { return out_spec_.c_str(); } + + void update_constants_array_from_map() { + if (!constants_map_) { + throw std::runtime_error{ + "constants_map_ was not ready when constants_ is trying to be " + "constructed from it!"}; + } + if (!constants_) { + constants_ = + std::make_shared>(constants_info_.size()); + } else { + constants_->resize(constants_info_.size()); + } + int idx = 0; + for (const auto &info : constants_info_) { + const auto it = constants_map_->find(info.name); + if (it != constants_map_->end()) { + constants_->at(idx) = ConstantHandle(it->second); + } + idx++; + } + } + + void update_constants_map(std::shared_ptr constants_map, + bool remap_constants_array = true) { + constants_map_ = std::move(constants_map); + if (remap_constants_array) { + update_constants_array_from_map(); + } + } + + // This function allows us to update the constants_ that is used to look up + // the corresponding constant tensor during runtime. + void update_constants_array( + std::shared_ptr> constants_array) { + constants_ = std::move(constants_array); + } + + /// Returns true if the model is complete. + bool is_finished() { +#if defined(USE_NPU) + if (!run_finished_) { + throw std::runtime_error{"Model NPU event was not initialized"}; + } + aclrtEventRecordedStatus recordStatus = ACL_EVENT_RECORDED_STATUS_NOT_READY; + AOTI_RUNTIME_DEVICE_CHECK( + aclrtQueryEventStatus(*run_finished_, &recordStatus)); + + if (recordStatus == ACL_EVENT_RECORDED_STATUS_COMPLETE) { + return true; + } else { + return false; + } +#else + return run_finished_; +#endif + } + + /// Synchronizes completion event. + void wait_for_completion() {} + + protected: + uint8_t *_get_constants_start() { +#ifndef USE_MMAP_SELF + // NOLINTNEXTLINE(*const-cast*) + return const_cast(_binary_constants_bin_start); +#else + if (self_mmap) { + return self_mmap; + } + Dl_info dl_info; + // get pointer to constant which are appended to the binary + AOTI_RUNTIME_CHECK(dladdr(__func__, &dl_info), + "Can't find shared library name"); + int fd = open(dl_info.dli_fname, O_RDONLY); + AOTI_RUNTIME_CHECK(fd >= 0, "Shared library file cannot be opened"); + auto fsize = lseek(fd, 0, SEEK_END); + auto weights_size = + reinterpret_cast(_binary_constants_bin_start)[0]; + auto magic_number = + reinterpret_cast(_binary_constants_bin_start)[1]; + auto weights_offset = fsize - weights_size; + AOTI_RUNTIME_CHECK((weights_offset & 0x3fff) == 0, + "weights_offset must be aligned to 16K boundary"); + auto ptr = mmap(NULL, weights_size, PROT_READ | PROT_WRITE, MAP_PRIVATE, fd, + weights_offset); + close(fd); + AOTI_RUNTIME_CHECK(ptr != MAP_FAILED, "mmap() failed"); + self_mmap = static_cast(ptr); + AOTI_RUNTIME_CHECK( + reinterpret_cast(self_mmap + weights_size - + sizeof(uint64_t))[0] == magic_number, + "Weigths data seems corrupt"); + return self_mmap; +#endif + } + struct ParamInfo { + const char *name = nullptr; + }; + + struct ConstInfo { + const char *name = nullptr; + std::vector shape; + std::vector stride; + int32_t dtype{}; + int64_t offset{}; + size_t data_size{}; + int32_t layout{}; + std::vector opaque_metadata; + int64_t opaque_metadata_size{}; + const char *original_fqn = nullptr; + bool from_folded{}; + int32_t type{}; + }; + + std::vector inputs_info_; + std::vector outputs_info_; + std::vector constants_info_; + std::string in_spec_; + std::string out_spec_; + + std::shared_ptr constants_map_; + std::shared_ptr> constants_; + + // Holds the blob storage for constants' at::Tensor. + RAIIDataPtr constant_blob_; + +#ifdef USE_MMAP_SELF + uint8_t *self_mmap = NULL; +#endif + + // A directory with npu binary files, e.g. compiled kernels, etc. + const std::optional cubin_dir_; + + // This is the flag that implies whether the weight is included in the model. + // If True, we would prepare the weight when loading the model, otherwise the + // model will be loaded without weights, and need to be provided by the user. + bool include_weights; + + // Record if the model finishes an inference run so that its owning + // AOTModelContainer can re-use this instance. +#if defined(USE_NPU) + std::optional run_finished_; +#else + bool run_finished_{}; +#endif + + int32_t device_type_{}; + int32_t device_idx_{}; +}; + +// Codegen-ed classes can derive from this to keep pointers to loaded kernels. +class AOTInductorModelKernelsBase { + public: + virtual ~AOTInductorModelKernelsBase() = default; +}; + +class AOTInductorModel : public AOTInductorModelBase { + public: + AOTInductorModel(std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string &device_str, + std::optional cubin_dir, + bool include_weights = true); + + std::unordered_map const_run_impl( + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor, + bool initialization = false); + + void _const_run_impl(std::vector &output_handles, + DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + void run_impl( + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the array + // itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor); + + template + Outputs run_impl_minimal_arrayref_interface( + const Inputs &inputs, DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor); + + static std::unique_ptr Create( + std::shared_ptr constants_map, + std::shared_ptr> constants_array, + const std::string &device_str, std::optional cubin_dir) { + return std::make_unique(std::move(constants_map), + std::move(constants_array), + device_str, std::move(cubin_dir)); + } + + private: + std::unique_ptr kernels_; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/model_container.h b/torch_npu/csrc/inductor/aoti_runtime/model_container.h new file mode 100644 index 0000000000..51a5172aff --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/model_container.h @@ -0,0 +1,589 @@ +#pragma once + +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +namespace torch::aot_inductor { + +class AOTInductorModelContainer { + public: + AOTInductorModelContainer( + size_t num_models, const std::string &device_str, + const std::optional &cubin_dir = std::nullopt) { + constants_map_ = std::make_shared(); + constants_array_ = std::make_shared>(); + + models_.reserve(num_models); + available_models_.reserve(num_models); + for (size_t i = 0; i < num_models; ++i) { + models_.push_back(AOTInductorModel::Create( + constants_map_, constants_array_, device_str, cubin_dir)); + available_models_.push_back(models_.back().get()); + } + + // Note that the all following fields (input_names_, output_names, + // etc) can be filled in by the AOT + // codegen. However, we choose to query such information from + // the owned AOTInductorModel for a couple of reasons: + // * simplify the codegen templates + // * reduce information fragmentation and duplication + // * the initialization process below is done only once when the container + // is constructed, so it would have little performance impact + auto *model = available_models_[0]; + size_t num_inputs = model->num_inputs(); + input_names_.reserve(num_inputs); + for (size_t i = 0; i < num_inputs; i++) { + input_names_.emplace_back(model->input_name(static_cast(i))); + } + + size_t num_outputs = model->num_outputs(); + output_names_.reserve(num_outputs); + for (size_t i = 0; i < num_outputs; i++) { + output_names_.emplace_back(model->output_name(static_cast(i))); + } + model->load_constants(); + constant_blob_ = model->release_constant_blob(); + constants_internal_offset_.resize(model->num_constants()); + model->compute_constant_blob(blob_size_, constants_internal_offset_); + + for (auto &model : models_) { + model->update_constants_map(constants_map_); + } + + in_spec_ = model->get_in_spec(); + out_spec_ = model->get_out_spec(); + } + + void run(AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the + // array itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { + std::shared_lock model_lk(model_exec_mutex_); + auto *model = get_available_model(); + + if (!constant_folded_) { + // At this point, constant is not ready yet. We need to call constant + // folding before we execute the model. We obtain a unique lock at this + // point to make sure constant is ready for all. + model_lk.unlock(); + std::unique_lock constants_folding_lk(model_exec_mutex_); + // Double locking to make sure constant folding is only ran once. + if (!constant_folded_) { + auto folded_const_map = model->run_const_fold( + stream, proxy_executor, /* initialization = */ true); + update_constant_buffer(std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + constant_folded_ = true; + } + constants_folding_lk.unlock(); + model_lk.lock(); + } + + try { + model->run(input_handles, output_handles, stream, proxy_executor); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + + { + std::lock_guard lk(models_mutex_); + pending_models_.push_back(model); + } + pending_models_available_.notify_one(); + } + + // Non-thread-aware variant of run(). Obviously unsafe to use in a threaded + // environment :) + void run_single_threaded( + AtenTensorHandle + *input_handles, // array of input AtenTensorHandle; handles + // are stolen; the array itself is borrowed + AtenTensorHandle + *output_handles, // array for writing output AtenTensorHandle; + // handles will be stolen by the caller; the array + // itself is borrowed + DeviceStreamType stream, AOTIProxyExecutorHandle proxy_executor) { + auto *model = available_models_[0]; + + if (!constant_folded_) { + auto folded_const_map = + model->run_const_fold(stream, proxy_executor, true); + update_constant_buffer(std::move(folded_const_map), false, false); + constant_folded_ = true; + } + + model->run_single_threaded(input_handles, output_handles, stream, + proxy_executor); + } + + size_t num_constants() const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->num_constants(); + } + + // retrieve the constant name of constants_info_[idx] + const char *constant_name(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_name(static_cast(idx)); + } + + // retrieve original FQN of constants_info_[idx] + const char *constant_original_fqn(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_original_fqn(static_cast(idx)); + } + + // retrieve whether constant is from folded of constants_info_[idx] + bool constant_from_folded(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_from_folded(static_cast(idx)); + } + + // retrieve type of constants_info_[idx] + int32_t constant_type(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_type(static_cast(idx)); + } + + // retrieve dtype of constants_info_[idx] + int32_t constant_dtype(size_t idx) const { + if (this->num_models() == 0) { + throw std::runtime_error("No available models in container!"); + } + return models_[0]->constant_dtype(static_cast(idx)); + } + + void run_const_fold(bool inactive_buffer, DeviceStreamType stream, + AOTIProxyExecutorHandle proxy_executor) { + std::shared_lock model_lk(model_exec_mutex_); + auto *model = get_available_model(); + + if (!inactive_buffer) { + // We would need to acquire a unique lock if we want to run constant + // folding on the active buffer. + model_lk.unlock(); + std::unique_lock constants_folding_lk(model_exec_mutex_); + try { + auto folded_const_map = model->run_const_fold(stream, proxy_executor); + update_constant_buffer(std::move(folded_const_map), + /* use_inactive = */ false, + /* validate_full_update = */ false); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + constants_folding_lk.unlock(); + model_lk.lock(); + } else { + // We swap the constant mapping to the inactive buffer in the model to run + // const run. + auto constants_map = get_constants_map(/* get_inactive= */ true); + auto constants_array = get_constants_array(/* get_inactive= */ true); + + try { + model->update_constants_map(constants_map, + /* remap_constants_array= */ false); + model->update_constants_array(constants_array); + + auto folded_const_map = model->run_const_fold(stream, proxy_executor); + update_constant_buffer(std::move(folded_const_map), + /* use_inactive = */ true, + /* validate_full_update = */ false); + + // Swap back the model's constants mapping + constants_map = get_constants_map(/* get_inactive= */ false); + constants_array = get_constants_array(/* get_inactive= */ false); + model->update_constants_map(constants_map, + /* remap_constants_array= */ false); + model->update_constants_array(constants_array); + } catch (...) { + std::lock_guard lk(models_mutex_); + available_models_.push_back(model); + throw; + } + } + + { + std::lock_guard lk(models_mutex_); + pending_models_.push_back(model); + } + pending_models_available_.notify_one(); + } + + bool _should_skip_update(const size_t idx) const { + auto constant_type = models_[0]->constant_type(static_cast(idx)); + // We should skip constants + return constant_type == ConstantType::TensorConstant; + } + + bool _could_skip_update(const size_t idx) const { + auto constant_type = models_[0]->constant_type(static_cast(idx)); + // Buffer can be optionally skipped, so if it not provided by upstream + // services, it is OK to relax the check. + return constant_type == ConstantType::Buffer; + } + + void assert_all_constants( + const std::unordered_map &constants_map) { + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + if (models_[0]->constant_from_folded(static_cast(idx))) { + continue; + } + + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end()) { + if (_should_skip_update(idx) || _could_skip_update(idx)) { + // tracing sometimes creates tensors that are non-existent in + // original graph. We could skip those and do a direct copy. + std::cerr << "[WARNING] Found constant or module state buffer " + << constant_name + << " in model, but not provided by user!\n"; + continue; + } + throw std::runtime_error(std::string("Cannot find constants ") + + constant_name + + std::string(" in constants_map!")); + } + } + } + + // We directly take ownership from AtenTensorHandle if constants are moved. + void update_constant_buffer( + std::unordered_map &&constants_map, + bool use_inactive, bool validate_full_update) { + if (this->num_models() == 0) { + throw std::runtime_error("No model available in container!"); + } + if (validate_full_update) { + assert_all_constants(constants_map); + } + + auto original_constants_map = get_constants_map(!use_inactive); + auto constants_map_to_update = get_constants_map(use_inactive); + + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end() && !use_inactive) { + continue; + } + + AtenTensorHandle tensor; + if (it == constants_map.end() && use_inactive) { + aoti_torch_clone( + original_constants_map->find(constant_name)->second.get(), &tensor); + } else { + tensor = it->second; + } + + constants_map_to_update->insert_or_assign(constant_name, tensor); + } + // Update the inactive constant array. + update_array_from_map(get_constants_array(use_inactive), + constants_map_to_update); + } + + // This function updates the buffer for storing constants. + // It will update the buffer, the mapping and the array mapping. + void update_constant_buffer( + const std::unordered_map &constants_map, + bool use_inactive, bool validate_full_update) { + if (this->num_models() == 0) { + throw std::runtime_error("No model available in container!"); + } + if (validate_full_update) { + assert_all_constants(constants_map); + } + + auto original_constants_map = get_constants_map(!use_inactive); + auto constants_map_to_update = get_constants_map(use_inactive); + + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + auto constant_name = + std::string(models_[0]->constant_name(static_cast(idx))); + auto it = constants_map.find(constant_name); + if (it == constants_map.end() && !use_inactive) { + continue; + } + + AtenTensorHandle tensor; + if (it == constants_map.end() && use_inactive) { + tensor = original_constants_map->find(constant_name)->second.get(); + } else { + tensor = it->second; + } + auto *constants_blob_ptr = + static_cast(get_constant_blob_ptr(use_inactive)); + + // Move the data to container handled blob. + uint8_t *internal_constants_ptr = + constants_blob_ptr + constants_internal_offset_[idx]; + void *user_constant_ptr; + int64_t constant_size; + aoti_torch_get_data_ptr(tensor, &user_constant_ptr); + aoti_torch_get_storage_size(tensor, &constant_size); +#if defined(USE_NPU) + AOTI_RUNTIME_DEVICE_CHECK( + aclrtMemcpy(internal_constants_ptr, constant_size, user_constant_ptr, + constant_size, ACL_MEMCPY_HOST_TO_DEVICE)); +#else + memcpy(internal_constants_ptr, user_constant_ptr, constant_size); +#endif + // Generate Tensor from container handled blob. + // We extract stride and offset from provided Tensor since we do not + // guarantee that the tensor is contiguous. + AtenTensorHandle tensor_handle; + int64_t *stride; + int64_t offset; + int device_type = models_[0]->get_device_type(); + int device_idx = models_[0]->get_device_idx(); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(tensor, &stride)); + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(tensor, &offset)); + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + internal_constants_ptr, models_[0]->constant_ndim(idx), + models_[0]->constant_shape(idx), stride, offset, + models_[0]->constant_dtype(idx), device_type, device_idx, + &tensor_handle)); + + // Now place the tensor to constants_map. Note at this point the ownership + // of the tensor_handle will be taken over. + constants_map_to_update->insert_or_assign(constant_name, tensor_handle); + } + // Update the inactive constant array. + update_array_from_map(get_constants_array(use_inactive), + constants_map_to_update); + } + + void update_array_from_map( + const std::shared_ptr> &constants_array, + const std::shared_ptr &constants_map) { + auto num_constants = models_[0]->num_constants(); + for (size_t idx = 0; idx < num_constants; idx++) { + if (constants_map->find(models_[0]->constant_name( + static_cast(idx))) != constants_map->end()) { + constants_array->at(idx) = ConstantHandle( + constants_map + ->find(models_[0]->constant_name(static_cast(idx))) + ->second); + } + } + } + + void swap_constant_buffer() { + std::lock_guard unique_lk(model_exec_mutex_); + + auto constants_map = get_constants_map(/* get_inactive= */ true); + auto constants_array = get_constants_array(/* get_inactive= */ true); + + for (auto &model : models_) { + model->update_constants_map(constants_map, + /* remap_constants_array = */ false); + model->update_constants_array(constants_array); + } + + use_secondary_ = !use_secondary_; + } + + size_t num_inputs() const { return input_names_.size(); } + + size_t num_outputs() const { return output_names_.size(); } + + const char *input_name(size_t idx) const { + return input_names_.at(idx).c_str(); + } + + const char *output_name(size_t idx) const { + return output_names_.at(idx).c_str(); + } + + size_t num_models() const { return models_.size(); } + + const char *get_in_spec() const { return in_spec_; } + + const char *get_out_spec() const { return out_spec_; } + + private: + std::vector input_names_; + std::vector output_names_; + const char *in_spec_; + const char *out_spec_; + + // Holds the blob storage for constants' at::Tensor within the container. + // This blob of memory will be managed by the container. + RAIIDataPtr constant_blob_; + RAIIDataPtr constant_blob_secondary_; + + size_t blob_size_; + std::vector constants_internal_offset_; + + // Determine which constants is being used for the model. + // If true, + // constants_map_secondary/constant_blob_secondary/constants_array_secondary + // is being used. + bool use_secondary_{false}; + + // Determine whether we have ran constant folding + bool constant_folded_{false}; + + // Holds the mapping of constants to at::Tensor. + std::shared_ptr constants_map_; + std::shared_ptr constants_map_secondary_; + + // Holds the indexed array of constant for faster lookup during runtime. + std::shared_ptr> constants_array_; + std::shared_ptr> constants_array_secondary_; + + // Holds all the AOTInductorModel instances owned by this container. + std::vector> models_; + + // Holds the AOTInductorModel instances available for inference. + std::vector available_models_; + + // Holds the AOTInductorModel instances that have started running + // inference and can be placed onto available_models_ upon their + // completion. + std::deque pending_models_; + + // Protects available_models_ and pending_models_. + std::mutex models_mutex_; + + // Notified whenever a model is placed onto pending_models_. + std::condition_variable pending_models_available_; + + AOTInductorModel *get_available_model() { + std::unique_lock lk(models_mutex_); + if (available_models_.empty()) { + reclaim_finished_models(lk); + } + auto *result = available_models_.back(); + available_models_.pop_back(); + return result; + } + + // This mutex is used to protect execution of model. + // We acquire the mutex in shared mode if we allow concurrent execution. + // We acquire the mutex in unique mode when we want exclusive access of the + // model. One such case is when we want to do a weight swapping. We want to + // make sure no one is executing the model. + std::shared_mutex model_exec_mutex_; + + void *get_constant_blob_ptr(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constant_blob_.get(); + } else { + if (!constant_blob_secondary_) { +#if defined(USE_NPU) + constant_blob_secondary_ = RAII_npuMalloc(blob_size_); +#else + constant_blob_secondary_ = RAII_cpuMalloc(blob_size_); +#endif // USE_NPU + } + return constant_blob_secondary_.get(); + } + } + + std::shared_ptr get_constants_map(bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constants_map_; + } else { + if (!constants_map_secondary_) { + constants_map_secondary_ = std::make_shared(); + } + return constants_map_secondary_; + } + } + + std::shared_ptr> get_constants_array( + bool get_inactive) { + if ((get_inactive && use_secondary_) || + (!get_inactive && !use_secondary_)) { + return constants_array_; + } else { + if (!constants_array_secondary_) { + constants_array_secondary_ = + std::make_shared>( + models_[0]->num_constants()); + } + return constants_array_secondary_; + } + } + + void reclaim_finished_models(std::unique_lock &lk) { +#ifdef __aarch64__ + // push finished model instances to the end of pending_models_ + auto it = + std::partition(pending_models_.begin(), pending_models_.end(), + [](AOTInductorModel *m) { return !m->is_finished(); }); +#else + // push finished model instances to the end of pending_models_ + auto it = std::stable_partition( + pending_models_.begin(), pending_models_.end(), + [](AOTInductorModel *m) { return !m->is_finished(); }); +#endif + + if (it != pending_models_.end()) { + // We have finished model instances that can be pushed into + // available_models_ so that we don't have to be blocked on waiting + // the pending_models_available_ condition. + available_models_.insert(available_models_.end(), it, + pending_models_.end()); + pending_models_.erase(it, pending_models_.end()); + return; + } + + pending_models_available_.wait( + lk, [this]() { return !pending_models_.empty(); }); + // Let's make the schedule simple first. We always wait on the first + // pending_models_ to be complete. + auto *model = pending_models_.front(); + pending_models_.pop_front(); + lk.unlock(); + try { + model->wait_for_completion(); + } catch (...) { + lk.lock(); + available_models_.push_back(model); + throw; + } + lk.lock(); + available_models_.push_back(model); + } +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h b/torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h new file mode 100644 index 0000000000..3f59d40711 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/scalar_to_tensor.h @@ -0,0 +1,38 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +template +inline RAIIAtenTensorHandle scalar_to_tensor_handle(T value) { + throw std::runtime_error("Unsupported scalar_to_tensor_handle"); +} + +// Specialize for supported C++ primitive types +#define AOTI_RUNTIME_SCALAR_TO_TENSOR(dtype, ctype) \ + template <> \ + inline RAIIAtenTensorHandle scalar_to_tensor_handle(ctype value) { \ + AtenTensorHandle tensor_handle; \ + AOTI_TORCH_ERROR_CODE_CHECK( \ + aoti_torch_scalar_to_tensor_##dtype(value, &tensor_handle)); \ + return RAIIAtenTensorHandle(tensor_handle); \ + } + +AOTI_RUNTIME_SCALAR_TO_TENSOR(float32, float) +AOTI_RUNTIME_SCALAR_TO_TENSOR(float64, double) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint8, uint8_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint16, uint16_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint32, uint32_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(uint64, uint64_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int8, int8_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int16, int16_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int32, int32_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(int64, int64_t) +AOTI_RUNTIME_SCALAR_TO_TENSOR(bool, bool) +AOTI_RUNTIME_SCALAR_TO_TENSOR(complex64, c10::complex) +AOTI_RUNTIME_SCALAR_TO_TENSOR(complex128, c10::complex) +#undef AOTI_RUNTIME_SCALAR_TO_TENSOR + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/thread_local.h b/torch_npu/csrc/inductor/aoti_runtime/thread_local.h new file mode 100644 index 0000000000..303f34c3c7 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/thread_local.h @@ -0,0 +1,144 @@ +#pragma once + +#include + +namespace torch::aot_inductor { + +template +struct ThreadLocalCachedOutputTensor; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const RAIIAtenTensorHandle &) {} + void copy_data_from(const RAIIAtenTensorHandle &handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } +}; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const AtenTensorHandle &) {} + void copy_data_from(const AtenTensorHandle &handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } +}; + +template <> +struct ThreadLocalCachedOutputTensor { + explicit ThreadLocalCachedOutputTensor(const ConstantHandle &) {} + void copy_data_from(const ConstantHandle &handle) { + throw std::runtime_error("can't happen"); + } + + AtenTensorHandle tensor() const { throw std::runtime_error("can't happen"); } +}; + +template +struct ThreadLocalCachedOutputTensor> { + explicit ThreadLocalCachedOutputTensor(const ArrayRefTensor &t) { + realloc(t); + } + + void copy_data_from(const ArrayRefTensor &t) { + if (t.numel() > capacity_) { + realloc(t); + } + std::copy(t.data(), t.data() + t.numel(), storage_.get()); + } + + AtenTensorHandle tensor() const { return tensor_.get(); } + + private: + void realloc(const ArrayRefTensor &t) { + capacity_ = t.numel(); + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(t.numel()); + AtenTensorHandle handle = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_create_tensor_from_blob_npu( + storage_.get(), t.sizes().size(), t.sizes().data(), t.strides().data(), + 0, aoti_torch_dtype>(), t.device_type(), + t.device_idx(), &handle)); + tensor_ = handle; + } + + // NOLINTNEXTLINE(*arrays*) + std::unique_ptr storage_; + int64_t capacity_ = 0; + RAIIAtenTensorHandle tensor_; +}; + +template +struct ThreadLocalCachedOutputArray; + +// Just needs to compile, doesn't need to do anything. +template <> +struct ThreadLocalCachedOutputArray { + explicit ThreadLocalCachedOutputArray(const RAIIAtenTensorHandle &) { + throw std::runtime_error("can't happen"); + } + + // Not supported yet! We would need to put contiguous() or + // expect_contiguous() into the ABI. + void copy_data_from(const RAIIAtenTensorHandle &) { + throw std::runtime_error("can't happen"); + } + + template + ArrayRefTensor arrayref_tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +// Just needs to compile, doesn't need to do anything. +template <> +struct ThreadLocalCachedOutputArray { + explicit ThreadLocalCachedOutputArray(const ConstantHandle &) { + throw std::runtime_error("can't happen"); + } + + // Not supported yet! We would need to put contiguous() or + // expect_contiguous() into the ABI. + void copy_data_from(const ConstantHandle &) { + throw std::runtime_error("can't happen"); + } + + template + ArrayRefTensor arrayref_tensor() const { + throw std::runtime_error("can't happen"); + } +}; + +template +struct ThreadLocalCachedOutputArray> { + explicit ThreadLocalCachedOutputArray(const ArrayRefTensor &t) {} + + template , + std::remove_const_t>, + bool> = true> + ArrayRefTensor arrayref_tensor() const { + return tensor_; + } + + void copy_data_from(const ArrayRefTensor &t) { + if (t.numel() > capacity_) { + capacity_ = t.numel(); + // NOLINTNEXTLINE(*arrays*) + storage_ = std::make_unique(capacity_); + } + std::copy(t.data(), t.data() + t.numel(), storage_.get()); + tensor_ = t; + tensor_.set_arrayref(MiniArrayRef(storage_.get(), t.numel())); + } + + private: + // NOLINTNEXTLINE(*arrays*) + std::unique_ptr storage_; + uint32_t capacity_ = 0; + ArrayRefTensor tensor_; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_runtime/utils.h b/torch_npu/csrc/inductor/aoti_runtime/utils.h new file mode 100644 index 0000000000..72c78e7ac1 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_runtime/utils.h @@ -0,0 +1,234 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +// WARNING: Be careful when adding new includes here. This header will be used +// in model.so, and should not refer to any aten/c10 headers except the stable +// C ABI defined in torch/csrc/inductor/aoti_torch/c/shim.h. The same rule +// applies to other files under torch/csrc/inductor/aoti_runtime/. +#include + +#if defined(__GNUC__) || defined(__clang__) +#define AOTI_NOINLINE __attribute__((noinline)) +#elif _MSC_VER +#define AOTI_NOINLINE __declspec(noinline) +#else +#define AOTI_NOINLINE +#endif + +AOTI_NOINLINE static void throw_exception(const char *call, const char *file, + int64_t line) { + std::stringstream ss; + ss << call << " API call failed at " << file << ", line " << line; + throw std::runtime_error(ss.str()); +} + +#define AOTI_TORCH_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_TORCH_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ + } + +using AOTIRuntimeError = int32_t; +#define AOTI_RUNTIME_SUCCESS 0 +#define AOTI_RUNTIME_FAILURE 1 + +#define AOTI_RUNTIME_ERROR_CODE_CHECK(call) \ + if ((call) != AOTI_RUNTIME_SUCCESS) { \ + throw_exception(#call, __FILE__, __LINE__); \ + } + +namespace torch::aot_inductor { + +using DeleterFnPtr = void (*)(void *); + +inline void noop_deleter(void *) {} + +inline void delete_tensor_object(void *ptr) { + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_delete_tensor_object(reinterpret_cast(ptr))); +} + +// RAIIAtenTensorHandle steals the tensor objects created by the libtorch C ABI +class RAIIAtenTensorHandle { + public: + RAIIAtenTensorHandle() : handle_(nullptr, noop_deleter) {} + RAIIAtenTensorHandle(const RAIIAtenTensorHandle &other) = delete; + RAIIAtenTensorHandle &operator=(const RAIIAtenTensorHandle &other) = delete; + + // Steal the ownership from another RAIIAtenTensorHandle using std::move + RAIIAtenTensorHandle(RAIIAtenTensorHandle &&other) = default; + RAIIAtenTensorHandle &operator=(RAIIAtenTensorHandle &&other) = default; + + // Steal the ownership from raw AtenTensorHandle + RAIIAtenTensorHandle(AtenTensorHandle handle) + : handle_(handle, delete_tensor_object) {} + + ~RAIIAtenTensorHandle() { handle_.reset(); } + + // Return a raw AtenTensorHandle to be used by aoti_torch functions + // Note: this function does NOT transfer the ownership of the handle + operator AtenTensorHandle() const { return handle_.get(); } + + AtenTensorHandle release() { return handle_.release(); } + + AtenTensorHandle get() const { return handle_.get(); } + + void reset() { handle_.reset(); } + + int64_t size(int64_t d) { + int64_t size = 0; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_size(handle_.get(), d, &size)); + return size; + } + + int64_t stride(int64_t d) { + int64_t stride = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_stride(handle_.get(), d, &stride)); + return stride; + } + + int64_t storage_offset() { + int64_t storage_offset = 0; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_storage_offset(handle_.get(), &storage_offset)); + return storage_offset; + } + + void *data_ptr() const { + void *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_data_ptr(handle_.get(), &result)); + return result; + } + + int64_t *sizes() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_.get(), &result)); + return result; + } + + int64_t *strides() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_.get(), &result)); + return result; + } + + private: + std::unique_ptr handle_; +}; + +// Steal the ownership from raw AtenTensorHandle to RAIIAtenTensorHandle +inline std::vector steal_from_raw_handles_to_raii_handles( + AtenTensorHandle *handles, size_t size) { + std::vector result; + result.reserve(size); + for (size_t i = 0; i < size; i++) { + result.emplace_back(handles[i]); + handles[i] = nullptr; + } + return result; +} + +inline AtenTensorHandle reinterpret_tensor_wrapper(AtenTensorHandle self, + int64_t ndim, + const int64_t *sizes_ptr, + const int64_t *strides_ptr, + int64_t storage_offset) { + AtenTensorHandle result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch__reinterpret_tensor( + self, ndim, sizes_ptr, strides_ptr, storage_offset, &result)); + return result; +} + +inline void *get_data_ptr_wrapper(AtenTensorHandle tensor) { + void *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(tensor, &result)); + return result; +} + +inline AtenTensorHandle unwrap_raii_handle_if_needed( + const RAIIAtenTensorHandle &handle) { + return handle.get(); +} + +inline RAIIAtenTensorHandle wrap_with_raii_handle_if_needed( + AtenTensorHandle handle) { + return RAIIAtenTensorHandle(handle); +} + +class ConstantHandle { + public: + ConstantHandle() = default; + + explicit ConstantHandle(AtenTensorHandle handle) : handle_(handle) { + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle_, &data_)); + } + + operator AtenTensorHandle() const { return handle_; } + + AtenTensorHandle tensor() const { return handle_; } + + AtenTensorHandle get() const { return handle_; } + + void *data_ptr() const { return data_; } + + int64_t *sizes() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle_, &result)); + return result; + } + + int64_t *strides() const { + int64_t *result = nullptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle_, &result)); + return result; + } + + private: + AtenTensorHandle handle_{}; + void *data_ = nullptr; +}; + +inline void *get_data_ptr_wrapper(const ConstantHandle &constant) { + return constant.data_ptr(); +} + +inline const ConstantHandle &unwrap_raii_handle_if_needed( + const ConstantHandle &handle) { + return handle; +} + +// Shouldn't be called. +inline AtenTensorHandle wrap_with_raii_handle_if_needed( + const ConstantHandle &handle) = delete; + +// DANGEROUS. Do not call unless you explicitly intend to get a reference to a +// temporary value, which will expire at the end of the current expression. +// This should only be called in cases where the C-shim API expects an optional +// input argument (passed by pointer), and a temporary needs to be passed to it. +template +T &temporary_reference(T &&t) { + return t; +} + +#define CACHE_TORCH_DTYPE(typename) \ + static auto cached_torch_dtype_##typename = aoti_torch_dtype_##typename() + +#define CACHE_TORCH_DEVICE(device) \ + static auto cached_torch_device_type_##device = \ + aoti_torch_device_type_##device() + +#define CACHE_TORCH_LAYOUT(layout) \ + static auto cached_torch_layout_##layout = aoti_torch_layout_##layout() + +#define CACHE_TORCH_MEMORY_FORMAT(format) \ + static auto cached_torch_memory_format_##format = \ + aoti_torch_memory_format_##format() + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/c/shim.h b/torch_npu/csrc/inductor/aoti_torch/c/shim.h new file mode 100644 index 0000000000..e86fc28330 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/c/shim.h @@ -0,0 +1,653 @@ +#ifndef AOTI_TORCH_SHIM +#define AOTI_TORCH_SHIM + +#include +#include + +// This header defines a stable C API for certain ATen functionality in +// libtorch. The AOTInductor compiled model.so will only refer to this header +// instead of other headers from aten/c10, which means it will NOT be able to +// directly use any data structures or call functions from libtorch. +// +// What problems are we trying to solve here? Direct use of aten/c10 APIs +// means use of C++ APIs on a library that doesn't have any ABI compatibility +// guarantees. However, we want model.so to remain usable across updates +// to the PyTorch C++ libraries, which requires a stable ABI. By introducing +// a C shim layer, we can minimize the surface that will cause breakage. The +// corresponding software stack can be illustrated as follows: +// +// |--------------------------------| +// | inference service code | +// |--------------------------------| +// | model.so | +// |--------------|-----------------| +// | | +// | libtorch.so | +// |--------------------------------| +// +// The general guidelines for the C API: +// +// - No exceptions, return an explicit error code to be checked at call site +// - Only pointers (AtenTensorHandle counts), integers and floats in headers +// +// If you want to make changes to this header, you MUST MAINTAIN ABI +// compatibility. Typically, this means you will have to add a _v2 version +// of a function that you, e.g., want to add a new function parameter to, and +// maintain the old and new versions of the APIs until all old model.so +// go out of use. + +#ifdef __GNUC__ +#define AOTI_TORCH_EXPORT __attribute__((__visibility__("default"))) +#else // !__GNUC__ +#ifdef _WIN32 +// PyTorch2 doesn't currently work on Windows. Exporting these APIs can lead +// to symbol clashes at link time if libtorch is included in a DLL and binary +// that depends on the DLL. As a short term fix, we don't export the symbols. +// In the long term, this will need to be addressed when Windows is supported. +#ifdef OVRSOURCE +// Do not export AOTI on Windows for internal builds +#define AOTI_TORCH_EXPORT +#else /* OVRSOURCE */ +#ifdef EXPORT_AOTI_FUNCTIONS +#define AOTI_TORCH_EXPORT __declspec(dllexport) +#else +#define AOTI_TORCH_EXPORT __declspec(dllimport) +#endif +#endif /* OVRSOURCE */ +#else // !_WIN32 +#define AOTI_TORCH_EXPORT +#endif // _WIN32 +#endif // __GNUC__ + +// The following files are implemented in a header-only way and are guarded by +// test/cpp/aoti_abi_check +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// AtenTensorHandle represents an abstract notion of Tensor that can be passed +// between model.so and libtorch.so. The contents of the structure itself +// are private; model.so is not allowed to access any fields directly, it must +// go through functions defined in this ABI. Under the hood, this is +// represented as at::Tensor*, but we reserve the right to change this (and in +// fact, we probably should change it to at::TensorImpl* at least). +// +// An AtenTensorHandle can be owning (please check the API reference for exact +// ownership/borrow semantics). If you have an owning AtenTensorHandle +// in model.so, you are obligated to aoti_torch_delete_tensor_object when you +// are done. You can use the helper C++ class RAIIAtenTensorHandle +// (see aot_runtime/model.h) to ensure the deallocator is called in RAII style +// (note that RAIIAtenTensorHandle is private to model.so, and never crosses +// the ABI boundary.) +struct AtenTensorOpaque; +using AtenTensorHandle = AtenTensorOpaque *; + +struct AtenGeneratorOpaque; +using AtenGeneratorHandle = AtenGeneratorOpaque *; + +struct AOTIProxyExecutorOpaque; +using AOTIProxyExecutorHandle = AOTIProxyExecutorOpaque *; + +using AOTITorchError = int32_t; +#define AOTI_TORCH_SUCCESS 0 +#define AOTI_TORCH_FAILURE 1 + +// Getter functions for retrieving various constants from the runtime, that +// can subsequently be passed to other aoti_* functions. By hiding these +// behind functions, the precise value of device/dtype is NOT part of the +// ABI contract. (In practice, aten/c10 is pretty good about not renumbering +// these, so we probably could later switch to having these in the ABI, if +// desired for perf reasons.) +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_cpu(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_meta(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_npu(); +AOTI_TORCH_EXPORT int32_t aoti_torch_device_type_privateuse1(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fn(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e5m2fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float8_e4m3fnuz(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bfloat16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_float64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint8(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_uint64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int8(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int16(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_int64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_bool(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex32(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex64(); +AOTI_TORCH_EXPORT int32_t aoti_torch_dtype_complex128(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_strided(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_coo(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_csr(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_csc(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_bsr(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_sparse_bsc(); +AOTI_TORCH_EXPORT int32_t aoti_torch_layout_jagged(); + +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_contiguous_format(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_channels_last(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_channels_last_3d(); +AOTI_TORCH_EXPORT int32_t aoti_torch_memory_format_preserve_format(); + +// Get TORCH_ABI_VERSION of the built libtorch.so +AOTI_TORCH_EXPORT uint64_t aoti_torch_abi_version(); + +// Functions for converting a single-element tensor to a scalar value +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float16(AtenTensorHandle tensor, c10::Half *ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float32(AtenTensorHandle tensor, float *ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_float64(AtenTensorHandle tensor, double *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint8(AtenTensorHandle tensor, + uint8_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint16(AtenTensorHandle tensor, + uint16_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint32(AtenTensorHandle tensor, + uint32_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_uint64(AtenTensorHandle tensor, + uint64_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int8(AtenTensorHandle tensor, + int8_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int16(AtenTensorHandle tensor, + int16_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int32(AtenTensorHandle tensor, + int32_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_int64(AtenTensorHandle tensor, + int64_t *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_bool(AtenTensorHandle tensor, + bool *ret_value); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_item_bfloat16(AtenTensorHandle tensor, c10::BFloat16 *ret_value); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_item_complex64( + AtenTensorHandle tensor, c10::complex *ret_value); + +// Functions for wrapping a scalar value to a single-element tensor +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float32( + float value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_float64( + double value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint8( + uint8_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint16( + uint16_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint32( + uint32_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_uint64( + uint64_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int8( + int8_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int16( + int16_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int32( + int32_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_int64( + int64_t value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_scalar_to_tensor_bool(bool value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex64( + c10::complex value, AtenTensorHandle *ret_new_tensor); +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scalar_to_tensor_complex128( + c10::complex value, AtenTensorHandle *ret_new_tensor); + +AOTI_TORCH_EXPORT bool aoti_torch_grad_mode_is_enabled(); +AOTI_TORCH_EXPORT void aoti_torch_grad_mode_set_enabled(bool enabled); + +// Free the tensor object +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_tensor_object(AtenTensorHandle tensor); + +// Get a pointer to the underlying storage data +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_data_ptr( + AtenTensorHandle tensor, + void **ret_data_ptr // returns borrowed reference +); + +// Get the nbytes of the underlying storage +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_storage_size(AtenTensorHandle tensor, int64_t *ret_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_dim(AtenTensorHandle tensor, + int64_t *ret_dim); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_numel(AtenTensorHandle tensor, + int64_t *ret_numel); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_storage_numel(AtenTensorHandle tensor, int64_t *ret_numel); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_sizes( + AtenTensorHandle tensor, + int64_t **ret_sizes // returns borrowed reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_size(AtenTensorHandle tensor, + int64_t d, + int64_t *ret_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_strides( + AtenTensorHandle tensor, + int64_t **ret_strides // returns borrowed reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_stride(AtenTensorHandle tensor, + int64_t d, + int64_t *ret_stride); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_dtype(AtenTensorHandle tensor, + int32_t *ret_dtype); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_device_type(AtenTensorHandle tensor, int32_t *ret_device_type); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_get_device_index(AtenTensorHandle tensor, int32_t *ret_device_index); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_get_storage_offset( + AtenTensorHandle tensor, int64_t *ret_storage_offset); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_new_tensor_handle( + AtenTensorHandle orig_handle, AtenTensorHandle *new_handle); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__alloc_from_pool( + AtenTensorHandle self, int64_t offset_bytes, int32_t dtype, int64_t ndim, + const int64_t *sizes_ptr, const int64_t *strides_ptr, + AtenTensorHandle *ret_new_tensor); + +// This function will create a new tensor object and its pointer is returned +// through *out. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__reinterpret_tensor( + AtenTensorHandle self, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, + AtenTensorHandle *ret_new_tensor // returns new reference +); + +// This function will create a new tensor object and its pointer is returned +// through *out. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_empty_strided( + int64_t ndim, const int64_t *sizes_ptr, const int64_t *strides_ptr, + int32_t dtype, int32_t device_type, int32_t device_index, + AtenTensorHandle *ret_new_tensor // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_as_strided(AtenTensorHandle self, const int64_t *sizes_ptr, + const int64_t *strides_ptr, AtenTensorHandle *ret); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_v2( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret, // returns new reference + int32_t layout, const uint8_t *opaque_metadata, + int64_t opaque_metadata_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_npu( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_create_tensor_from_blob_npu_v2( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret, // returns new reference + int32_t layout, const uint8_t *opaque_metadata, + int64_t opaque_metadata_size); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__embedding_bag( + AtenTensorHandle weight, AtenTensorHandle indices, AtenTensorHandle offsets, + int32_t scale_grad_by_freq, int32_t mode, int32_t sparse, + AtenTensorHandle per_sample_weights, // optional argument + int32_t include_last_offset, int32_t padding_idx, + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__fft_c2c( + AtenTensorHandle self, const int64_t *dim_ptr, int64_t dim_size, + int64_t normalization, int32_t forward, + AtenTensorHandle *ret // returns new reference +); + +// This version is deprecated. We will remove it later +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention( + AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, + double dropout_p, bool is_causal, bool return_debug_mask, double scale, + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3, // returns new reference + int64_t *ret4, int64_t *ret5, + AtenTensorHandle *ret6, // returns new reference + AtenTensorHandle *ret7, // returns new reference + AtenTensorHandle *ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_flash_attention_v2( + AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, + double dropout_p, int is_causal, int return_debug_mask, + double *scale, // optional argument + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3, // returns new reference + int64_t *ret4, int64_t *ret5, + AtenTensorHandle *ret6, // returns new reference + AtenTensorHandle *ret7, // returns new reference + AtenTensorHandle *ret8 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch__scaled_dot_product_efficient_attention( + AtenTensorHandle query, AtenTensorHandle key, AtenTensorHandle value, + AtenTensorHandle attn_bias, // optional argument + int compute_log_sumexp, double dropout_p, int is_causal, + double *scale, // optional argument + AtenTensorHandle *ret0, // returns new reference + AtenTensorHandle *ret1, // returns new reference + AtenTensorHandle *ret2, // returns new reference + AtenTensorHandle *ret3 // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm( + AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle bias, + int32_t *out_dtype, AtenTensorHandle scale_a, AtenTensorHandle scale_b, + AtenTensorHandle scale_result, int8_t use_fast_accum, + AtenTensorHandle *ret0, AtenTensorHandle *ret1); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_mm_v2( + AtenTensorHandle self, AtenTensorHandle mat2, AtenTensorHandle scale_a, + AtenTensorHandle scale_b, AtenTensorHandle bias, + AtenTensorHandle scale_result, int32_t *out_dtype, int8_t use_fast_accum, + AtenTensorHandle *ret0); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_convolution( + AtenTensorHandle input, AtenTensorHandle weight, + AtenTensorHandle bias, // optional argument + const int64_t *stride_ptr, int64_t stride_size, const int64_t *padding_ptr, + int64_t padding_size, const int64_t *dilation_ptr, int64_t dilation_size, + int transposed, const int64_t *output_padding_ptr, + int64_t output_padding_size, int64_t groups, + AtenTensorHandle *ret // returns new reference +); + +// This function will create a new uninitialized tensor object +// and its pointer is returned through *ret. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_new_uninitialized_tensor(AtenTensorHandle *ret); + +// WARNING: This will be deprecated. Use aoti_torch_copy_ instead. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_tensor_copy_(AtenTensorHandle src, + AtenTensorHandle dst); + +// Make the tensor referred to by dst an alias for the tensor referred +// to by src. The two tensors must still be deleted with +// aoti_torch_delete_tensor separately (or not) as before the call. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors(AtenTensorHandle src, AtenTensorHandle dst); + +// Make a shallow copy of the tensor referred to by src and assign +// it to the handle in the ret_dst. This is similar to the above +// aoti_torch_assign_tensors function, but creates and sets the +// ret_dst from within. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_assign_tensors_out(AtenTensorHandle src, AtenTensorHandle *ret_dst); + +// This function will create a new tensor object and its pointer is returned +// through *ret. The caller is responsible for wrapping the tensor pointer +// with RAIIAtenTensorHandle which will call aoti_torch_delete_tensor_object +// when going out of scope. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_clone(AtenTensorHandle self, + AtenTensorHandle *ret); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_clone_preserve_strides(AtenTensorHandle self, AtenTensorHandle *ret); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_addmm_out(AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat1, + AtenTensorHandle mat2, + float beta, float alpha); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_bmm_out(AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_copy_(AtenTensorHandle self, + AtenTensorHandle src, + int32_t non_blocking); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_mm_out(AtenTensorHandle out, + AtenTensorHandle self, + AtenTensorHandle mat2); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch__mm_plus_mm_out( + AtenTensorHandle out, AtenTensorHandle a, AtenTensorHandle b, + AtenTensorHandle c, AtenTensorHandle d); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_wrapped_fbgemm_pack_gemm_matrix_fp16(AtenTensorHandle weight, + AtenTensorHandle *out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu__wrapped_linear_prepack( + AtenTensorHandle weight, AtenTensorHandle weight_scale, + AtenTensorHandle weight_zero_point, AtenTensorHandle bias, + AtenTensorHandle *out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu_wrapped_fbgemm_linear_fp16_weight(AtenTensorHandle input, + AtenTensorHandle weight, + AtenTensorHandle bias, + int64_t out_channel, + AtenTensorHandle *out); + +// This will soon be deprecated after ao_quantization is complete. +// Please refrain from using this or increasing callsites. +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_cpu__wrapped_quantized_linear_prepacked( + AtenTensorHandle input, AtenTensorHandle input_scale, + AtenTensorHandle input_zero_point, AtenTensorHandle weight, + AtenTensorHandle out_scale, AtenTensorHandle out_zeropoint, + int64_t out_channel, AtenTensorHandle *out); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_nonzero(AtenTensorHandle self, + AtenTensorHandle *out); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_zero_(AtenTensorHandle self); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_repeat_interleave_Tensor( + AtenTensorHandle repeats, int64_t *output_size, AtenTensorHandle *out); + +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_check_inf_and_nan(const char *tensor_name, AtenTensorHandle tensor); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_out(AtenTensorHandle out, + AtenTensorHandle self, + int64_t dim, + AtenTensorHandle index, + AtenTensorHandle src); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_scatter_reduce_out( + AtenTensorHandle out, AtenTensorHandle self, int64_t dim, + AtenTensorHandle index, AtenTensorHandle src, const char *reduce, + int32_t include_self); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_index_put_out( + AtenTensorHandle out, AtenTensorHandle self, + const AtenTensorHandle *indices, const uint32_t num_indices, + const AtenTensorHandle values, bool accumulate); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_as_real( + AtenTensorHandle self, + AtenTensorHandle *ret // returns new reference +); + +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_view_dtype( + AtenTensorHandle self, int32_t dtype, + AtenTensorHandle *ret // returns new reference +); + +AOTI_TORCH_EXPORT void aoti_torch_print_tensor_handle(AtenTensorHandle self, + const char *msg); + +// When AOTI debug printer option is enabled, this function will be invoked to +// torch pickle save the intermediate tensor for debugging purpose. +AOTI_TORCH_EXPORT void aoti_torch_save_tensor_handle(AtenTensorHandle self, + const char *tensor_name, + const char *launch_prefix, + const char *kernel_name); + +// helpers for converting between StableIValue and actual IValues +using StableIValue = uint64_t; + +class TorchLibraryOpaque; +using TorchLibraryHandle = TorchLibraryOpaque *; + +// stable corollary to torch::Library constructor with Kind::IMPL +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_impl( + const char *ns, const char *k, const char *file, uint32_t line, + TorchLibraryHandle *ret_new_torch_lib); + +// stable corollary to torch::Library constructor with Kind::DEF +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_library_init_def(const char *ns, const char *file, uint32_t line, + TorchLibraryHandle *ret_new_torch_lib); + +// stable corollary to torch::Library constructor with Kind::FRAGMENT +// will create a new torch::Library object on the heap +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_init_fragment( + const char *ns, const char *file, uint32_t line, + TorchLibraryHandle *ret_new_torch_lib); + +// stable corollary to torch::Library method m.impl(), should be +// called from StableLibrary +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_library_impl(TorchLibraryHandle self, const char *name, + void (*fn)(StableIValue *, uint64_t, uint64_t)); + +// stable corollary to torch::Library method m.def(), should be +// called from StableLibrary +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_library_def(TorchLibraryHandle self, + const char *schema); + +// the above stable constructors for torch::Library add Library objects +// to the heap. if you are calling those functions directly, please use +// this function to free the Library's memory. The more user friendly +// alternative is to use StableLibrary, which will free its handle upon +// destruction +AOTI_TORCH_EXPORT AOTITorchError +aoti_torch_delete_library_object(TorchLibraryHandle tlh); + +// calls the op overload defined by a given opName, overloadName, and a +// stack of StableIValues. This call will populate any return values of the +// op into the stack in their StableIValue form, with ret0 at index 0, ret1 +// at index 1, and so on. +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_call_dispatcher( + const char *opName, const char *overloadName, StableIValue *stack); + + +// See `ProxyExecutor Design Note` in ir.py for more details +AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function( + AOTIProxyExecutorHandle proxy_executor, int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args); + +AOTI_TORCH_EXPORT void aoti_torch_check(bool cond, const char *func, + const char *file, uint32_t line, + const char *msg); + +#ifdef STRIP_ERROR_MESSAGES +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check(false, __func__, __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", __VA_ARGS__)); \ + } +#else +#define AOTI_TORCH_CHECK(cond, ...) \ + if (!(cond)) { \ + aoti_torch_check(false, __func__, __FILE__, \ + static_cast(__LINE__), \ + TORCH_CHECK_MSG(cond, "", ##__VA_ARGS__)); \ + } +#endif + +AOTI_TORCH_EXPORT void aoti_torch_warn(const char *func, const char *file, + uint32_t line, const char *msg); + +#ifdef DISABLE_WARN +#define AOTI_TORCH_WARN(...) ((void)0); +#else +#define AOTI_TORCH_WARN(...) \ + aoti_torch_warn(__func__, __FILE__, static_cast(__LINE__), \ + #__VA_ARGS__); +#endif + +#ifdef __cplusplus +} // extern "C" + +template +int32_t aoti_torch_dtype() = delete; + +#define DEFINE_DTYPE_SPECIALIZATION(ctype, typename) \ + template <> \ + inline int32_t aoti_torch_dtype() { \ + return aoti_torch_dtype_##typename(); \ + } + +namespace c10 { +struct BFloat16; +struct Half; +} // namespace c10 + +DEFINE_DTYPE_SPECIALIZATION(c10::BFloat16, bfloat16) +DEFINE_DTYPE_SPECIALIZATION(c10::Half, float16) +DEFINE_DTYPE_SPECIALIZATION(c10::complex, complex64) +DEFINE_DTYPE_SPECIALIZATION(float, float32) +DEFINE_DTYPE_SPECIALIZATION(double, float64) +DEFINE_DTYPE_SPECIALIZATION(uint8_t, uint8) +DEFINE_DTYPE_SPECIALIZATION(int8_t, int8) +DEFINE_DTYPE_SPECIALIZATION(int16_t, int16) +DEFINE_DTYPE_SPECIALIZATION(int32_t, int32) +DEFINE_DTYPE_SPECIALIZATION(int64_t, int64) +DEFINE_DTYPE_SPECIALIZATION(bool, bool) + +#endif +#endif // AOTI_TORCH_SHIM diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp new file mode 100644 index 0000000000..7495569f8e --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.cpp @@ -0,0 +1,545 @@ +#include +#include + +#include +#include +#include +#include + +namespace { +at::Tensor *tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} +} // namespace + +namespace torch::aot_inductor { + +void OSSProxyExecutor::prefill_stack_with_static_arguments( + size_t index, const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, OSSOpKernel &op_kernel) { + auto &stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + TORCH_CHECK(serialized_arg.size() == 1); + std::string serialized_arg_type = serialized_arg.begin().key(); + auto &serialized_arg_val = serialized_arg.begin().value(); + + switch (schema_arg_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_arg_type == "as_tensor", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensor for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::IntType: { + TORCH_CHECK(serialized_arg_type == "as_int", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_int for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::SymIntType: { + TORCH_CHECK(serialized_arg_type == "as_int" || + serialized_arg_type == "as_sym_int", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_int or as_sym_int for " + "argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::FloatType: { + TORCH_CHECK(serialized_arg_type == "as_float", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_float for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::BoolType: { + TORCH_CHECK(serialized_arg_type == "as_bool", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_bool for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::NumberType: { + if (serialized_arg_type == "as_int") { + // Only int Scalar is treated as dynamic arg for now + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + } else if (serialized_arg_type == "as_float") { + stack.at(index) = serialized_arg_val.get(); + } else if (serialized_arg_type == "as_bool") { + stack.at(index) = serialized_arg_val.get(); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a scalar input for argument ", index, " but got ", + serialized_arg_type); + } + break; + } + case c10::TypeKind::StringType: { + TORCH_CHECK(serialized_arg_type == "as_string", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_string for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::ScalarTypeType: { + TORCH_CHECK( + serialized_arg_type == "as_scalar_type", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_scalar_type for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::MemoryFormatType: { + TORCH_CHECK( + serialized_arg_type == "as_memory_format", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_memory_format for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::LayoutType: { + TORCH_CHECK(serialized_arg_type == "as_layout", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_layout for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::DeviceObjType: { + TORCH_CHECK(serialized_arg_type == "as_device", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_device for argument ", + index, " but got ", serialized_arg_type); + + std::string device_string = serialized_arg_val["type"].get(); + if (serialized_arg_val.contains("index") && + serialized_arg_val["index"].is_number()) { + device_string += ":" + serialized_arg_val["index"].get(); + } + + c10::Device device(device_string); + + if (device != *device_) { + VLOG(1) << "ProxyExecutor is using " << *device_ << " for " + << op_kernel.target_ << " argument #" << index + << ", which is different from the one serialized in thrift: " + << device << ". Please ensure this is intentional."; + } + + stack.at(index) = *device_; + break; + } + case c10::TypeKind::ListType: { + if (schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK( + serialized_arg_type == "as_tensors", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensors for argument ", index, + " but got ", serialized_arg_type); + TORCH_CHECK(serialized_arg_type == "as_tensors"); + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_ints for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints" || + serialized_arg_type == "as_sym_ints", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_ints or as_sym_ints " + "for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) { + TORCH_CHECK(serialized_arg_type == "as_floats", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_floats for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) { + TORCH_CHECK(serialized_arg_type == "as_bools", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_bools for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) { + if (serialized_arg_type == "as_ints") { + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (serialized_arg_type == "as_floats") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else if (serialized_arg_type == "as_bools") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a List[Scalar] input for argument ", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf( + at::ListType::ofOptionalTensors())) { + if (serialized_arg_type == "as_optional_tensors") { + std::vector list_item_types; + for (const auto &arg : serialized_arg_val) { + list_item_types.push_back(arg.begin().key()); + } + dynamic_args.emplace_back(index, + DynamicArgType::ListOptionalTensorType, + serialized_arg_val.size(), list_item_types); + } else if (serialized_arg_type == "as_tensors") { + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a Tensor?[] input for argument ", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofStrings())) { + TORCH_CHECK( + serialized_arg_type == "as_strings", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_strings for argument ", index, + " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "NYI: Unsupported list type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } + break; + } + case c10::TypeKind::OptionalType: { + auto inner_type = + schema_arg_type->castRaw()->getElementType(); + + if (serialized_arg_type == "as_none") { + stack.at(index) = c10::IValue{}; + if (inner_type->kind() == c10::TypeKind::TensorType) { + // Tensor is None + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::IntType || + inner_type->kind() == c10::TypeKind::SymIntType) { + // Int or SymInt is None + dynamic_args.emplace_back(index, DynamicArgType::IntType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + // List[Tensor] is None + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + // List[SymInt] is None + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, 0); + } + } else { + prefill_stack_with_static_arguments(index, inner_type, serialized_arg, + op_kernel); + } + break; + } + default: + TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } +} + +// Populates op_kernel.stack_, op_kernel.dynamic_args_ +void OSSProxyExecutor::get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector filled(schema_args.size(), false); + TORCH_CHECK(op_kernel.stack_.size() == 0); + op_kernel.stack_.resize(schema_args.size()); + for (const auto &named_argument : serialized_node["inputs"]) { + const auto &arg = named_argument["arg"]; + const auto &name = named_argument["name"].get(); + + // Doing a linear lookup in the schema to find the index + // of a static argument. Should be fine performance wise + // because we usually only have small amount of arguments. + for (size_t index = 0; index < schema_args.size(); index++) { + auto &schema_arg = schema_args[index]; + if (schema_arg.name() == name) { + prefill_stack_with_static_arguments(index, schema_arg.real_type(), arg, + op_kernel); + filled[index] = true; + break; + } + } + } + + // If an argument is not filled and has a default value, we should + // also prefill the default value. + for (size_t index = 0; index < schema_args.size(); index++) { + if (!filled[index] && schema_args[index].default_value()) { + auto default_value = *schema_args[index].default_value(); + op_kernel.stack_.at(index) = default_value; + } + } +} + +// Populates op_kernel.outputs_ +void OSSProxyExecutor::get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector &outputs = op_kernel.outputs_; + + TORCH_CHECK(schema_returns.size() == serialized_node["outputs"].size(), + "Serialized node doesn't match operator ", + serialized_node["target"], "'s schema outputs."); + + size_t output_index = 0; + for (const auto &serialized_output : serialized_node["outputs"]) { + TORCH_CHECK(serialized_output.size() == 1); + std::string serialized_output_type = serialized_output.begin().key(); + auto &serialized_output_val = serialized_output.begin().value(); + + auto &schema_return = schema_returns[output_index]; + const at::TypePtr &schema_return_type = schema_return.real_type(); + + switch (schema_return_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_output_type == "as_tensor", + "Expected extern kernel ", serialized_node["target"], + " to have serialized output type as_tensor, ", " but got ", + serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::ListType: { + if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK(serialized_output_type == "as_tensors", + "Expected extern kernel ", serialized_node["target"], + " to have serialized output type as_tensors, ", + " but got ", serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::ListTensorType, + serialized_output_val.size()); + } else { + TORCH_CHECK(false, "Unsupported return list type ", + schema_return_type->repr_str()); + } + break; + } + default: { + TORCH_CHECK(false, "Unsupported return type ", + schema_return_type->repr_str(), " for extern kernel ", + op_kernel.target_); + } + } + + output_index++; + } +} + +OSSProxyExecutor::OSSProxyExecutor(const std::string &json_path, bool is_cpu) { + if (is_cpu) { + device_ = std::make_unique(c10::DeviceType::CPU); + } else { + int device_idx = -1; + // Mock a cuda device for now. + device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); + } + + std::string extern_kernel_nodes_serialized; + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open(), "Unable to open file ", json_path); + + // Parse file into a json object + nlohmann::json json_obj; + json_file >> json_obj; + + // Access data + for (auto const &serialized_extern_node : json_obj["nodes"]) { + auto const &serialized_node = serialized_extern_node["node"]; + + const std::string &target = serialized_node["target"]; + + std::string opName; + std::string overloadName; + size_t pos = target.find('.'); + if (pos == std::string::npos) { + opName = target; + overloadName = ""; + } else { + // There should be no more periods + size_t pos2 = target.find('.', pos + 1); + TORCH_CHECK(pos2 == std::string::npos); + + opName = target.substr(0, pos); + overloadName = target.substr(pos + 1, target.length() - pos); + } + + c10::OperatorHandle op_handle = + c10::Dispatcher::singleton().findSchemaOrThrow(opName.c_str(), + overloadName.c_str()); + const c10::FunctionSchema &schema = op_handle.schema(); + + const auto &schema_args = schema.arguments(); + const auto &schema_returns = schema.returns(); + + OSSOpKernel op_kernel(target, op_handle); + get_input_info_from_serialized(schema_args, serialized_node, op_kernel); + get_output_info_from_serialized(schema_returns, serialized_node, op_kernel); + + op_kernels_.emplace_back(std::move(op_kernel)); + } +} + +void OSSProxyExecutor::call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) { + TORCH_CHECK(extern_node_index < static_cast(op_kernels_.size()), + "Invalid extern node index"); + OSSOpKernel &op_kernel = op_kernels_[extern_node_index]; + + std::vector stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + int tensor_id = 0; + int int_id = 0; + for (auto &dynamic_arg : dynamic_args) { + int arg_index = dynamic_arg.arg_index; + DynamicArgType dynamic_arg_type = dynamic_arg.arg_type; + int length = dynamic_arg.length; + + if (length == 0) { + continue; + } + + switch (dynamic_arg_type) { + case DynamicArgType::TensorType: { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + stack[arg_index] = *tensor; + break; + } + case DynamicArgType::IntType: { + int64_t val = flatten_int_args[int_id++]; + stack[arg_index] = val; + break; + } + case DynamicArgType::ListTensorType: { + std::vector tensor_list; + for (int j = 0; j < length; j++) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + tensor_list.push_back(*tensor); + } + stack[arg_index] = tensor_list; + break; + } + case DynamicArgType::ListOptionalTensorType: { + std::vector> optional_tensor_list; + auto &list_item_types = dynamic_arg.list_item_types; + TORCH_CHECK( + list_item_types.has_value(), + "Could not find list of item types for optional tensor list input"); + + for (const std::string &item_type : list_item_types.value()) { + if (item_type == "as_tensor") { + at::Tensor *tensor = tensor_handle_to_tensor_pointer( + flatten_tensor_args[tensor_id++]); + optional_tensor_list.emplace_back(*tensor); + } else if (item_type == "as_none") { + optional_tensor_list.emplace_back(std::nullopt); + } + } + stack[arg_index] = optional_tensor_list; + break; + } + case DynamicArgType::ListIntType: { + std::vector vals; + vals.reserve(length); + for (int j = 0; j < length; j++) { + vals.push_back(flatten_int_args[int_id++]); + } + stack[arg_index] = vals; + break; + } + default: + TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type); + } + } + + int num_output_tensors = op_kernel.num_output_tensors(); + TORCH_CHECK(tensor_id == num_tensors - num_output_tensors, + "Mismatch between tensors consumed and num of input tensor, got " + "tensor_id = .", + tensor_id, ", expected num = ", num_tensors - num_output_tensors); + TORCH_CHECK(int_id == num_ints, + "Mismatch between ints consumed and num_ints, got int_id = ", + int_id, ", num_ints = ", num_ints); + + // Call the op with the prepared stack. + const c10::OperatorHandle &op = op_kernel.op_handle_; + op.callBoxed(stack); + + const c10::FunctionSchema &schema = op.schema(); + const auto &schema_returns = schema.returns(); + + TORCH_CHECK(op_kernel.outputs_.size() == stack.size()); + TORCH_CHECK(stack.size() == schema_returns.size()); + + int index = 0; + for (const auto &schema_return : schema_returns) { + if (schema_return.type()->kind() == c10::TypeKind::TensorType) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = stack[index++].toTensor(); + } else if (schema_return.type()->kind() == c10::TypeKind::ListType && + schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) { + auto tensors = stack[index++].toTensorList(); + for (auto &&t : tensors) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = t; + } + } else { + TORCH_CHECK(false, "NYI: Unsupported return type for schema: ", + schema_return.type()->repr_str()); + } + } + + TORCH_CHECK( + tensor_id == num_tensors, + "Mismatch between tensors consumed and num_tensors, got tensor_id = ", + tensor_id, ", expected num = ", num_tensors); +} + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h new file mode 100644 index 0000000000..7dbadc99b7 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +enum class DynamicArgType : int { + TensorType = 0, + ListTensorType = 1, + ListOptionalTensorType = 2, + IntType = 3, + ListIntType = 4, +}; + +inline std::ostream &operator<<(std::ostream &os, DynamicArgType arg_type) { + os << static_cast(arg_type); + return os; +} + +inline bool isTensorType(DynamicArgType arg_type) { + return arg_type == DynamicArgType::TensorType || + arg_type == DynamicArgType::ListTensorType || + arg_type == DynamicArgType::ListOptionalTensorType; +} + +struct OSSDynamicArg { + OSSDynamicArg( + int arg_index, DynamicArgType arg_type, int length, + std::optional> list_item_types = std::nullopt) + : arg_index(arg_index), + arg_type(arg_type), + length(length), + list_item_types(std::move(list_item_types)) {} + int arg_index; + DynamicArgType arg_type; + int length; + std::optional> + list_item_types; // only used for parsing list of optional tensors +}; + +struct OSSOpKernel { + OSSOpKernel(std::string target, c10::OperatorHandle op_handle) + : target_(std::move(target)), op_handle_(std::move(op_handle)) {} + + std::string target_; + c10::OperatorHandle op_handle_; + std::vector dynamic_args_; + std::vector outputs_; + std::vector stack_; + + int num_output_tensors() const { + int num_output_tensors = 0; + for (const auto &output : outputs_) { + if (isTensorType(output.arg_type)) { + num_output_tensors += output.length; + } + } + return num_output_tensors; + } +}; + +class OSSProxyExecutor : public ProxyExecutor { + public: + explicit OSSProxyExecutor(const std::string &json_path, bool is_cpu); + + void call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) override; + + private: + void prefill_stack_with_static_arguments(size_t index, + const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, + OSSOpKernel &op_kernel); + + void get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + void get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + std::vector op_kernels_; + std::unique_ptr device_; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp new file mode 100644 index 0000000000..f7404b389c --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.cpp @@ -0,0 +1,537 @@ +#include + +#include +#include +#include + +namespace { +at::Tensor *tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} +} // namespace + +namespace torch::aot_inductor { +void OSSProxyExecutorNpu::prefill_stack_with_static_arguments( + size_t index, const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, OSSOpKernel &op_kernel) { + auto &stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + TORCH_CHECK(serialized_arg.size() == 1); + std::string serialized_arg_type = serialized_arg.begin().key(); + auto &serialized_arg_val = serialized_arg.begin().value(); + + switch (schema_arg_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_arg_type == "as_tensor", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensor for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::IntType: { + TORCH_CHECK(serialized_arg_type == "as_int", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_int for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::SymIntType: { + TORCH_CHECK(serialized_arg_type == "as_int" || + serialized_arg_type == "as_sym_int", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_int or as_sym_int for " + "argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + break; + } + case c10::TypeKind::FloatType: { + TORCH_CHECK(serialized_arg_type == "as_float", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_float for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::BoolType: { + TORCH_CHECK(serialized_arg_type == "as_bool", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_bool for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::NumberType: { + if (serialized_arg_type == "as_int") { + // Only int Scalar is treated as dynamic arg for now + dynamic_args.emplace_back(index, DynamicArgType::IntType, 1); + } else if (serialized_arg_type == "as_float") { + stack.at(index) = serialized_arg_val.get(); + } else if (serialized_arg_type == "as_bool") { + stack.at(index) = serialized_arg_val.get(); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a scalar input for argument ", index, " but got ", + serialized_arg_type); + } + break; + } + case c10::TypeKind::StringType: { + TORCH_CHECK(serialized_arg_type == "as_string", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_string for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::ScalarTypeType: { + TORCH_CHECK( + serialized_arg_type == "as_scalar_type", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_scalar_type for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::MemoryFormatType: { + TORCH_CHECK( + serialized_arg_type == "as_memory_format", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_memory_format for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::LayoutType: { + TORCH_CHECK(serialized_arg_type == "as_layout", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_layout for argument ", + index, " but got ", serialized_arg_type); + stack.at(index) = serialized_arg_val.get(); + break; + } + case c10::TypeKind::DeviceObjType: { + TORCH_CHECK(serialized_arg_type == "as_device", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_device for argument ", + index, " but got ", serialized_arg_type); + + std::string device_string = serialized_arg_val["type"].get(); + if (serialized_arg_val.contains("index") && + serialized_arg_val["index"].is_number()) { + device_string += ":" + serialized_arg_val["index"].get(); + } + + c10::Device device(device_string); + + if (device != *device_) { + VLOG(1) << "ProxyExecutor is using " << *device_ << " for " + << op_kernel.target_ << " argument #" << index + << ", which is different from the one serialized in thrift: " + << device << ". Please ensure this is intentional."; + } + + stack.at(index) = *device_; + break; + } + case c10::TypeKind::ListType: { + if (schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK( + serialized_arg_type == "as_tensors", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_tensors for argument ", index, + " but got ", serialized_arg_type); + TORCH_CHECK(serialized_arg_type == "as_tensors"); + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_ints for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + TORCH_CHECK(serialized_arg_type == "as_ints" || + serialized_arg_type == "as_sym_ints", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_ints or as_sym_ints " + "for argument ", + index, " but got ", serialized_arg_type); + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofFloats())) { + TORCH_CHECK(serialized_arg_type == "as_floats", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_floats for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofBools())) { + TORCH_CHECK(serialized_arg_type == "as_bools", + "Expected extern kernel ", op_kernel.target_, + " to have serialized argument type as_bools for argument ", + index, " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofNumbers())) { + if (serialized_arg_type == "as_ints") { + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, + serialized_arg_val.size()); + } else if (serialized_arg_type == "as_floats") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else if (serialized_arg_type == "as_bools") { + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a List[Scalar] input for argument ", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf( + at::ListType::ofOptionalTensors())) { + if (serialized_arg_type == "as_optional_tensors") { + std::vector list_item_types; + for (const auto &arg : serialized_arg_val) { + list_item_types.push_back(arg.begin().key()); + } + dynamic_args.emplace_back(index, + DynamicArgType::ListOptionalTensorType, + serialized_arg_val.size(), list_item_types); + } else if (serialized_arg_type == "as_tensors") { + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, + serialized_arg_val.size()); + } else { + TORCH_CHECK(false, "Expected extern kernel ", op_kernel.target_, + " to have a Tensor?[] input for argument ", index, + " but got ", serialized_arg_type); + } + } else if (schema_arg_type->isSubtypeOf(at::ListType::ofStrings())) { + TORCH_CHECK( + serialized_arg_type == "as_strings", "Expected extern kernel ", + op_kernel.target_, + " to have serialized argument type as_strings for argument ", index, + " but got ", serialized_arg_type); + std::vector ret; + for (const auto &arg : serialized_arg_val) { + ret.push_back(arg.get()); + } + stack.at(index) = std::move(ret); + } else { + TORCH_CHECK(false, "NYI: Unsupported list type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } + break; + } + case c10::TypeKind::OptionalType: { + auto inner_type = + schema_arg_type->castRaw()->getElementType(); + + if (serialized_arg_type == "as_none") { + stack.at(index) = c10::IValue{}; + if (inner_type->kind() == c10::TypeKind::TensorType) { + // Tensor is None + dynamic_args.emplace_back(index, DynamicArgType::TensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::IntType || + inner_type->kind() == c10::TypeKind::SymIntType) { + // Int or SymInt is None + dynamic_args.emplace_back(index, DynamicArgType::IntType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofTensors())) { + // List[Tensor] is None + dynamic_args.emplace_back(index, DynamicArgType::ListTensorType, 0); + } else if (inner_type->kind() == c10::TypeKind::ListType && + schema_arg_type->isSubtypeOf(at::ListType::ofSymInts())) { + // List[SymInt] is None + dynamic_args.emplace_back(index, DynamicArgType::ListIntType, 0); + } + } else { + prefill_stack_with_static_arguments(index, inner_type, serialized_arg, + op_kernel); + } + break; + } + default: + TORCH_CHECK(false, "Unsupported input type ", serialized_arg_type, + " for extern kernel ", op_kernel.target_, " argument ", + index); + } +} + +// Populates op_kernel.stack_, op_kernel.dynamic_args_ +void OSSProxyExecutorNpu::get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector filled(schema_args.size(), false); + TORCH_CHECK(op_kernel.stack_.size() == 0); + op_kernel.stack_.resize(schema_args.size()); + for (const auto &named_argument : serialized_node["inputs"]) { + const auto &arg = named_argument["arg"]; + const auto &name = named_argument["name"].get(); + + // Doing a linear lookup in the schema to find the index + // of a static argument. Should be fine performance wise + // because we usually only have small amount of arguments. + for (size_t index = 0; index < schema_args.size(); index++) { + auto &schema_arg = schema_args[index]; + if (schema_arg.name() == name) { + prefill_stack_with_static_arguments(index, schema_arg.real_type(), arg, + op_kernel); + filled[index] = true; + break; + } + } + } + + // If an argument is not filled and has a default value, we should + // also prefill the default value. + for (size_t index = 0; index < schema_args.size(); index++) { + if (!filled[index] && schema_args[index].default_value()) { + auto default_value = *schema_args[index].default_value(); + op_kernel.stack_.at(index) = default_value; + } + } +} + +// Populates op_kernel.outputs_ +void OSSProxyExecutorNpu::get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel) { + std::vector &outputs = op_kernel.outputs_; + + TORCH_CHECK(schema_returns.size() == serialized_node["outputs"].size(), + "Serialized node doesn't match operator ", + serialized_node["target"], "'s schema outputs."); + + size_t output_index = 0; + for (const auto &serialized_output : serialized_node["outputs"]) { + TORCH_CHECK(serialized_output.size() == 1); + std::string serialized_output_type = serialized_output.begin().key(); + auto &serialized_output_val = serialized_output.begin().value(); + + auto &schema_return = schema_returns[output_index]; + const at::TypePtr &schema_return_type = schema_return.real_type(); + + switch (schema_return_type->kind()) { + case c10::TypeKind::TensorType: { + TORCH_CHECK(serialized_output_type == "as_tensor", + "Expected extern kernel ", serialized_node["target"], + " to have serialized output type as_tensor, ", " but got ", + serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::TensorType, 1); + break; + } + case c10::TypeKind::ListType: { + if (schema_return_type->isSubtypeOf(at::ListType::ofTensors())) { + TORCH_CHECK(serialized_output_type == "as_tensors", + "Expected extern kernel ", serialized_node["target"], + " to have serialized output type as_tensors, ", + " but got ", serialized_output_type); + outputs.emplace_back(output_index, DynamicArgType::ListTensorType, + serialized_output_val.size()); + } else { + TORCH_CHECK(false, "Unsupported return list type ", + schema_return_type->repr_str()); + } + break; + } + default: { + TORCH_CHECK(false, "Unsupported return type ", + schema_return_type->repr_str(), " for extern kernel ", + op_kernel.target_); + } + } + + output_index++; + } +} + +OSSProxyExecutorNpu::OSSProxyExecutorNpu(const std::string &json_path, + bool is_cpu) { + if (is_cpu) { + device_ = std::make_unique(c10::DeviceType::CPU); + } else { + int device_idx = -1; + // Mock a cuda device for now. + device_ = std::make_unique(c10::DeviceType::CUDA, device_idx); + } + + std::string extern_kernel_nodes_serialized; + + std::ifstream json_file(json_path); + TORCH_CHECK(json_file.is_open(), "Unable to open file ", json_path); + + // Parse file into a json object + nlohmann::json json_obj; + json_file >> json_obj; + + // Access data + for (auto const &serialized_extern_node : json_obj["nodes"]) { + auto const &serialized_node = serialized_extern_node["node"]; + + const std::string &target = serialized_node["target"]; + + std::string opName; + std::string overloadName; + size_t pos = target.find('.'); + if (pos == std::string::npos) { + opName = target; + overloadName = ""; + } else { + // There should be no more periods + size_t pos2 = target.find('.', pos + 1); + TORCH_CHECK(pos2 == std::string::npos); + + opName = target.substr(0, pos); + overloadName = target.substr(pos + 1, target.length() - pos); + } + + c10::OperatorHandle op_handle = + c10::Dispatcher::singleton().findSchemaOrThrow(opName.c_str(), + overloadName.c_str()); + const c10::FunctionSchema &schema = op_handle.schema(); + + const auto &schema_args = schema.arguments(); + const auto &schema_returns = schema.returns(); + + OSSOpKernel op_kernel(target, op_handle); + get_input_info_from_serialized(schema_args, serialized_node, op_kernel); + get_output_info_from_serialized(schema_returns, serialized_node, op_kernel); + + op_kernels_.emplace_back(std::move(op_kernel)); + } +} + +void OSSProxyExecutorNpu::call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, + int num_tensors, + AtenTensorHandle *flatten_tensor_args) { + TORCH_CHECK(extern_node_index < static_cast(op_kernels_.size()), + "Invalid extern node index"); + OSSOpKernel &op_kernel = op_kernels_[extern_node_index]; + + std::vector stack = op_kernel.stack_; + auto &dynamic_args = op_kernel.dynamic_args_; + + int tensor_id = 0; + int int_id = 0; + for (auto &dynamic_arg : dynamic_args) { + int arg_index = dynamic_arg.arg_index; + DynamicArgType dynamic_arg_type = dynamic_arg.arg_type; + int length = dynamic_arg.length; + + if (length == 0) { + continue; + } + + switch (dynamic_arg_type) { + case DynamicArgType::TensorType: { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + stack[arg_index] = *tensor; + break; + } + case DynamicArgType::IntType: { + int64_t val = flatten_int_args[int_id++]; + stack[arg_index] = val; + break; + } + case DynamicArgType::ListTensorType: { + std::vector tensor_list; + for (int j = 0; j < length; j++) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + tensor_list.push_back(*tensor); + } + stack[arg_index] = tensor_list; + break; + } + case DynamicArgType::ListOptionalTensorType: { + std::vector> optional_tensor_list; + auto &list_item_types = dynamic_arg.list_item_types; + TORCH_CHECK( + list_item_types.has_value(), + "Could not find list of item types for optional tensor list input"); + + for (const std::string &item_type : list_item_types.value()) { + if (item_type == "as_tensor") { + at::Tensor *tensor = tensor_handle_to_tensor_pointer( + flatten_tensor_args[tensor_id++]); + optional_tensor_list.emplace_back(*tensor); + } else if (item_type == "as_none") { + optional_tensor_list.emplace_back(std::nullopt); + } + } + stack[arg_index] = optional_tensor_list; + break; + } + case DynamicArgType::ListIntType: { + std::vector vals; + vals.reserve(length); + for (int j = 0; j < length; j++) { + vals.push_back(flatten_int_args[int_id++]); + } + stack[arg_index] = vals; + break; + } + default: + TORCH_CHECK(false, "Unsupported dynamic arg type: ", dynamic_arg_type); + } + } + + int num_output_tensors = op_kernel.num_output_tensors(); + + // Call the op with the prepared stack. + + const c10::OperatorHandle &op = op_kernel.op_handle_; + op.callBoxed(stack); + + const c10::FunctionSchema &schema = op.schema(); + const auto &schema_returns = schema.returns(); + + TORCH_CHECK(op_kernel.outputs_.size() == stack.size()); + TORCH_CHECK(stack.size() == schema_returns.size()); + + int index = 0; + for (const auto &schema_return : schema_returns) { + if (schema_return.type()->kind() == c10::TypeKind::TensorType) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = stack[index++].toTensor(); + } else if (schema_return.type()->kind() == c10::TypeKind::ListType && + schema_return.type()->isSubtypeOf(at::ListType::ofTensors())) { + auto tensors = stack[index++].toTensorList(); + for (auto &&t : tensors) { + at::Tensor *tensor = + tensor_handle_to_tensor_pointer(flatten_tensor_args[tensor_id++]); + *tensor = t; + } + } else { + TORCH_CHECK(false, "NYI: Unsupported return type for schema: ", + schema_return.type()->repr_str()); + } + } + + TORCH_CHECK( + tensor_id == num_tensors, + "Mismatch between tensors consumed and num_tensors, got tensor_id = ", + tensor_id, ", expected num = ", num_tensors); +} +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h new file mode 100644 index 0000000000..f2478bbeaa --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/oss_proxy_executor_npu.h @@ -0,0 +1,42 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace torch::aot_inductor { + +class OSSProxyExecutorNpu : public ProxyExecutor { + public: + explicit OSSProxyExecutorNpu(const std::string &json_path, bool is_cpu); + + void call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) override; + + private: + void prefill_stack_with_static_arguments(size_t index, + const at::TypePtr &schema_arg_type, + const nlohmann::json &serialized_arg, + OSSOpKernel &op_kernel); + + void get_input_info_from_serialized( + const std::vector &schema_args, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + void get_output_info_from_serialized( + const std::vector &schema_returns, + const nlohmann::json &serialized_node, OSSOpKernel &op_kernel); + + std::vector op_kernels_; + std::unique_ptr device_; +}; + +} // namespace torch::aot_inductor \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_torch/proxy_executor.h b/torch_npu/csrc/inductor/aoti_torch/proxy_executor.h new file mode 100644 index 0000000000..486fc80235 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/proxy_executor.h @@ -0,0 +1,19 @@ +#pragma once + +#include +#include +#include + +namespace torch::aot_inductor { + +class ProxyExecutor { + public: + ProxyExecutor() = default; + virtual ~ProxyExecutor() = default; + + virtual void call_function(int extern_node_index, int num_ints, + int64_t *flatten_int_args, int num_tensors, + AtenTensorHandle *flatten_tensor_args) = 0; +}; + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp b/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp new file mode 100644 index 0000000000..f920bae6e9 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/shim_npu.cpp @@ -0,0 +1,69 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#ifdef __cplusplus +extern "C" { +#endif +int32_t aoti_torch_device_type_npu() { + return (int32_t)c10::DeviceType::PrivateUse1; +} + +#ifdef __cplusplus +} // extern "C" +#endif + +namespace { +static c10::Device c10_device(int32_t device_type, int32_t device_index) { + if (device_type == aoti_torch_device_type_cpu()) { + return c10::Device(static_cast(device_type)); + } else { + return c10::Device(static_cast(device_type), + static_cast(device_index)); + } +} +} // namespace + +AOTITorchError aoti_torch_create_tensor_from_blob_npu( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, + AtenTensorHandle *ret_new_tensor) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + c10::IntArrayRef sizes(sizes_ptr, ndim); + c10::IntArrayRef strides(strides_ptr, ndim); + c10::Device device = c10_device(device_type, device_index); + c10::TensorOptions options = c10::TensorOptions().device(device).dtype( + static_cast(dtype)); + *ret_new_tensor = torch::aot_inductor::new_tensor_handle( + // data == nullptr can happen for a 0-size tensor + (data != nullptr) + ? at_npu::native::from_blob(data, sizes, strides, storage_offset, + options, device) + : at::empty_strided(sizes, strides, options)); + }); +} + +AOTITorchError aoti_torch_create_tensor_from_blob_npu_v2( + void *data, int64_t ndim, const int64_t *sizes_ptr, + const int64_t *strides_ptr, int64_t storage_offset, int32_t dtype, + int32_t device_type, int32_t device_index, AtenTensorHandle *ret_new_tensor, + int32_t layout, const uint8_t *opaque_metadata, + int64_t opaque_metadata_size) { + AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({ + if (layout == static_cast(at::kMkldnn)) { + throw std::runtime_error("do not support mkldnn on npu."); + } else { + aoti_torch_create_tensor_from_blob_npu(data, ndim, sizes_ptr, strides_ptr, + storage_offset, dtype, device_type, + device_index, ret_new_tensor); + } + }); +} \ No newline at end of file diff --git a/torch_npu/csrc/inductor/aoti_torch/tensor_converter.h b/torch_npu/csrc/inductor/aoti_torch/tensor_converter.h new file mode 100644 index 0000000000..f92d556eed --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/tensor_converter.h @@ -0,0 +1,25 @@ +#pragma once + +#include +#include + +namespace torch::aot_inductor { + +// Functions declared here are not meant to be called from the AOTInductor +// generated model.so + +// unsafe_alloc_new_handles_from_tensors is used for allocating new aten +// tensor objects and return them as a vector of AtenTensorHandle (raw +// pointers), and those pointers will be stolen by model.so. +TORCH_API std::vector unsafe_alloc_new_handles_from_tensors( + const std::vector &tensors); + +// alloc_tensors_by_stealing_from_handles is used for creating a vector of aten +// tensors by stealing from an array of handles. Only the handles are stolen, +// and the array itself is borrowed. +// +// WARNING: Can NOT be called in model.so +TORCH_API std::vector alloc_tensors_by_stealing_from_handles( + AtenTensorHandle *handles, size_t length); + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/aoti_torch/utils.h b/torch_npu/csrc/inductor/aoti_torch/utils.h new file mode 100644 index 0000000000..6764296582 --- /dev/null +++ b/torch_npu/csrc/inductor/aoti_torch/utils.h @@ -0,0 +1,217 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE(...) \ + try { \ + __VA_ARGS__ \ + } catch (const std::exception &e) { \ + LOG(ERROR) << "Exception in aoti_torch: " << e.what(); \ + return AOTI_TORCH_FAILURE; \ + } catch (...) { \ + LOG(ERROR) << "Exception in aoti_torch: UNKNOWN"; \ + return AOTI_TORCH_FAILURE; \ + } \ + return AOTI_TORCH_SUCCESS; + +namespace torch::aot_inductor { + +inline at::Tensor *tensor_handle_to_tensor_pointer(AtenTensorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor *tensor) { + return reinterpret_cast(tensor); +} + +inline at::Tensor resolve_tensor_dispatch_flags(AtenTensorHandle handle) { + at::Tensor *tensor{tensor_handle_to_tensor_pointer(handle)}; + if (tensor->is_conj() || tensor->is_neg()) { + // If the conjugation or negation dispatch flags are set, runtime dispatch + // handles them by cloning the tensor before passing them to the native ATen + // function. Since the C-shim calls the native function directly, we have + // to handle the flags ourselves, or results will be silently incorrect. + return tensor->clone(); + } + return *tensor; +} + +inline std::optional resolve_tensor_dispatch_flags( + const AtenTensorHandle *handle) { + return handle ? std::make_optional(resolve_tensor_dispatch_flags(*handle)) + : std::nullopt; +} + +inline std::vector resolve_tensor_list_dispatch_flags( + const AtenTensorHandle *handle, int64_t len) { + std::vector ret{}; + ret.reserve(len); + for (int64_t i{0}; i < len; ++i) { + ret.emplace_back(resolve_tensor_dispatch_flags(handle[i])); + } + return ret; +} + +inline std::vector> +resolve_tensor_list_dispatch_flags(const AtenTensorHandle **handle, + int64_t len) { + std::vector> ret{}; + ret.reserve(len); + for (int64_t i{0}; i < len; ++i) { + ret.emplace_back(resolve_tensor_dispatch_flags(handle[i])); + } + return ret; +} + +inline at::Generator *generator_handle_to_generator_pointer( + AtenGeneratorHandle handle) { + return reinterpret_cast(handle); +} + +inline AtenGeneratorHandle generator_pointer_to_generator_handle( + at::Generator *generator) { + return reinterpret_cast(generator); +} + +inline AtenTensorHandle new_tensor_handle(at::Tensor &&tensor) { + at::Tensor *new_tensor = new at::Tensor(std::move(tensor)); + return tensor_pointer_to_tensor_handle(new_tensor); +} + +inline void assert_inf_and_nan(const std::string &tensor_name, + at::Tensor &check_tensor) { + auto isnan_tensor = check_tensor.isnan(); + if (isnan_tensor.any().item()) { + throw std::runtime_error("At least one NaN in " + tensor_name); + } + auto isinf_tensor = check_tensor.isinf(); + if (isinf_tensor.any().item()) { + throw std::runtime_error("At least one INF in " + tensor_name); + } +} + +// utility functions to convert a pointer to an optional value +template +inline std::optional pointer_to_optional(T *ptr) { + return ptr ? std::make_optional(*ptr) : std::nullopt; +} + +template >> +inline std::optional pointer_to_optional(U *ptr) { + return ptr ? std::make_optional(T(*ptr)) : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional(AtenTensorHandle *ptr) { + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional( + const AtenTensorHandle *ptr) { + return ptr ? std::make_optional(*tensor_handle_to_tensor_pointer(*ptr)) + : std::nullopt; +} + +template <> +inline std::optional pointer_to_optional( + AtenGeneratorHandle *ptr) { + return ptr ? std::make_optional(*generator_handle_to_generator_pointer(*ptr)) + : std::nullopt; +} + +inline std::optional pointer_to_optional_device( + int32_t *device_type, int32_t device_index) { + return device_type ? std::make_optional(c10::Device( + static_cast(*device_type), + static_cast(device_index))) + : std::nullopt; +} + +// utility functions to convert a pointer to a list +template +struct is_optional : std::false_type {}; +template +struct is_optional> : std::true_type {}; + +template +inline c10::ArrayRef pointer_to_list(T *ptr, int64_t len) { + return c10::ArrayRef(ptr, len); +} + +template >, + typename = std::enable_if_t::value>> +inline std::vector pointer_to_list(U *ptr, int64_t len) { + // std::vector will be implicitly converted to c10::ArrayRef at the call + // site + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(T(ptr[i])); + } + return result; +} + +template ::value>> +inline std::vector pointer_to_list(U **ptr, int64_t len) { + // Here U** denotes a list of optional arguments + // std::vector will be implicitly converted to c10::ArrayRef at the call + // site + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(pointer_to_optional(ptr[i])); + } + return result; +} + +template <> +inline std::vector pointer_to_list(const AtenTensorHandle *ptr, + int64_t len) { + std::vector result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(*tensor_handle_to_tensor_pointer(ptr[i])); + } + return result; +} + +template <> +inline std::vector> pointer_to_list( + const AtenTensorHandle **ptr, int64_t len) { + std::vector> result; + result.reserve(len); + for (int64_t i = 0; i < len; i++) { + result.emplace_back(pointer_to_optional(ptr[i])); + } + return result; +} + +template +inline std::array pointer_to_list(const int32_t *ptr) { + std::array result; + std::copy(ptr, ptr + N, result.begin()); + return result; +} + +// Utility function to convert a pointer to an optional list of values +template +inline std::optional> pointer_to_optional_list(U **ptr, + int64_t len) { + return ptr ? std::make_optional>( + pointer_to_list(*ptr, len)) + : std::nullopt; +} + +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/array_ref_impl.h b/torch_npu/csrc/inductor/array_ref_impl.h new file mode 100644 index 0000000000..1c0aeb549e --- /dev/null +++ b/torch_npu/csrc/inductor/array_ref_impl.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include +#include +#include + +namespace torch::aot_inductor { +template +void convert_output_to_handle(const ArrayRefTensor &output, + AtenTensorHandle &handle) { + handle = output.expensiveCopyToTensor(); +} + +template +void convert_outputs_to_handles_helper( + const std::tuple...> &outputs, + AtenTensorHandle *output_handles, std::index_sequence) { + (convert_output_to_handle(std::get(outputs), output_handles[Is]), ...); +} +template +void convert_outputs_to_handles( + const std::tuple...> &outputs, + AtenTensorHandle *output_handles) { + convert_outputs_to_handles_helper(outputs, output_handles, + std::make_index_sequence()); +} + +template +void convert_handle_to_arrayref_tensor(AtenTensorHandle handle, + ArrayRefTensor &input) { + void *data_ptr; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_data_ptr(handle, &data_ptr)); + int64_t dim; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dim(handle, &dim)); + int64_t numel; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_numel(handle, &numel)); + int64_t *sizes; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_sizes(handle, &sizes)); + int64_t *strides; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_strides(handle, &strides)); + int32_t dtype; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(handle, &dtype)); + int32_t device_type; + AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(handle, &device_type)); + int32_t device_index; + AOTI_TORCH_ERROR_CODE_CHECK( + aoti_torch_get_device_index(handle, &device_index)); + + input = ArrayRefTensor( + MiniArrayRef(reinterpret_cast(data_ptr), numel), + MiniArrayRef(sizes, dim), + MiniArrayRef(strides, dim), device_type, device_index); +} + +template +void convert_handles_to_inputs_helper(AtenTensorHandle *input_handles, + std::tuple...> &inputs, + std::index_sequence) { + (convert_handle_to_arrayref_tensor(input_handles[Is], std::get(inputs)), + ...); +} + +template +void convert_handles_to_inputs(AtenTensorHandle *input_handles, + std::tuple...> &inputs) { + convert_handles_to_inputs_helper(input_handles, inputs, + std::make_index_sequence()); +} + +template +void assert_numel(const ArrayRefTensor &tensor, uint64_t numel) { + if (tensor.numel() != numel) { + std::stringstream err; + err << "incorrect numel for input tensor. expected " << numel << ", got " + << tensor.numel(); + throw std::runtime_error(err.str()); + } +} +} // namespace torch::aot_inductor diff --git a/torch_npu/csrc/inductor/inductor_ops.h b/torch_npu/csrc/inductor/inductor_ops.h new file mode 100644 index 0000000000..924d54050c --- /dev/null +++ b/torch_npu/csrc/inductor/inductor_ops.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +namespace torch::inductor { + +TORCH_API at::Tensor _mm_plus_mm_out(at::Tensor &out, const at::Tensor &a, + const at::Tensor &b, const at::Tensor &c, + const at::Tensor &d); + +// After adding _mm_plus_mm_out, this should not be exposed and called by model +// code. Keeping it around for backward compatibility. Will be deprecated later. +TORCH_API at::Tensor _mm_plus_mm(const at::Tensor &a, const at::Tensor &b, + const at::Tensor &c, const at::Tensor &d, + at::Tensor &out); + +TORCH_API at::Tensor _alloc_from_pool(const at::Tensor &self, + int64_t offset_bytes, + at::ScalarType dtype, + at::IntArrayRef size, + at::IntArrayRef stride); + +// Similar to as_strided with the following differences +// - offset is added to the existing offset (rather than replacing it) +// - view tracking is disabled similar to unsafe_view +TORCH_API at::Tensor _reinterpret_tensor(const at::Tensor &self, + at::IntArrayRef size, + at::IntArrayRef stride, + int64_t offset_increment = 0); + +} // namespace torch::inductor -- Gitee