如果我只对某些样本进行转发,计算图什么时候

When will the computation graph be freed if I only do forward for some samples?(如果我只对某些样本进行转发,计算图什么时候会被释放?)

本文介绍了如果我只对某些样本进行转发,计算图什么时候会被释放?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我有一个用例,我对批次中的每个样本进行转发,并且仅根据样本模型输出的某些条件累积某些样本的损失.这是一个说明代码,

I have a use case where I do forward for each sample in a batch and only accumulate loss for some of the samples based on some condition on the model output of the sample. Here is an illustrating code,

for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    total_loss = 0

    loss_count_local = 0
    for i in range(len(target)):
        im = Variable(data[i].unsqueeze(0).cuda())
        y = Variable(torch.FloatTensor([target[i]]).cuda())

        out = model(im)

        # if out satisfy some condtion, we will calculate loss
        # for this sample, else proceed to next sample
        if some_condition(out):
            loss = criterion(out, y)
        else:
            continue

        total_loss += loss
        loss_count_local += 1

        if loss_count_local == 32 or i == (len(target)-1):
            total_loss /= loss_count_local
            total_loss.backward()
            total_loss = 0
            loss_count_local = 0

    optimizer.step()

我的问题是,因为我对所有样本都进行了转发,但只对某些样本进行了向后操作.那些对损失没有贡献的样本的图表什么时候会被释放?只有在 for 循环结束后或在我为下一个样本执行转发后立即释放这些图吗?我在这里有点困惑.

My question is, as I do forward for all samples but only do backward for some of the samples. When will the graph for those samples which do not contribute to the loss be freed? Will these graphs be freed only after the for loop has ended or immediately after I do forward for the next sample? I am a little confused here.

同样对于那些对 total_loss 有贡献的样本,它们的图将在我们执行 total_loss.backward() 后立即被释放.是吗?

Also for those samples that do contribute to the total_loss, their graph will be freed immediately after we do total_loss.backward(). Is that right?

推荐答案

让我们从一般性讨论 PyTorch 如何释放内存开始:

Let's start with a general discussion of how PyTorch frees memory:

首先,我们应该强调 PyTorch 使用存储在 Python 对象属性中的隐式声明图.(请记住,它是 Python,所以一切都是对象).更具体地说,torch.autograd.Variables 有一个 .grad_fn 属性.这个属性的类型定义了我们有什么样的计算节点(例如一个加法),以及该节点的输入.

First, we should emphasize that PyTorch uses an implicitly declared graph that is stored in Python object attributes. (Remember, it's Python, so everything is an object). More specifically, torch.autograd.Variables have a .grad_fn attribute. This attribute's type defines what kind of computation node we have (e.g. an addition), and the input to that node.

这很重要,因为 Pytorch 只需使用标准的 Python 垃圾收集器(如果相当积极的话)就可以释放内存.在这种情况下,这意味着(隐式声明的)计算图将保持活动状态,只要在当前作用域中存在对持有它们的对象的引用!

This is important because Pytorch frees memory simply by using the standard python garbage collector (if fairly aggressively). In this context, this means that the (implicitly declared) computation graphs will be kept alive as long as there are references to the objects holding them in the current scope!

这意味着如果您例如对样本 s_1 ... s_k 进行某种批处理,计算每个损失并在最后添加损失,该累积损失将保存对每个单独损失的引用,后者又保存对每个计算节点的引用

This means that if you e.g. do some kind of batching on samples s_1 ... s_k, compute the loss for each and add the loss at the end, that cumulative loss will hold references to each individual loss, which in turn holds references to each of the computation nodes that computed it.

因此,您应用于代码的问题更多是关于 Python(或者更具体地说是它的垃圾收集器)如何处理引用,而不是关于 Pytorch.由于您将损失累积在一个对象中 (total_loss),您可以保持指针处于活动状态,因此在外循环中重新初始化该对象之前不会释放内存.

So your question applied to your code is more about how Python (or, more specifically its garbage collector) handles references than about Pytorch does. Since you accumulate the loss in one object (total_loss), you keep pointers alive, and thereby do not free the memory until you re-initialize that object in the outer loop.

应用于您的示例,这意味着您在前向传递中创建的计算图(在 out = model(im))仅由 out 对象引用以及任何未来的计算.因此,如果您计算损失并将其求和,您将保持对 out 的引用,从而保持对计算图的引用.但是,如果您不使用它,垃圾收集器应该递归地收集 out 及其计算图.

Applied to your example, this means that the computation graph you create in the forward pass (at out = model(im)) is only referenced by the out object and any future computations thereof. So if you compute the loss and sum it, you will keep references to out alive, and thereby to the computation graph. If you do not use it, however, the garbage collector should recursively collect out, and its computation graph.

这篇关于如果我只对某些样本进行转发,计算图什么时候会被释放?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!

本文标题为:如果我只对某些样本进行转发,计算图什么时候

基础教程推荐