diff --git a/test/profiler/test_export_memory_timeline.py b/test/profiler/test_export_memory_timeline.py
index f4200f13db37802a67753f5eacbbba87f2449887..01413a2be2ca164ea96a17e07b1dcab07b0d2c6f 100644
--- a/test/profiler/test_export_memory_timeline.py
+++ b/test/profiler/test_export_memory_timeline.py
@@ -1,13 +1,42 @@
 import os
-import stat
+import io
+import functools
+import gc
+import itertools as it
+import textwrap
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
+from unittest.mock import patch
 
 import torch
+from torch.profiler import _utils
+from torch.testing._internal.common_utils import skipIfTorchDynamo
+
 import torch_npu
 from torch_npu.utils.path_manager import PathManager
 from torch_npu.profiler._profiler_path_creator import ProfPathCreator
+from torch_npu.profiler.analysis.prof_parse._event_tree_parser import EventTree
+from torch_npu.profiler.analysis.prof_view import _memory_timeline_parser
+from torch_npu.profiler.analysis.prof_view._memory_timeline_parser import (
+    _EventType,
+    _DeviceType,
+    Category,
+    TensorKeyAndVersion,
+    Action,
+    MemoryProfile,
+)
 from torch_npu.testing.testcase import TestCase, run_tests
 
 
+_DEVICE_DICT = {
+    "cpu": _DeviceType.CPU.value,
+    "npu": _DeviceType.NPU.value,
+}
+
+
+device = "npu:0"
+profile = functools.partial(torch_npu.profiler.profile, record_shapes=True, profile_memory=True, with_stack=True)
+
+
 class SimpleCNN(torch.nn.Module):
     def __init__(self):
         super(SimpleCNN, self).__init__()
@@ -32,12 +61,11 @@ class SimpleCNN(torch.nn.Module):
 
 class TrainModel:
     def __init__(self):
-        self.device = "npu:0"
-        self.model = SimpleCNN().to(self.device)
+        self.model = SimpleCNN().to(device)
         self.criterion = torch.nn.CrossEntropyLoss()
         self.optimizer = torch.optim.SGD(self.model.parameters(),
                                          lr=0.01, momentum=0.9, weight_decay=5e-4)
-        self.inputs = torch.randn(1, 3, 8, 8, device=self.device)
+        self.inputs = torch.randn(1, 3, 8, 8, device=device)
         self.labels = torch.rand_like(self.model(self.inputs))
 
     def train_one_step(self):
@@ -135,5 +163,721 @@ class TestExportMemoryTimeline(TestCase):
         return False
 
 
+@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
+class TestMemoryProfiler(TestCase):
+    @patch("sys.stdout", new_callable=io.StringIO)
+    def test_config_check(self, mock_stdout) -> None:
+        with torch_npu.profiler.profile() as prof:
+            x = torch.ones((1,), device=device)
+        prof.export_memory_timeline(output_path="test.json")
+        prof_dir = ProfPathCreator().get_prof_dir()
+        PathManager.remove_path_safety(prof_dir)
+        self.assertIn("record_shapes=True, profile_memory=True, with_stack=True or with_modules=True",
+                      mock_stdout.getvalue())
+
+        with torch_npu.profiler.profile(record_shapes=True, with_stack=True) as prof:
+            x = torch.ones((1,), device=device)
+        prof.export_memory_timeline(output_path="test.json")
+        prof_dir = ProfPathCreator().get_prof_dir()
+        PathManager.remove_path_safety(prof_dir)
+        self.assertIn("profile_memory=True required for memory profiling", mock_stdout.getvalue())
+
+
+class ScaleLayer(torch.nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.scale = torch.nn.Parameter(torch.rand(()), requires_grad=True)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x * self.scale
+
+
+class LazyLinear(torch.nn.Module):
+    def __init__(self, in_features: int, out_features: int):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+
+    def forward(self, x) -> 
torch.Tensor:
+        if getattr(self, "weight", None) is None:
+            self.weight = torch.nn.Parameter(
+                torch.empty((self.out_features, self.in_features), device=device)
+            )
+            self.bias = torch.nn.Parameter(torch.empty(self.out_features, device=device))
+        return torch.nn.functional.linear(x, self.weight, self.bias)
+
+
+@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
+class TestIdentifyGradients(TestCase):
+    def gradient_detected(
+        self,
+        prof_dir: str,
+        ctx: _EventType,
+        grad_tensor: torch.Tensor,
+        parameter: Optional[torch.Tensor] = None,
+    ) -> bool:
+        # This is not an exhaustive check, but for the purpose of unit testing
+        # it is sufficient.
+        def key_matches_tensor(key, tensor) -> bool:
+            # Vacuous case.
+            if tensor is None:
+                return True
+
+            if key is None:
+                return False
+
+            return tensor.storage().data_ptr() == key.storage.ptr
+
+        tree = EventTree(prof_dir)
+        for node in _utils.traverse_dfs(tree.get_root_nodes()):
+            if node.tag == ctx:
+                gradients_list = list(_memory_timeline_parser.extract_gradients(node))
+                if not gradients_list:
+                    continue
+
+                for p_grad_key in gradients_list:
+                    if key_matches_tensor(p_grad_key, grad_tensor):
+                        return True
+
+        return False
+
+    def assertGradientDetected(self, name: str, *args, **kwargs) -> None:
+        self.assertTrue(
+            self.gradient_detected(*args, **kwargs),
+            f"Failed to identify gradient `{name}` from profile.",
+        )
+
+    def assertOnlyGradients(
+        self, prof_dir: str, tensors: Iterator[torch.Tensor]
+    ) -> None:
+        allowed_set = {t.storage().data_ptr() for t in tensors}
+        tree = EventTree(prof_dir)
+        for node in _utils.traverse_dfs(tree.get_root_nodes()):
+            for p_grad_key in _memory_timeline_parser.extract_gradients(node):
+                self.assertTrue(
+                    p_grad_key.storage.ptr in allowed_set,
+                    f"Tensor wrongly marked as gradient: {node.name}: {p_grad_key}",
+                )
+
+    def test_extract_gradients_low_level(self) -> None:
+        x = torch.ones((1,), device=device)
+        w0 = torch.ones((1,), requires_grad=True, device=device)
+        w1 = torch.ones((1,), requires_grad=True, device=device)
+
+        def check(cold_start: bool):
+            self.assertEqual(w0.grad is None, cold_start)
+            self.assertEqual(w1.grad is None, cold_start)
+            with profile() as _:
+                z = x.expand(4) * w0
+                (z * w1).sum().backward()
+            prof_dir = ProfPathCreator().get_prof_dir()
+
+            # Gradient detection through op inspection does not provide a
+            # reference to the parameter corresponding to the gradient.
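+            # The checks below therefore pass only the gradient tensor.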
+ self.assertGradientDetected("w0", prof_dir, _EventType.TorchOp, w0.grad) + self.assertGradientDetected("w1", prof_dir, _EventType.TorchOp, w1.grad) + self.assertOnlyGradients(prof_dir, (w0.grad, w1.grad)) + + PathManager.remove_path_safety(prof_dir) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_module(self) -> None: + model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()).to(device) + named_parameters = dict(model.named_parameters()) + self.assertEqual(len(named_parameters), 3) + + def assert_only_gradients(prof_dir: str): + gradients = tuple(i.grad for i in named_parameters.values()) + self.assertFalse(any(i is None for i in gradients)) + self.assertOnlyGradients(prof_dir, gradients) + + def check(cold_start: bool): + x = torch.ones((2, 2), device=device) + with profile() as _: + model(x).sum().backward() + prof_dir = ProfPathCreator().get_prof_dir() + + for name, p in named_parameters.items(): + self.assertNotEqual( + self.gradient_detected(prof_dir, _EventType.PyCall, p.grad, p), + cold_start, + name, + ) + + # Op based detection should still identify the gradients. + self.assertGradientDetected(name, prof_dir, _EventType.TorchOp, p.grad) + assert_only_gradients(prof_dir) + + PathManager.remove_path_safety(prof_dir) + + # We can detect gradients even when `.backward()` is not called. + with profile() as _: + model(torch.ones((2, 2), device=device)) + prof_dir = ProfPathCreator().get_prof_dir() + + for name, p in named_parameters.items(): + self.assertGradientDetected(name, prof_dir, _EventType.PyCall, p.grad, p) + self.assertFalse( + self.gradient_detected(prof_dir, _EventType.TorchOp, p.grad), name + ) + assert_only_gradients(prof_dir) + + PathManager.remove_path_safety(prof_dir) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_optimizer(self) -> None: + x = torch.ones((1,), device=device) + w0 = torch.ones((1,), requires_grad=True, device=device) + w1 = torch.ones((1,), requires_grad=True, device=device) + optimizer = torch.optim.SGD((w0, w1), lr=0.1, momentum=0.9) + + def check(cold_start: bool): + self.assertEqual(w0.grad is None, cold_start) + self.assertEqual(w1.grad is None, cold_start) + with profile() as _: + optimizer.zero_grad() + z = x.expand(4) * w0 + (z * w1).sum().backward() + optimizer.step() + prof_dir = ProfPathCreator().get_prof_dir() + + # Optimizer instrumentation runs late in the step, so we can detect + # gradients for both cold and warm start. + self.assertGradientDetected("w0", prof_dir, _EventType.PyCall, w0.grad, w0) + self.assertGradientDetected("w1", prof_dir, _EventType.PyCall, w1.grad, w1) + + self.assertGradientDetected("w0", prof_dir, _EventType.TorchOp, w0.grad) + self.assertGradientDetected("w1", prof_dir, _EventType.TorchOp, w1.grad) + self.assertOnlyGradients(prof_dir, (w0.grad, w1.grad)) + + PathManager.remove_path_safety(prof_dir) + + with profile() as _: + for _ in range(2): + optimizer.zero_grad() + z = x.expand(4) * w0 + (z * w1).sum().backward() + optimizer.step() + prof_dir = ProfPathCreator().get_prof_dir() + + self.assertTrue(self.gradient_detected(prof_dir, _EventType.PyCall, w0.grad, w0)) + self.assertTrue(self.gradient_detected(prof_dir, _EventType.PyCall, w1.grad, w1)) + + PathManager.remove_path_safety(prof_dir) + + check(cold_start=True) + check(cold_start=False) + + def test_extract_gradients_from_module_and_optimizer(self) -> None: + # Module and optimizer are thoroughly tested individually and should be + # additive. 
Thus a lightweight check that they don't
+        # interact adversely is sufficient.
+        model = torch.nn.Sequential(torch.nn.Linear(2, 1), ScaleLayer()).to(device)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
+        with profile() as _:
+            model(torch.ones((2, 2), device=device)).sum().backward()
+            optimizer.step()
+        prof_dir = ProfPathCreator().get_prof_dir()
+
+        self.assertGradientDetected(
+            "weight", prof_dir, _EventType.PyCall, model[0].weight.grad, model[0].weight
+        )
+
+        PathManager.remove_path_safety(prof_dir)
+
+
+@skipIfTorchDynamo("TorchDynamo removes profiler altogether.")
+class TestDataFlow(TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        self.maxDiff = None
+
+    @staticmethod
+    def formatSchemas(
+        prof_dir: str, indent: int = 12
+    ) -> Tuple[Tuple[str, Tuple[bool, ...]], ...]:
+        tree = EventTree(prof_dir)
+        out: List[Tuple[str, Tuple[bool, ...]]] = []
+        for node in _utils.traverse_dfs(tree.get_root_nodes()):
+            if node.tag == _EventType.TorchOp:
+                e = node.extra_fields
+                schemas = _memory_timeline_parser.SchemaMatcher.match_schemas(e)
+                name = node.name
+                if len(schemas) == 1:
+                    name = f"{name}.{schemas[0].overload_name}"
+                elif len(schemas) > 1:
+                    name = f"{name}.{{{', '.join(s.overload_name for s in schemas)}}}"
+                out.append((name, _memory_timeline_parser.SchemaMatcher.inputs_are_mutable(e)))
+        return tuple(out)
+
+    @staticmethod
+    def _run_and_format_data_flow(
+        inputs: Dict[str, torch.Tensor],
+        f: Callable[..., Optional[Dict[str, torch.Tensor]]],
+        indent: int = 12,
+    ) -> str:
+        with profile() as _:
+            outputs = f(**inputs) or {}
+            gc.collect()
+        prof_dir = ProfPathCreator().get_prof_dir()
+        memory_profile = MemoryProfile(prof_dir)
+        PathManager.remove_path_safety(prof_dir)
+
+        graph = memory_profile._data_flow_graph
+        storage_to_id = {key.storage.ptr: key.id for key in graph._active_version}
+
+        lines: List[str] = []
+        for name, t in it.chain(inputs.items(), outputs.items()):
+            lines.append(f"{name + ':':<8} T{storage_to_id.get(t.storage().data_ptr(), '?')}")
+            if t.grad is not None:
+                grad_id = storage_to_id[t.grad.storage().data_ptr()]
+                lines.append(f"{name + '.grad:':<9} T{grad_id}")
+
+        if lines:
+            lines.append("")
+
+        for node in graph.flow_nodes:
+            destroyed = {k for k, v in node._edges.items() if v.is_deletion}
+
+            node_inputs: List[str] = []
+            for key, (_, v) in node.inputs.items():
+                node_inputs.append(f"T{key.id}(v{v}{'*' if key in destroyed else ''})")
+
+            node_outputs = [f"T{key.id}(v{v})" for key, v in node.outputs.items()]
+            if node_inputs or node_outputs:
+                event_name = node._event.name.replace("torch::autograd::", "")
+                lines.append(
+                    f"{event_name:<25} {', '.join(node_inputs):<15} -> {', '.join(node_outputs)}"
+                )
+
+        return textwrap.indent("\n".join([line.rstrip() for line in lines]), " " * indent)
+
+    def test_data_flow_graph_with_annotations(self) -> None:
+        def f(x, y):
+            # torch._C._jit_get_schemas_for_operator will reject any name that
+            # is missing a namespace (denoted by the presence of "::"). We want
+            # to check that we skip both annotations which have no schema
+            # (return empty tuple from SchemaMatcher.lookup_schemas) and
+            # annotations which cannot have schema (return None from
+            # SchemaMatcher.lookup_schemas).
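+            # ("Namespaced::Annotation" exercises the former case and
+            # "My Annotation" the latter.)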
+            with torch.profiler.record_function("Namespaced::Annotation"):
+                with torch.profiler.record_function("My Annotation"):
+                    x.zero_()
+                    y.zero_()
+                    x0 = torch.ones_like(x, device=device)
+                    y0 = torch.zeros_like(y, device=device)
+                    return {"x0": x0, "y0": y0}
+
+        inputs = {"x": torch.ones((1,), device=device), "y": torch.ones((1,), device=device)}
+        self.assertExpectedInline(
+            self._run_and_format_data_flow(inputs, f),
+            """\
+            x:       T0
+            y:       T1
+            x0:      T?
+            y0:      T?
+
+            aten::zero_               T0(v0)          -> T0(v1)
+            aten::zero_               T1(v0)          -> T1(v1)
+            aten::ones_like           T0(v1)          ->
+            aten::zeros_like          T1(v1)          ->""")
+
+    def test_data_flow_graph_non_op_allocations(self) -> None:
+        def f(x):
+            x.mul(2)
+
+        self.assertExpectedInline(
+            self._run_and_format_data_flow({"x": torch.ones((1,), device=device)}, f),
+            """\
+            x:       T1
+
+            [Memory]                                  -> T0(v0)
+            aten::mul                 T0(v0), T1(v0)  ->
+            [Memory]                  T0(v0*)         ->""",
+        )
+
+    def test_data_flow_graph_simple(self) -> None:
+        inputs = {
+            "x": torch.ones((25,), device=device),
+            "y": torch.ones((25,), requires_grad=True, device=device)
+        }
+
+        def f0(x, y):
+            z = x.mul(y)
+            return {"z": z.view_as(z)}
+
+        def f1(x, y):
+            with torch.no_grad():
+                return f0(x, y)
+
+        self.assertExpectedInline(
+            self._run_and_format_data_flow(inputs, f0),
+            """\
+            x:       T0
+            y:       T1
+            z:       T2
+
+            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
+            aten::view_as             T2(v0)          ->""",
+        )
+
+        # Out of place is identical regardless of Autograd.
+        self.assertExpectedInline(
+            self._run_and_format_data_flow(inputs, f1),
+            """\
+            x:       T0
+            y:       T1
+            z:       T2
+
+            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
+            aten::view_as             T2(v0)          ->""",
+        )
+
+    def test_data_flow_graph_simple_inplace(self) -> None:
+        inputs = {
+            "x": torch.ones((25,), device=device),
+            "y": torch.ones((25,), requires_grad=True, device=device)
+        }
+
+        def f0(x, y):
+            x.mul_(y)
+
+        def f1(x, y):
+            with torch.no_grad():
+                return f0(x, y)
+
+        self.assertExpectedInline(
+            self._run_and_format_data_flow(inputs, f0),
+            """\
+            x:       T0
+            y:       T1
+
+            aten::mul_                T0(v0), T1(v0)  -> T0(v1), T2(v0)""",
+        )
+
+        self.assertExpectedInline(
+            self._run_and_format_data_flow(inputs, f1),
+            """\
+            x:       T0
+            y:       T1
+
+            aten::mul_                T0(v0), T1(v0)  -> T0(v1)""",
+        )
+
+    def test_data_flow_graph_simple_backward(self) -> None:
+        inputs = {
+            "x": torch.ones((1,), device=device),
+            "w": torch.ones((1,), requires_grad=True, device=device),
+        }
+        self.assertExpectedInline(
+            self._run_and_format_data_flow(
+                inputs, lambda x, w: (x * w).sin().backward()
+            ),
+            """\
+            x:       T0
+            w:       T1
+            w.grad:   T7
+
+            aten::mul                 T0(v0), T1(v0)  -> T2(v0)
+            aten::sin                 T2(v0)          -> T3(v0)
+            aten::ones_like           T3(v0)          -> T4(v0)
+            SinBackward0              T2(v0), T4(v0)  -> T6(v0)
+            [Memory]                  T2(v0*)         ->
+            MulBackward0              T0(v0), T6(v0)  -> T7(v0)
+            [Memory]                  T6(v0*)         ->
+            AccumulateGrad            T7(v0)          ->
+            [Memory]                  T4(v0*)         ->
+            [Memory]                  T3(v0*)         ->""",
+        )
+
+
+@skipIfTorchDynamo("TorchDynamo changes Python calls that memory profiling relies on.")
+class TestMemoryProfilerE2E(TestCase):
+    @staticmethod
+    def _lookup_tensor_categories(
+        t: torch.Tensor, memory_profile: MemoryProfile
+    ) -> Dict[TensorKeyAndVersion, Optional[Category]]:
+        storage = t.storage()
+        if storage is None:
+            raise ValueError("Cannot look up uninitialized Tensor.")
+
+        snapshot = memory_profile._category_snapshot()
+        ids = set()
+        for key, _ in snapshot:
+            if (
+                key.storage.ptr == storage.data_ptr()
+                and key.device_type == _DEVICE_DICT.get(storage.device.type)
+                and key.device_index == storage.device.index
+            ):
+                ids.add(key.storage.allocation_id)
+
+        max_id = max(ids) if ids else -1
+        return {
+            (key, version): category
+            for (key, version), category in snapshot.items()
+            # A storage address may be reused, so if the Tensor is live we
+            # want only the entries for the most recent allocation ID.
+            if key.storage.allocation_id == max_id
+        }
+
+    def _run_and_check_parameters_and_gradients(
+        self, inner_fn, model, grads_none: bool = False
+    ):
+        with profile() as _:
+            inner_fn()
+        prof_dir = ProfPathCreator().get_prof_dir()
+        memory_profile = MemoryProfile(prof_dir)
+        PathManager.remove_path_safety(prof_dir)
+
+        def assert_category(
+            t: torch.Tensor,
+            category: Category,
+            should_be_none: bool = False,
+        ):
+            if should_be_none:
+                assert t is None, "tensor should be None but is not."
+                return
+            self.assertIsNotNone(t)
+            categories = self._lookup_tensor_categories(t, memory_profile)
+            self.assertGreater(len(categories), 0)
+            self.assertTrue(all(c == category for c in categories.values()), categories)
+
+        for p in model.parameters():
+            assert_category(p, Category.PARAMETER)
+            assert_category(p.grad, Category.GRADIENT, grads_none)
+
+        # Rely on internal asserts
+        _ = memory_profile.timeline
+
+    def test_parameters_and_gradients(self):
+        model = torch.nn.Sequential(
+            torch.nn.Linear(2, 2), ScaleLayer(), torch.nn.Linear(2, 1), ScaleLayer()
+        ).to(device)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        def fwd_only():
+            _ = model(torch.ones((2, 2), device=device))
+
+        def fwd_bwd_step():
+            optimizer.zero_grad()
+            y = model(torch.ones((2, 2), device=device))
+            torch.nn.functional.mse_loss(y, torch.rand((2, 1), device=device)).backward()
+            optimizer.step()
+
+        # If we profile the first step then gradients will not have been
+        # created when we call `model.forward`, so if we don't call `.backward`
+        # then gradients are never created.
+        self._run_and_check_parameters_and_gradients(
+            inner_fn=fwd_only, model=model, grads_none=True
+        )
+
+        # On the first step we must rely on `AccumulateGrad`, since gradients
+        # did not exist when `model.forward` was called.
+        self.assertTrue(all(p.grad is None for p in model.parameters()))
+        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)
+
+        # After one step the python tracer will also flag gradients.
+        self.assertTrue(not any(p.grad is None for p in model.parameters()))
+        self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model)
+
+        # The parameter gradients are not used but we still detect them with
+        # the python tracer.
+        self._run_and_check_parameters_and_gradients(inner_fn=fwd_only, model=model)
+
+    def test_parameters_and_gradients_set_to_none(self):
+        model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)).to(device)
+        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        def fwd_bwd_step():
+            for _ in range(3):
+                # Zero grads at the start so gradients are still live to be
+                # checked.
+ optimizer.zero_grad(set_to_none=True) + + y = model(torch.ones((2, 2), device=device)) + torch.nn.functional.mse_loss(y, torch.rand((2, 1), device=device)).backward() + optimizer.step() + + fwd_bwd_step() + self.assertTrue(not any(p.grad is None for p in model.parameters())) + self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model) + + optimizer.zero_grad(set_to_none=True) + self.assertTrue(all(p.grad is None for p in model.parameters())) + self._run_and_check_parameters_and_gradients(inner_fn=fwd_bwd_step, model=model) + + def test_inputs_fwd(self): + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)).to(device) + inputs = [torch.ones((2, 2), device=device) for _ in range(2)] + + with profile() as _: + # Inputs which were allocated before profiling began + for x in inputs: + _ = model(x) + + # Inputs which were allocated after profiling began + for _ in range(2): + x = torch.ones((2, 2), device=device) + inputs.append(x) + _ = model(x) + prof_dir = ProfPathCreator().get_prof_dir() + memory_profile = MemoryProfile(prof_dir) + PathManager.remove_path_safety(prof_dir) + + for x in inputs: + categories = self._lookup_tensor_categories(x, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue( + all(i == Category.INPUT for i in categories.values()), + categories, + ) + + snapshot = memory_profile._category_snapshot() + self.assertTrue(Category.INPUT in snapshot.values()) + + def test_inputs_fwd_lazy(self): + model = torch.nn.Sequential(LazyLinear(2, 2), LazyLinear(2, 1)).to(device) + inputs = [torch.ones((2, 2), device=device) for _ in range(2)] + + with profile() as _: + # Inputs which were allocated before profiling began + for x in inputs: + _ = model(x) + + # Inputs which were allocated after profiling began + for _ in range(2): + x = torch.ones((2, 2), device=device) + inputs.append(x) + _ = model(x) + prof_dir = ProfPathCreator().get_prof_dir() + memory_profile = MemoryProfile(prof_dir) + PathManager.remove_path_safety(prof_dir) + + for x in inputs: + categories = self._lookup_tensor_categories(x, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue(all(i is None for i in categories.values()), categories) + + snapshot = memory_profile._category_snapshot() + self.assertFalse(Category.INPUT in snapshot.values()) + + def test_inputs_fwd_bwd(self): + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)).to(device) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + inputs_targets = [(torch.ones((2, 2), device=device), torch.rand((2, 1), device=device)) for _ in range(2)] + + def fwd_bwd_step(x, targets): + y = model(x) + torch.nn.functional.mse_loss(y, targets).backward() + optimizer.step() + optimizer.zero_grad() + + with profile() as _: + # Inputs which were allocated before profiling began + for x, targets in inputs_targets: + fwd_bwd_step(x, targets) + + # Inputs which were allocated after profiling began + for _ in range(2): + x = torch.ones((2, 2), device=device) + targets = torch.rand((2, 1), device=device) + inputs_targets.append((x, targets)) + fwd_bwd_step(x, targets) + prof_dir = ProfPathCreator().get_prof_dir() + memory_profile = MemoryProfile(prof_dir) + PathManager.remove_path_safety(prof_dir) + + def check(t): + categories = self._lookup_tensor_categories(t, memory_profile) + self.assertGreater(len(categories), 0) + self.assertTrue( + all(i == Category.INPUT for i in categories.values()) + ) + + for x, targets in inputs_targets: + check(x) + check(targets) + 
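+    # `LazyLinear` materializes its parameters on the first forward call, so
+    # the next test also exercises categorization of parameters created while
+    # profiling is active.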
+ def test_lazily_initialized(self) -> None: + model = torch.nn.Sequential( + torch.nn.Linear(2, 2), + torch.nn.ReLU(), + LazyLinear(2, 2), + torch.nn.ReLU(), + torch.nn.Linear(2, 1), + ).to(device) + + self.assertEqual(len(list(model.parameters())), 4) + + def inner_fn(): + y = model(torch.ones((2, 2), device=device)) + optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + optimizer.zero_grad() + torch.nn.functional.mse_loss(y, torch.rand((2, 1), device=device)).backward() + optimizer.step() + + self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model) + self.assertEqual(len(list(model.parameters())), 6) + + def test_manual_optimizer_step(self) -> None: + model = torch.nn.Sequential(torch.nn.Linear(2, 2), torch.nn.Linear(2, 1)).to(device) + + def inner_fn(): + y = model(torch.ones((2, 2), device=device)) + torch.nn.functional.mse_loss(y, torch.rand((2, 1), device=device)).backward() + + with torch.no_grad(): + for p in model.parameters(): + grad = p.grad + self.assertIsNotNone(grad) + p.add_(grad, alpha=-0.1) + + self._run_and_check_parameters_and_gradients(inner_fn=inner_fn, model=model) + + def test_memory_timeline_no_id(self) -> None: + x = torch.ones((1024,), device=device) + + with profile() as _: + # We never see `x` used so we don't know the storage is for a + # Tensor, but we do still see the free event. + del x + + # For empty we see the allocation and free, but not any use. + # So this also cannot be identified as a Tensor. + y = torch.empty((64,), device=device) + del y + + z = torch.empty((256,), device=device) + z.view_as(z) # Show `z` to the profiler + del z + prof_dir = ProfPathCreator().get_prof_dir() + memory_profile = MemoryProfile(prof_dir) + PathManager.remove_path_safety(prof_dir) + + expected = [ + # x + (Action.PREEXISTING, 4096 + 512), + (Action.DESTROY, 4096 + 512), + # y + (Action.CREATE, 512), + (Action.DESTROY, 512), + # z + (Action.CREATE, 1024 + 512), + (Action.DESTROY, 1024 + 512), + ] + + actual = [(action, size) for _, action, _, size in memory_profile.timeline] + + self.assertEqual( + actual, + expected, + f"expected does not match actual: {actual}", + ) + + if __name__ == "__main__": run_tests() \ No newline at end of file diff --git a/torch_npu/profiler/analysis/prof_view/_memory_timeline_parser.py b/torch_npu/profiler/analysis/prof_view/_memory_timeline_parser.py index 5f5a64b6a105a538405fe6bfe214a0e15c195aea..998eec9d7ee9504834a4b73555e64df6f9379b76 100644 --- a/torch_npu/profiler/analysis/prof_view/_memory_timeline_parser.py +++ b/torch_npu/profiler/analysis/prof_view/_memory_timeline_parser.py @@ -507,24 +507,24 @@ class DataFlowGraph: the event tree partially. Consider the following code: ``` - with record_function("## Init ##"): + with record_function("## My Annotation ##"): x.zero_() y.zero_() ``` The event tree will look like: - TorchOp: "## Init ##" + TorchOp: "## My Annotation ##" TorchOp: zero_ - TorchOp: fill_ + TorchOp: aclnnInplaceZero TorchOp: zero_ - TorchOp: fill_ + TorchOp: aclnnInplaceZero It's important to select the right operator as a node in the - dataflow graph. In this case, choosing "## Init ##" loses - detail from subsequent calls, while `fill_` makes the graph - too detailed. The best nodes are top-level torch ops matching - the torch operator schema. Memory allocations and frees should - also be included to capture all memory usage. + dataflow graph. In this case, choosing "## My Annotation ##" + loses detail from subsequent calls, while `aclnnInplaceZero` + makes the graph too detailed. 
The best nodes are top-level + torch ops matching the torch operator schema. Memory allocations + and frees should also be included to capture all memory usage. """ leaf_events: List[_ProfilerEvent] = [] for event in traverse_dfs(root_nodes, children_fn=lambda e: self._get_children(e)):