{ "cells": [ { "cell_type": "markdown", "id": "2db98323fb940c7f", "metadata": {}, "source": [ "# Bayes Factor Tutorial\n", "\n", "Bayes factors are a key concept in Bayesian model comparison: they quantify how strongly the data support one model relative to another. They are computed from the marginal likelihoods (or evidences) of the models. This tutorial covers several methods for computing marginal likelihoods.\n", "\n", "An introduction and an extensive review can be found in [Llorente et al. (2023)](https://doi.org/10.1137/20M1310849)." ] }, { "cell_type": "markdown", "id": "5c56f766bcf7ab48", "metadata": {}, "source": [ "\n", "## Marginal Likelihood\n", "\n", "The marginal likelihood (or evidence) of a model $\\mathcal{M}$ given data $\\mathcal{D}$ is defined as:\n", "\n", "$$\n", "P(\\mathcal{D} \\mid \\mathcal{M}) = \\int P(\\mathcal{D} \\mid \\theta, \\mathcal{M}) P(\\theta \\mid \\mathcal{M}) \\, d\\theta\n", "$$\n", "\n", "where $\\theta$ are the parameters of the model. This integral averages the likelihood over the prior distribution of the parameters. It therefore measures how well the model explains the data across all parameter values that the prior considers plausible, rather than only at the best-fitting ones." 
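, "\n", "\n", "For simple conjugate models, this integral can be solved in closed form, which is what makes the toy examples below checkable. As an illustration (a standard result, stated here for reference): if the observations $y_1, \\dots, y_n$ are independent draws from $\\mathcal{N}(\\mu, \\sigma^2)$ with known $\\sigma$, and the prior on the only parameter $\\theta = \\mu$ is $\\mathcal{N}(m, \\tau^2)$, then the data are jointly Gaussian with covariance $\\sigma^2 I + \\tau^2 \\mathbf{1}\\mathbf{1}^\\top$, and\n", "\n", "$$\n", "\\ln P(\\mathcal{D} \\mid \\mathcal{M}) = -\\frac{n}{2} \\ln(2\\pi\\sigma^2) - \\frac{1}{2} \\ln\\left(1 + \\frac{n\\tau^2}{\\sigma^2}\\right) - \\frac{1}{2\\sigma^2} \\left( \\sum_{i=1}^n (y_i - m)^2 - \\frac{\\tau^2 \\left( \\sum_{i=1}^n (y_i - m) \\right)^2}{\\sigma^2 + n\\tau^2} \\right)\n", "$$\n", "\n", "The second term penalizes a wide prior (large $\\tau$), an instance of the complexity penalty that is built into the evidence."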
] }, { "cell_type": "markdown", "id": "6337b6a3", "metadata": {}, "source": [ "## Bayes Factor\n", "\n", "The Bayes factor comparing two models $\\mathcal{M}_1$ and $\\mathcal{M}_2$ given data $\\mathcal{D}$ is the ratio of their marginal likelihoods:\n", "\n", "$$\n", "\\operatorname{BF}_{12} = \\frac{P(\\mathcal{D} \\mid \\mathcal{M}_1)}{P(\\mathcal{D} \\mid \\mathcal{M}_2)}\n", "$$\n", "\n", "A $\\operatorname{BF}_{12} > 1$ indicates that the data favor model $\\mathcal{M}_1$ over model $\\mathcal{M}_2$, while $\\operatorname{BF}_{12} < 1$ indicates the opposite.\n", "\n", "Jeffreys (1961) suggested interpreting Bayes factors in half-units on the log10 scale (this was further simplified in Kass and Raftery (1995)):\n", "\n", "- Not worth more than a bare mention: $0 < \\log_{10} \\operatorname{BF}_{12} \\leq 0.5$\n", "- Substantial: $0.5 < \\log_{10}\\operatorname{BF}_{12} \\leq 1$\n", "- Strong: $1 < \\log_{10}\\operatorname{BF}_{12} \\leq 2$\n", "- Decisive: $2 < \\log_{10}\\operatorname{BF}_{12}$" ] }, { "cell_type": "markdown", "id": "a6b7640cff0280de", "metadata": {}, "source": [ "## Example\n", "\n", "To illustrate different methods to compute marginal likelihoods, we introduce two toy models for which the marginal likelihoods can be computed analytically:\n", "\n", "1. **Mixture of Two Gaussians (True Data Generator)**: Two Gaussian components, $\\mathcal{N}(\\mu_1, \\sigma^2)$ and $\\mathcal{N}(\\mu_2, \\sigma^2)$, with a shared standard deviation $\\sigma$ (the data are generated with $\\mu_1 = -2$, $\\mu_2 = 2$, $\\sigma = 2$). Of the $N = 10$ observations, the first 3 are drawn from the first component and the remaining 7 from the second (a mixing proportion of $\\pi = 0.7$ for the second component), and the component memberships are treated as known.\n", "\n", "2. **Single Gaussian (Alternative Model)**: A single Gaussian distribution, $\\mathcal{N}(\\mu, \\sigma^2)$.\n", "\n", "We sample synthetic data from the first model and create pyPESTO problems for both models with the same data.\n", "The free parameters are the means: $\\mu_1$ and $\\mu_2$ for the true model and $\\mu$ for the alternative model.\n", "For this example, we assume that the standard deviation is known and fixed to the true value $\\sigma = 2$.\n", "As priors, we assume normal distributions: $\\mathcal{N}(-2, 2^2)$ and $\\mathcal{N}(2, 2^2)$ for $\\mu_1$ and $\\mu_2$, and $\\mathcal{N}(0, 1)$ for $\\mu$." 
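, "\n", "\n", "Because the component memberships are known and the two means have independent priors, the evidence of the true model factorizes over the two groups of observations. Writing $\\mathcal{D}_g$ for the $n_g$ observations of group $g$ (here $n_1 = 3$, $n_2 = 7$) and $m_g$ for the prior means ($m_1 = -2$, $m_2 = 2$), and using that the prior standard deviation equals $\\sigma$, the analytic log evidence of the true model is\n", "\n", "$$\n", "\\ln P(\\mathcal{D} \\mid \\mathcal{M}_{\\text{true}}) = \\sum_{g=1}^{2} \\left[ -\\frac{n_g}{2} \\ln(2\\pi\\sigma^2) - \\frac{1}{2} \\ln(n_g + 1) - \\frac{1}{2\\sigma^2} \\left( \\sum_{i \\in \\mathcal{D}_g} (y_i - m_g)^2 - \\frac{\\left( \\sum_{i \\in \\mathcal{D}_g} (y_i - m_g) \\right)^2}{n_g + 1} \\right) \\right]\n", "$$\n", "\n", "This is algebraically the same quantity that the hidden helper `log_evidence_true` evaluates below, written in terms of the centered observations $y_i - m_g$."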
] }, { "cell_type": "code", "execution_count": null, "id": "6eb930b7", "metadata": {}, "outputs": [], "source": [ "from functools import partial\n", "\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "from scipy import stats\n", "from scipy.special import logsumexp\n", "\n", "from pypesto import optimize, sample, variational, visualize\n", "from pypesto.C import LIN, NORMAL\n", "from pypesto.objective import (\n", " AggregatedObjective,\n", " NegLogParameterPriors,\n", " Objective,\n", " get_parameter_prior_dict,\n", ")\n", "from pypesto.problem import Problem\n", "\n", "# For testing purposes. Remove if not running the exact example.\n", "np.random.seed(42)" ] }, { "cell_type": "code", "execution_count": null, "id": "ad47e3f67a7896d3", "metadata": {}, "outputs": [], "source": [ "# model hyperparameters\n", "N = 10\n", "N2_1 = 3\n", "N2_2 = N - N2_1\n", "sigma2 = 2.0\n", "true_params = np.array([-2.0, 2.0])\n", "\n", "rng = np.random.default_rng(seed=0)\n", "# Alternative Model\n", "Y1 = rng.normal(loc=0.0, scale=1.0, size=N)\n", "\n", "# True Model\n", "Y2_1 = rng.normal(loc=true_params[0], scale=sigma2, size=N2_1)\n", "Y2_2 = rng.normal(loc=true_params[1], scale=sigma2, size=N2_2)\n", "Y2 = np.concatenate([Y2_1, Y2_2])\n", "mixture_data, sigma = Y2, sigma2\n", "n_obs = len(mixture_data)\n", "\n", "# plot the alternative model distribution as a normal distribution\n", "plt.figure()\n", "x = np.linspace(-10, 10, 100)\n", "plt.plot(\n", " x,\n", " stats.norm.pdf(x, loc=0.0, scale=1.0),\n", " label=\"Alternative Model\",\n", " color=\"red\",\n", ")\n", "plt.plot(\n", " x,\n", " stats.norm.pdf(x, loc=true_params[0], scale=sigma2),\n", " label=\"True Model Y2_1\",\n", " color=\"blue\",\n", ")\n", "plt.plot(\n", " x,\n", " stats.norm.pdf(x, loc=true_params[1], scale=sigma2),\n", " label=\"True Model Y2_2\",\n", " color=\"green\",\n", ")\n", "\n", "\n", "# Plot the data of the alternative and true model as dots on the x-axis for each model\n", 
"plt.scatter(Y1, np.zeros_like(Y1), label=\"Y1 samples\", color=\"red\")\n", "plt.scatter(Y2_1, np.full(len(Y2_1), 0.05), label=\"Y2_1 samples\", color=\"blue\")\n", "plt.scatter(Y2_2, np.full(len(Y2_2), 0.1), label=\"Y2_2 samples\", color=\"green\")\n", "plt.legend()\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "2143410833d86594", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "# analytic log evidences of both models (conjugate normal priors)\n", "def log_evidence_alt(data: np.ndarray, std: float, prior_std: float = 1.0):\n", " # y_i ~ N(mu, std^2) with prior mu ~ N(0, prior_std^2)\n", " n = int(data.size)\n", " y_sum = np.sum(data)\n", " y_sq_sum = np.sum(data**2)\n", "\n", " term1 = 1 / (np.sqrt(2 * np.pi) * std)\n", " log_term2 = -0.5 * np.log(1 + n * prior_std**2 / std**2)\n", " inside_exp = -0.5 / (std**2) * (\n", " y_sq_sum - prior_std**2 * y_sum**2 / (std**2 + n * prior_std**2)\n", " )\n", " return n * np.log(term1) + log_term2 + inside_exp\n", "\n", "\n", "def log_evidence_true(data: np.ndarray, std: float):\n", " # two groups with known memberships, prior means -2 and 2, prior std equal to std\n", " y1 = data[:N2_1]\n", " y2 = data[N2_1:]\n", " n = N2_1 + N2_2\n", "\n", " y_mean_1 = np.mean(y1)\n", " y_mean_2 = np.mean(y2)\n", " y_sq_sum = np.sum(y1**2) + np.sum(y2**2)\n", "\n", " term1 = (1 / (np.sqrt(2 * np.pi) * std)) ** n\n", " term2 = 1 / (np.sqrt(N2_1 + 1) * np.sqrt(N2_2 + 1))\n", "\n", " inside_exp = (\n", " -1\n", " / (2 * std**2)\n", " * (\n", " y_sq_sum\n", " + 8\n", " - (N2_1 * y_mean_1 - 2) ** 2 / (N2_1 + 1)\n", " - (N2_2 * y_mean_2 + 2) ** 2 / (N2_2 + 1)\n", " )\n", " )\n", " return np.log(term1) + np.log(term2) + inside_exp\n", "\n", "\n", "# the alternative model uses the prior N(0, 1) on mu (see prior_alt below)\n", "true_log_evidence_alt = log_evidence_alt(mixture_data, sigma, prior_std=1.0)\n", "true_log_evidence_true = log_evidence_true(mixture_data, sigma)\n", "\n", "print(\"True log evidence, true model:\", true_log_evidence_true)\n", "print(\"True log evidence, alternative model:\", true_log_evidence_alt)" ] }, { "cell_type": "code", "execution_count": null, "id": "784510a5dfdc9140", "metadata": {}, "outputs": [], "source": [ "# define likelihood for each model, and build the objective functions for the pyPESTO problem\n", "def neg_log_likelihood(params: np.ndarray | 
list, data: np.ndarray):\n", " # i.i.d. normal observations with mean mu and standard deviation std\n", " mu, std = params\n", " n = int(data.size)\n", " return (\n", " 0.5 * n * np.log(2 * np.pi)\n", " + n * np.log(std)\n", " + np.sum((data - mu) ** 2) / (2 * std**2)\n", " )\n", "\n", "\n", "def neg_log_likelihood_grad(params: np.ndarray | list, data: np.ndarray):\n", " mu, std = params\n", " n = int(data.size)\n", " grad_mu = -np.sum(data - mu) / (std**2)\n", " grad_std = n / std - np.sum((data - mu) ** 2) / (std**3)\n", " return np.array([grad_mu, grad_std])\n", "\n", "\n", "def neg_log_likelihood_hess(params: np.ndarray | list, data: np.ndarray):\n", " mu, std = params\n", " n = int(data.size)\n", " hess_mu_mu = n / (std**2)\n", " hess_mu_std = 2 * np.sum(data - mu) / (std**3)\n", " hess_std_std = -n / (std**2) + 3 * np.sum((data - mu) ** 2) / (std**4)\n", " return np.array([[hess_mu_mu, hess_mu_std], [hess_mu_std, hess_std_std]])\n", "\n", "\n", "def neg_log_likelihood_m2(\n", " params: np.ndarray | list, data: np.ndarray, n_mix: int\n", "):\n", " # two groups with known memberships (data[:n_mix] and data[n_mix:]) sharing one std\n", " y1 = data[:n_mix]\n", " y2 = data[n_mix:]\n", " m1, m2, std = params\n", "\n", " term1 = neg_log_likelihood([m1, std], y1)\n", " term2 = neg_log_likelihood([m2, std], y2)\n", " return term1 + term2\n", "\n", "\n", "def neg_log_likelihood_m2_grad(\n", " params: np.ndarray, data: np.ndarray, n_mix: int\n", "):\n", " m1, m2, std = params\n", " y1 = data[:n_mix]\n", " y2 = data[n_mix:]\n", "\n", " grad_m1, grad_std1 = neg_log_likelihood_grad([m1, std], y1)\n", " grad_m2, grad_std2 = neg_log_likelihood_grad([m2, std], y2)\n", " return np.array([grad_m1, grad_m2, grad_std1 + grad_std2])\n", "\n", "\n", "def neg_log_likelihood_m2_hess(\n", " params: np.ndarray, data: np.ndarray, n_mix: int\n", "):\n", " m1, m2, std = params\n", " y1 = data[:n_mix]\n", " y2 = data[n_mix:]\n", "\n", " [[hess_m1_m1, hess_m1_std], [_, hess_std_std1]] = neg_log_likelihood_hess(\n", " [m1, std], y1\n", " )\n", " [[hess_m2_m2, 
hess_m2_std], [_, hess_std_std2]] = neg_log_likelihood_hess(\n", " [m2, std], y2\n", " )\n", " hess_m1_m2 = 0\n", "\n", " return np.array(\n", " [\n", " [hess_m1_m1, hess_m1_m2, hess_m1_std],\n", " [hess_m1_m2, hess_m2_m2, hess_m2_std],\n", " [hess_m1_std, hess_m2_std, hess_std_std1 + hess_std_std2],\n", " ]\n", " )\n", "\n", "\n", "nllh_true = Objective(\n", " fun=partial(neg_log_likelihood_m2, data=mixture_data, n_mix=N2_1),\n", " grad=partial(neg_log_likelihood_m2_grad, data=mixture_data, n_mix=N2_1),\n", " hess=partial(neg_log_likelihood_m2_hess, data=mixture_data, n_mix=N2_1),\n", ")\n", "nllh_alt = Objective(\n", " fun=partial(neg_log_likelihood, data=mixture_data),\n", " grad=partial(neg_log_likelihood_grad, data=mixture_data),\n", " hess=partial(neg_log_likelihood_hess, data=mixture_data),\n", ")\n", "\n", "prior_true = NegLogParameterPriors(\n", " [\n", " get_parameter_prior_dict(\n", " index=0,\n", " prior_type=NORMAL,\n", " prior_parameters=[true_params[0], sigma2],\n", " parameter_scale=LIN,\n", " ),\n", " get_parameter_prior_dict(\n", " index=1,\n", " prior_type=NORMAL,\n", " prior_parameters=[true_params[1], sigma2],\n", " parameter_scale=LIN,\n", " ),\n", " ]\n", ")\n", "\n", "prior_alt = NegLogParameterPriors(\n", " [\n", " get_parameter_prior_dict(\n", " index=0,\n", " prior_type=NORMAL,\n", " prior_parameters=[0.0, 1.0],\n", " parameter_scale=LIN,\n", " ),\n", " ]\n", ")\n", "\n", "mixture_problem_true = Problem(\n", " objective=AggregatedObjective(objectives=[nllh_true, prior_true]),\n", " lb=[-10, -10, 0],\n", " ub=[10, 10, 10],\n", " x_names=[\"mu1\", \"mu2\", \"sigma\"],\n", " x_scales=[LIN, LIN, LIN],\n", " x_fixed_indices=[2],\n", " x_fixed_vals=[sigma],\n", " x_priors_defs=prior_true,\n", ")\n", "\n", "mixture_problem_alt = Problem(\n", " objective=AggregatedObjective(objectives=[nllh_alt, prior_alt]),\n", " lb=[-10, 0],\n", " ub=[10, 10],\n", " x_names=[\"mu\", \"sigma\"],\n", " x_scales=[LIN, LIN],\n", " x_fixed_indices=[1],\n", " 
x_fixed_vals=[sigma],\n", " x_priors_defs=prior_alt,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "cf9af2fa37f3a0cf", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "# to make the code more readable, we define a dictionary with all models\n", "# from here on, we use the pyPESTO problem objects, so the code can be reused for any other problem\n", "models = {\n", " \"mixture_model1\": {\n", " \"name\": \"True-Model\",\n", " \"true_log_evidence\": true_log_evidence_true,\n", " \"prior_mean\": np.array([-2, 2]),\n", " \"prior_std\": np.array([2, 2]),\n", " \"prior_cov\": np.diag([4, 4]),\n", " \"true_params\": true_params,\n", " \"problem\": mixture_problem_true,\n", " },\n", " \"mixture_model2\": {\n", " \"name\": \"Alternative-Model\",\n", " \"true_log_evidence\": true_log_evidence_alt,\n", " \"prior_mean\": np.array([0]),\n", " \"prior_std\": np.array([1]),\n", " \"prior_cov\": np.diag([1]),\n", " \"problem\": mixture_problem_alt,\n", " },\n", "}\n", "\n", "for m in models.values():\n", " # evaluate only the likelihood term (first objective of the AggregatedObjective);\n", " # get_full_vector inserts the fixed sigma, x is on the problem's parameter scale (linear here)\n", " m[\"neg_log_likelihood\"] = lambda x, m=m: m[\n", " \"problem\"\n", " ].objective._objectives[0](\n", " m[\"problem\"].get_full_vector(\n", " x=x, x_fixed_vals=m[\"problem\"].x_fixed_vals\n", " )\n", " )" ] }, { "cell_type": "markdown", "id": "e273503367e8bf4d", "metadata": {}, "source": [ "## Methods for Computing Marginal Likelihoods" ] }, { "cell_type": "code", "execution_count": null, "id": "95ec6b53c9133332", "metadata": {}, "outputs": [], "source": [ "%%time\n", "# run optimization for each model\n", "for m in models.values():\n", " m[\"results\"] = optimize.minimize(\n", " problem=m[\"problem\"],\n", " n_starts=100,\n", " )\n", "\n", " if \"true_params\" in m.keys():\n", " visualize.parameters(\n", " results=m[\"results\"],\n", " reference={\n", " \"x\": m[\"true_params\"],\n", " \"fval\": 
m[\"problem\"].objective(m[\"true_params\"]),\n", " },\n", " )\n", " else:\n", " visualize.parameters(m[\"results\"])" ] }, { "cell_type": "markdown", "id": "ffd895262133fe00", "metadata": {}, "source": [ "### 1. Bayesian Information Criterion (BIC)\n", "\n", "The BIC is a simple and widely used approximation to the marginal likelihood. It is computed as:\n", "\n", "$$\n", "\\operatorname{BIC} = k \\ln(n) - 2 \\ln(\\hat{L})\n", "$$\n", "\n", "where $k$ is the number of free parameters, $n$ is the number of data points, and $\\hat{L}$ is the maximized value of the likelihood function. $-\\frac12 \\operatorname{BIC}$ approximates the log marginal likelihood under the assumption that the prior is non-informative and the sample size is large. Below, the likelihood is evaluated at the best optimization result; since the optimized objective includes the prior, this is the MAP estimate rather than the maximum likelihood estimate.\n", "\n", "\n", "BIC is easy to compute, but its error in approximating the log marginal likelihood does not vanish in general as the sample size grows, and it ignores the prior completely. Hence, it may not capture the full complexity of model selection, especially for complex models or informative priors." ] }, { "cell_type": "code", "execution_count": null, "id": "1b40d72091d00e9f", "metadata": {}, "outputs": [], "source": [ "for m in models.values():\n", " m[\"BIC\"] = len(m[\"problem\"].x_free_indices) * np.log(n_obs) + 2 * m[\n", " \"neg_log_likelihood\"\n", " ](m[\"results\"].optimize_result.x[0])\n", " print(\n", " m[\"name\"], \"BIC log marginal likelihood approximation:\", -1 / 2 * m[\"BIC\"]\n", " )" ] }, { "cell_type": "markdown", "id": "67cb4a7bb781d42", "metadata": {}, "source": [ "### 2. Laplace Approximation\n", "\n", "The Laplace approximation estimates the marginal likelihood by approximating the posterior distribution as a Gaussian centered at the maximum a posteriori (MAP) estimate $\\hat{\\theta}$, with covariance given by the inverse Hessian of the negative log posterior at $\\hat{\\theta}$. 
The marginal likelihood is then approximated as:\n", "\n", "$$\n", "P(\\mathcal{D} \\mid \\mathcal{M}) \\approx (2\\pi)^{k/2} \\left| \\Sigma \\right|^{1/2} P(\\mathcal{D} \\mid \\hat{\\theta}, \\mathcal{M}) P(\\hat{\\theta} \\mid \\mathcal{M})\n", "$$\n", "\n", "where $k$ is the number of free parameters and $\\Sigma$ is the inverse of the Hessian of the negative log unnormalized posterior (likelihood $\\times$ prior) at $\\hat{\\theta}$, i.e., the covariance matrix of the Gaussian approximation.\n", "\n", "\n", "The Laplace approximation is accurate if the posterior is unimodal and roughly Gaussian. For the two toy models here, the posterior of the means is exactly Gaussian, so the approximation is exact up to numerical errors." ] }, { "cell_type": "code", "execution_count": null, "id": "548513d76b8887dd", "metadata": {}, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " laplace_evidences = []\n", " for x in m[\"results\"].optimize_result.x:\n", " log_evidence = sample.evidence.laplace_approximation_log_evidence(\n", " m[\"problem\"], x\n", " )\n", " laplace_evidences.append(log_evidence)\n", "\n", " m[\"laplace_evidences\"] = np.array(laplace_evidences)\n", " print(m[\"name\"], f\"Laplace approximation: {m['laplace_evidences'][0]}\")" ] }, { "cell_type": "markdown", "id": "b5ac29500e0e678b", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "source": [ "### 3. Sampling-Based Methods\n", "\n", "Sampling-based methods, such as Markov chain Monte Carlo (MCMC) or nested sampling, make no assumptions about the shape of the posterior and can therefore provide more accurate estimates of the marginal likelihood. However, they can be computationally intensive." ] }, { "cell_type": "markdown", "id": "212297d07ef90600", "metadata": {}, "source": [ "\n", "#### Arithmetic Mean Estimator\n", "\n", "The arithmetic mean estimator approximates the marginal likelihood by averaging the likelihood over samples $\\theta_i$ drawn from the prior:\n", "\n", "$$\n", "P(\\mathcal{D} \\mid \\mathcal{M}) \\approx \\frac{1}{N} \\sum_{i=1}^N P(\\mathcal{D} \\mid \\theta_i, \\mathcal{M})\n", "$$\n", "\n", "The arithmetic mean estimator requires a large number of samples and is very inefficient, because most prior samples fall into regions of low likelihood. 
In practice, it tends to underestimate the marginal likelihood: the estimator itself is unbiased, but with a finite number of samples the high-likelihood region is easily missed, and the logarithm of the estimate is biased downward." ] }, { "cell_type": "code", "execution_count": null, "id": "ec2f000c836abad6", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " # draw prior samples and average their likelihoods (in log space for numerical stability)\n", " prior_sample = np.random.multivariate_normal(\n", " mean=m[\"prior_mean\"], cov=m[\"prior_cov\"], size=1000\n", " )\n", " log_likelihoods = np.array(\n", " [-m[\"neg_log_likelihood\"](x) for x in prior_sample]\n", " )\n", " m[\"arithmetic_log_evidence\"] = logsumexp(log_likelihoods) - np.log(\n", " log_likelihoods.size\n", " )\n", "\n", " print(m[\"name\"], f\"arithmetic mean: {m['arithmetic_log_evidence']}\")" ] }, { "cell_type": "markdown", "id": "77ec3e1ec016d0d1", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "source": [ "#### Harmonic Mean\n", "\n", "The harmonic mean estimator uses posterior samples to estimate the marginal likelihood:\n", "\n", "$$\n", "P(\\mathcal{D} \\mid \\mathcal{M}) \\approx \\left( \\frac{1}{N} \\sum_{i=1}^N \\frac{1}{P(\\mathcal{D} \\mid \\theta_i, \\mathcal{M})} \\right)^{-1}\n", "$$\n", "\n", "where $\\theta_i$ are samples from the posterior distribution.\n", "\n", "Although the estimator is consistent, it tends to overestimate the evidence in practice, because posterior samples rarely visit low-likelihood regions, which may carry much of the prior mass.\n", "Moreover, its variance can be very large or even infinite, because the inverse likelihood becomes huge for the rare samples that do fall into low-likelihood regions.\n", "Hence, it can be very unstable and even fail catastrophically. A more stable version, the stabilized harmonic mean, additionally uses samples from the prior (see [Newton and Raftery (1994)](https://doi.org/10.1111/j.2517-6161.1994.tb01956.x)). 
However, there are more efficient methods available.\n", "\n", "A reliable sampling method is bridge sampling (see [\"A Tutorial on Bridge Sampling\" by Gronau et al. (2017)](https://api.semanticscholar.org/CorpusID:5447695) for a nice introduction). It uses samples from a proposal and the posterior to estimate the marginal likelihood. The proposal distribution should be chosen to have a high overlap with the posterior (we construct it from half of the posterior samples by fitting a Gaussian distribution with the same mean and covariance). This method is more stable than the harmonic mean estimator. However, its accuracy may depend on the choice of the proposal distribution.\n", "\n", "A different approach, the learnt harmonic mean estimator, was proposed by [McEwen et al. (2021)](https://api.semanticscholar.org/CorpusID:244709474). The estimator solves the large variance problem by interpreting the harmonic mean estimator as importance sampling and introducing a new target distribution, which is learned from the posterior samples. 
The method can be applied just using samples from the posterior and is implemented in the software package accompanying the paper.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "ba4cc742f71fad4", "metadata": {}, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " results = sample.sample(\n", " problem=m[\"problem\"],\n", " n_samples=1000,\n", " result=m[\"results\"],\n", " )\n", " # compute harmonic mean\n", " m[\"harmonic_log_evidence\"] = sample.evidence.harmonic_mean_log_evidence(\n", " results\n", " )\n", " print(m[\"name\"], f\"harmonic mean: {m['harmonic_log_evidence']}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "a7272997b60de2e2", "metadata": {}, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " results = sample.sample(\n", " problem=m[\"problem\"],\n", " n_samples=800,\n", " result=m[\"results\"],\n", " )\n", " # compute stabilized harmonic mean\n", " prior_samples = np.random.multivariate_normal(\n", " mean=m[\"prior_mean\"], cov=m[\"prior_cov\"], size=200\n", " )\n", " m[\"harmonic_stabilized_log_evidence\"] = (\n", " sample.evidence.harmonic_mean_log_evidence(\n", " result=results,\n", " prior_samples=prior_samples,\n", " neg_log_likelihood_fun=m[\"neg_log_likelihood\"],\n", " )\n", " )\n", " print(\n", " m[\"name\"],\n", " f\"stabilized harmonic mean: {m['harmonic_stabilized_log_evidence']}\",\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "ce38f1a4975cd72a", "metadata": {}, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " results = sample.sample(\n", " problem=m[\"problem\"],\n", " n_samples=1000,\n", " result=m[\"results\"],\n", " )\n", " m[\"bridge_log_evidence\"] = sample.evidence.bridge_sampling_log_evidence(\n", " results\n", " )\n", " print(m[\"name\"], f\"bridge sampling: {m['bridge_log_evidence']}\")" ] }, { "cell_type": "markdown", "id": "443bf17c8ae27a15", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": 
false } }, "source": [ "#### Nested Sampling\n", "\n", "Nested sampling is specifically designed for estimating marginal likelihoods. The static nested sampler is optimized for evidence computation and provides accurate estimates but may give less accurate posterior samples unless dynamic nested sampling is used.\n", "\n", "Dynamic nested sampling can improve the accuracy of posterior samples. The package [dynesty](https://dynesty.readthedocs.io/en/stable/) offers many hyperparameters to balance accuracy and efficiency between computing samples from the posterior and estimating the marginal likelihood." ] }, { "cell_type": "code", "execution_count": null, "id": "c0236f455dfc64d5", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " # define prior transformation needed for nested sampling\n", " def prior_transform(u, m=m):\n", " \"\"\"Transform prior sample from unit cube to normal prior.\"\"\"\n", " t = stats.norm.ppf(u) # convert to standard normal\n", " c_sqrt = np.linalg.cholesky(m[\"prior_cov\"]) # Cholesky decomposition\n", " u_new = np.dot(c_sqrt, t) # correlate with appropriate covariance\n", " u_new += m[\"prior_mean\"] # add mean\n", " return u_new\n", "\n", " # initialize nested sampler\n", " nested_sampler = sample.DynestySampler(\n", " # sampler_args={'nlive': 250},\n", " run_args={\"maxcall\": 1000}, # cap on likelihood evaluations, kept small for speed\n", " dynamic=False, # static nested sampler is optimized for evidence computation\n", " prior_transform=prior_transform,\n", " )\n", "\n", " # run nested sampling\n", " result_dynesty_sample = sample.sample(\n", " problem=m[\"problem\"], n_samples=None, sampler=nested_sampler\n", " )\n", "\n", " # extract log evidence\n", " m[\"nested_log_evidence\"] = nested_sampler.sampler.results.logz[-1]\n", " print(m[\"name\"], f\"nested sampling: {m['nested_log_evidence']}\")" ] }, { "cell_type": "markdown", "id": "dcb16e2efcf4bf0d", "metadata": { "collapsed": false, 
"jupyter": { "outputs_hidden": false } }, "source": [ "#### Thermodynamic Integration and Steppingstone Sampling\n", "\n", "These methods are based on power posteriors $p_t(\\theta) \\propto P(\\mathcal{D} \\mid \\theta, \\mathcal{M})^t P(\\theta \\mid \\mathcal{M})$, in which the likelihood is raised to a power (inverse temperature) $t \\in [0, 1]$. With the normalizing constants $Z_t = \\int P(\\mathcal{D} \\mid \\theta, \\mathcal{M})^t P(\\theta \\mid \\mathcal{M}) \\, d\\theta$, we have $Z_0 = 1$ and $Z_1 = P(\\mathcal{D} \\mid \\mathcal{M})$. Since $\\frac{d}{dt} \\ln Z_t = \\mathbb{E}_{p_t}\\left[\\ln P(\\mathcal{D} \\mid \\theta, \\mathcal{M})\\right]$, the log evidence can be written as an integral over $t$:\n", "\n", "$$\n", "\\ln P(\\mathcal{D} \\mid \\mathcal{M}) = \\ln Z_1 - \\ln Z_0 = \\int_0^1 \\mathbb{E}_{p_t}\\left[\\ln P(\\mathcal{D} \\mid \\theta, \\mathcal{M})\\right] dt\n", "$$\n", "\n", "Parallel tempering is a sampling algorithm that improves mixing for multimodal posteriors by sampling from different temperatures simultaneously and exchanging samples between parallel chains. It can be used to sample from all power posteriors simultaneously, allowing for thermodynamic integration and steppingstone sampling [(Annis et al., 2019)](https://doi.org/10.1016/j.jmp.2019.01.005). These methods can be seen as path sampling methods and are hence related to bridge sampling.\n", "\n", "These methods can be more accurate for complex posteriors but are computationally intensive. Thermodynamic integration (TI) evaluates the integral above numerically on a grid of temperatures, while steppingstone sampling writes $Z_1 / Z_0$ as a product of ratios $Z_{t_{j+1}} / Z_{t_j}$ between neighboring temperatures and estimates each ratio by importance sampling. Accuracy can be improved by using more temperatures.\n", "Errors in the estimator might come from the MCMC sampler in both cases and from numerical integration when applying TI. 
Steppingstone sampling can be a biased estimator for a small number of temperatures [(Annis et al., 2019)](https://doi.org/10.1016/j.jmp.2019.01.005).\n" ] }, { "cell_type": "code", "execution_count": null, "id": "13059e00c982d98d", "metadata": {}, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " # initialize parallel tempering sampler\n", " ti_sampler = sample.ParallelTemperingSampler( # not adaptive, since we want fixed temperatures\n", " internal_sampler=sample.AdaptiveMetropolisSampler(), n_chains=10\n", " )\n", "\n", " # run mcmc with parallel tempering\n", " result_ti = sample.sample(\n", " problem=m[\"problem\"],\n", " n_samples=1000,\n", " sampler=ti_sampler,\n", " result=m[\"results\"],\n", " )\n", " # compute log evidence via thermodynamic integration\n", " m[\"thermodynamic_log_evidence\"] = (\n", " sample.evidence.parallel_tempering_log_evidence(\n", " result_ti, use_all_chains=False\n", " )\n", " )\n", " print(\n", " m[\"name\"],\n", " f\"thermodynamic integration: {m['thermodynamic_log_evidence']}\",\n", " )\n", "\n", " # compute log evidence via steppingstone sampling\n", " m[\"steppingstone_log_evidence\"] = (\n", " sample.evidence.parallel_tempering_log_evidence(\n", " result_ti, method=\"steppingstone\", use_all_chains=False\n", " )\n", " )\n", " print(\n", " m[\"name\"], f\"steppingstone sampling: {m['steppingstone_log_evidence']}\"\n", " )" ] }, { "cell_type": "markdown", "id": "90fd0f80a9d94b7d", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "source": [ "#### Variational Inference\n", "\n", "Variational inference approximates the posterior with a simpler distribution and can be faster than sampling methods for large problems. 
The marginal likelihood can then be estimated with the estimators introduced above, applied to samples from the variational approximation, but the accuracy is limited by the choice of the variational family.\n", "\n", "Variational inference maximizes the evidence lower bound (ELBO), which is a lower bound on the log evidence and therefore provides an additional check for the estimator." ] }, { "cell_type": "code", "execution_count": null, "id": "c616b8a566478d0d", "metadata": {}, "outputs": [], "source": [ "%%time\n", "for m in models.values():\n", " # one could define callbacks to check convergence during optimization\n", " # import pymc as pm\n", " # cb = [\n", " # pm.callbacks.CheckParametersConvergence(\n", " # tolerance=1e-3, diff='absolute'),\n", " # pm.callbacks.CheckParametersConvergence(\n", " # tolerance=1e-3, diff='relative'),\n", " # ]\n", "\n", " pypesto_variational_result = variational.variational_fit(\n", " problem=m[\"problem\"],\n", " method=\"advi\",\n", " n_iterations=10000,\n", " n_samples=None,\n", " result=m[\"results\"],\n", " # callbacks=cb,\n", " )\n", "\n", " # hist stores the negative ELBO per iteration; its negation is a (noisy) lower bound on the log evidence\n", " vi_lower_bound = np.max(\n", " -pypesto_variational_result.variational_result.data.hist\n", " )\n", "\n", " # compute the harmonic mean from samples of the variational approximation\n", " approx_sample = pypesto_variational_result.variational_result.sample(1000)[\n", " \"trace_x\"\n", " ][0]\n", " neg_log_likelihoods = np.array(\n", " [m[\"neg_log_likelihood\"](ps) for ps in approx_sample]\n", " )\n", " m[\"vi_harmonic_log_evidences\"] = -logsumexp(neg_log_likelihoods) + np.log(\n", " neg_log_likelihoods.size\n", " )\n", " print(\n", " m[\"name\"],\n", " f\"harmonic mean with variational inference: {m['vi_harmonic_log_evidences']}\",\n", " )\n", " print(\"Evidence lower bound:\", vi_lower_bound)\n", "\n", " # the log evidence cannot be smaller than the ELBO, so use it as a floor\n", " m[\"vi_harmonic_log_evidences\"] = max(\n", " m[\"vi_harmonic_log_evidences\"], vi_lower_bound\n", " )" ] }, { "cell_type": "markdown", "id": "5e6c53b1a6414210", "metadata": 
{}, "source": [ "## Comparison" ] }, { "cell_type": "code", "execution_count": null, "id": "fbb5a071645523d4", "metadata": {}, "outputs": [], "source": [ "labels = [\n", " \"-1/2 BIC\",\n", " \"Arithmetic Mean\",\n", " \"Laplace\",\n", " \"Harmonic Mean\",\n", " \"Stabilized\\nHarmonic Mean\",\n", " \"Bridge Sampling\",\n", " \"Nested Sampling\",\n", " \"Thermodynamic\\nIntegration\",\n", " \"Steppingstone\\nSampling\",\n", " \"Variational Inference\\nHarmonic Mean\",\n", "]\n", "\n", "bayes_factors = [\n", " -1 / 2 * models[\"mixture_model1\"][\"BIC\"]\n", " + 1 / 2 * models[\"mixture_model2\"][\"BIC\"],\n", " models[\"mixture_model1\"][\"arithmetic_log_evidence\"]\n", " - models[\"mixture_model2\"][\"arithmetic_log_evidence\"],\n", " models[\"mixture_model1\"][\"laplace_evidences\"][0]\n", " - models[\"mixture_model2\"][\"laplace_evidences\"][0],\n", " models[\"mixture_model1\"][\"harmonic_log_evidence\"]\n", " - models[\"mixture_model2\"][\"harmonic_log_evidence\"],\n", " models[\"mixture_model1\"][\"harmonic_stabilized_log_evidence\"]\n", " - models[\"mixture_model2\"][\"harmonic_stabilized_log_evidence\"],\n", " models[\"mixture_model1\"][\"bridge_log_evidence\"]\n", " - models[\"mixture_model2\"][\"bridge_log_evidence\"],\n", " models[\"mixture_model1\"][\"nested_log_evidence\"]\n", " - models[\"mixture_model2\"][\"nested_log_evidence\"],\n", " models[\"mixture_model1\"][\"thermodynamic_log_evidence\"]\n", " - models[\"mixture_model2\"][\"thermodynamic_log_evidence\"],\n", " models[\"mixture_model1\"][\"steppingstone_log_evidence\"]\n", " - models[\"mixture_model2\"][\"steppingstone_log_evidence\"],\n", " models[\"mixture_model1\"][\"vi_harmonic_log_evidences\"]\n", " - models[\"mixture_model2\"][\"vi_harmonic_log_evidences\"],\n", "]\n", "\n", "true_bf = (\n", " models[\"mixture_model1\"][\"true_log_evidence\"]\n", " - models[\"mixture_model2\"][\"true_log_evidence\"]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "30fea0ed78548d6b", 
"metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(2, 1, tight_layout=True, sharex=True, figsize=(6, 6))\n", "colors = [\"blue\", \"orange\"]\n", "\n", "for i, m in enumerate(models.values()):\n", " m[\"log_evidences\"] = np.array(\n", " [\n", " -1 / 2 * m[\"BIC\"],\n", " m[\"arithmetic_log_evidence\"],\n", " m[\"laplace_evidences\"][0],\n", " m[\"harmonic_log_evidence\"],\n", " m[\"harmonic_stabilized_log_evidence\"],\n", " m[\"bridge_log_evidence\"],\n", " m[\"nested_log_evidence\"],\n", " m[\"thermodynamic_log_evidence\"],\n", " m[\"steppingstone_log_evidence\"],\n", " m[\"vi_harmonic_log_evidences\"],\n", " ]\n", " )\n", " ax[0].scatter(\n", " x=np.arange(m[\"log_evidences\"].size),\n", " y=m[\"log_evidences\"],\n", " color=colors[i],\n", " label=m[\"name\"],\n", " )\n", " ax[0].axhline(\n", " m[\"true_log_evidence\"],\n", " color=colors[i],\n", " alpha=0.75,\n", " label=f\"True evidence of {m['name']}\",\n", " )\n", "\n", " # squared error on the evidence (not log evidence) scale\n", " m[\"error\"] = (\n", " np.exp(m[\"log_evidences\"]) - np.exp(m[\"true_log_evidence\"])\n", " ) ** 2\n", "summed_error = np.sum(np.array([m[\"error\"] for m in models.values()]), axis=0)\n", "ax[1].scatter(x=np.arange(len(labels)), y=summed_error)\n", "\n", "ax[1].set_xlabel(\"Estimator\")\n", "ax[0].set_title(\"Comparison of different evidence estimators\")\n", "ax[0].set_ylabel(\"ln Evidence\")\n", "ax[1].set_ylabel(\"Squared Error of Evidence\\nsum of both models\")\n", "ax[1].set_yscale(\"log\")\n", "ax[1].set_xticks(ticks=np.arange(len(labels)), labels=labels, rotation=60)\n", "fig.legend(ncols=1, loc=\"center right\", bbox_to_anchor=(1.5, 0.7))\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "5d6590690b5c7a30", "metadata": { "collapsed": false, "jupyter": { "outputs_hidden": false } }, "outputs": [], "source": [ "fig, ax = plt.subplots(1, 1, tight_layout=True, figsize=(6, 5))\n", "ax.axhline(true_bf, linestyle=\"-\", color=\"r\", label=\"True Bayes Factor\")\n", "plt.scatter(\n", " 
x=np.arange(len(bayes_factors)), y=bayes_factors, label=\"Estimates\"\n", ")\n", "\n", "# add decision thresholds\n", "def log10_to_ln(x):\n", " # thresholds are usually defined on the log10 scale; convert to ln\n", " return x * np.log(10)\n", "\n", "\n", "ax.axhline(\n", " log10_to_ln(0),\n", " color=\"red\",\n", " linestyle=\"--\",\n", " label='\"Not worth more than a bare mention\"',\n", ")\n", "ax.axhline(log10_to_ln(0.5), color=\"orange\", linestyle=\"--\", label='\"Substantial\"')\n", "ax.axhline(log10_to_ln(1), color=\"yellow\", linestyle=\"--\", label='\"Strong\"')\n", "ax.axhline(log10_to_ln(2), color=\"green\", linestyle=\"--\", label='\"Decisive\"')\n", "\n", "ax.set_ylabel(\"ln Bayes Factor\")\n", "ax.set_xlabel(\"Estimator\")\n", "ax.set_title(\n", " f\"Bayes Factor of {models['mixture_model1']['name']} vs. {models['mixture_model2']['name']}\"\n", ")\n", "plt.xticks(ticks=np.arange(len(bayes_factors)), labels=labels, rotation=60)\n", "fig.legend(ncols=1, loc=\"center right\", bbox_to_anchor=(1.5, 0.7))\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "6cbfd915823d6989", "metadata": {}, "source": [ "Depending on the available computational resources, we recommend bridge sampling, nested sampling, or one of the methods based on power posteriors (thermodynamic integration or steppingstone sampling).\n", "\n", "Bayes factors and marginal likelihoods are powerful tools for Bayesian model comparison. Each of the methods for computing marginal likelihoods has its strengths and weaknesses, and the appropriate choice depends on the specific context, the complexity of the models, and the computational resources available." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }