The Official PyTorch Implementation of "Brain-like Variational Inference" (NeurIPS 2025 Paper)
Welcome to the "Brain-like Variational Inference" codebase!
Variational free energy (F) is the same quantity as the negative evidence lower bound (ELBO) from machine learning. Why do we care? Because minimizing F unifies popular generative models, such as VAEs, with major cornerstones of theoretical neuroscience, such as Sparse Coding and Predictive Coding.
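For reference, here is the standard identity relating the two, written for a generic approximate posterior $q(z)$ and generative model $p(x, z)$ (the notation here is generic, not taken verbatim from the paper):

$$
\mathcal{F} \;=\; \mathbb{E}_{q(z)}\!\left[\log q(z) - \log p(x, z)\right]
\;=\; D_{\mathrm{KL}}\!\left(q(z)\,\|\,p(z)\right) \;-\; \mathbb{E}_{q(z)}\!\left[\log p(x \mid z)\right]
\;=\; -\,\mathrm{ELBO}.
$$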
Building on this unification potential, we introduced FOND (Free energy Online Natural-gradient Dynamics): a framework for deriving brain-like adaptive iterative inference algorithms from first principles.
We then applied the FOND framework to derive a family of iterative VAE models, including the spiking iterative Poisson VAE (iP-VAE). This repository provides the implementation code for these iterative VAEs.
Before diving into the code, take a quick detour to watch an iP-VAE neuron in action, where we reproduce the classic Hubel & Wiesel bar of light experiment: 🎥 Watch the video with sound
To learn more, check out:
- Research paper: https://openreview.net/forum?id=573IcLusXq
- X summary thread: https://x.com/hadivafaii/status/1924344415063294287
- Talk: https://www.youtube.com/watch?v=fTg-S81Ymto
- `./main/`: Full architecture and training code for the iterative VAE models, including iP-VAE and iG-VAE.
- `./base/`: Core functionality, including distributions, optimization, and dataset handling.
- `./analysis/`: Data analysis and result-generation code.
- `./scripts/`: Model fitting scripts (examples below).
We also provide a minimal PyTorch Lightning implementation of the iP-VAE, stripped down to its essential components. This serves as an excellent starting point for understanding the model. Check it out:
To train a model, run:
cd scripts/
./fit_model.sh <device> <dataset> <model> [additional args]

- `<device>`: int, CUDA device index.
- `<dataset>`: str, choices = {'vH16-wht', 'MNIST', ...}.
- `<model>`: str, choices = {'poisson', 'gaussian'}.
Additional arguments can be passed to customize the training process.
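For example, an invocation might look like the following (the flag values here are illustrative placeholders; the flags themselves are described below):

./fit_model.sh 0 'vH16-wht' 'poisson' --n_latents 1024 --beta_outer 16.0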
Key parameters include:
- `t_train`: Number of inference iterations (default: 16).
- `n_iters_outer`: Number of repeats of the outer loop (default: 1). Controls gradient accumulation cycles. When > 1, this implements truncated backpropagation through time by performing multiple cycles of "run `t_train` inference iterations, accumulate gradients, then update weights". Higher values allow training on longer effective sequences while managing memory constraints. In the paper we only use the default value of 1.
- `n_iters_inner`: Number of gradient updates during inference (default: 1). When > 1, the KL regularization term kicks in (i.e., the "leak" term in iP-VAE).
- `beta_outer`: Beta used during learning (default: 16.0). This value is used when computing the loss for the weight update.
- `beta_inner`: Beta used during inference (default: 1.0). This value is used only for the inner-loop updates, so it has no effect when `n_iters_inner = 1`.
- `n_latents`: Dimensionality of the latent space (default: 512).
See ./main/config.py for all available configuration options.
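To make the interplay between these loop parameters concrete, here is a minimal structural sketch of one training step. It assumes a hypothetical model interface: `init_state`, `inference_update`, and `elbo_loss` are illustrative names, not the actual API in `./main/`.

```python
def train_step(model, optimizer, x, t_train=16, n_iters_outer=1,
               n_iters_inner=1, beta_outer=16.0, beta_inner=1.0):
    """Structural sketch only; the model methods below are hypothetical."""
    state = model.init_state(x)                  # (hypothetical) initial posterior/latent state
    for _ in range(n_iters_outer):               # outer loop: one truncated-BPTT cycle per pass
        optimizer.zero_grad()
        state = state.detach()                   # truncate the graph between outer cycles
        for _ in range(t_train):                 # unrolled inference iterations
            for _ in range(n_iters_inner):       # inner updates; beta_inner only matters here
                state = model.inference_update(x, state, beta=beta_inner)
        loss = model.elbo_loss(x, state, beta=beta_outer)  # weight-update loss uses beta_outer
        loss.backward()                          # backprop through the unrolled iterations
        optimizer.step()                         # then update the weights
    return loss.detach()
```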
To reproduce Figure 3 from the paper, train models using the following configurations, corresponding to iP-VAE, iG-VAE, and iGrelu-VAE, respectively:
./fit_model.sh 0 'vH16-wht' 'poisson' --t_train 16 --n_latents 512 --beta_outer 24.0
./fit_model.sh 0 'vH16-wht' 'gaussian' --t_train 16 --n_latents 512 --beta_outer 8.0
./fit_model.sh 0 'vH16-wht' 'gaussian' --t_train 16 --n_latents 512 --beta_outer 8.0 --latent_act 'relu'

We also provide the following notebooks:

- `results.ipynb`: Generates figures and analyses from the paper.
- `load_models.ipynb`: Visualizes trained models and their features.
- `hubel_wiesel.ipynb`: Reproduces the classic Hubel & Wiesel bar of light experiment on a model "neuron".
We provide model checkpoints trained on whitened 16 x 16 patches extracted from the van Hateren dataset (vH16-wht). These are the same models you would obtain by running the scripts above; they are located in ./checkpoints/ and can be loaded and visualized using load_models.ipynb. If additional model checkpoints would be helpful, feel free to reach out.
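If you want to inspect a checkpoint outside the notebook, a minimal sketch with plain PyTorch might look like this (the file name is a placeholder; list `./checkpoints/` for the actual files, and prefer `load_models.ipynb` for the supported workflow):

```python
import torch

# Placeholder file name; see ./checkpoints/ for the actual checkpoint files.
ckpt = torch.load('./checkpoints/<checkpoint_name>.pt', map_location='cpu')
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys())[:10])  # peek at the top-level keys
```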
Download the processed datasets from the following links:
- Complete folder: Drive Link.
- Or individual datasets:
Place the downloaded data under ~/Datasets/ with the following structure:
~/Datasets/DOVES/vH16
~/Datasets/MNIST/processed
For details, see the make_dataset() function in ./base/dataset.py.
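As a quick way to set up the layout above, the directories can be created up front (an illustrative one-liner mirroring the structure shown; double-check against make_dataset() before relying on it):

mkdir -p ~/Datasets/DOVES/vH16 ~/Datasets/MNIST/processed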
If you use our code in your research, please cite our paper:
@inproceedings{
vafaii2025brainlike,
title={Brain-like Variational Inference},
author={Hadi Vafaii and Dekel Galor and Jacob L. Yates},
booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
year={2025},
url={https://openreview.net/forum?id=573IcLusXq}
}

- For code-related questions, please open an issue in this repository.
- For paper-related questions, contact me at vafaii@berkeley.edu.
