Source code for xpsi.PostProcessing._signalplotter

from .. import _warning

try:
    import h5py
except ImportError:
    _warning('Cannot import h5py for signal caching.')
    h5py = None

try:
    from tqdm.auto import trange as _range # detect if Jupyter notebook
except ImportError:
    def _range(n, *args, **kwargs): # ignore extra (kw)args
        return range(n)

from ._global_imports import *

from .. import Likelihood
from ._postprocessor import PostProcessor
from ._cache import _Cache
from ._signalplot import SignalPlot

class SignalPlotter(PostProcessor):
    """ Plot conditional posterior distributions of thematic X-ray signals.

    Methods to plot the data and model for posterior checking. Plots are
    generated for each posterior selected, using the associated likelihood
    object.

    For a given model, there may be multiple :class:`~xpsi.Signal.Signal`
    instances per likelihood object. If this is the case, you need to reduce
    the model down to only the objects needed for the functioning of the
    :class:`~xpsi.Signal.Signal` object to be handled. To do this, simply
    supply to your likelihood object only the references to this minimal set
    of objects. This minimises computation time and ensures explicit
    declaration of the signals to be plotted.

    .. note::

        If a model has multiple instruments, then the energy set for signal
        integration is calculated based on the union of the waveband
        coverage. If instruments are omitted from the likelihood object in
        order to execute posterior signal plotting, the number of energies
        that span the waveband of the remaining instrument should be set to
        match the number in the full model (with all instruments) if the
        likelihood factor for the remaining instrument is to be exactly the
        same.

    """
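    # Illustrative usage (a minimal sketch, not part of the original module):
    # the ``plots`` argument of :meth:`plot` maps posterior IDs to lists of
    # :class:`~._signalplot.SignalPlot` instances. The plotter instance,
    # posterior ID, and plot class below are hypothetical placeholders.
    #
    #     plotter = ...  # a SignalPlotter configured as a PostProcessor
    #     plotter.plot(plots={'my_posterior_ID': [MySignalPlot()]},
    #                  nsamples=200,
    #                  cache=True,
    #                  cache_dir='./cache')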
    @fix_random_seed
    @make_verbose('Plotting signals for posterior checking',
                  'Plotted signals for posterior checking')
    def plot(self,
             plots,
             IDs=None,
             combine=False,
             combine_all=False,
             only_combined=False,
             force_combine=True,
             nsamples=200,
             cache=True,
             force_cache=False,
             cache_dir='./',
             read_only=False,
             archive=True):
        """ Compute and plot signals *a posteriori*.

        :param dict plots:
            Dictionary of lists of plot objects, where each dictionary key
            must match a posterior ID.

        :param OrderedDict IDs:
            Keys must be string identifiers of :class:`Runs` instances. Each
            dictionary element must be a list of string identifiers, each
            matching objects collected in the :class:`Runs` instance
            corresponding to the key. Defaults to ``None``, meaning attempt
            to use as many runs as possible subject to plotting restrictions.

        :param bool cache:
            Cache intermediary model objects to accelerate post-processing?

        :param bool force_cache:
            Force caching irrespective of checks of existing cache files.
            Useful if some part of the model is tweaked and the cache file
            with the same name and sample set is not manually moved from the
            designated directory.

        :param int nsamples:
            Number of samples to use. Equally-weighted samples are generated,
            thus introducing additional Monte Carlo noise which is ignored.

        :param str cache_dir:
            Directory to write the cache to.

        :param bool read_only:
            Do not write to the cache file?

        :param bool archive:
            If not read-only, then archive an existing cache file found at
            the same path?

        """
        self.set_subset(IDs, combine, combine_all,
                        force_combine, only_combined)

        for posterior in self.subset:
            yield 'Handling posterior %s' % posterior.ID

            run = posterior.subset_to_plot[0]

            try:
                likelihood = posterior.likelihood
            except AttributeError:
                print('Supply a likelihood object to proceed.')
                raise

            state_to_restore = likelihood.externally_updated
            likelihood.externally_updated = False

            try:
                _plots = plots[posterior.ID]
            except (TypeError, KeyError):
                print('Invalid plot object specification.')
                raise
            else:
                try:
                    iter(_plots)
                except TypeError:
                    if not isinstance(_plots, SignalPlot):
                        raise TypeError('Invalid plot object type.')
                    _plots.posterior = run.prepend_ID
                    _plots = [_plots]
                else:
                    for plot in _plots:
                        if not isinstance(plot, SignalPlot):
                            raise TypeError('Invalid plot object type.')
                        plot.posterior = run.prepend_ID

            # assimilate all declared caching targets
            _caching_targets = []
            for plot in _plots:
                try:
                    _caching_targets += plot.__caching_targets__
                except (TypeError, AttributeError):
                    print('Invalid specification of caching targets.')
                    raise

            # eliminate duplicates (some plot types can share caching targets)
            caching_targets = []
            for target in _caching_targets:
                if target not in caching_targets:
                    caching_targets.append(target)

            # caching targets are assimilated from plot objects
            likelihood.signal.caching_targets = caching_targets

            self._driver(run, likelihood, nsamples, cache, force_cache,
                         cache_dir, read_only, archive, _plots)

            likelihood.externally_updated = state_to_restore

            yield 'Handled posterior %s.' % posterior.ID

        self._plots = plots

        # in case user needs a handle (e.g., if plot objects created
        # via a classmethod), could return handle here
        yield
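    # Sketch of the caching-target assimilation performed in :meth:`plot`
    # above (the target names below are purely illustrative, not a statement
    # of the attributes cached by any concrete plot class):
    #
    #     plot_a.__caching_targets__ = ['signals', 'expected_counts']
    #     plot_b.__caching_targets__ = ['signals', 'shifts']
    #     # after duplicate elimination the signal object is assigned:
    #     # likelihood.signal.caching_targets = ['signals',
    #     #                                      'expected_counts',
    #     #                                      'shifts']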
    @property
    def plots(self):
        """ Get the dictionary of plot objects last processed. """
        return self._plots

    @staticmethod
    def _draw_equally_weighted(samples, nsamples, num_params):
        """ Get a set of equally weighted samples from a weighted set.

        .. note::

            Additional Monte Carlo noise from trimming the sample set is
            ignored.

        """
        assert nsamples < samples.shape[0], ('Number of samples for plotting '
                                             'cannot exceed number of nested '
                                             'samples.')

        weights = samples[:,0].copy()

        if _np.abs(1.0 - _np.sum(weights)) > 0.01:
            print('Warning: 1 - (sum of weights) = %.8e'
                  % (1.0 - _np.sum(weights)))
            print('Weights renormalized to sum to unity.')

        weights /= _np.sum(weights)

        indices = _np.random.choice(samples.shape[0], nsamples,
                                    replace=True, p=weights)

        thetas = samples[indices, 2:2+num_params]

        return thetas

    def _driver(self, run, likelihood, nsamples, cache, force_cache,
                cache_dir, read_only, archive, plots):
        """ Execute the plotting loop given samples. """
        thetas = self._draw_equally_weighted(run.samples, nsamples,
                                             len(likelihood))

        signal = likelihood.signal # should only be one signal object available

        for p in plots:
            p.signal = likelihood.signal

        names = likelihood.names

        if cache and h5py is not None:
            try:
                s = signal.prefix
            except AttributeError:
                s = ''

            filename = run.prepend_ID.replace(' ', '_')
            filename += '__signal' + ('_' + s + '__' if s else '__')
            filename += 'cached__'
            filename += '__'.join(signal.caching_target_names)

            cache = _Cache(filename, cache_dir, read_only, archive)

            if cache.do_caching(thetas, force_cache):
                # skips body if the cache can simply be read
                for i in _range(thetas.shape[0], desc='Signal caching loop'):
                    likelihood([thetas[i, run.get_index(n)] for n in names])
                    cache.cache(signal.caching_targets)
        elif cache and h5py is None:
            raise ImportError('You need to install h5py to use caching.')

        def update(theta):
            # order the parameter vector appropriately
            vector = [theta[run.get_index(n)] for n in names]

            if cache: # use the cache if available
                # set the parameter values in case needed by plot objects
                # that call methods of the signal object, or methods of
                # objects that the signal object encapsulates references to,
                # an example being the interstellar attenuation applied for
                # the spectrum plot type
                super(Likelihood, likelihood).__call__(vector)

                # now restore the signal objects that were cached
                cached = next(cache)
                for key, value in cached.items():
                    try:
                        delattr(signal, key)
                    except AttributeError:
                        pass
                    if len(value.shape) == 3:
                        for i in range(value.shape[0]):
                            setattr(signal, key, value[i,...])
                    else:
                        setattr(signal, key, value)
            else: # otherwise resort to likelihood evaluations
                likelihood(vector)

        def wrapper(plot_obj, delete_me=None, index=None):
            """ Wrap a plot object's next method into a cache-enabled
            callback. """
            if cache:
                cache.reset_iterator()

            if delete_me is not None:
                try:
                    iter(delete_me)
                except TypeError:
                    delete_me = [delete_me]
                for attr in delete_me:
                    try:
                        delattr(plot_obj, attr)
                    except AttributeError:
                        pass

            # ignore x because it is already known by a plot object in
            # this more object-oriented approach to calling fgivenx:
            if index is not None:
                def callback(x, theta):
                    update(theta)
                    return next(plot_obj)[index]
                return callback # for fgivenx
            else:
                def callback(x, theta):
                    update(theta)
                    return next(plot_obj)
                return callback

        for plot in plots:
            with plot: # acquire plot as context manager
                plot.execute(thetas, wrapper) # iterate over samples

        truths = [run.truth_vector[run.get_index(n)]
                  for n in likelihood.names]
        if None not in truths:
            likelihood(truths)
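# Illustrative sketch of the equal-weight resampling performed by
# ``_draw_equally_weighted`` (assumed sample layout, as in the method above:
# column 0 holds the weights and columns 2 onwards the parameter values; the
# file name and ``ndims`` are hypothetical):
#
#     import numpy as np
#     samples = np.loadtxt('run.txt')   # weighted nested-sample table
#     weights = samples[:, 0] / samples[:, 0].sum()
#     indices = np.random.choice(samples.shape[0], 200,
#                                replace=True, p=weights)
#     thetas = samples[indices, 2:2 + ndims]  # equally-weighted parameter draws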