import sys

import numpy as np
import torch
import torch.nn as nn


def GaussianPDF(mean, logvar, z):
    r"""Return the PDF value of z under N(mean, exp(logvar)).

    mean, logvar : [*, dim_z]
    z            : [*, N, dim_z]
    return       : [*, N, dim_z]
    """
    if isinstance(mean, torch.Tensor):
        # Broadcast the distribution parameters against the sample dimension N.
        mean, logvar = mean.unsqueeze(-2), logvar.unsqueeze(-2)
        return 1 / (np.sqrt(2 * np.pi) * torch.exp(0.5 * logvar)) * torch.exp(-((z - mean) ** 2) / (2 * torch.exp(logvar)))
    elif isinstance(mean, np.ndarray):
        mean, logvar = np.expand_dims(mean, axis=-2), np.expand_dims(logvar, axis=-2)
        return 1 / (np.sqrt(2 * np.pi) * np.exp(0.5 * logvar)) * np.exp(-((z - mean) ** 2) / (2 * np.exp(logvar)))
    return None
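

# A minimal usage sketch (my addition, not part of the original post; the
# shapes below are illustrative assumptions): the parameters are [batch, dim_z]
# and the samples are [batch, N, dim_z], so the returned densities keep the
# per-dimension shape [batch, N, dim_z].
def _example_gaussian_pdf():
    mean = torch.zeros(4, 2)             # [batch, dim_z]
    logvar = torch.zeros(4, 2)           # log variance 0 -> unit variance
    z = torch.randn(4, 16, 2)            # [batch, N, dim_z]
    pdf = GaussianPDF(mean, logvar, z)   # -> torch.Size([4, 16, 2])
    print(pdf.shape)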


def LogGaussianPDF(mean, logvar, z):
    r"""Return the log PDF value of z under N(mean, exp(logvar)).

    mean, logvar : [*, dim_z]
    z            : [*, N, dim_z]
    return       : [*, N, dim_z]
    """
    if isinstance(mean, torch.Tensor):
        mean, logvar = mean.unsqueeze(-2), logvar.unsqueeze(-2)
        # The 1e-6 terms are small constants for numerical stability.
        return -0.5 * np.log(2 * np.pi) - 0.5 * logvar - ((z - mean) ** 2 + 1e-6) / (2 * torch.exp(logvar) + 1e-6)
    elif isinstance(mean, np.ndarray):
        mean, logvar = np.expand_dims(mean, axis=-2), np.expand_dims(logvar, axis=-2)
        return -0.5 * np.log(2 * np.pi) - 0.5 * logvar - ((z - mean) ** 2 + 1e-6) / (2 * np.exp(logvar) + 1e-6)
    return None
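

# Another illustrative sketch (my addition, not the author's): up to the
# stabilising 1e-6 terms, exp(LogGaussianPDF(...)) should match GaussianPDF(...).
def _example_log_gaussian_pdf():
    mean, logvar = torch.zeros(4, 2), torch.zeros(4, 2)
    z = torch.randn(4, 16, 2)
    pdf = GaussianPDF(mean, logvar, z)
    log_pdf = LogGaussianPDF(mean, logvar, z)
    print(torch.allclose(log_pdf.exp(), pdf, atol=1e-4))  # expected: True (approximately)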


def kld_loss(mean1, log_var1, mean2, log_var2):
    '''
    Analytic KL divergence between two diagonal Gaussians, summed over dim 1:
        KLD = -0.5 * (log(var1/var2) - (var1 + (mu1 - mu2)^2)/var2 + 1)
            = -0.5 * (A - B + 1)
        A = log(var1) - log(var2)
        B = (var1 + (mu1 - mu2)^2) / var2
    posterior ~ N(mean1, var1), prior ~ N(mean2, var2)
    '''
    A = log_var1 - log_var2
    B = log_var1.exp() + (mean1 - mean2).pow(2)
    kld = -0.5 * (A - B.div(log_var2.exp() + 1e-10) + 1)
    return torch.sum(kld, dim=1)
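

# A quick sanity check (my addition, assuming 2-D [batch, dim] inputs):
# kld_loss should agree with PyTorch's analytic kl_divergence between diagonal
# Gaussians parameterised by (mean, log variance).
def _example_kld_loss():
    mu_q, logvar_q = torch.randn(4, 2), torch.randn(4, 2)
    mu_p, logvar_p = torch.randn(4, 2), torch.randn(4, 2)
    kl_ours = kld_loss(mu_q, logvar_q, mu_p, logvar_p)
    q = torch.distributions.Normal(mu_q, torch.exp(0.5 * logvar_q))
    p = torch.distributions.Normal(mu_p, torch.exp(0.5 * logvar_p))
    kl_ref = torch.distributions.kl_divergence(q, p).sum(dim=1)
    print(torch.allclose(kl_ours, kl_ref, atol=1e-4))  # expected: True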


class NormalDist(nn.Module):
    def __init__(self, mu, logvar):
        super(NormalDist, self).__init__()
        self.mu = mu
        self.logvar = logvar

    def rsample(self, best_k):
        '''
        Draw best_k reparameterised samples.
        output : (best_k x batch) x dim
        '''
        Z = []
        for _ in range(best_k):
            std = torch.exp(0.5 * self.logvar)
            eps = torch.randn_like(std)
            Z.append(eps.mul(std).add_(self.mu))
        return torch.cat(Z, dim=0)

    def kld(self, p_z_dist=None, type='anal', n_samples=1000):
        '''
        type : anal (analytic), mc-b (mc built-in), mc-w (mc written)
        It is found that
        (1) 'anal' and 'mc-b' give the same results
        (2) 'mc-w' approximates 'mc-b'
        '''
        if p_z_dist is None:
            # Default prior: standard normal N(0, I), i.e. log variance = 0.
            p_z_dist = NormalDist(mu=torch.zeros_like(self.mu),
                                  logvar=torch.zeros_like(self.logvar))
        if type == 'anal':
            return kld_loss(self.mu, self.logvar, p_z_dist.mu, p_z_dist.logvar).mean()
        elif type == 'mc-b':
            q = torch.distributions.Normal(self.mu, torch.exp(0.5 * self.logvar))
            p = torch.distributions.Normal(p_z_dist.mu, torch.exp(0.5 * p_z_dist.logvar))
            return torch.distributions.kl_divergence(q, p).sum(dim=-1).mean()
        elif type == 'mc-w':
            batch, dim = self.mu.size()
            z_ = self.rsample(best_k=n_samples)
            z = [z_[k * batch:(k + 1) * batch].unsqueeze(1) for k in range(n_samples)]
            z = torch.cat(z, dim=1)  # batch x n_samples x dim
            logq = LogGaussianPDF(self.mu, self.logvar, z).mean(1)  # batch x dim
            logp = LogGaussianPDF(p_z_dist.mu, p_z_dist.logvar, z).mean(1)
            return (logq - logp).sum(dim=-1).mean()
        else:
            sys.exit("current type %s is not supported for KLD calculation" % type)
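

# A small end-to-end sketch (my addition; the shapes, sample count, and prior
# are illustrative assumptions): the analytic ('anal') and built-in ('mc-b')
# values should coincide, and the Monte Carlo estimate ('mc-w') should approach
# them as n_samples grows.
if __name__ == "__main__":
    mu = torch.randn(8, 4)
    logvar = torch.randn(8, 4)
    q_dist = NormalDist(mu=mu, logvar=logvar)

    kld_anal = q_dist.kld(type='anal')
    kld_mcb = q_dist.kld(type='mc-b')
    kld_mcw = q_dist.kld(type='mc-w', n_samples=5000)
    print(kld_anal.item(), kld_mcb.item(), kld_mcw.item())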