Source code for xpsi.SBI_wrapper

from random import randint

import numpy as np
import torch

import xpsi
from xpsi import Likelihood
# Exception types handled by the base-class likelihood driver; imported from
# the module that defines xpsi.Likelihood so the customised driver in this file
# catches exactly the same errors as the base implementation.
from xpsi.Likelihood import Elsewhere, HotRegion, LikelihoodError, xpsiError
from xpsi.tools.synthesise import synthesise_exposure
from xpsi.tools.synthesise import synthesise_given_total_count_number


class Custom_SBI_Likelihood(xpsi.Likelihood):
    """ Custom likelihood function for use with SBI.

    Modifies the `_driver` and `synthesise` methods from the base class to
    return `model_flux` that is the synthesised signal.

    Unlike the base class, `_driver` here always returns a 2-tuple
    ``(status, model_flux)``. ``status`` is a bool on success and
    ``self.random_near_llzero`` on a handled failure.
    """

    def _distance_scaled(self, components):
        """ Divide every signal component of every hot region by
        ``self._star.spacetime.d_sq`` using ``self._divide``.

        :param components: Nested iterable (per hot region, per component)
            such as ``photosphere.signal``, ``photosphere.signalQ`` or
            ``photosphere.signalU``.
        :returns: Nested tuple with the same structure as *components*.
        """
        d_sq = self._star.spacetime.d_sq  # presumably the squared distance
        return tuple(
            tuple(self._divide(component, d_sq) for component in hot_region)
            for hot_region in components)

    def _driver(self, fast_mode=False, synthesise=False, force_update=False,
                **kwargs):
        """ Main likelihood evaluation driver routine.

        :param optional[bool] fast_mode: Run the low-resolution pass.
        :param optional[bool] synthesise: Synthesise data via
            ``signal.synthesise`` instead of evaluating the likelihood.
        :param optional[bool] force_update: Force a star update.
        :param dict kwargs: Per-signal synthesis keyword arguments, keyed
            by ``signal.prefix``.
        :returns: ``(star_updated, model_flux)`` on success, where
            ``model_flux`` is the output of the last ``signal.synthesise``
            call (``None`` if no signal was synthesised). On a handled
            failure it returns ``(self.random_near_llzero, None)``.
        """
        self._star.activate_fast_mode(fast_mode)

        star_updated = False
        # Stays None unless a signal is synthesised below (never in fast mode),
        # so the final return can never hit an unbound name.
        model_flux = None

        if self._star.needs_update or force_update:
            # ignore fast parameters in this version
            try:
                if fast_mode or not self._do_fast:
                    fast_total_counts = None
                else:
                    fast_total_counts = tuple(signal.fast_total_counts for
                                              signal in self._signals)

                self._star.update(fast_total_counts, self.threads,
                                  force_update=force_update)
            except xpsiError as e:
                if isinstance(e, HotRegion.RayError):
                    print('Warning: HotRegion.RayError raised.')
                elif isinstance(e, Elsewhere.RayError):
                    print('Warning: Elsewhere.RayError raised.')

                return self.random_near_llzero, None

            for photosphere, signals in zip(self._star.photospheres,
                                            self._signals):
                try:
                    if fast_mode:
                        energies = signals[0].fast_energies
                    else:
                        energies = signals[0].energies

                    photosphere.integrate(energies, self.threads)
                except xpsiError as e:
                    try:
                        prefix = ' prefix ' + photosphere.prefix
                    except AttributeError:
                        prefix = ''
                    if isinstance(e, HotRegion.PulseError):
                        print('Warning: HotRegion.PulseError raised for '
                              'photosphere%s.' % prefix)
                    elif isinstance(e, Elsewhere.IntegrationError):
                        print('Warning: Elsewhere.IntegrationError for '
                              'photosphere%s.' % prefix)
                    elif isinstance(e, HotRegion.AtmosError):
                        raise
                    elif isinstance(e, Elsewhere.AtmosError):
                        raise

                    print('Parameter vector: ',
                          super(Likelihood, self).__call__())
                    return self.random_near_llzero, None

            star_updated = True

        def _register(signal, components):
            # Register distance-scaled components with the signal object.
            signal.register(self._distance_scaled(components),
                            fast_mode=fast_mode, threads=self.threads)

        # register the signals by operating with the instrument response
        for signals, photosphere in zip(self._signals,
                                        self._star.photospheres):
            for signal in signals:
                if star_updated or signal.needs_update:
                    if signal.isI:
                        _register(signal, photosphere.signal)
                    elif signal.isQ:
                        _register(signal, photosphere.signalQ)
                    elif signal.isU:
                        _register(signal, photosphere.signalU)
                    elif signal.isQn:
                        _register(signal, photosphere.signalQ)
                        Qsignal = signal.signals
                        _register(signal, photosphere.signal)
                        Isignal = signal.signals
                        # Q/I ratio, defined as 0 wherever I == 0.
                        for ihot in range(len(photosphere.signalQ)):
                            signal._signals[ihot] = np.where(
                                Isignal[ihot] == 0.0, 0.0,
                                Qsignal[ihot] / Isignal[ihot])
                    elif signal.isUn:
                        _register(signal, photosphere.signalU)
                        Usignal = signal.signals
                        _register(signal, photosphere.signal)
                        Isignal = signal.signals
                        # U/I ratio, defined as 0 wherever I == 0.
                        for ihot in range(len(photosphere.signalU)):
                            signal._signals[ihot] = np.where(
                                Isignal[ihot] == 0.0, 0.0,
                                Usignal[ihot] / Isignal[ihot])
                    else:
                        raise TypeError('Signal type must be either I, Q, U, '
                                        'Qn, or Un.')
                    reregistered = True
                else:
                    reregistered = False

                if not fast_mode and reregistered:
                    if synthesise:
                        hot = photosphere.hot
                        try:
                            kws = kwargs.pop(signal.prefix)
                        except AttributeError:
                            kws = {}

                        shifts = [h['phase_shift'] for h in hot.objects]
                        signal.shifts = np.array(shifts)
                        model_flux = signal.synthesise(threads=self._threads,
                                                       **kws)
                    else:
                        try:
                            hot = photosphere.hot
                            shifts = [h['phase_shift'] for h in hot.objects]
                            signal.shifts = np.array(shifts)
                            signal(threads=self._threads, llzero=self._llzero)
                        except LikelihoodError:
                            try:
                                prefix = ' prefix ' + signal.prefix
                            except AttributeError:
                                prefix = ''
                            print('Warning: LikelihoodError raised for '
                                  'signal%s.' % prefix)
                            print('Parameter vector: ',
                                  super(Likelihood, self).__call__())
                            return self.random_near_llzero, None

        return star_updated, model_flux

    def synthesise(self, p, reinitialise=False, force=False, **kwargs):
        """ Synthesise pulsation data.

        :param list p: Parameter vector.

        :param optional[bool] reinitialise: Call ``self.reinitialise()``?

        :param optional[bool] force:
            Force complete reevaluation even if some parameters are unchanged.

        :param dict kwargs:
            Keyword arguments propagated to custom signal synthesis methods.
            Examples of such arguments include exposure times or
            required total count numbers (see example notebooks).

        :returns ndarray model_flux:
            *model_flux* (`numpy.ndarray`) synthesised counts, or ``None`` if
            the prior check or the model evaluation failed (the previously
            cached parameter values are restored in that case).

        :raises TypeError: If *p* is ``None``.
        """
        if reinitialise:  # for safety if settings have been changed
            self.reinitialise()  # do setup again given exisiting object refs
            self.clear_cache()  # clear cache and values
        elif force:  # no need to reinitialise, just clear cache and values
            self.clear_cache()

        if p is None:  # expected a vector of values instead of nothing
            raise TypeError('Parameter values are None.')
        super(Likelihood, self).__call__(p)  # update free parameters

        try:
            logprior = self._prior(p)  # pass vector just in case wanted
        except AttributeError:
            pass
        else:
            if not np.isfinite(logprior):
                because_of_1D_bounds = False
                for param in self._prior.parameters:
                    if param.bounds[0] is not None:
                        if not param.bounds[0] <= param.value:
                            because_of_1D_bounds = True
                    if param.bounds[1] is not None:
                        if not param.value <= param.bounds[1]:
                            because_of_1D_bounds = True
                if because_of_1D_bounds:
                    print("Warning: Prior check failed, because at least one of the parameters is not within the hard 1D-bounds. No synthetic data will be produced.")
                else:
                    print('Warning: Prior check failed because a requirement set in CustomPrior has failed. No synthetic data will be produced.')
                # we need to restore due to premature return
                super(Likelihood, self).__call__(self.cached)
                return None

        if self._do_fast:
            # perform a low-resolution precomputation to direct cell
            # allocation
            x, _ = self._driver(fast_mode=True)
            if not isinstance(x, bool):
                super(Likelihood, self).__call__(self.cached)  # restore
                return None
            if not x:
                # The star did not need updating in the fast pass, so no
                # full synthesis pass is run.
                return None

        x, model_flux = self._driver(synthesise=True, **kwargs)
        if not isinstance(x, bool):
            super(Likelihood, self).__call__(self.cached)  # restore
            return None
        return model_flux
