How to balance the generator and the discriminator performances in a GAN?

Submitted by 自闭症网瘾萝莉.ら on 2021-02-19 08:06:26

Question


It's the first time I'm working with GANs and I am facing an issue regarding the Discriminator repeatedly outperforming the Generator. I am trying to reproduce the PA model from this article and I'm looking at this slightly different implementation to help me out.

I have read quite a lot of papers on how GANs work and also followed some tutorials to understand them better. Moreover, I've read articles on how to overcome the major instabilities, but I can't find a way to overcome this behavior.

In my environment, I'm using PyTorch and BCELoss(). Following the DCGAN PyTorch tutorial, I'm using the following training loop:

import torch
import torch.nn as nn

criterion = nn.BCELoss()
train_d = False

# Discriminator: real samples (target label 1)
optim_d.zero_grad()
disc_train_real = target.to(device)
batch_size = disc_train_real.size(0)
# Use a float fill value: BCELoss needs a floating-point target
label = torch.full((batch_size,), 1.0, device=device)
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label)
if lossT:
    loss_d_real *= 2
# Only backpropagate when the discriminator is not already too confident
if loss_d_real.item() > 0.3:
    loss_d_real.backward()
    train_d = True
D_x = output_d.mean().item()

# Discriminator: generated samples (target label 0)
output_g = generator(image)
# detach() keeps the discriminator loss from backpropagating into the generator
output_d = discriminator(output_g.detach()).view(-1)
label.fill_(0)
loss_d_fake = criterion(output_d, label)
D_G_z1 = output_d.mean().item()
if lossT:
    loss_d_fake *= 2
loss_d = loss_d_real + loss_d_fake
if loss_d_fake.item() > 0.3:
    loss_d_fake.backward()
    train_d = True
# The discriminator is stepped only if at least one of its losses exceeded 0.3
if train_d:
    optim_d.step()

# Generator: wants the discriminator to output 1 on generated samples
label.fill_(1)
output_d = discriminator(output_g).view(-1)
loss_g = criterion(output_d, label)
D_G_z2 = output_d.mean().item()
if lossT:
    loss_g *= 2

loss_g.backward()
optim_g.step()

and, after a period of settlement, everything seems to work fine:

Epoch 1/5 - Step: 1900/9338  Loss G: 3.057388  Loss D: 0.214545  D(x): 0.940985  D(G(z)): 0.114064 / 0.114064
Time for the last step: 51.55 s    Epoch ETA: 01:04:13
Epoch 1/5 - Step: 2000/9338  Loss G: 2.984724  Loss D: 0.222931  D(x): 0.879338  D(G(z)): 0.159163 / 0.159163
Time for the last step: 52.68 s    Epoch ETA: 01:03:24
Epoch 1/5 - Step: 2100/9338  Loss G: 2.824713  Loss D: 0.241953  D(x): 0.905837  D(G(z)): 0.110231 / 0.110231
Time for the last step: 50.91 s    Epoch ETA: 01:02:29
Epoch 1/5 - Step: 2200/9338  Loss G: 2.807455  Loss D: 0.252808  D(x): 0.908131  D(G(z)): 0.218515 / 0.218515
Time for the last step: 51.72 s    Epoch ETA: 01:01:37
Epoch 1/5 - Step: 2300/9338  Loss G: 2.470529  Loss D: 0.569696  D(x): 0.620966  D(G(z)): 0.512615 / 0.350175
Time for the last step: 51.96 s    Epoch ETA: 01:00:46
Epoch 1/5 - Step: 2400/9338  Loss G: 2.148863  Loss D: 1.071563  D(x): 0.809529  D(G(z)): 0.114487 / 0.114487
Time for the last step: 51.59 s    Epoch ETA: 00:59:53
Epoch 1/5 - Step: 2500/9338  Loss G: 2.016863  Loss D: 0.904711  D(x): 0.621433  D(G(z)): 0.440721 / 0.435932
Time for the last step: 52.03 s    Epoch ETA: 00:59:02
Epoch 1/5 - Step: 2600/9338  Loss G: 2.495639  Loss D: 0.949308  D(x): 0.671085  D(G(z)): 0.557924 / 0.420826
Time for the last step: 52.66 s    Epoch ETA: 00:58:12
Epoch 1/5 - Step: 2700/9338  Loss G: 2.519842  Loss D: 0.798667  D(x): 0.775738  D(G(z)): 0.246357 / 0.265839
Time for the last step: 51.20 s    Epoch ETA: 00:57:19
Epoch 1/5 - Step: 2800/9338  Loss G: 2.545630  Loss D: 0.756449  D(x): 0.895455  D(G(z)): 0.403628 / 0.301851
Time for the last step: 51.88 s    Epoch ETA: 00:56:27
Epoch 1/5 - Step: 2900/9338  Loss G: 2.458109  Loss D: 0.653513  D(x): 0.820105  D(G(z)): 0.379199 / 0.103250
Time for the last step: 53.50 s    Epoch ETA: 00:55:39
Epoch 1/5 - Step: 3000/9338  Loss G: 2.030103  Loss D: 0.948208  D(x): 0.445385  D(G(z)): 0.303225 / 0.263652
Time for the last step: 51.57 s    Epoch ETA: 00:54:47
Epoch 1/5 - Step: 3100/9338  Loss G: 1.721604  Loss D: 0.949721  D(x): 0.365646  D(G(z)): 0.090072 / 0.232912
Time for the last step: 52.19 s    Epoch ETA: 00:53:55
Epoch 1/5 - Step: 3200/9338  Loss G: 1.438854  Loss D: 1.142182  D(x): 0.768163  D(G(z)): 0.321164 / 0.237878
Time for the last step: 50.79 s    Epoch ETA: 00:53:01
Epoch 1/5 - Step: 3300/9338  Loss G: 1.924418  Loss D: 0.923860  D(x): 0.729981  D(G(z)): 0.354812 / 0.318090
Time for the last step: 52.59 s    Epoch ETA: 00:52:11

That is, the gradients on the Generator start out higher and decrease after a while, and meanwhile the gradients on the Discriminator rise. As for the losses, the Generator's goes down while the Discriminator's goes up. Compared to the tutorial, I guess this is acceptable.

Here's my first question: I've noticed that in the tutorial, D_G_z2 usually decreases as D_G_z1 rises (and vice versa), while in my example this happens a lot less. Is it just a coincidence or am I doing something wrong?

Given that, I've let the training procedure go on, but now I'm noticing this:

Epoch 3/5 - Step: 1100/9338  Loss G: 4.071329  Loss D: 0.031608  D(x): 0.999969  D(G(z)): 0.024329 / 0.024329
Time for the last step: 51.41 s    Epoch ETA: 01:11:24
Epoch 3/5 - Step: 1200/9338  Loss G: 3.883331  Loss D: 0.036354  D(x): 0.999993  D(G(z)): 0.043874 / 0.043874
Time for the last step: 51.63 s    Epoch ETA: 01:10:29
Epoch 3/5 - Step: 1300/9338  Loss G: 3.468963  Loss D: 0.054542  D(x): 0.999972  D(G(z)): 0.050145 / 0.050145
Time for the last step: 52.47 s    Epoch ETA: 01:09:40
Epoch 3/5 - Step: 1400/9338  Loss G: 3.504971  Loss D: 0.053683  D(x): 0.999972  D(G(z)): 0.052180 / 0.052180
Time for the last step: 50.75 s    Epoch ETA: 01:08:41
Epoch 3/5 - Step: 1500/9338  Loss G: 3.437765  Loss D: 0.056286  D(x): 0.999941  D(G(z)): 0.058839 / 0.058839
Time for the last step: 52.20 s    Epoch ETA: 01:07:50
Epoch 3/5 - Step: 1600/9338  Loss G: 3.369209  Loss D: 0.062133  D(x): 0.955688  D(G(z)): 0.058773 / 0.058773
Time for the last step: 51.05 s    Epoch ETA: 01:06:54
Epoch 3/5 - Step: 1700/9338  Loss G: 3.290109  Loss D: 0.065704  D(x): 0.999975  D(G(z)): 0.056583 / 0.056583
Time for the last step: 51.27 s    Epoch ETA: 01:06:00
Epoch 3/5 - Step: 1800/9338  Loss G: 3.286248  Loss D: 0.067969  D(x): 0.993238  D(G(z)): 0.063815 / 0.063815
Time for the last step: 52.28 s    Epoch ETA: 01:05:09
Epoch 3/5 - Step: 1900/9338  Loss G: 3.263996  Loss D: 0.065335  D(x): 0.980270  D(G(z)): 0.037717 / 0.037717
Time for the last step: 51.59 s    Epoch ETA: 01:04:16
Epoch 3/5 - Step: 2000/9338  Loss G: 3.293503  Loss D: 0.065291  D(x): 0.999873  D(G(z)): 0.070188 / 0.070188
Time for the last step: 51.85 s    Epoch ETA: 01:03:25
Epoch 3/5 - Step: 2100/9338  Loss G: 3.184164  Loss D: 0.070931  D(x): 0.999971  D(G(z)): 0.059657 / 0.059657
Time for the last step: 52.14 s    Epoch ETA: 01:02:34
Epoch 3/5 - Step: 2200/9338  Loss G: 3.116310  Loss D: 0.080597  D(x): 0.999850  D(G(z)): 0.074931 / 0.074931
Time for the last step: 51.85 s    Epoch ETA: 01:01:42
Epoch 3/5 - Step: 2300/9338  Loss G: 3.142180  Loss D: 0.073999  D(x): 0.995546  D(G(z)): 0.054752 / 0.054752
Time for the last step: 51.76 s    Epoch ETA: 01:00:50
Epoch 3/5 - Step: 2400/9338  Loss G: 3.185711  Loss D: 0.072601  D(x): 0.999992  D(G(z)): 0.076053 / 0.076053
Time for the last step: 50.53 s    Epoch ETA: 00:59:54
Epoch 3/5 - Step: 2500/9338  Loss G: 3.027437  Loss D: 0.083906  D(x): 0.997390  D(G(z)): 0.082501 / 0.082501
Time for the last step: 52.06 s    Epoch ETA: 00:59:03
Epoch 3/5 - Step: 2600/9338  Loss G: 3.052374  Loss D: 0.085030  D(x): 0.999924  D(G(z)): 0.073295 / 0.073295
Time for the last step: 52.37 s    Epoch ETA: 00:58:12

Not only has D(x) increased again and become stuck at almost one, but D_G_z1 and D_G_z2 also always show the same value. Moreover, looking at the losses, it seems pretty clear that the Discriminator has outperformed the Generator. This behavior continued for the rest of the epoch and for the whole of the next one, until the end of training.

Hence my second question: is this normal? If not, what am I doing wrong within the procedure? How can I achieve a more stable training?

EDIT: I've tried training the network with MSELoss() as suggested, and this is the output:

Epoch 1/1 - Step: 100/9338  Loss G: 0.800785  Loss D: 0.404525  D(x): 0.844653  D(G(z)): 0.030439 / 0.016316
Time for the last step: 55.22 s    Epoch ETA: 01:25:01
Epoch 1/1 - Step: 200/9338  Loss G: 1.196659  Loss D: 0.014051  D(x): 0.999970  D(G(z)): 0.006543 / 0.006500
Time for the last step: 51.41 s    Epoch ETA: 01:21:11
Epoch 1/1 - Step: 300/9338  Loss G: 1.197319  Loss D: 0.000806  D(x): 0.999431  D(G(z)): 0.004821 / 0.004724
Time for the last step: 51.79 s    Epoch ETA: 01:19:32
Epoch 1/1 - Step: 400/9338  Loss G: 1.198960  Loss D: 0.000720  D(x): 0.999612  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.47 s    Epoch ETA: 01:18:09
Epoch 1/1 - Step: 500/9338  Loss G: 1.212810  Loss D: 0.000021  D(x): 0.999938  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.18 s    Epoch ETA: 01:17:11
Epoch 1/1 - Step: 600/9338  Loss G: 1.216168  Loss D: 0.000000  D(x): 0.999945  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.24 s    Epoch ETA: 01:16:02
Epoch 1/1 - Step: 700/9338  Loss G: 1.212301  Loss D: 0.000000  D(x): 0.999970  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.61 s    Epoch ETA: 01:15:02
Epoch 1/1 - Step: 800/9338  Loss G: 1.214397  Loss D: 0.000005  D(x): 0.999973  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.58 s    Epoch ETA: 01:14:04
Epoch 1/1 - Step: 900/9338  Loss G: 1.212016  Loss D: 0.000003  D(x): 0.999932  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.20 s    Epoch ETA: 01:13:13
Epoch 1/1 - Step: 1000/9338  Loss G: 1.215162  Loss D: 0.000000  D(x): 0.999988  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.28 s    Epoch ETA: 01:12:23
Epoch 1/1 - Step: 1100/9338  Loss G: 1.216291  Loss D: 0.000000  D(x): 0.999983  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.78 s    Epoch ETA: 01:11:28
Epoch 1/1 - Step: 1200/9338  Loss G: 1.215526  Loss D: 0.000000  D(x): 0.999978  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.88 s    Epoch ETA: 01:10:35

As can be seen, the situation gets even worse. Moreover, on rereading the EnhanceNet paper, Section 4.2.4 (Adversarial Training) states that the adversarial loss function used is a BCELoss(), which is what I would expect in order to avoid the vanishing gradients problem that I get with MSELoss().
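The vanishing-gradient effect mentioned above can be illustrated with a small calculation. This is a minimal sketch in plain Python, assuming the discriminator ends in a sigmoid and the generator is trained against the label 1; the function names are illustrative and not taken from the EnhanceNet code. It compares the gradient of each loss with respect to the pre-sigmoid logit when the discriminator confidently rejects a generated sample.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def grad_bce_wrt_logit(z, target=1.0):
    # BCE = -[t*log(p) + (1-t)*log(1-p)] with p = sigmoid(z)
    # d(BCE)/dz simplifies to p - t
    return sigmoid(z) - target


def grad_mse_wrt_logit(z, target=1.0):
    # MSE = (p - t)^2 with p = sigmoid(z)
    # d(MSE)/dz = 2 * (p - t) * p * (1 - p)
    p = sigmoid(z)
    return 2.0 * (p - target) * p * (1.0 - p)


# Increasingly negative logits mean D(G(z)) gets closer to 0
for z in (0.0, -5.0, -10.0):
    print(
        f"logit={z:6.1f}  D(G(z))={sigmoid(z):.6f}  "
        f"dBCE/dz={grad_bce_wrt_logit(z):+.6f}  "
        f"dMSE/dz={grad_mse_wrt_logit(z):+.6f}"
    )
```

Under these assumptions, once the sigmoid saturates near zero the BCE gradient stays close to -1 while the MSE gradient shrinks towards zero, which would be consistent with the Generator loss sitting at a plateau in the MSELoss() run.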


Answer 1:


Interpreting GAN losses is a bit of a black art.

Question 1: The frequency of swinging between discriminator and generator dominance varies with a few factors, primarily (in my experience) learning rate and batch size, both of which affect the propagated loss. The particular loss function used also affects how much the D and G training fluctuates. The EnhanceNet paper (for baseline) and the tutorial use a Mean Squared Error loss too, whereas you're using a Binary Cross Entropy loss, which will change the rate at which the networks converge. I'm no expert, so here's a pretty good article by Rohan Varma that explains the difference between loss functions. I'd be curious to see whether your network behaves differently when you change the loss function. Could you try it and update the question?
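One detail of the loop in the question may also matter here. This is a reading of the posted code, not a confirmed diagnosis: optim_d.step() only runs when loss_d_real or loss_d_fake exceeds 0.3, so on iterations where both stay below that threshold the discriminator is unchanged between its two forward passes, and D_G_z1 and D_G_z2 come out identical. The sketch below reproduces that gating in plain Python with a toy one-weight logistic discriminator. The model, learning rate, and inputs are hypothetical, and only the fake-sample half of the update is modelled.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def gated_step(w, x_fake, threshold=0.3, lr=0.5):
    """One gated discriminator update on a fake sample (label 0).

    Toy discriminator: D(x) = sigmoid(w * x).
    Returns (new_w, d_g_z1, d_g_z2).
    """
    d_g_z1 = sigmoid(w * x_fake)          # score before the (possible) update
    loss_fake = -math.log(1.0 - d_g_z1)   # BCE against label 0
    if loss_fake > threshold:             # same kind of gate as in the question
        grad_w = d_g_z1 * x_fake          # d(loss)/dw = (p - 0) * x
        w = w - lr * grad_w
    d_g_z2 = sigmoid(w * x_fake)          # score after the (possible) update
    return w, d_g_z1, d_g_z2


# Confident discriminator: loss is below the threshold, so the step is skipped
_, z1_skip, z2_skip = gated_step(w=-3.0, x_fake=1.0)

# Unsure discriminator: loss is above the threshold, so the step is taken
_, z1_upd, z2_upd = gated_step(w=0.0, x_fake=1.0)

print(z1_skip == z2_skip, z1_upd == z2_upd)  # True False
```

If this reading is right, the identical D(G(z)) pairs in the later log would mean the discriminator was not being updated on those iterations, rather than that the two networks had reached some balance.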

Question 2: Over time, both the D and G losses should settle to a value. However, it's somewhat difficult to tell whether they've converged on strong performance or because of something like mode collapse or diminishing gradients (see Jonathan Hui's explanation of the problems in training GANs). The best way I've found is to take a cross section of the generated images and either inspect the output visually or compute some kind of perceptual metric (SSIM, PSNR, PIQ, etc.) across the generated image set.
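As a concrete example of such a metric, PSNR can be computed directly from the mean squared error between a generated image and its reference. This is a minimal sketch in plain Python, assuming both images are given as equal-length flat sequences of pixel values in the range [0, max_val]; the function name and the sample values are illustrative.

```python
import math


def psnr(reference, generated, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two flat pixel sequences."""
    if len(reference) != len(generated) or len(reference) == 0:
        raise ValueError("images must be non-empty and of equal length")
    # Mean squared error over all pixels
    mse = sum((r - g) ** 2 for r, g in zip(reference, generated)) / len(reference)
    if mse == 0.0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_val ** 2 / mse)


# Hypothetical 4-pixel images with values in [0, 1]
reference = [0.0, 0.5, 1.0, 0.25]
generated = [0.1, 0.4, 0.9, 0.35]
print(f"PSNR: {psnr(reference, generated):.2f} dB")
```

Averaging a score like this over a fixed validation set after each epoch gives a number that can be tracked alongside the losses, although it does not replace looking at the images.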

Some other leads that you might find useful in finding an answer:

This post has a couple of reasonably good pointers on interpreting GAN Losses.

Ian Goodfellow's NIPS 2016 tutorial also has some solid ideas on how to balance D & G training.
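One of the techniques discussed in that tutorial is one-sided label smoothing, where the discriminator's target for real samples is lowered from 1 to a value such as 0.9 while the target for fake samples stays at 0. The sketch below is a plain-Python illustration with hypothetical numbers, assuming a sigmoid output trained with BCE. It shows that with a smoothed target the gradient on the logit changes sign once the prediction exceeds the target, so the discriminator is no longer pushed towards D(x) = 1.

```python
import math


def bce(p, target):
    # Binary cross entropy for a single prediction p in (0, 1)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def grad_wrt_logit(p, target):
    # For p = sigmoid(z), d(BCE)/dz = p - target
    return p - target


p_real = 0.999  # hypothetical, very confident D(x) on a real sample

hard = grad_wrt_logit(p_real, 1.0)    # negative: descent keeps raising the logit
smooth = grad_wrt_logit(p_real, 0.9)  # positive: descent pulls the logit back down

print(f"hard target 1.0:   grad={hard:+.3f}  loss={bce(p_real, 1.0):.4f}")
print(f"smooth target 0.9: grad={smooth:+.3f}  loss={bce(p_real, 0.9):.4f}")
```

In the loop from the question, this would roughly correspond to filling the real-sample label with 0.9 instead of 1 for the discriminator update only, while the generator update keeps its label at 1.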



Source: https://stackoverflow.com/questions/62174141/how-to-balance-the-generator-and-the-discriminator-performances-in-a-gan
