{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# MCMC sampling diagnostics" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this notebook, we illustrate how to assess the quality of your MCMC samples, e.g. convergence and auto-correlation, in pyPESTO." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:28:30.546369Z", "start_time": "2023-06-26T07:28:30.540934Z" } }, "outputs": [], "source": [ "# install if not done yet\n", "# !apt install libatlas-base-dev swig\n", "# %pip install pypesto[amici,petab] --quiet" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## The pipeline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we load the model and data from which the MCMC samples are generated. Here, we show a toy example of a conversion reaction, loaded as a [PEtab](https://github.com/petab-dev/petab) problem." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:33:33.768292Z", "start_time": "2023-06-26T07:33:30.535312Z" } }, "outputs": [], "source": [ "import logging\n", "\n", "import numpy as np\n", "import petab\n", "\n", "import pypesto\n", "import pypesto.optimize as optimize\n", "import pypesto.petab\n", "import pypesto.sample as sample\n", "import pypesto.visualize as visualize\n", "\n", "# log diagnostics\n", "logger = logging.getLogger(\"pypesto.sample.diagnostics\")\n", "logger.setLevel(logging.INFO)\n", "logger.addHandler(logging.StreamHandler())\n", "\n", "# load the PEtab problem\n", "petab_problem = petab.Problem.from_yaml(\n", " \"conversion_reaction/multiple_conditions/conversion_reaction.yaml\"\n", ")\n", "# import the PEtab problem into pyPESTO\n", "importer = pypesto.petab.PetabImporter(petab_problem)\n", "# create the pyPESTO problem\n", "problem = importer.create_problem()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create the sampler object; in this case we will use adaptive parallel 
tempering with 3 temperatures." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:34:35.117499Z", "start_time": "2023-06-26T07:34:34.841805Z" } }, "outputs": [], "source": [ "sampler = sample.AdaptiveParallelTemperingSampler(\n", " internal_sampler=sample.AdaptiveMetropolisSampler(), n_chains=3\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, we will initiate the MCMC chain at a \"random\" point in parameter space, e.g. $\\theta_{start} = [3, -4]$." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:44:20.629679Z", "start_time": "2023-06-26T07:43:02.209978Z" } }, "outputs": [], "source": [ "%%capture\n", "result = sample.sample(\n", " problem,\n", " n_samples=1000,\n", " sampler=sampler,\n", " x0=np.array([3, -4]),\n", " filename=None,\n", ")\n", "elapsed_time = result.sample_result.time\n", "print(f\"Elapsed time: {round(elapsed_time, 2)}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:18.376334Z", "start_time": "2023-06-26T07:37:17.333572Z" } }, "outputs": [], "source": [ "ax = visualize.sampling_parameter_traces(\n", " result, use_problem_bounds=False, size=(12, 5)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Visualizing the chains shows a warm-up phase that lasts until the chain converges. This is commonly known as the \"burn-in\" phase and should be discarded. The Geweke test provides an automatic way to find the chain index at which the warm-up ends." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:22.016036Z", "start_time": "2023-06-26T07:37:21.380827Z" } }, "outputs": [], "source": [ "sample.geweke_test(result=result)\n", "ax = visualize.sampling_parameter_traces(\n", " result, use_problem_bounds=False, size=(12, 5)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:25.353191Z", "start_time": "2023-06-26T07:37:24.097221Z" } }, "outputs": [], "source": [ "ax = visualize.sampling_parameter_traces(\n", " result, use_problem_bounds=False, full_trace=True, size=(12, 5)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "One can calculate the effective sample size per computation time. We save the results in a variable to compare them later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:28.830563Z", "start_time": "2023-06-26T07:37:28.805396Z" } }, "outputs": [], "source": [ "sample.effective_sample_size(result=result)\n", "ess = result.sample_result.effective_sample_size\n", "print(\n", " f\"Effective sample size per computation time: {round(ess / elapsed_time, 2)}\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:30.874974Z", "start_time": "2023-06-26T07:37:30.610636Z" } }, "outputs": [], "source": [ "alpha = [99, 95, 90]\n", "ax = visualize.sampling_parameter_cis(result, alpha=alpha, size=(10, 5))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predictions can be performed by creating a parameter ensemble from the sample, then applying a predictor to the ensemble. The predictor requires a simulation tool. Here, [AMICI](https://github.com/AMICI-dev/AMICI) is used. First, the predictor is set up." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:43.265325Z", "start_time": "2023-06-26T07:37:43.257428Z" } }, "outputs": [], "source": [ "from pypesto.C import AMICI_STATUS, AMICI_T, AMICI_X, AMICI_Y\n", "from pypesto.predict import AmiciPredictor\n", "\n", "\n", "# This post_processor will transform the output of the simulation tool\n", "# such that the output is compatible with the next steps.\n", "def post_processor(amici_outputs, output_type, output_ids):\n", " outputs = [\n", " (\n", " amici_output[output_type]\n", " if amici_output[AMICI_STATUS] == 0\n", " else np.full((len(amici_output[AMICI_T]), len(output_ids)), np.nan)\n", " )\n", " for amici_output in amici_outputs\n", " ]\n", " return outputs\n", "\n", "\n", "# Setup post-processors for both states and observables.\n", "from functools import partial\n", "\n", "amici_objective = result.problem.objective\n", "state_ids = amici_objective.amici_model.get_state_ids()\n", "observable_ids = amici_objective.amici_model.get_observable_ids()\n", "post_processor_x = partial(\n", " post_processor,\n", " output_type=AMICI_X,\n", " output_ids=state_ids,\n", ")\n", "post_processor_y = partial(\n", " post_processor,\n", " output_type=AMICI_Y,\n", " output_ids=observable_ids,\n", ")\n", "\n", "# Create pyPESTO predictors for states and observables\n", "predictor_x = AmiciPredictor(\n", " amici_objective,\n", " post_processor=post_processor_x,\n", " output_ids=state_ids,\n", ")\n", "predictor_y = AmiciPredictor(\n", " amici_objective,\n", " post_processor=post_processor_y,\n", " output_ids=observable_ids,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, the ensemble is created." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:37:48.270529Z", "start_time": "2023-06-26T07:37:48.258116Z" } }, "outputs": [], "source": [ "from pypesto.C import EnsembleType\n", "from pypesto.ensemble import Ensemble\n", "\n", "# corresponds to only the estimated parameters\n", "x_names = result.problem.get_reduced_vector(result.problem.x_names)\n", "\n", "# Create the ensemble with the MCMC chain from parallel tempering with the real temperature.\n", "ensemble = Ensemble.from_sample(\n", " result,\n", " chain_slice=slice(\n", " None, None, 5\n", " ), # Optional argument: only use every fifth vector in the chain.\n", " x_names=x_names,\n", " ensemble_type=EnsembleType.sample,\n", " lower_bound=result.problem.lb,\n", " upper_bound=result.problem.ub,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The predictor is then applied to the ensemble to generate predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:38:00.338142Z", "start_time": "2023-06-26T07:37:52.885968Z" } }, "outputs": [], "source": [ "from pypesto.engine import MultiProcessEngine\n", "\n", "engine = MultiProcessEngine()\n", "\n", "ensemble_prediction = ensemble.predict(\n", " predictor_x, prediction_id=AMICI_X, engine=engine\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:38:04.409272Z", "start_time": "2023-06-26T07:38:03.600374Z" } }, "outputs": [], "source": [ "from pypesto.C import CONDITION, OUTPUT\n", "\n", "credibility_interval_levels = [90, 95, 99]\n", "\n", "ax = visualize.sampling_prediction_trajectories(\n", " ensemble_prediction,\n", " levels=credibility_interval_levels,\n", " size=(10, 5),\n", " labels={\"A\": \"state_A\", \"condition_0\": \"cond_0\"},\n", " axis_label_padding=60,\n", " groupby=CONDITION,\n", " condition_ids=[\"condition_0\", \"condition_1\"], # 
`None` for all conditions\n", " output_ids=[\"A\", \"B\"], # `None` for all outputs\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:38:06.930290Z", "start_time": "2023-06-26T07:38:06.061992Z" } }, "outputs": [], "source": [ "ax = visualize.sampling_prediction_trajectories(\n", " ensemble_prediction,\n", " levels=credibility_interval_levels,\n", " size=(10, 5),\n", " labels={\"A\": \"obs_A\", \"condition_0\": \"cond_0\"},\n", " axis_label_padding=60,\n", " groupby=OUTPUT,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Predictions are stored in `ensemble_prediction.prediction_summary`." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Finding parameter point estimates\n", "Commonly, optimization is performed as a first step to find good parameter point estimates." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:38:19.022901Z", "start_time": "2023-06-26T07:38:13.471020Z" } }, "outputs": [], "source": [ "res = optimize.minimize(problem, n_starts=10, filename=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By passing the result object to the sampling function, the best parameter vector found during optimization is used as the starting point for the MCMC sampling." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:40:56.975257Z", "start_time": "2023-06-26T07:39:47.907190Z" } }, "outputs": [], "source": [ "%%capture\n", "res = sample.sample(\n", " problem, n_samples=1000, sampler=sampler, result=res, filename=None\n", ")\n", "elapsed_time = res.sample_result.time\n", "print(f\"Elapsed time: {round(elapsed_time, 2)}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When the sampling is finished, we can analyse our results. 
pyPESTO provides functions to analyse both the sampling process and the obtained sampling result. For example, visualizing the traces allows one to detect burn-in phases or to fine-tune hyperparameters. First, the parameter trajectories can be visualized:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:41:07.398824Z", "start_time": "2023-06-26T07:41:06.448954Z" } }, "outputs": [], "source": [ "ax = visualize.sampling_parameter_traces(\n", " res, use_problem_bounds=False, size=(12, 5)\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By visual inspection, one can see that the chain has already converged from the start. This shows the benefit of initiating the chain at the optimal parameter vector. However, this may not always be the case." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:41:11.826906Z", "start_time": "2023-06-26T07:41:11.174104Z" } }, "outputs": [], "source": [ "sample.geweke_test(result=res)\n", "ax = visualize.sampling_parameter_traces(\n", " res, use_problem_bounds=False, size=(12, 5)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:41:16.161400Z", "start_time": "2023-06-26T07:41:15.930905Z" } }, "outputs": [], "source": [ "sample.effective_sample_size(result=res)\n", "ess = res.sample_result.effective_sample_size\n", "print(\n", " f\"Effective sample size per computation time: {round(ess / elapsed_time, 2)}\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:41:18.244703Z", "start_time": "2023-06-26T07:41:17.884791Z" } }, "outputs": [], "source": [ "percentiles = [99, 95, 90]\n", "ax = visualize.sampling_parameter_cis(res, alpha=percentiles, size=(10, 5))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": 
"2023-06-26T07:42:28.807853Z", "start_time": "2023-06-26T07:42:20.364810Z" } }, "outputs": [], "source": [ "# Create the ensemble with the MCMC chain from parallel tempering with the real temperature.\n", "ensemble = Ensemble.from_sample(\n", " res,\n", " x_names=x_names,\n", " chain_slice=slice(None, None, 5),\n", " ensemble_type=EnsembleType.sample,\n", " lower_bound=res.problem.lb,\n", " upper_bound=res.problem.ub,\n", ")\n", "\n", "ensemble_prediction = ensemble.predict(\n", " predictor_y, prediction_id=AMICI_Y, engine=engine\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:42:31.086433Z", "start_time": "2023-06-26T07:42:30.290238Z" } }, "outputs": [], "source": [ "ax = visualize.sampling_prediction_trajectories(\n", " ensemble_prediction,\n", " levels=credibility_interval_levels,\n", " size=(10, 5),\n", " labels={\"A\": \"obs_A\", \"condition_0\": \"cond_0\"},\n", " axis_label_padding=60,\n", " groupby=CONDITION,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:42:34.826288Z", "start_time": "2023-06-26T07:42:34.207270Z" } }, "outputs": [], "source": [ "ax = visualize.sampling_prediction_trajectories(\n", " ensemble_prediction,\n", " levels=credibility_interval_levels,\n", " size=(10, 5),\n", " labels={\"A\": \"obs_A\", \"condition_0\": \"cond_0\"},\n", " axis_label_padding=60,\n", " groupby=OUTPUT,\n", " reverse_opacities=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Custom timepoints can also be specified, either for each condition\n", "- `amici_objective.set_custom_timepoints(..., timepoints=...)`\n", "\n", "or for all conditions\n", "- `amici_objective.set_custom_timepoints(..., timepoints_global=...)`.\n", "\n", "Plotting of measurement data (`petab_problem.measurement_df`) is also demonstrated here, which requires correct IDs in the `AmiciPredictor` that align with the observable and 
condition IDs in the measurement data." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "ExecuteTime": { "end_time": "2023-06-26T07:42:49.563714Z", "start_time": "2023-06-26T07:42:38.258389Z" } }, "outputs": [], "source": [ "# Create a custom objective with new output timepoints.\n", "timepoints = [np.linspace(0, 10, 100), np.linspace(0, 20, 200)]\n", "amici_objective_custom = amici_objective.set_custom_timepoints(\n", " timepoints=timepoints\n", ")\n", "\n", "# Create an observable predictor with the custom objective.\n", "predictor_y_custom = AmiciPredictor(\n", " amici_objective_custom,\n", " post_processor=post_processor_y,\n", " output_ids=observable_ids,\n", " condition_ids=[edata.id for edata in amici_objective_custom.edatas],\n", ")\n", "\n", "# Predict then plot.\n", "ensemble_prediction = ensemble.predict(\n", " predictor_y_custom, prediction_id=AMICI_Y, engine=engine\n", ")\n", "\n", "ax = visualize.sampling_prediction_trajectories(\n", " ensemble_prediction,\n", " levels=credibility_interval_levels,\n", " groupby=CONDITION,\n", " measurement_df=petab_problem.measurement_df,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 4 }