class SynthesiseData(xpsi.Data):
    """ Custom data container to enable synthesis.

    :param ndarray channels: Instrument channel numbers which must be equal in
        number to the first dimension of the count matrix.

    :param ndarray phases: Phases of the phase bins which must be equal in
        number to the second dimension of the count matrix.

    :param int first: First channel index number to include in the
        synthesised data.

    :param int last: Last channel index number to include in the synthesised
        data.

    :raises TypeError: If *first* or *last* cannot be converted with ``int``
        (``TypeError`` from the conversion).

    :raises ValueError: If *first* is not lower than *last*.
    """

    def __init__(self, channels, phases, first, last):
        # NOTE(review): xpsi.Data.__init__ is not called; only the attributes
        # set below exist after construction.
        self.channels = channels
        self._phases = phases

        try:
            self._first = int(first)
            self._last = int(last)
        except TypeError as exc:
            raise TypeError('The first and last channels must be '
                            'integers.') from exc
        if self._first >= self._last:
            raise ValueError('The first channel number must be lower than '
                             'the last channel number.')

    def synthesise(self,
                   exposure_time=None,
                   expected_source_counts=None,
                   nchannels=None,
                   nphases=None,
                   seed=0,
                   **kwargs):
        """ Synthesise data set.

        :param float exposure_time: Exposure time in seconds to scale the
            expected count rate.

        :param float expected_source_counts: Total expected number of source
            counts. Used only when *exposure_time* is falsy.

        :param int nchannels: Number of channels in the synthesised data.

        :param int nphases: Number of phase bins in the synthesised data.

        :param optional[int] seed: Seed for random number generation for
            Poisson noise in synthesised data.

        :param dict kwargs: Accepted and ignored.

        :return: **synthetic** (`numpy.ndarray`) The synthesised data set.

        :raises ValueError: If *nchannels* or *nphases* is missing, or if
            neither *exposure_time* nor *expected_source_counts* is given.
        """
        # NOTE(review): self._data, self._signals and self._shifts are not set
        # by __init__; they must be provided by whatever object this method is
        # used on — confirm with callers.
        if nchannels is None or nphases is None:
            raise ValueError('nchannels and nphases must be specified.')

        # Background fixed to zero counts in every (channel, phase) bin.
        bkg = np.zeros((nchannels, nphases))

        if exposure_time:
            _, synthetic, _ = synthesise_exposure(exposure_time,
                                                  self._data.phases,
                                                  self._signals,
                                                  self._phases,
                                                  self._shifts,
                                                  np.sum(bkg),
                                                  bkg,
                                                  gsl_seed=seed)
        elif expected_source_counts:
            _, synthetic, _, _ = synthesise_given_total_count_number(
                self._data.phases,
                expected_source_counts,
                self._signals,
                self._phases,
                self._shifts,
                np.sum(bkg),
                bkg,
                gsl_seed=seed)
        else:
            raise ValueError('Must specify either exposure time or expected '
                             'source counts.')

        return synthetic
class xpsi_wrappers:
    """ Class that wraps the xpsi likelihood and prior into a SBI compatible
    interface.

    :param xpsi.Prior prior: xpsi.Prior instance.

    :param xpsi.Likelihood likelihood: xpsi.Likelihood instance.

    :param optional[dict] instr_kwargs: Instrument keyword arguments for the
        likelihood synthesise method. Defaults to a new empty dict per
        instance. A dict passed in is used (and updated with a ``seed`` entry
        by :meth:`simulator`) directly, not copied.

    :param optional[bool] train_using_CNNs: Whether to use CNNs for training.
        Defaults to True.
    """
    # Todo: Enable cuda parallelisation.

    def __init__(self, prior, likelihood, instr_kwargs=None,
                 train_using_CNNs=True):
        self.prior = prior
        self.likelihood = likelihood
        # A fresh dict per instance: simulator() writes a 'seed' entry into
        # this mapping, so a shared mutable default would leak state between
        # wrapper instances.
        self.instr_kwargs = {} if instr_kwargs is None else instr_kwargs
        self.train_using_CNNs = train_using_CNNs

    @staticmethod
    def _to_device(values):
        """ Convert *values* with ``torch.Tensor`` and move the result to the
        GPU when CUDA is available.
        """
        tensor = torch.Tensor(values)
        return tensor.cuda() if torch.cuda.is_available() else tensor

    def sample(self, sample_shape=torch.Size([])):
        """ Sample from the prior distribution.

        :param torch.Size sample_shape: The shape of the sample. Only its
            first entry (the number of draws) is used; an empty shape draws
            once. Defaults to torch.Size([]).

        :return: **sample** (`torch.Tensor` or `torch.cuda.Tensor` if CUDA is
            available.) The sampled values, i.e. the first element returned
            by ``self.prior.draw``.
        """
        n_draws = sample_shape[0] if len(sample_shape) > 0 else 1
        return self._to_device(self.prior.draw(n_draws)[0])

    def log_prob(self, parameter_vector):
        """ Compute the log probability of the parameter vector.

        :param torch.Tensor parameter_vector: The parameter vectors, indexed as
            ``parameter_vector[i, :]`` (one row per evaluation).

        :return: **log_probability** (`torch.Tensor`) The log probability of
            each row, as returned by ``self.prior``.
        """
        parameter_vector = self._to_device(parameter_vector)
        log_probability = [self.prior(parameter_vector[i, :])
                           for i in range(parameter_vector.shape[0])]
        return self._to_device(log_probability)

    def simulator(self, parameter_vector):
        """ Compute the likelihood of the parameter vector.

        Stores a fresh random ``seed`` in ``self.instr_kwargs`` and passes it
        to ``self.likelihood.synthesise`` as the ``instr`` keyword argument.

        :param torch.Tensor parameter_vector: The parameter vector for which
            to simulate pulse profile.

        :return: **model_flux** (`torch.Tensor`) The pulse profile for the
            input parameter vector; flattened to 1-D unless
            ``train_using_CNNs`` is truthy.
        """
        self.instr_kwargs['seed'] = randint(0, 1000000000000000)
        parameter_vector = np.array(parameter_vector.cpu())
        model_flux = self.likelihood.synthesise(parameter_vector, force=True,
                                                instr=self.instr_kwargs)

        model_flux = self._to_device(model_flux)
        if not self.train_using_CNNs:
            model_flux = torch.flatten(model_flux)
        return model_flux