Generative modelling with neural probabilistic circuits

The current state of the art in generative modelling is dominated by neural networks. Despite their impressive performance on many benchmark tasks, these algorithms do not generally support tractable inference for common and important probabilistic queries, such as computing marginal or conditional probabilities. Moreover, the leading methods in this area, namely GANs, VAEs, normalizing flows, and diffusion models, are notoriously data-hungry and often require extensive hyperparameter tuning. Building on classic work in tree-based density estimation, this project will develop and study the properties of a hybrid class of models we call neural probabilistic circuits (NPCs). NPCs combine the expressive power of neural networks with the speed and flexibility of tree-based ensembles such as random forests and gradient boosting machines. The resulting neurosymbolic algorithms can be used for a wide variety of downstream tasks, including compression, imputation, and inference.
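To make "tractable inference" concrete, here is a minimal sketch. It assumes a toy tree-based density: a weighted mixture over leaves, where each leaf is a product of univariate uniform densities on an axis-aligned box. In circuit terms, this is a single sum node over product nodes. The `Leaf` class, the uniform leaf densities, and the toy parameters are hypothetical illustrations, not the NPC method this project will develop.

```python
import random
from dataclasses import dataclass
from typing import Dict, Tuple


@dataclass
class Leaf:
    # Hypothetical leaf: mixture weight (e.g. share of training samples in the leaf)
    weight: float
    # Axis-aligned box: variable name -> (lower, upper) bound
    bounds: Dict[str, Tuple[float, float]]


def uniform_pdf(x: float, lo: float, hi: float) -> float:
    # Univariate uniform density on the half-open interval [lo, hi)
    return 1.0 / (hi - lo) if lo <= x < hi else 0.0


def density(leaves, evidence: Dict[str, float]) -> float:
    """Evaluate p(evidence) for a full or partial assignment.

    Variables absent from `evidence` are marginalised out exactly: each leaf
    factor integrates to one, so it is simply dropped from the product.
    """
    total = 0.0
    for leaf in leaves:  # sum node
        term = leaf.weight
        for var, value in evidence.items():  # product node over disjoint variables
            lo, hi = leaf.bounds[var]
            term *= uniform_pdf(value, lo, hi)
        total += term
    return total


def conditional(leaves, query: Dict[str, float], given: Dict[str, float]) -> float:
    # p(query | given) = p(query, given) / p(given), both computed in one pass each
    marginal = density(leaves, given)
    if marginal == 0.0:
        raise ValueError("conditioning event has zero density")
    return density(leaves, {**given, **query}) / marginal


def sample(leaves, rng: random.Random) -> Dict[str, float]:
    # Ancestral sampling: choose a leaf by weight, then sample each variable independently
    leaf = rng.choices(leaves, weights=[l.weight for l in leaves], k=1)[0]
    return {var: rng.uniform(lo, hi) for var, (lo, hi) in leaf.bounds.items()}


# Toy two-leaf model on [0, 2) x [0, 2); weights sum to one, so it is a valid density
toy = [
    Leaf(0.5, {"x": (0.0, 1.0), "y": (0.0, 1.0)}),
    Leaf(0.5, {"x": (1.0, 2.0), "y": (0.0, 2.0)}),
]
print(density(toy, {"x": 0.5, "y": 0.5}))         # joint density: 0.5
print(density(toy, {"y": 0.5}))                   # x marginalised out: 0.75
print(conditional(toy, {"y": 1.5}, {"x": 1.5}))   # p(y | x): 0.5
print(sample(toy, random.Random(0)))
```

Because every product node covers disjoint variables and every leaf density integrates to one, marginalising a variable amounts to dropping its factor. Marginal and conditional queries therefore cost a single pass over the leaves rather than a numerical integral. This structural property is one that neural generative models of the kind listed above do not, in general, offer.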

This project has theoretical and practical components. Primary objectives include: (1) developing new algorithms for tractable probabilistic modelling using NPCs; (2) encoding structural assumptions into model training to enable efficient estimation of unobservable quantities such as individual treatment effects; (3) studying the convergence rate and computational complexity of these algorithms; and (4) benchmarking against state-of-the-art alternatives on simulated and real-world data.

Watson, D., Blesch, K., Kapar, J., & Wright, M. (2023). Adversarial random forests for density estimation and generative modelling. In Proceedings of the 26th International Conference on Artificial Intelligence and Statistics (pp. 5357-5375). Valencia, Spain.
Choi, Y., Vergari, A., & Van den Broeck, G. (2020). Probabilistic circuits: A unifying framework for tractable probabilistic models. Technical Report, University of California, Los Angeles.
Tian, J., & Pearl, J. (2000). Probabilities of causation: Bounds and identification. Annals of Mathematics and Artificial Intelligence, 28, 287-313.
