diff --git a/imaginairy/modules/distributions.py b/imaginairy/modules/distributions.py index 1cb8694..62a671b 100644 --- a/imaginairy/modules/distributions.py +++ b/imaginairy/modules/distributions.py @@ -24,23 +24,24 @@ class DiagonalGaussianDistribution: def kl(self, other=None): if self.deterministic: return torch.Tensor([0.0]) - else: - if other is None: - return 0.5 * torch.sum( - torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, - dim=[1, 2, 3], - ) - else: - return 0.5 * torch.sum( - torch.pow(self.mean - other.mean, 2) / other.var - + self.var / other.var - - 1.0 - - self.logvar - + other.logvar, - dim=[1, 2, 3], - ) - def nll(self, sample, dims=[1, 2, 3]): + if other is None: + return 0.5 * torch.sum( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=None): + dims = dims if dims is None else [1, 2, 3] if self.deterministic: return torch.Tensor([0.0]) logtwopi = np.log(2.0 * np.pi)