PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.
- General factor graphs: PGMax supports easy specification of general factor graphs with potentially complicated topology, factor definitions, and discrete variables with a varying number of states.
- LBP in JAX: PGMax generates pure JAX functions implementing LBP for a given factor graph. The generated pure JAX functions run on modern accelerators (GPU/TPU), work with JAX transformations (e.g. `vmap` for processing batches of models/samples, `grad` for differentiating through the LBP iterative process), and can easily be used as part of a larger end-to-end differentiable system.
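To make the first point concrete, here is a small factor graph written in plain Python/NumPy. It has three discrete variables with 2, 3, and 4 states, connected by three pairwise factors that form a loop, and each factor is a table of log-potentials. The `num_states`/`factors` container and the helper functions are hypothetical illustrations of the underlying data model, not PGMax's API. Exact marginals are computed by brute-force enumeration. This is only feasible for tiny graphs, and it is the quantity that LBP approximates on loopy graphs.

```python
import itertools

import numpy as np

# Hypothetical minimal factor-graph container (not PGMax's API): each variable has
# its own number of states, and each factor is a log-potential table indexed by
# the states of an ordered tuple of variables.
num_states = {"a": 2, "b": 3, "c": 4}
factors = [
    (("a", "b"), np.log(np.array([[1.0, 2.0, 1.0],
                                  [3.0, 1.0, 1.0]]))),
    (("b", "c"), np.log(np.arange(1.0, 13.0).reshape(3, 4))),
    (("a", "c"), np.log(np.array([[2.0, 1.0, 1.0, 1.0],
                                  [1.0, 1.0, 1.0, 2.0]]))),
]  # a-b, b-c, a-c: the three factors form a cycle, so the graph is loopy


def log_score(assignment):
    """Unnormalized log-probability: sum of every factor's log-potential."""
    return sum(table[tuple(assignment[v] for v in vars_)] for vars_, table in factors)


def exact_marginals():
    """Brute-force marginals by enumerating every joint assignment (tiny graphs only)."""
    names = list(num_states)
    marginals = {v: np.zeros(n) for v, n in num_states.items()}
    for states in itertools.product(*(range(num_states[v]) for v in names)):
        assignment = dict(zip(names, states))
        weight = np.exp(log_score(assignment))
        for v in names:
            marginals[v][assignment[v]] += weight
    return {v: m / m.sum() for v, m in marginals.items()}


marg = exact_marginals()
print({v: np.round(m, 3) for v, m in marg.items()})
```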
See our companion paper (arXiv:2202.04110) for more details.
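The second point, LBP as a pure JAX function that composes with `vmap` and `grad`, can be illustrated with a minimal sketch of sum-product LBP. It assumes a hypothetical toy model: a ring of binary variables with per-variable log-potentials `unary` and a shared Ising `coupling`. It uses an undamped, parallel message schedule. The name `run_lbp` and its arguments are our own and are not PGMax's generated code.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def run_lbp(unary, coupling, num_iters=50):
    """Sum-product LBP on a ring of binary variables (hypothetical toy model).

    unary: (n, 2) log-potentials per variable; coupling: scalar Ising coupling.
    Returns (n, 2) approximate marginals (beliefs).
    """
    # Edge (i, i+1) log-potential indexed [x_i, x_{i+1}]; positive coupling favors agreement.
    pair = coupling * jnp.array([[1.0, -1.0], [-1.0, 1.0]])
    n = unary.shape[0]

    def normalize(m):
        return m - logsumexp(m, axis=-1, keepdims=True)

    def step(_, msgs):
        fwd, bwd = msgs  # fwd[i]: message i -> i+1; bwd[i]: message i -> i-1
        # Combine the unary with the message arriving from the *other* neighbor.
        h_fwd = unary + jnp.roll(fwd, 1, axis=0)   # adds message (i-1) -> i
        h_bwd = unary + jnp.roll(bwd, -1, axis=0)  # adds message (i+1) -> i
        # Sum out x_i in log space.
        new_fwd = logsumexp(h_fwd[:, :, None] + pair[None, :, :], axis=1)
        new_bwd = logsumexp(h_bwd[:, :, None] + pair.T[None, :, :], axis=1)
        return normalize(new_fwd), normalize(new_bwd)

    init = (jnp.zeros((n, 2)), jnp.zeros((n, 2)))
    fwd, bwd = jax.lax.fori_loop(0, num_iters, step, init)
    log_beliefs = unary + jnp.roll(fwd, 1, axis=0) + jnp.roll(bwd, -1, axis=0)
    return jax.nn.softmax(log_beliefs, axis=-1)


key_u, key_b = jax.random.split(jax.random.PRNGKey(0))
unary = jax.random.normal(key_u, (6, 2))

# vmap: run LBP on a batch of 4 ring models that share the same coupling.
batch_unary = jax.random.normal(key_b, (4, 6, 2))
batch_beliefs = jax.vmap(run_lbp, in_axes=(0, None))(batch_unary, 0.5)


def loss(coupling):
    # Toy objective: negative log-belief that variable 0 is in state 1.
    return -jnp.log(run_lbp(unary, coupling)[0, 1])


# grad: differentiate through every LBP iteration with respect to the coupling.
grad_coupling = jax.grad(loss)(0.5)
```

With `coupling = 0` the variables are independent and the beliefs reduce to `softmax(unary)`, which makes a convenient sanity check. Using `jax.lax.fori_loop` with a fixed Python-integer iteration count keeps the loop compiled while still allowing reverse-mode differentiation.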
Installation
Install the latest release from PyPI:
pip install pgmax
or install the latest version directly from GitHub:
pip install git+https://github.com/deepmind/PGMax.git
To install PGMax from source in editable mode (e.g. for development), run the commands below. While you can install PGMax in your standard Python environment, we strongly recommend using a Python virtual environment to manage your dependencies. This helps avoid version conflicts and generally makes installation easier.
git clone https://github.com/deepmind/PGMax.git
cd PGMax
python3 -m venv pgmax_env
source pgmax_env/bin/activate
pip install --upgrade pip setuptools
pip install -e .
By default, the commands above install the CPU-only version of JAX. If you have access to a GPU, follow the official JAX installation instructions to install JAX with GPU support.
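After installing a GPU build of JAX, you can confirm which backend JAX picked up. The minimal check below uses JAX's standard `jax.default_backend()` and `jax.devices()` functions. The exact device strings printed depend on your JAX version and hardware.

```python
import jax

# Name of the default backend JAX selected, typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
# Devices visible to JAX on that backend.
devices = jax.devices()
print(f"JAX backend: {backend}, devices: {devices}")
```

If this reports `cpu` on a machine with a GPU, JAX is most likely still the CPU-only build.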
Getting started
Here are a few self-contained Colab notebooks to help you get started with PGMax. We recommend running them on GPU instances:
- First tutorial for basic PGMax inference on an Ising model
- Advanced tutorial running inference on a Restricted Boltzmann Machine
- Implementing max-product LBP for Recursive Cortical Networks
- End-to-end differentiable LBP for gradient-based PGM training
- 2D binary deconvolution
- Alternative inference using a Smooth Dual LP-MAP solver
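The Recursive Cortical Networks notebook above uses max-product LBP, which replaces the sum in sum-product message passing with a max. The following is our own minimal NumPy sketch, not code from the notebooks. It assumes a hypothetical Ising chain with log-potentials `unary` and a scalar `coupling`. On a tree-structured model such as a chain, taking the argmax of each variable's max-marginal recovers the exact MAP assignment whenever that assignment is unique, and the brute-force function checks this. On loopy graphs the same updates only give an approximation.

```python
import itertools

import numpy as np

ISING = np.array([[1.0, -1.0], [-1.0, 1.0]])  # log-potential favoring equal neighbors


def chain_max_product(unary, coupling):
    """Max-product BP on an Ising chain in log space; returns per-variable argmax states."""
    n = unary.shape[0]
    pair = coupling * ISING  # indexed [x_i, x_{i+1}]
    fwd = np.zeros((n, 2))  # fwd[i]: max-message into i from i-1
    bwd = np.zeros((n, 2))  # bwd[i]: max-message into i from i+1
    for i in range(1, n):
        fwd[i] = np.max((unary[i - 1] + fwd[i - 1])[:, None] + pair, axis=0)
    for i in range(n - 2, -1, -1):
        bwd[i] = np.max(pair + (unary[i + 1] + bwd[i + 1])[None, :], axis=1)
    max_marginals = unary + fwd + bwd  # log max-marginals
    return max_marginals.argmax(axis=1)


def brute_force_map(unary, coupling):
    """Exact MAP by enumerating all 2**n assignments (tiny chains only)."""
    n = unary.shape[0]
    pair = coupling * ISING

    def score(x):
        return sum(unary[i, x[i]] for i in range(n)) + sum(
            pair[x[i], x[i + 1]] for i in range(n - 1)
        )

    return np.array(max(itertools.product((0, 1), repeat=n), key=score))


rng = np.random.default_rng(0)
unary = rng.normal(size=(8, 2))
print(chain_max_product(unary, 0.8), brute_force_map(unary, 0.8))
```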
PGMax is part of the DeepMind JAX ecosystem. If you use PGMax in your work, please consider citing our companion paper:
@article{zhou2022pgmax,
author = {Zhou, Guangyao and Dedieu, Antoine and Kumar, Nishanth and L{\'a}zaro-Gredilla, Miguel and Kushagra, Shrinu and George, Dileep},
title = {{PGMax: Factor Graphs for Discrete Probabilistic Graphical Models and Loopy Belief Propagation in JAX}},
journal = {arXiv preprint arXiv:2202.04110},
year={2022}
}
and using the DeepMind JAX Ecosystem citation:
@software{deepmind2020jax,
title = {The {D}eep{M}ind {JAX} {E}cosystem},
author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\'{c}, Milo\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},
url = {http://github.com/google-deepmind},
year = {2020},
}
This is not an officially supported Google product.