Question
I have a use case where I run a forward pass for each sample in a batch and only accumulate the loss for some of the samples, based on a condition on the model's output for that sample. Here is some illustrative code:
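```python
import torch
from torch.autograd import Variable

# model, criterion, optimizer, train_loader and some_condition are defined elsewhere
for batch_idx, (data, target) in enumerate(train_loader):
    optimizer.zero_grad()
    total_loss = 0
    loss_count_local = 0
    for i in range(len(target)):
        im = Variable(data[i].unsqueeze(0).cuda())
        y = Variable(torch.FloatTensor([target[i]]).cuda())
        out = model(im)
        # if out satisfies some condition, we calculate the loss
        # for this sample, else proceed to the next sample
        if some_condition(out):
            loss = criterion(out, y)
        else:
            continue
        total_loss += loss
        loss_count_local += 1
        # back-propagate the averaged loss of every 32 contributing samples
        if loss_count_local == 32:
            total_loss /= loss_count_local
            total_loss.backward()
            total_loss = 0
            loss_count_local = 0
    # back-propagate any leftover samples (fewer than 32); a check on
    # i == len(target) - 1 inside the loop is skipped by `continue`
    # whenever the last sample fails the condition
    if loss_count_local > 0:
        total_loss /= loss_count_local
        total_loss.backward()
    optimizer.step()
```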
My question is: since I run the forward pass for all samples but only run backward for some of them, when will the graphs of the samples that do not contribute to the loss be freed? Will these graphs be freed only after the for loop has ended, or immediately after I run the forward pass for the next sample? I am a little confused here.

Also, for the samples that do contribute to `total_loss`, their graphs will be freed immediately after we call `total_loss.backward()`. Is that right?
Answer 1:
Let's start with a general discussion of how PyTorch frees memory:
First, we should emphasize that PyTorch uses an implicitly declared graph that is stored in Python object attributes. (Remember, it's Python, so everything is an object.) More specifically, `torch.autograd.Variable`s have a `.grad_fn` attribute (since PyTorch 0.4, `Variable` has been merged into `Tensor`, so tensors carry `.grad_fn` directly). This attribute's type defines what kind of computation node we have (e.g. an addition), and the node references the nodes of its inputs through `.grad_fn.next_functions`.
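A minimal sketch of inspecting these nodes, using plain tensors (the post-0.4 equivalent of `Variable`); `x`, `y` and `z` are illustrative names only:

```python
import torch

x = torch.ones(3, requires_grad=True)   # leaf tensor: no grad_fn
y = x * 2                                # produced by a multiplication node
z = y + x                                # produced by an addition node

print(x.grad_fn)                         # None
print(type(z.grad_fn).__name__)          # AddBackward0
# Each entry of next_functions is (node, output_index) for one input of the node;
# a leaf that requires grad shows up as an AccumulateGrad node.
print([type(fn).__name__ for fn, _ in z.grad_fn.next_functions])
# ['MulBackward0', 'AccumulateGrad']
```

Following `next_functions` from any result tensor walks back through every node that produced it, which is exactly the chain of references discussed next.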
This is important because PyTorch frees memory simply by relying on the standard Python garbage collector (albeit fairly aggressively). In CPython this is mostly reference counting, which frees an object as soon as its last reference is dropped. In this context, this means that the (implicitly declared) computation graphs will be kept alive as long as there are references to the objects holding them in the current scope!
This means that if you, e.g., do some kind of batching on samples s_1, ..., s_k, compute the loss for each, and add the losses up at the end, that cumulative loss will hold references to each individual loss, which in turn holds references to each of the computation nodes that computed it.
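As a rough illustration, the sketch below walks `total_loss.grad_fn` through `next_functions` and counts the per-sample `MulBackward0` nodes it can reach. `w`, `samples`, `count_nodes` and the "keep even-indexed samples" rule are hypothetical stand-ins for the model parameters, the data and `some_condition`:

```python
import torch


def count_nodes(grad_fn, name):
    """Count nodes of type `name` reachable from grad_fn (tree walk, no de-duplication)."""
    if grad_fn is None:
        return 0
    count = int(type(grad_fn).__name__ == name)
    for next_fn, _ in grad_fn.next_functions:
        count += count_nodes(next_fn, name)
    return count


torch.manual_seed(0)
w = torch.randn(4, requires_grad=True)          # stand-in for model parameters
samples = [torch.randn(4) for _ in range(5)]    # hypothetical per-sample inputs

total_loss = 0
for i, s in enumerate(samples):
    out = w * s                                 # one MulBackward0 node per sample
    if i % 2 == 1:                              # hypothetical some_condition: skip odd samples
        continue
    total_loss = total_loss + out.sum()         # the sum now references this sample's nodes

print(count_nodes(total_loss.grad_fn, "MulBackward0"))  # 3
```

Only the three accumulated samples are reachable from `total_loss`; the skipped samples' nodes are not part of its graph, so nothing in it keeps them alive.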
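So, applied to your code, your question is more about how Python (or, more specifically, its garbage collector) handles references than about how PyTorch does. Since you accumulate the loss in one object (`total_loss`), you keep pointers alive, and thereby do not free the memory until you re-initialize that object (`total_loss = 0`), either after `backward()` or at the start of the next outer iteration. Note that `backward()` itself already releases the graph's saved intermediate buffers (unless you pass `retain_graph=True`), but the graph nodes stay referenced until `total_loss` is rebound.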
Applied to your example, this means that the computation graph you create in the forward pass (at `out = model(im)`) is only referenced by the `out` object and any future computations thereof. So if you compute the loss and sum it, you will keep references to `out` alive, and thereby to the computation graph. If you do not use it, however, the garbage collector should recursively collect `out` and its computation graph as soon as `out` is rebound by the next forward pass, rather than only after the loop ends.
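A minimal sketch that makes these lifetimes observable, assuming CPython. `Tagged` (a toy identity op) and `Marker` are hypothetical helpers introduced only so that `weakref` can report whether each sample's graph node still exists; `w * 2` stands in for `model(im)` and `i % 2 == 0` for `some_condition(out)`:

```python
import gc
import weakref

import torch


class Marker:
    """Hypothetical plain object; lives exactly as long as the node that stores it."""


class Tagged(torch.autograd.Function):
    """Hypothetical identity op whose backward node (ctx) holds a Marker."""

    @staticmethod
    def forward(ctx, x, marker):
        ctx.marker = marker          # the autograd node now owns the marker
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        return grad_out, None        # no gradient for the non-tensor argument


def alive(refs):
    """Return which markers (and hence which graph nodes) still exist."""
    gc.collect()
    return [r() is not None for r in refs]


w = torch.randn(3, requires_grad=True)   # stand-in for the model parameters
refs = []
total_loss = 0
for i in range(4):
    marker = Marker()
    refs.append(weakref.ref(marker))
    out = Tagged.apply(w * 2, marker)     # stand-in for out = model(im)
    del marker                            # only the graph references the marker now
    if i % 2 == 0:                        # hypothetical some_condition(out)
        total_loss = total_loss + out.sum()
del out

after_loop = alive(refs)       # skipped samples: graphs already freed
total_loss.backward()
after_backward = alive(refs)   # nodes still reachable from total_loss
total_loss = 0
after_reset = alive(refs)      # last reference dropped: graphs freed
print(after_loop, after_backward, after_reset)
```

In this sketch the skipped samples (indices 1 and 3) are gone before `backward()` is ever called, while the accumulated ones survive `backward()` and disappear only once `total_loss` is reset to `0`.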
Source: https://stackoverflow.com/questions/47587122/when-will-the-computation-graph-be-freed-if-i-only-do-forward-for-some-samples