由 Aaron Shi 和 Zachary DeVito 撰写

这是《理解 GPU 内存》博客系列的第二部分。我们的第一篇《理解 GPU 内存 1:可视化所有分配随时间的变化》展示了如何使用内存快照工具。在本部分中,我们将使用内存快照来可视化由引用循环引起的 GPU 内存泄漏,然后使用引用循环检测器在我们的代码中定位并移除它们。

有时候当我们使用内存快照时,我们会看到类似这样的 GPU 内存图表。

GPU memory

在这个快照中,每个峰值都表示 GPU 张量随时间积累,然后一次性释放多个张量。此外,右侧发生 CUDA OOM,导致所有张量都被释放。看到张量以这种方式积累是问题的明显迹象,但它并不立即表明原因。

参考周期中的张量

在早期调试期间,我们进一步挖掘发现,当你的 Python 代码中有具有引用周期的对象时,这种**模式经常发生。** Python 会立即使用引用计数清理非循环对象。然而,循环中的对象只能由循环收集器稍后清理。如果这些循环引用了 GPU 张量,GPU 张量将保持活跃状态,直到循环收集器运行并移除引用循环。让我们看看一个简化的例子。

Simple reference cycle

快照背后的代码片段(完整代码见附录 A)

    def leak(tensor_size, num_iter=100000, device="cuda:0"):
      class Node:
        def __init__(self, T):
          self.tensor = T
          self.link = None

      for _ in range(num_iter):
        A = torch.zeros(tensor_size, device=device)
        B = torch.zeros(tensor_size, device=device)
        a, b = Node(A), Node(B)

        # A reference cycle will force refcounts to be non-zero.
        a.link, b.link = b, a
        # Python will eventually garbage collect a & b, but will
        # OOM on the GPU before that happens (since python
        # runtime doesn't know about CUDA memory usage).

在这个代码示例中,创建了张量 A 和 B,其中 A 指向 B,反之亦然。这迫使 A 和 B 超出作用域时具有非零引用计数。当我们运行 100,000 次迭代时,我们预计自动垃圾回收会在超出作用域时释放引用循环。然而,这实际上会导致 CUDA OOM。

为什么自动垃圾回收不起作用?

当有大量额外内存时,如 CPU 上常见的情况,自动垃圾回收工作良好,因为它通过使用代际垃圾回收来分摊昂贵的垃圾回收成本。但是为了分摊收集工作,它会推迟一些内存清理,使得最大内存使用量更高,这不太适合内存受限的环境。Python 运行时也没有对 CUDA 内存使用的洞察,因此它不能在高内存压力下触发。由于 GPU 训练几乎总是内存受限,因为我们会经常提高批大小以使用任何额外的空闲内存,所以这更加具有挑战性。

CPython 的垃圾回收通过标记-清除算法释放引用循环中不可达的对象。当对象数量超过一定阈值时,垃圾回收会自动运行。有 3 个代阈值,以帮助分摊在所有对象上运行垃圾回收的昂贵成本。较晚的代运行频率较低。这可以解释为什么自动收集每次只会清除每个峰值上的几个张量,然而仍然有一些张量泄漏,导致 CUDA OOM。这些张量在较晚的代中被引用循环持有。

显式调用 gc.collect()

解决这个问题的一种方法是通过频繁地显式调用垃圾回收器。在这里,我们可以看到,当我们每 100 次迭代显式调用垃圾回收器时,超出作用域的张量 GPU 内存就会被清理。这也控制了泄漏的张量所持有的最大 GPU 峰值内存。

memory leak

虽然这样做可以解决 CUDA OOM 问题,但频繁调用 gc.collect()可能会导致其他问题,包括 QPS 回归。因此,我们无法简单地增加每个训练作业的垃圾回收频率。最好的办法是首先避免创建引用循环。更多内容请参阅章节,引用循环检测器。

隐藏的回调内存泄漏

真实例子更复杂,所以让我们看看一个具有类似行为的更实际的例子。在这个快照中,我们可以观察到张量在自动垃圾回收过程中被累积和释放,直到我们遇到 CUDA OOM。

memory leak

此快照背后的代码片段(完整代码示例见附录 A):

    class AwaitableTensor:
      def __init__(self, tensor_size):
        self._tensor_size = tensor_size
        self._tensor = None

      def wait(self):
        self._tensor = torch.zeros(self._tensor_size, device="cuda:0")
        return self._tensor

    class AwaitableTensorWithViewCallback:
      def __init__(self, tensor_awaitable, view_dim):
        self._tensor_awaitable = tensor_awaitable
        self._view_dim = view_dim
        # Add a view filter callback to the tensor.
        self._callback = lambda ret: ret.view(-1, self._view_dim)

      def wait(self):
        return self._callback(self._tensor_awaitable.wait())

    async def awaitable_leak(
      tensor_size=2**27, num_iter=100000,
    ):
      for _ in range(num_iter):
        A = AwaitableTensor(tensor_size)
        AwaitableTensorWithViewCallBack(A, 4).wait()

在此代码中,我们定义了两个类。AwaitableTensor 类会在等待时创建张量。另一个类 AwaitableTensorWithViewCallback 会通过回调 lambda 对 AwaitableTensor 应用视图过滤器。

当运行 awaitable_leak 时,它创建张量 A(512 MB)并应用视图过滤器进行 100,000 次迭代,我们预期每次 A 离开作用域时都应该被回收,因为引用计数应该达到 0。然而,这实际上会导致 OOM!

虽然我们知道这里存在引用循环,但代码中并不清楚循环是在哪里创建的。为了帮助解决这些情况,我们创建了一个工具来定位和报告这些循环。

引用循环检测器

介绍参考周期检测器,它帮助我们找到保持 GPU 张量活跃的参考周期。该 API 相当简单:

  • 在模型初始化期间:
    • 导入: from torch.utils.viz._cycles import warn_tensor_cycles
    • 开始: warn_tensor_cycles()

引用周期检测器将在周期收集器运行并找到被释放的 CUDA 张量时发出警告。警告提供了对象图,展示了引用周期如何引用 GPU 张量。

object graph

例如,在这个对象图中,我们可以轻松观察到图的外圈存在循环依赖,并且用红色突出显示的是保持活跃的 GPU 张量。

一旦发现,大多数循环都很容易修复。例如,在这里我们可以移除由 self._view_dim 在回调中创建的自我引用。

code snippet

我们已经花费了一些时间使用这些工具修复现有模型中的循环。例如,在 TorchRec 中,我们在 PR#1226 中找到了并移除了一个引用循环。

code snippet

删除引用循环后,代码将不再引发 CUDA OOM,也不会在它们的快照中显示任何内存泄漏。

使用引用循环检测器的其他好处是什么?

移除这些循环还将直接降低最大 GPU 内存使用量,并使内存碎片化的可能性降低,因为分配器在每次迭代后都会回到相同的状态。

我在哪里可以找到这些工具?

我们希望参考循环检测器能大大提高您查找和删除由参考循环引起的内存泄漏的能力。参考循环检测器作为实验性功能,已包含在 PyTorch v2.1 版本中,更多关于参考循环检测器的信息可以在 PyTorch 内存文档中找到。

反馈

我们期待收到您关于任何增强功能、错误或由我们的工具帮助解决的内存故事的反馈!一如既往,请随时在 PyTorch 的 GitHub 页面上提交新问题。

我们也欢迎开源社区的贡献,请自由地在任何 GitHub PR 中提及 Aaron Shi 和 Zachary DeVito 以供审查。

致谢

非常感谢内容审阅者 Mark Saroufim、Gregory Chanan 和 Adnan Aziz 审阅这篇帖子并提高其可读性。

附录

附录 A - 代码示例

此代码片段用于生成所展示的图表和示例。以下是重现各部分的参数:

  • 简介: python sample.py
  • 明确调用 gc.collect(): python sample.py --gc_collect_interval=100
  • 回调中的偷偷内存泄漏: python sample.py --workload=awaitable
  • 循环引用检测器: python sample.py --workload=awaitable --warn_tensor_cycles

sample.py:

# (c) Meta Platforms, Inc. and affiliates. 
import argparse
import asyncio
import gc
import logging
import socket
from datetime import datetime, timedelta

import torch

logging.basicConfig(
   format="%(levelname)s:%(asctime)s %(message)s",
   level=logging.INFO,
   datefmt="%Y-%m-%d %H:%M:%S",
)
logger: logging.Logger = logging.getLogger(__name__)
logger.setLevel(level=logging.INFO)

TIME_FORMAT_STR: str = "%b_%d_%H_%M_%S"

# Keep a max of 100,000 alloc/free events in the recorded history
# leading up to the snapshot.
MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT: int = 100000

def start_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Starting snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(
       max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
   )

def stop_record_memory_history() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not recording memory history")
       return

   logger.info("Stopping snapshot record_memory_history")
   torch.cuda.memory._record_memory_history(enabled=None)

def export_memory_snapshot() -> None:
   if not torch.cuda.is_available():
       logger.info("CUDA unavailable. Not exporting memory snapshot")
       return

   # Prefix for file names.
   host_name = socket.gethostname()
   timestamp = datetime.now().strftime(TIME_FORMAT_STR)
   file_prefix = f"{host_name}_{timestamp}"

   try:
       logger.info(f"Saving snapshot to local file: {file_prefix}.pickle")
       torch.cuda.memory._dump_snapshot(f"{file_prefix}.pickle")
   except Exception as e:
       logger.error(f"Failed to capture memory snapshot {e}")
       return

# This function will leak tensors due to the reference cycles.
def simple_leak(tensor_size, gc_interval=None, num_iter=30000, device="cuda:0"):
    class Node:
        def __init__(self, T):
            self.tensor = T
            self.link = None

    for i in range(num_iter):
        A = torch.zeros(tensor_size, device=device)
        B = torch.zeros(tensor_size, device=device)
        a, b = Node(A), Node(B)
        # A reference cycle will force refcounts to be non-zero, when
        # a and b go out of scope.
        a.link, b.link = b, a
        # Python will eventually gc a and b, but may OOM on the CUDA
        # device before that happens (since python runtime doesn't
        # know about CUDA memory usage).

        # Since implicit gc is not called frequently enough due to
        # generational gc, adding an explicit gc is necessary as Python
        # runtime does not know about CUDA memory pressure.
        # https://en.wikipedia.org/wiki/Tracing_garbage_collection#Generational_GC_(ephemeral_GC)
        if gc_interval and i % int(gc_interval) == 0:
            gc.collect()

async def awaitable_leak(
    tensor_size, gc_interval=None, num_iter=100000, device="cuda:0"
):
    class AwaitableTensor:
        def __init__(self, tensor_size, device) -> None:
            self._tensor_size = tensor_size
            self._device = device
            self._tensor = None

        def wait(self) -> torch.Tensor:
            self._tensor = torch.zeros(self._tensor_size, device=self._device)
            return self._tensor

    class AwaitableTensorWithViewCallBack:
        def __init__(
            self,
            tensor_awaitable: AwaitableTensor,
            view_dim: int,
        ) -> None:
            self._tensor_awaitable = tensor_awaitable
            self._view_dim = view_dim
            # Add a view filter callback to the tensor.
            self._callback = lambda ret: ret.view(-1, self._view_dim)

        def wait(self) -> torch.Tensor:
            return self._callback(self._tensor_awaitable.wait())

    for i in range(num_iter):
        # Create an awaitable tensor
        a_tensor = AwaitableTensor(tensor_size, device)

        # Apply a view filter callback on the awaitable tensor.
        AwaitableTensorWithViewCallBack(a_tensor, 4).wait()

        # a_tensor will go out of scope.

        if gc_interval and i % int(gc_interval) == 0:
            gc.collect()

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="A memory_leak binary instance")
    parser.add_argument(
        "--gc_collect_interval",
        default=None,
        help="Explicitly call GC every given interval. Default is off.",
    )
    parser.add_argument(
        "--workload",
        default="simple",
        help="Toggle which memory leak workload to run. Options are simple, awaitable.",
    )
    parser.add_argument(
        "--warn_tensor_cycles",
        action="store_true",
        default=False,
        help="Toggle whether to enable reference cycle detector.",
    )
    args = parser.parse_args()

    if args.warn_tensor_cycles:
        from tempfile import NamedTemporaryFile

        from torch.utils.viz._cycles import observe_tensor_cycles

        logger.info("Enabling warning for Python reference cycles for CUDA Tensors.")

        def write_and_log(html):
            with NamedTemporaryFile("w", suffix=".html", delete=False) as f:
                f.write(html)
                logger.warning(
                    "Reference cycle includes a CUDA Tensor see visualization of cycle %s",
                    f.name,
                )

        observe_tensor_cycles(write_and_log)
    else:
        # Start recording memory snapshot history
        start_record_memory_history()

    # Run the workload with a larger tensor size.
    # For smaller sizes, we will not CUDA OOM as gc will kick in often enough
    # to reclaim reference cycles before an OOM occurs.
    size = 2**26  # 256 MB
    try:
        if args.workload == "awaitable":
            size *= 2
            logger.info(f"Running tensor_size: {size*4/1024/1024} MB")
            asyncio.run(
                awaitable_leak(tensor_size=size, gc_interval=args.gc_collect_interval)
            )
        elif args.workload == "simple":
            logger.info(f"Running tensor_size: {size*4/1024/1024} MB")
            simple_leak(tensor_size=size, gc_interval=args.gc_collect_interval)
        else:
            raise Exception("Unknown workload.")
    except Exception:
        logger.exception(f"Failed to allocate {size*4/1024/1024} MB")

    # Create the memory snapshot file
    export_memory_snapshot()

    # Stop recording memory snapshot history
    stop_record_memory_history()