From 28e630a71ec2768fa74461a325e76687e721c1ca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=94=B0=E9=87=8E?= Date: Sun, 29 Sep 2024 15:56:24 +0800 Subject: [PATCH] test empty --- torch_npu/meta/_meta_registrations.py | 1 + torch_npu/onnx/wrapper_onnx_ops.py | 1 + 2 files changed, 2 insertions(+) diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py index 2e839b86ac..4b551f6ca8 100644 --- a/torch_npu/meta/_meta_registrations.py +++ b/torch_npu/meta/_meta_registrations.py @@ -39,6 +39,7 @@ def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, return torch.empty_like(query, dtype=query.dtype) + @impl(m, "npu_fusion_attention") def npu_fusion_attention_forward(query, key, value, head_num, input_layout, pse=None, padding_mask=None, atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, diff --git a/torch_npu/onnx/wrapper_onnx_ops.py b/torch_npu/onnx/wrapper_onnx_ops.py index 8217a22461..29b012fdf8 100644 --- a/torch_npu/onnx/wrapper_onnx_ops.py +++ b/torch_npu/onnx/wrapper_onnx_ops.py @@ -855,6 +855,7 @@ class _NPUQuantizeOP(torch.autograd.Function): acl_dtype = 3 else: raise ValueError("The argument 'dtype' must be torch.quint8, torch.qint8 or torch.qint32") + return g.op("npu::NPUQuantize", inputs, scales, zero_points, dtype_i=acl_dtype, axis_i=axis, div_mode_i=div_mode) -- Gitee