diff --git a/torch_npu/meta/_meta_registrations.py b/torch_npu/meta/_meta_registrations.py
index 93db00131f7c4be7fddc64050e96734ac6110c9b..0a2e6be69b7887154ae704b8e0b9a31a1801ee29 100644
--- a/torch_npu/meta/_meta_registrations.py
+++ b/torch_npu/meta/_meta_registrations.py
@@ -1036,6 +1036,50 @@ def npu_prefetch_meta(self, dependency, max_size, offset=0):
     )
 
 
+@impl(m, "npu_dequant_swiglu_quant")
+def npu_dequant_swiglu_quant_meta(x, weight_scale, activation_scale, bias, quant_scale,
+                                  quant_offset, group_index, activate_left=False, quant_mode=0):
+    y_size = []
+    scale_size = []
+    for i in range(x.dim() - 1):
+        y_size.append(x.size(i))
+        scale_size.append(x.size(i))
+    y_size.append(math.floor(x.size(x.dim() - 1) / 2))
+    return (torch.empty(y_size, dtype=torch.int8, device=x.device),
+            torch.empty(scale_size, dtype=torch.float32, device=x.device))
+
+
+@impl(m, "npu_single_rope")
+def npu_single_rope_meta(x, cos, sin):
+    return torch.empty_like(x)
+
+
+@impl(m, "npu_kv_rmsnorm_rope_cache")
+def npu_kv_rmsnorm_rope_cache_meta(x, gamma, cos, sin, index, k_cache, v_cache, epsilon=1e-5):
+    return (k_cache, v_cache)
+
+
+@impl(m, "npu_moe_gating_top_k")
+def npu_moe_gating_top_k_meta(x, bias=None, k=1, k_group=1, group_count=1, group_select_mode=0, renorm=0, norm_type=0, y2_flag=False, routed_scaling_factor=1.0, eps=1e-20):
+    x_dim = x.dim()
+    torch._check(
+        x_dim == 2,
+        lambda: "the x shape supports only 2d" + ops_error(ErrCode.VALUE),
+    )
+    if bias is not None:
+        bias_dim = bias.dim()
+        torch._check(
+            bias_dim == 1,
+            lambda: "the bias shape supports only 1d" + ops_error(ErrCode.VALUE),
+        )
+    y_dim_list = [x.size(0), k]
+    expert_idx_dim_list = [x.size(0), k]
+    y2_dim_list = [x.size(0), x.size(1)]
+    return (x.new_empty(tuple(y_dim_list), dtype=x.dtype),
+            x.new_empty(tuple(expert_idx_dim_list), dtype=torch.int32),
+            x.new_empty(tuple(y2_dim_list), dtype=torch.float32))
+
+
 @impl(m, "npu_swiglu")
 def npu_swiglu_meta(x, dim=-1):
     output_size = []