You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
20 lines
532 B
Python
20 lines
532 B
Python
"""Functions for discriminator loss calculation"""
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def hinge_d_loss(logits_real, logits_fake):
    """Return the hinge discriminator loss for real and fake logits.

    Real logits are penalized only while they are below +1, and fake logits
    only while they are above -1. Each hinge term is averaged over every
    element of its tensor, and the two averages are then averaged together.

    Args:
        logits_real: Discriminator logits for real samples (tensor of any
            shape; it is reduced over all elements).
        logits_fake: Discriminator logits for generated samples (tensor of
            any shape; it is reduced over all elements).

    Returns:
        A 0-dim tensor: ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    # Margin violation for real samples: zero once a logit reaches +1.
    real_term = F.relu(1.0 - logits_real).mean()
    # Margin violation for fake samples: zero once a logit drops to -1.
    fake_term = F.relu(1.0 + logits_fake).mean()
    return (real_term + fake_term) * 0.5
|
|
|
|
|
|
def vanilla_d_loss(logits_real, logits_fake):
    """Return the standard (logistic / BCE-with-logits) discriminator loss.

    Because ``softplus(-x) == -log(sigmoid(x))`` and
    ``softplus(x) == -log(1 - sigmoid(x))``, this equals binary
    cross-entropy with target 1 for real logits and target 0 for fake
    logits. Computing it through ``softplus`` on the raw logits avoids an
    explicit ``sigmoid`` followed by ``log``.

    Args:
        logits_real: Discriminator logits for real samples (tensor of any
            shape; it is reduced over all elements).
        logits_fake: Discriminator logits for generated samples (tensor of
            any shape; it is reduced over all elements).

    Returns:
        A 0-dim tensor:
        ``0.5 * (mean(softplus(-real)) + mean(softplus(fake)))``.
    """
    # Use the module-level ``F`` alias, as ``hinge_d_loss`` does, instead of
    # spelling out ``torch.nn.functional``.
    loss_real = torch.mean(F.softplus(-logits_real))
    loss_fake = torch.mean(F.softplus(logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss
|