{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Error bars: the ``Profiler``\n", "Taking a fully Bayesian approach to MSD inference allows us to calculate reliable error bars on the fit results." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Initial example" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "from tqdm.auto import tqdm\n", "\n", "import bayesmsd\n", "\n", "np.random.seed(113756719)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We generate data from the powerlaw $\\text{MSD}(\\Delta t) = (\\Delta t)^{0.63}$ and fit it with ``bayesmsd.lib.NPXFit``; so we are hoping to recover the parameter values $\\Gamma = 1$ and $\\alpha = 0.63$. We run the fit & profiler for increasing numbers of trajectories in the dataset (between 1 and 20 trajectories); clearly with more data we should get more precise estimates." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "@bayesmsd.deco.MSDfun\n", "def msd(dt):\n", " return dt**0.63\n", "\n", "out = []\n", "for n_traj in tqdm([1, 2, 5, 10, 20]):\n", " data = bayesmsd.gp.generate((msd, 1, 1), T=100, n=n_traj)\n", " \n", " fit = bayesmsd.lib.NPXFit(data, ss_order=1)\n", " fit.parameters['log(σ²) (dim 0)'].fix_to = -np.inf\n", " \n", " profiler = bayesmsd.Profiler(fit)\n", " mci = profiler.find_MCI()\n", " \n", " out.append((n_traj, mci))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Reformat the outputs and plot them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Reformat\n", "n = np.array([n_traj for n_traj, mci in out])\n", "G_pe = np.exp(np.array([mci['log(Γ) (dim 0)'][0] for n_traj, mci in out]))\n", "G_ci = np.exp(np.array([mci['log(Γ) (dim 0)'][1] for n_traj, mci in out]))\n", "a_pe = np.array([mci['α (dim 0)' ][0] for n_traj, mci in 
out])\n", "a_ci = np.array([mci['α (dim 0)' ][1] for n_traj, mci in out])\n", "\n", "# Plot\n", "fig, axs = plt.subplots(1, 2, figsize=[10, 4])\n", "\n", "for (pe, ci,\n", " name, true_value,\n", " ax,\n", " ) in zip([a_pe, G_pe], [a_ci, G_ci],\n", " ['α', 'Γ'], [0.63, 1.00],\n", " axs,\n", " ):\n", " ax.axhline(true_value, linestyle='--', color='k', label='ground truth')\n", " ax.errorbar(n, pe, yerr=np.abs(ci.T - pe[None, :]),\n", " linestyle='', marker='o',\n", " label='inference\\n(95% credible interval)',\n", " )\n", "\n", " ax.legend()\n", " ax.set_xlabel('number of trajectories in dataset')\n", " ax.set_ylabel(f'inferred {name}')\n", " ax.set_xscale('log')\n", " ax.set_title(f'Inference accuracy for {name}')\n", " \n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As expected, the point estimates become better with increasing amounts of data. Crucially though, we also know how imprecise the estimates are for, e.g., a single trajectory! Note that the 95% credible interval covers the true parameter value in most cases: the fit might be off, but at least we can gauge by how much we might be off. This provision of reliable error estimates is a core strength of ``bayesmsd``." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Usage\n", "Using the ``Profiler`` is relatively simple, as illustrated above. On initialization it takes a ``Fit`` object, which contains the data and definition of the fit model. Beyond that you can tweak a few settings, like the confidence level (95% by default), the verbosity (how much output is printed while running), or whether to use profile likelihoods (``profiling = True``) or marginal likelihoods (``profiling = False``).\n", "\n", "Once the profiler is set up with all that information, you just have to call its ``find_MCI()`` method to let it run. This method returns a dict containing, for each parameter of the fit, the point estimate and 95% credible interval. 
By default, ``find_MCI()`` calculates intervals for all parameters; you can use the ``parameters`` keyword argument to restrict the computation to a select subset." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Internal workings\n", "While running the profiler requires just two lines of code, internally of course there is a bit more to it. Here is a rundown of what the profiler does upon execution of ``find_MCI()``:\n", "\n", " #. Get a point estimate. This just means running an initial fit, like you would do manually. If, in fact, you did run the fit already before initializing the profiler, you can just set ``profiler.point_estimate`` to the fit result, so the profiler won't run the same fit again (cf. [Quickstart](00_intro.ipynb)).\n", " #. Start the profiling runs for the individual parameters. It is possible that during this more rigorous exploration of the posterior we come across a better point estimate than what the ``Fit`` found initially (e.g. if the likelihood landscape is sufficiently rugged). In that case, the profiler starts over completely, using that new point estimate as the initial condition.\n", " #. Now, for each parameter we walk along the profile likelihood until it decreases below a threshold (dependent on the desired credibility level). How these steps are taken is determined by the ``linearization`` attribute of each parameter, as described [here](01_parameters.ipynb#Advanced:-linearization).\n", " #. Having found a point below the threshold, we now have a \"bracket\" enclosing the desired bound of the credible interval (which is the point exactly *at* the threshold): the point estimate is certainly above, while the point we just found is below. We can therefore now solve the problem to a given precision (cf. ``Profiler.conf_precision``) by bisection, which halves the bracket at each step and thus converges exponentially.\n", " #. 
Repeat by walking in the other direction from the point estimate.\n", " \n", "By running through these steps for all parameters successively, the ``Profiler`` finds the interval boundaries where the profile likelihood (or marginal likelihood, if ``profiling = False``) decreases by a given amount from the maximum (the point estimate). These are the credible intervals for each parameter." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }