{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Posterior Inference using SBI" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Simulation-Based Inference (SBI) with Neural Posterior Estimation (NPE) is a statistical framework for parameter inference in complex models. Instead of relying on an explicit likelihood function, NPE uses simulations to generate synthetic data and trains neural networks to learn an approximate posterior distribution of the parameters given observed data. This approach is typically used when the likelihood is intractable or expensive to compute, making traditional methods impractical. For pulse profile modelling, the likelihood computation is tractable, but sampling complex, high-dimensional parameter spaces can become computationally expensive. SBI offers a potentially attractive alternative in such cases.\n", "\n", "In this example notebook, we use the `sbi` package to perform this type of inference. Refer to the installation instructions for the additional dependencies required to run this notebook." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Initialization" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/=============================================\\\n", "| X-PSI: X-ray Pulse Simulation and Inference |\n", "|---------------------------------------------|\n", "| Version: 3.0.0 |\n", "|---------------------------------------------|\n", "| https://xpsi-group.github.io/xpsi |\n", "\\=============================================/\n", "\n", "Imported emcee version: 3.1.6\n", "Imported PyMultiNest.\n", "Imported UltraNest.\n", "Imported GetDist version: 1.5.3\n", "Imported nestcheck version: 0.2.1\n" ] } ], "source": [ "## IMPORTANT: Import sequence - torch, sbi, and xpsi.\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from sbi import utils as utils\n", "from sbi import analysis as analysis\n", "from sbi.neural_nets import posterior_nn\n", "from sbi.inference import SNPE, simulate_for_sbi\n", "from sbi.utils.user_input_checks import prepare_for_sbi\n", "\n", "import numpy as np\n", "import math\n", "\n", "from matplotlib import pyplot as plt\n", "from matplotlib import rcParams\n", "rc = {\"font.family\" : \"serif\",\n", " \"mathtext.fontset\" : \"stix\"}\n", "plt.rcParams.update(rc)\n", "plt.rcParams[\"font.serif\"] = [\"Times New Roman\"] + plt.rcParams[\"font.serif\"]\n", "plt.rcParams.update({'font.size': 18})\n", "plt.rcParams.update({'legend.fontsize': 15})\n", "\n", "import xpsi\n", "\n", "import sys\n", "## Add your path to the example modules to run this notebook\n", "sys.path.append('../../examples/examples_fast/Modules/')\n", "\n", "\n", "import xpsi.SBI_wrapper as SBI_wrapper\n", "from xpsi.SBI_wrapper import xpsi_wrappers\n", "import xpsi.utilities.Example_CNNs as CNNs" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda is available: False\n" ] } ], "source": [ "# Check that cuda is available\n", "print( 'cuda is available: ' , torch.cuda.is_available() )" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Importing modules from `examples_fast`" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting channels for event data...\n", "Channels set.\n", "Setting channels for loaded instrument response 
(sub)matrix...\n", "Channels set.\n", "No parameters supplied... empty subspace created.\n", "Creating parameter:\n", " > Named \"phase_shift\" with fixed value 0.000e+00.\n", " > The phase shift for the signal, a periodic parameter [cycles].\n", "Creating parameter:\n", " > Named \"frequency\" with fixed value 3.140e+02.\n", " > Spin frequency [Hz].\n", "Creating parameter:\n", " > Named \"mass\" with bounds [1.000e+00, 1.600e+00].\n", " > Gravitational mass [solar masses].\n", "Creating parameter:\n", " > Named \"radius\" with bounds [1.000e+01, 1.300e+01].\n", " > Coordinate equatorial radius [km].\n", "Creating parameter:\n", " > Named \"distance\" with bounds [5.000e-01, 2.000e+00].\n", " > Earth distance [kpc].\n", "Creating parameter:\n", " > Named \"cos_inclination\" with bounds [0.000e+00, 1.000e+00].\n", " > Cosine of Earth inclination to rotation axis.\n", "Creating parameter:\n", " > Named \"super_colatitude\" with bounds [1.000e-03, 1.570e+00].\n", " > The colatitude of the centre of the superseding region [radians].\n", "Creating parameter:\n", " > Named \"super_radius\" with bounds [1.000e-03, 1.570e+00].\n", " > The angular radius of the (circular) superseding region [radians].\n", "Creating parameter:\n", " > Named \"phase_shift\" with bounds [-2.500e-01, 7.500e-01].\n", " > The phase of the hot region, a periodic parameter [cycles].\n", "Creating parameter:\n", " > Named \"super_temperature\" with bounds [6.000e+00, 7.000e+00].\n", " > log10(superseding region effective temperature [K]).\n", "Creating parameter:\n", " > Named \"mode_frequency\" with fixed value 3.140e+02.\n", " > Coordinate frequency of the mode of radiative asymmetry in the\n", "photosphere that is assumed to generate the pulsed signal [Hz].\n", "No parameters supplied... empty subspace created.\n", "Checking likelihood and prior evaluation before commencing sampling...\n", "Not using ``allclose`` function from NumPy.\n", "Using fallback implementation instead.\n", "Checking closeness of likelihood arrays:\n", "-3.1603740790e+04 | -3.1603740790e+04 .....\n", "Closeness evaluated.\n", "Log-likelihood value checks passed on root process.\n", "Checks passed.\n" ] } ], "source": [ "import main" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Preparing for SBI" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First we follow procedures for synthetic data generation that will be used by SBI to generate training data. The `SBI_wrapper` module consists of multiple classes and functions, including the usual data synthesis process but with some extended functionalities." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Setting channels for event data...\n", "Channels set.\n" ] } ], "source": [ "_data = SBI_wrapper.SynthesiseData(main.Instrument.channels,\n", " main.data.phases,\n", " 0, \n", " len(main.Instrument.channels) - 1) " ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating parameter:\n", " > Named \"phase_shift\" with fixed value 0.000e+00.\n", " > The phase shift for the signal, a periodic parameter [cycles].\n", "No data... 
can synthesise data but cannot evaluate a likelihood function.\n" ] } ], "source": [ "main.CustomSignal.synthesise = SBI_wrapper.synthesise\n", "signal = main.CustomSignal(data = _data,\n", " instrument = main.Instrument,\n", " background = None,\n", " interstellar = None,\n", " prefix='instr',\n", " cache=True)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the code block above, the background has been set to `None`.\n", "In its current form, it is recommended to either fix the background or use a parameterized functional background model, since the `default_background_marginalisation` is only utilized during the likelihood computation, which SBI skips. \n", "\n", "In principle, one may leave the background free and allow the neural network to simply learn what the background is for any given dataset. However, performance in such a scenario has not been tested, and one may expect it to introduce too much degeneracy in the parameter space for SBI to work meaningfully." ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "In the code block below, we use a `Custom_SBI_Likelihood` that inherits from the `xpsi.Likelihood` class and modifies the `_driver` and `synthesise` methods to return `model_flux`, the synthesised signal, which then constitutes the training dataset." ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Free parameters\n", "---------------\n", "mass: Gravitational mass [solar masses].\n", "radius: Coordinate equatorial radius [km].\n", "distance: Earth distance [kpc].\n", "cos_inclination: Cosine of Earth inclination to rotation axis.\n", "p__phase_shift: The phase of the hot region, a periodic parameter [cycles].\n", "p__super_colatitude: The colatitude of the centre of the superseding region [radians].\n", "p__super_radius: The angular radius of the (circular) superseding region [radians].\n", "p__super_temperature: log10(superseding region effective temperature [K]).\n", "\n" ] } ], "source": [ "likelihood = SBI_wrapper.Custom_SBI_Likelihood(star = main.star,\n", " signals = signal,\n", " prior = main.prior,\n", " num_energies = 64,\n", " threads = 1)\n", "print(likelihood)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "During training, SBI learns the joint distribution of model parameters and data, $P(\\theta, D)$, where the model is defined by the free parameters (the ones shown above in this example). In this process, it essentially also approximates the likelihood $P(D|\\theta)$ without explicitly calculating it.\n" ] },
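{ "cell_type": "markdown", "metadata": {}, "source": [ "Schematically, NPE achieves this by training a conditional density estimator $q_{\\phi}(\\theta \\mid D)$, where $q_{\\phi}$ denotes the neural network with weights $\\phi$, on pairs $(\\theta_i, D_i)$ drawn from the prior and the simulator, minimising the loss\n", "\n", "$$\\mathcal{L}(\\phi) = -\\frac{1}{N} \\sum_{i=1}^{N} \\log q_{\\phi}(\\theta_i \\mid D_i),$$\n", "\n", "which, for a sufficiently flexible network and enough simulations, is minimised when $q_{\\phi}(\\theta \\mid D) \\approx P(\\theta \\mid D)$. No explicit likelihood evaluations are required at any stage." ] },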
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Training SBI" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "With our prerequisites in place, we first instantiate our `simulator` and `prior` for SBI. To do so, we need to inform the `sbi` package about our pulse profile generator (defined by `SBI_wrapper.xpsi_wrappers.simulator`) and the prior distribution of our free model parameters (defined by `SBI_wrapper.xpsi_wrappers.sample` and `SBI_wrapper.xpsi_wrappers.log_prob`). These methods comply with the interface that the `sbi` package requires for training.\n", "\n", "The `SBI_wrapper.xpsi_wrappers.simulator` calls the `synthesise` method that we bound to `main.CustomSignal` above, which in turn requires information about the `exposure_time` (or `expected_source_counts`), `nchannels` and `nphases`.\n", "\n", "The `prepare_for_sbi` function of `sbi` then checks that the simulator and prior meet its internal requirements, and reshapes and typecasts them into objects usable for training. (As the warning below notes, `prepare_for_sbi` is deprecated in recent versions of `sbi` in favour of `process_prior` and `process_simulator`.)" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_36371/4002819185.py:11: DeprecationWarning: This method is deprecated as of sbi version v0.23.0. and will be removed in a future release.Please use `process_prior` and `process_simulator` in the future.\n", " simulator, prior = prepare_for_sbi(wrapped.simulator, wrapped)\n", "/home/saltuomo/.conda/envs/xpsi_py3/lib/python3.12/site-packages/sbi/utils/user_input_checks_utils.py:28: UserWarning: No prior bounds were passed, consider passing lower_bound and / or upper_bound if your prior has bounded support.\n", " self.custom_support = build_support(lower_bound, upper_bound)\n", "/home/saltuomo/.conda/envs/xpsi_py3/lib/python3.12/site-packages/sbi/utils/user_input_checks_utils.py:30: UserWarning: Prior is lacking mean attribute, estimating prior mean from samples.\n", " self._set_mean_and_variance()\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n", "Drawing samples from the joint prior...\n", "Samples drawn.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/home/saltuomo/.conda/envs/xpsi_py3/lib/python3.12/site-packages/sbi/utils/user_input_checks_utils.py:30: UserWarning: Prior is lacking variance attribute, estimating prior variance from samples.\n", " self._set_mean_and_variance()\n" ] } ], "source": [ "instr_kwargs = dict(exposure_time = 1.0E+06, # alternatively input \n", " # expected_source_counts\n", " nchannels = len(main.Instrument.channels),\n", " nphases = len(main.data.phases))\n", "\n", "wrapped = xpsi_wrappers(prior = main.prior,\n", " likelihood = likelihood,\n", " instr_kwargs = instr_kwargs,\n", " train_using_CNNs = True)\n", "\n", "simulator, prior = prepare_for_sbi(wrapped.simulator, wrapped)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Now we generate the training samples. Here we use 1000 samples for training; this is insufficient for the complexity of the given model and is only used for tutorial purposes.\n", "(Tip: save the training samples for future use. That is not done here.)" ] },
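{ "cell_type": "markdown", "metadata": {}, "source": [ "The cell below is a minimal sketch of the standard `sbi` workflow for this step, using only functions already imported above: draw $(\\theta, D)$ training pairs with `simulate_for_sbi`, train a neural posterior estimator with `SNPE`, and build the `posterior` object that is sampled further below. The variable names, the number of simulations, the `'maf'` density estimator and the saving step are illustrative choices rather than fixed requirements; an embedding network (for example one of the CNNs provided in `xpsi.utilities.Example_CNNs`) can be passed to `posterior_nn` via its `embedding_net` argument to compress the synthesised pulse profiles before density estimation." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Minimal sketch (illustrative settings): simulate, train NPE, build the posterior.\n", "\n", "# 1. Run the simulator over prior draws to obtain (theta, x) training pairs.\n", "theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=1000)\n", "\n", "# Optionally save the training set for re-use (illustrative file name).\n", "torch.save({'theta': theta, 'x': x}, 'sbi_training_samples.pt')\n", "\n", "# 2. Choose a conditional density estimator; an embedding network could be\n", "#    supplied here via posterior_nn(model='maf', embedding_net=...).\n", "density_estimator_build_fun = posterior_nn(model='maf')\n", "\n", "inference = SNPE(prior=prior, density_estimator=density_estimator_build_fun)\n", "density_estimator = inference.append_simulations(theta, x).train()\n", "\n", "# 3. Build the posterior object that is sampled in the next cell.\n", "posterior = inference.build_posterior(density_estimator)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "With a trained `posterior` in hand, we can synthesise a test observation for a chosen parameter vector and compare the inferred posterior with the injected values using a pair plot." ] },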
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Drawing samples from the joint prior...\n", "Samples drawn.\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "413bb9dfa4be421fb229e8bbde7487c2", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/100 [00:00" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "p = [1.4, # mass\n", " 12, # radius\n", " 1.0, # distance\n", " math.cos(np.radians(60)), # cos(inclination)\n", " 0.5, # phase shift\n", " np.radians(70), # super colatitude\n", " 0.75, # super radius\n", " 6.7] # super temperature\n", "test_observation = torch.Tensor(likelihood.synthesise(p, force=True, instr=instr_kwargs))\n", "if torch.cuda.is_available():\n", " test_observation = test_observation.cuda()\n", "\n", "samples = posterior.sample((10000,), x=test_observation)\n", "\n", "_ = analysis.pairplot(samples.cpu(), \n", " limits=[[1, 3.0], [3, 20], [0.0, 2.0], [0.0, 1.0], [0.0, 1.0], [0.0, math.pi], [0.0, math.pi/2.0], [5.0, 7.0]], \n", " figsize=(10, 10),\n", " points=np.array(p),\n", " labels=[r'M$_{\\odot}$', r'R [km]', r'D [kpc]', r'cos $i$', r'$\\phi$ [cycles]', r'$\\theta$ [rad]', r'$\\zeta$ [rad]', r'log$_{10}$(T [K])'])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }