diff --git a/torch_npu/testing/common_distributed.py b/torch_npu/testing/common_distributed.py index 3a6367622598fce17b142fad659a0c1316680bc6..cf4e4dcb40f5b2681bded0b9b92595f975f7ad09 100644 --- a/torch_npu/testing/common_distributed.py +++ b/torch_npu/testing/common_distributed.py @@ -38,6 +38,17 @@ def skipIfUnsupportMultiNPU(npu_number_needed): return skip_dec +def skipOpOnlySupport910B(op_name): + def skip_dec(func): + def wrapper(self): + device_name = torch_npu.npu.get_device_name(0)[:10] + if device_name != 'Ascend910B': + raise unittest.SkipTest(f"The op {op_name} is only suppport on 910B, skip this ut for this device type!") + return func(self) + return wrapper + return skip_dec + + def with_comms(func): if func is None: raise RuntimeError("Test function is None.")