From ac3df001c72c99b4cdc7a50de3348ee1bfd3459b Mon Sep 17 00:00:00 2001
From: rmch
Date: Thu, 9 May 2024 17:43:39 +0800
Subject: [PATCH 1/5] add test comm converter for all_reduce and reduce_scatter

---
 test/dynamo/test_comm_converter.py | 114 +++++++++++++++++++++++++++++
 1 file changed, 114 insertions(+)
 create mode 100644 test/dynamo/test_comm_converter.py

diff --git a/test/dynamo/test_comm_converter.py b/test/dynamo/test_comm_converter.py
new file mode 100644
index 0000000000..88292dcbeb
--- /dev/null
+++ b/test/dynamo/test_comm_converter.py
@@ -0,0 +1,114 @@
+import os
+from copy import deepcopy
+
+import torch
+import torch_npu
+import torchair
+
+import torch.distributed as dist
+from torch import nn
+import torch.distributed
+import torch.multiprocessing as mp
+import torch.distributed._functional_collectives as fcol
+
+from torch._dynamo.test_case import TestCase
+from torch._dynamo.testing import normalize_gm
+
+
+DIM = 200
+
+class Net(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(DIM, DIM)
+
+    def forward(self, x, group):
+        _fc1 = self.fc1(x)
+        _fc1 = fcol.all_reduce(_fc1, "sum", group=group)
+        _fc1 = fcol.reduce_scatter_tensor(_fc1, "sum", scatter_dim=0, group=group)
+        _fc1 = _fc1.reshape(2, -1)[0]
+        return _fc1
+
+
+def _test_compile(
+    rank,
+    world_size,
+):
+    backend = "hccl"
+    dist.init_process_group(
+        backend=backend,
+        rank=rank,
+        world_size=world_size
+    )
+
+    graph = None
+
+    def compiler_fn(gm):
+        def inner_compiler(gm_, example_inputs_):
+            nonlocal graph
+            if not graph is None:
+                raise AssertionError('TestCommConverter Failed, before run, graph should be None')
+            graph = gm_
+            graph = normalize_gm(graph.print_readable(False))
+            return torchair.get_npu_backend()(gm_, example_inputs_)
+
+        return torch.compile(
+            gm, backend=inner_compiler, dynamic=False, fullgraph=True
+        )
+
+    torch_npu.npu.set_device(f"npu:{rank}")
+    device = torch.device("npu")
+    torch.manual_seed(123)
+    model = Net().to(device)
+
+    compiled_model = compiler_fn(deepcopy(model))
+    group = torch.distributed.distributed_c10d._get_default_group()
+    ret = []
+    for i in range(3):
+        torch.manual_seed(123 + rank + i)
+        input = torch.randn([DIM, DIM], device=device)
+        compiled_output = compiled_model(input, group)
+        loss_output = model(input, group)
+        expect = """\
+class GraphModule(torch.nn.Module):
+    def forward(self, L_x_ : torch.Tensor):
+        l_x_ = L_x_
+
+        _fc1 = self.L__self___fc1(l_x_); l_x_ = None
+
+        tensor = torch.ops.c10d_functional.all_reduce(_fc1, 'sum', 'ptd:0', [0, 1], 2); _fc1 = None
+
+        _fc1_1 = torch.ops.c10d_functional.wait_tensor(tensor); tensor = None
+
+        tensor_1 = torch.ops.c10d_functional.reduce_scatter_tensor(_fc1_1, 'sum', 'ptd:0', [0, 1], 2); _fc1_1 = None
+
+        _fc1_2 = torch.ops.c10d_functional.wait_tensor(tensor_1); tensor_1 = None
+
+        reshape = _fc1_2.reshape(2, -1); _fc1_2 = None
+        _fc1_3 = reshape[0]; reshape = None
+        return (_fc1_3,)
+"""
+        if expect != graph:
+            raise RuntimeError('TestCommConverter Failed, fx graph is not expected')
+        if not (compiled_output == loss_output).all():
+            raise RuntimeError('TestCommConverter Failed, dynamo outputs are not equal to eager outputs')
+
+
+def mp_main(rank, world_size):
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12355'
+    os.environ['TORCH_DISABLE_NATIVE_FUNCOL'] = '1'
+    _test_compile(rank=rank, world_size=world_size)
+
+
+class TestCommConverter(TestCase):
+    def test_comm_converter(self):
+        world_size = 2
+        mp.spawn(mp_main, args=(world_size,), nprocs=world_size, join=True)
+
+
+if __name__ == "__main__":
+
+    from torch._dynamo.test_case import run_tests
+
+    run_tests()
--
Gitee
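Note: patch 1's inner_compiler is the standard Dynamo custom-backend hook. Dynamo calls it with the traced fx.GraphModule, the test snapshots that graph via normalize_gm(gm.print_readable(False)) for the string comparison, and real compilation is delegated to torchair's NPU backend. A minimal sketch of the same capture pattern, runnable on CPU with the gloo backend so no NPU is needed; all names below are illustrative and not part of the patch:

    import os
    import torch
    import torch.distributed as dist
    import torch.distributed._functional_collectives as fcol

    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)
    group = dist.distributed_c10d._get_default_group()

    captured = []

    def capture_backend(gm, example_inputs):
        captured.append(gm)  # keep the traced GraphModule for later inspection
        return gm.forward    # run the captured graph as-is instead of compiling it

    @torch.compile(backend=capture_backend, fullgraph=True)
    def fn(x):
        return fcol.all_reduce(x, "sum", group=group)

    fn(torch.randn(4))
    # Depending on the torch build (and on TORCH_DISABLE_NATIVE_FUNCOL, which
    # patch 1 sets in mp_main), the collective is traced as
    # c10d_functional.all_reduce or _c10d_functional.all_reduce, each followed
    # by a wait_tensor, matching the expect string in the patch above.
    print(captured[0].print_readable(False))
    dist.destroy_process_group()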
From 991e6a25467a2aa3c9cbb80676a18d716597968d Mon Sep 17 00:00:00 2001
From: rmch
Date: Thu, 9 May 2024 19:01:41 +0800
Subject: [PATCH 2/5] add test comm converter for all_reduce and reduce_scatter

---
 test/dynamo/test_comm_converter.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/test/dynamo/test_comm_converter.py b/test/dynamo/test_comm_converter.py
index 88292dcbeb..5172f463a8 100644
--- a/test/dynamo/test_comm_converter.py
+++ b/test/dynamo/test_comm_converter.py
@@ -2,8 +2,6 @@ import os
 from copy import deepcopy
 
 import torch
-import torch_npu
-import torchair
 
 import torch.distributed as dist
 from torch import nn
@@ -14,9 +12,12 @@ import torch.distributed._functional_collectives as fcol
 from torch._dynamo.test_case import TestCase
 from torch._dynamo.testing import normalize_gm
 
+import torch_npu
+
 
 DIM = 200
 
+
 class Net(nn.Module):
     def __init__(self):
         super().__init__()
         self.fc1 = nn.Linear(DIM, DIM)
@@ -50,6 +51,7 @@ def _test_compile(
             raise AssertionError('TestCommConverter Failed, before run, graph should be None')
             graph = gm_
             graph = normalize_gm(graph.print_readable(False))
+            import torchair
             return torchair.get_npu_backend()(gm_, example_inputs_)
 
         return torch.compile(
--
Gitee

From 728944ed68ebc5adca5c09de678aeed844f9998b Mon Sep 17 00:00:00 2001
From: rmch
Date: Thu, 9 May 2024 19:02:48 +0800
Subject: [PATCH 3/5] add test comm converter for all_reduce and reduce_scatter

---
 test/dynamo/test_comm_converter.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/test/dynamo/test_comm_converter.py b/test/dynamo/test_comm_converter.py
index 5172f463a8..e0b627b164 100644
--- a/test/dynamo/test_comm_converter.py
+++ b/test/dynamo/test_comm_converter.py
@@ -47,7 +47,7 @@ def _test_compile(
     def compiler_fn(gm):
         def inner_compiler(gm_, example_inputs_):
             nonlocal graph
-            if not graph is None:
+            if graph is not None:
                 raise AssertionError('TestCommConverter Failed, before run, graph should be None')
             graph = gm_
             graph = normalize_gm(graph.print_readable(False))
@@ -68,9 +68,9 @@ def _test_compile(
     ret = []
     for i in range(3):
         torch.manual_seed(123 + rank + i)
-        input = torch.randn([DIM, DIM], device=device)
-        compiled_output = compiled_model(input, group)
-        loss_output = model(input, group)
+        input_tensor = torch.randn([DIM, DIM], device=device)
+        compiled_output = compiled_model(input_tensor, group)
+        loss_output = model(input_tensor, group)
         expect = """\
 class GraphModule(torch.nn.Module):
     def forward(self, L_x_ : torch.Tensor):
--
Gitee

From f93d2d91b382ed11ab0d32d4e1158e803309301b Mon Sep 17 00:00:00 2001
From: rmch
Date: Thu, 9 May 2024 19:32:29 +0800
Subject: [PATCH 4/5] add test comm converter for all_reduce and reduce_scatter

---
 test/dynamo/test_comm_converter.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/dynamo/test_comm_converter.py b/test/dynamo/test_comm_converter.py
index e0b627b164..314271c758 100644
--- a/test/dynamo/test_comm_converter.py
+++ b/test/dynamo/test_comm_converter.py
@@ -105,7 +105,7 @@ def mp_main(rank, world_size):
 
 class TestCommConverter(TestCase):
     def test_comm_converter(self):
-        world_size = 2
+        world_size = 1
         mp.spawn(mp_main, args=(world_size,), nprocs=world_size, join=True)
 
 
--
Gitee
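Note: patch 4 drops world_size from 2 to 1, presumably so the test can run on a single NPU. With one rank every collective is numerically an identity, which keeps the eager-vs-compiled comparison exact; patch 5 below then extends Net.forward with all_gather_tensor and all_to_all_single. A hedged shape check for the single-rank case, reusing the gloo group from the sketch above (illustrative only):

    x = torch.randn(200, 200)
    y = fcol.all_reduce(x, "sum", group=group)  # shape preserved: (200, 200)
    z = fcol.reduce_scatter_tensor(y, "sum", scatter_dim=0, group=group)  # dim 0 divided by world_size: still (200, 200) for one rank
    assert z.shape == (200, 200)
    assert torch.allclose(z, x)  # "sum" over a single rank changes nothing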
From 876664db95d251af712892c0ac4cecfc42891106 Mon Sep 17 00:00:00 2001
From: rmch
Date: Thu, 9 May 2024 20:07:09 +0800
Subject: [PATCH 5/5] add test comm converter for all_gather and all_to_all

---
 test/dynamo/test_comm_converter.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/test/dynamo/test_comm_converter.py b/test/dynamo/test_comm_converter.py
index 314271c758..65272ee483 100644
--- a/test/dynamo/test_comm_converter.py
+++ b/test/dynamo/test_comm_converter.py
@@ -27,6 +27,8 @@ class Net(nn.Module):
         _fc1 = self.fc1(x)
         _fc1 = fcol.all_reduce(_fc1, "sum", group=group)
         _fc1 = fcol.reduce_scatter_tensor(_fc1, "sum", scatter_dim=0, group=group)
+        _fc1 = fcol.all_gather_tensor(_fc1, 0, group=group)
+        _fc1 = fcol.all_to_all_single(_fc1, None, None, group=[0, 1])
         _fc1 = _fc1.reshape(2, -1)[0]
         return _fc1
 
@@ -86,9 +88,17 @@ class GraphModule(torch.nn.Module):
 
         _fc1_2 = torch.ops.c10d_functional.wait_tensor(tensor_1); tensor_1 = None
 
-        reshape = _fc1_2.reshape(2, -1); _fc1_2 = None
-        _fc1_3 = reshape[0]; reshape = None
-        return (_fc1_3,)
+        tensor_2 = torch.ops.c10d_functional.all_gather_into_tensor(_fc1_2, 'ptd:0', [0, 1], 2); _fc1_2 = None
+
+        _fc1_3 = torch.ops.c10d_functional.wait_tensor(tensor_2); tensor_2 = None
+
+        tensor_3 = torch.ops.c10d_functional.all_to_all_single(_fc1_3, None, None, '', [0, 1], 2); _fc1_3 = None
+
+        _fc1_4 = torch.ops.c10d_functional.wait_tensor(tensor_3); tensor_3 = None
+
+        reshape = _fc1_4.reshape(2, -1); _fc1_4 = None
+        _fc1_5 = reshape[0]; reshape = None
+        return (_fc1_5,)
 """
         if expect != graph:
             raise RuntimeError('TestCommConverter Failed, fx graph is not expected')
--
Gitee
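Note: after patch 5 the expected graph carries all four legacy c10d_functional collectives, each paired with a wait_tensor; the trailing positional arguments in the expect string are the group tag ('ptd:0' for the default group, '' for the explicit rank list), the rank list, and the group size. Eager-mode shapes for the two collectives patch 5 adds, continuing the single-rank gloo sketch above rather than the NPU path the test actually exercises:

    w = fcol.all_gather_tensor(z, 0, group=group)  # concatenates along dim 0: (200 * world_size, 200)
    v = fcol.all_to_all_single(w, None, None, group=[0])  # None split sizes mean even chunks per rank
    out = v.reshape(2, -1)[0]  # mirrors the tail of Net.forward
    print(out.shape)  # torch.Size([20000]) with one rank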