

Implement KL Divergence using Pytorch

import sys

import numpy as np
import torch
import torch.nn as nn


def GaussianPDF(mean, logvar, z):
    r""" Return the PDF value of z under N(mean, exp(logvar))
    mean, logvar : [*, dim_z]
    z : [*, N, dim_z]
    return: [*, N, dim_z]
    """
    if isinstance(mean, torch.Tensor):
        # Broadcast the distribution parameters over the sample dimension N
        mean, logvar = mean.unsqueeze(-2), logvar.unsqueeze(-2)
        return 1 / (np.sqrt(2 * np.pi) * torch.exp(0.5 * logvar)) * torch.exp(-((z - mean) ** 2) / (2 * torch.exp(logvar)))
    elif isinstance(mean, np.ndarray):
        mean, logvar = np.expand_dims(mean, axis=-2), np.expand_dims(logvar, axis=-2)
        return 1 / (np.sqrt(2 * np.pi) * np.exp(0.5 * logvar)) * np.exp(-((z - mean) ** 2) / (2 * np.exp(logvar)))
    return None

def LogGaussianPDF(mean, logvar, z):
    r""" Return the log PDF value of z under N(mean, exp(logvar))
    mean, logvar : [*, dim_z]
    z : [*, N, dim_z]
    return: [*, N, dim_z]
    """
    if isinstance(mean, torch.Tensor):
        mean, logvar = mean.unsqueeze(-2), logvar.unsqueeze(-2)
        # The 1e-6 terms are small stabilizers against overflow / division by zero
        return -0.5 * np.log(2 * np.pi) - 0.5 * logvar - ((z - mean) ** 2 + 1e-6) / (2 * torch.exp(logvar) + 1e-6)
    elif isinstance(mean, np.ndarray):
        mean, logvar = np.expand_dims(mean, axis=-2), np.expand_dims(logvar, axis=-2)
        return -0.5 * np.log(2 * np.pi) - 0.5 * logvar - ((z - mean) ** 2 + 1e-6) / (2 * np.exp(logvar) + 1e-6)
    return None
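
To check that these helpers broadcast the way the shape comments claim, they can be compared against torch.distributions.Normal. The snippet below is a minimal sketch with made-up shapes; logvar is clamped and the tolerance is loose because of the 1e-6 stabilizers in LogGaussianPDF.

# Sanity-check sketch: GaussianPDF / LogGaussianPDF vs. torch.distributions.Normal
mean = torch.randn(4, 2)                   # [batch, dim_z]
logvar = torch.randn(4, 2).clamp(-2, 2)    # [batch, dim_z]
z = torch.randn(4, 10, 2)                  # [batch, N, dim_z]

ref = torch.distributions.Normal(mean.unsqueeze(-2), torch.exp(0.5 * logvar).unsqueeze(-2))
assert torch.allclose(GaussianPDF(mean, logvar, z), ref.log_prob(z).exp(), atol=1e-3)
assert torch.allclose(LogGaussianPDF(mean, logvar, z), ref.log_prob(z), atol=1e-3)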

def kld_loss(mean1, log_var1, mean2, log_var2):

    '''
    KL( N(mu1, var1) || N(mu2, var2) )
        = -0.5 * ( log(var1/var2) - (var1 + (mu1 - mu2)^2) / var2 + 1 )
        = -0.5 * ( A - B / var2 + 1 )

    A = log(var1) - log(var2)
    B = var1 + (mu1 - mu2)^2

    posterior ~ N(mean1, var1)
    prior     ~ N(mean2, var2)
    '''

    A = log_var1 - log_var2
    B = log_var1.exp() + (mean1 - mean2).pow(2)
    kld = -0.5 * (A - B.div(log_var2.exp() + 1e-10) + 1)   # 1e-10 guards against division by zero

    return torch.sum(kld, dim=1)   # sum over the latent dimension -> [batch]
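
kld_loss can be checked against PyTorch's built-in kl_divergence for diagonal Gaussians; the sketch below assumes [batch, dim] inputs and uses arbitrary random parameters.

# Sketch: closed-form kld_loss vs. torch.distributions.kl_divergence
mu1, lv1 = torch.randn(8, 3), torch.randn(8, 3)
mu2, lv2 = torch.randn(8, 3), torch.randn(8, 3)

q = torch.distributions.Normal(mu1, torch.exp(0.5 * lv1))
p = torch.distributions.Normal(mu2, torch.exp(0.5 * lv2))
ref = torch.distributions.kl_divergence(q, p).sum(dim=1)   # [batch]

assert torch.allclose(kld_loss(mu1, lv1, mu2, lv2), ref, atol=1e-3)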

class NormalDist(nn.Module):

    def __init__(self, mu, logvar):
        super(NormalDist, self).__init__()

        self.mu = mu
        self.logvar = logvar

    def rsample(self, best_k):

        '''
        output : (best_k x batch) x dim
        '''

        std = torch.exp(0.5 * self.logvar)
        Z = []
        for _ in range(best_k):
            # Reparameterization trick: z = mu + std * eps, eps ~ N(0, I)
            eps = torch.randn_like(std)
            Z.append(eps.mul(std).add_(self.mu))
        return torch.cat(Z, dim=0)

    def kld(self, p_z_dist=None, type='anal', n_samples=1000):
        '''
        type : anal (analytic), mc-b (mc-built-in), mc-w (mc-written)

        It is found that
        (1) 'anal' and 'mc-b' show the same results, since PyTorch's built-in
            kl_divergence for two Normals is also computed in closed form
        (2) 'mc-w' can approximate 'mc-b'
        '''

        if p_z_dist is None:
            # Default prior: standard normal N(0, I), i.e. logvar = 0 so that var = 1
            p_z_dist = NormalDist(mu=torch.zeros_like(self.mu),
                                  logvar=torch.zeros_like(self.logvar))

        if type == 'anal':
            # Closed-form KLD between two diagonal Gaussians
            return kld_loss(self.mu, self.logvar, p_z_dist.mu, p_z_dist.logvar).mean()
        elif type == 'mc-b':
            q = torch.distributions.Normal(self.mu, torch.exp(0.5 * self.logvar))
            p = torch.distributions.Normal(p_z_dist.mu, torch.exp(0.5 * p_z_dist.logvar))
            return torch.distributions.kl_divergence(q, p).sum(dim=-1).mean()
        elif type == 'mc-w':
            # Monte-Carlo estimate: KLD ~= E_q[ log q(z) - log p(z) ]
            batch, dim = self.mu.size()
            z_ = self.rsample(best_k=n_samples)
            z = [z_[_*batch:(_+1)*batch].unsqueeze(1) for _ in range(n_samples)]
            z = torch.cat(z, dim=1)  # batch x n_samples x dim

            logq = LogGaussianPDF(self.mu, self.logvar, z).mean(1)  # batch x dim
            logp = LogGaussianPDF(p_z_dist.mu, p_z_dist.logvar, z).mean(1)

            return (logq - logp).sum(dim=-1).mean()

        else:
            sys.exit("current type %s is not supported for KLD calculation" % type)
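
For a quick end-to-end comparison, the sketch below instantiates a hypothetical approximate posterior and a standard-normal prior (both [batch, dim] shaped) and prints the three KLD estimates; 'anal' and 'mc-b' should match exactly, while 'mc-w' should land close to them.

# Usage sketch: compare the three KLD modes
torch.manual_seed(0)
q_dist = NormalDist(mu=torch.randn(16, 4), logvar=torch.randn(16, 4))
p_dist = NormalDist(mu=torch.zeros(16, 4), logvar=torch.zeros(16, 4))

print(q_dist.kld(p_dist, type='anal'))                  # analytic, closed form
print(q_dist.kld(p_dist, type='mc-b'))                  # built-in, matches 'anal'
print(q_dist.kld(p_dist, type='mc-w', n_samples=2000))  # Monte-Carlo approximation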