-
Notifications
You must be signed in to change notification settings - Fork 35
Open
Description
Hi there!
Thanks for your excellent tutorial on NFs and the code. While playing with it, one thing I wanted to do was generate transformed pdfs and check whether they integrate to 1. Using your planar flow class and its forward method, I wrote a code snippet to test this with np.trapz().
import torch
from torch import nn
import matplotlib.pyplot as plt
import numpy as np
import pyro.distributions as dist
class Planar(nn.Module):
    """Planar normalizing-flow layer: f(z) = z + u_hat * tanh(w^T z + b).

    See Rezende et al., "Variational Inference with Normalizing Flows",
    https://arxiv.org/pdf/1505.05770.pdf

    Args:
        size: dimensionality of the vectors being transformed.
        init_sigma: standard deviation used to initialise ``u`` and ``w``.
    """

    def __init__(self, size=1, init_sigma=0.01):
        super().__init__()
        self.u = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.w = nn.Parameter(torch.randn(1, size).normal_(0, init_sigma))
        self.b = nn.Parameter(torch.zeros(1))

    @property
    def normalized_u(self):
        """
        Needed for invertibility condition.
        See Appendix A.1
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf

        Returns u_hat = u + (m(w^T u) - w^T u) * w / ||w||^2, so that
        w^T u_hat = m(w^T u) >= -1.
        """
        # m(x) = -1 + softplus(x). The library softplus is used instead of
        # log(1 + exp(x)) because the latter overflows for large x.
        def m(x):
            return -1 + nn.functional.softplus(x)

        wtu = torch.matmul(self.w, self.u.t())
        # Divide by the *squared* norm: with a plain norm the correction
        # only yields w^T u_hat = m(w^T u) when ||w|| happens to be 1.
        w_div_w2 = self.w / torch.norm(self.w) ** 2
        return self.u + (m(wtu) - wtu) * w_div_w2

    def psi(self, z):
        """
        ψ(z) =h′(w^tz+b)w
        See eq(11)
        Rezende et al. Variational Inference with Normalizing Flows
        https://arxiv.org/pdf/1505.05770.pdf
        """
        return self.h_prime(z @ self.w.t() + self.b) @ self.w

    def h(self, x):
        """Non-linearity of the flow (tanh)."""
        return torch.tanh(x)

    def h_prime(self, z):
        """Derivative of ``h``: 1 - tanh(z)^2."""
        return 1 - torch.tanh(z) ** 2

    def forward(self, z):
        """Apply the flow and accumulate the log |det Jacobian|.

        Args:
            z: either a tensor whose last dimension is ``size``, or a tuple
                ``(z, accumulated_ldj)`` as returned by a previous layer.

        Returns:
            A tuple ``(f(z), ldj + accumulated_ldj)``; ``ldj`` keeps a
            trailing dimension of length 1.
        """
        if isinstance(z, tuple):
            z, accumulating_ldj = z
        else:
            z, accumulating_ldj = z, 0
        psi = self.psi(z)
        u = self.normalized_u
        # determinant of jacobian
        det = 1 + psi @ u.t()
        # log |det Jac|; the small constant guards against log(0)
        ldj = torch.log(torch.abs(det) + 1e-6)
        wzb = z @ self.w.t() + self.b
        fz = z + (u * self.h(wzb))
        return fz, ldj + accumulating_ldj
Perhaps I am missing something in the way I generate the pdf? Is there anything apart from the Jacobian adjustment I need to worry about?
if __name__ == '__main__':
    # Regular evaluation grid over [-5, 5] x [-5, 5]. The same axis is reused
    # below for the numerical integration so that the sample spacing matches
    # the points at which the density was evaluated.
    grid_axis = torch.linspace(-5, 5, 100)
    x1_s, x2_s = torch.meshgrid(grid_axis, grid_axis)
    x_field = torch.stack([x1_s, x2_s], dim=-1)
    # unit Gaussian base dist.
    base_dist = dist.MultivariateNormal(loc=torch.zeros(2), covariance_matrix=torch.eye(2))
    # Planar flow
    pf = Planar(size=2)
    xk, ldj = pf.forward(x_field)
    # Change of variables: log q(f(z)) = log p(z) - log|det J(z)|.
    # NOTE: planar_pdf[i, j] is therefore the flow density at the warped
    # point xk[i, j], not at x_field[i, j]. Integrating it over the regular
    # grid is only an approximation, reasonable while the flow stays close to
    # the identity (as with the default init_sigma=0.01).
    planar_pdf = torch.exp(base_dist.log_prob(x_field) - ldj.squeeze(-1))
    # Generating pdf and checking if integrates to 1
    axis_np = grid_axis.numpy()
    inner = np.trapz(planar_pdf.detach().numpy(), axis_np, axis=0)
    print(np.trapz(inner, axis_np))
Metadata
Metadata
Assignees
Labels
No labels