from ._global_imports import *
# fgivenx is optional: it is only needed for conditional posterior contours
# (``use_fgivenx=True``); PulsePlot.__init__ raises ImportError if it was
# requested but could not be imported here.
try:
    import fgivenx
except ImportError:
    _warning('Cannot import fgivenx for conditional posterior contours.')
    fgivenx = None
from ..tools.phase_interpolator import phase_interpolator as interp
from ._signalplot import SignalPlot
class PulsePlot(SignalPlot):
    """ Plot posterior-averaged channel-summed count-rate pulse profiles.

    The figure contains four panels which share phase as an x-axis:

        * the first (topmost) panel displays the posterior expectation of
          the specific photon flux signal from the source, jointly resolved in
          energy and phase;
        * the second panel displays the energy-integrated photon specific flux
          signal as a function of phase for a subset of samples, optionally
          using :mod:`fgivenx`;
        * the third panel displays the posterior expectation of the count-rate
          signal as a function of channel and phase;
        * the last (bottommost) panel displays the channel-summed pulse as
          a function of phase for a subset of samples, optionally using
          :mod:`fgivenx`.

    The second and last panels aim to render the conditional posterior
    distribution of the associated signal as a function of phase, ideally with
    contours to map out the conditional posterior mass. These panels have space
    to optionally display other elements such as: the posterior-expected total
    signal; the posterior-expected component signals; and the true total and
    component signals if the ground truth (the injected signal corresponding
    to some model parameter vector) is known.

    The following example is from `Riley et al. 2019 <https://ui.adsabs.harvard.edu/abs/2019ApJ...887L..21R/abstract>`_:

    .. image:: _static/_pulseplot.png

    :param int num_phases:
        The number of phase-bin edges, uniformly spaced over the phase
        interval [0, 2] cycles. The pulse-profile signals are interpolated
        at the ``num_phases - 1`` bin midpoints.

    :param str incident_cmap:
        Colormap name from :mod:`matplotlib` to use for the posterior-expected
        incident signal as a function of energy and phase (top panel).

    :param str registered_cmap:
        Colormap name from :mod:`matplotlib` to use for the posterior-expected
        registered signal as a function of channel and phase (third panel).

    :param bool show_components:
        If the :class:`~.Signal.Signal` instance has multiple components,
        display the posterior expectations of those components as a
        function of phase (second and last panels).

    :param dict expectation_line_kwargs:
        Keyword arguments for plotting the posterior-expected signal lines (in
        the second and last panels). Defaults to a solid line, black if
        :mod:`fgivenx` is used and red otherwise.

    :param dict comp_expectation_line_kwargs:
        Keyword arguments for plotting the posterior-expected component
        signal lines (second and last panels). Defaults to the keyword
        arguments used for the total posterior-expected signal lines.

    :param dict sample_line_kwargs:
        Keyword arguments for plotting the signal of each posterior sample
        when :mod:`fgivenx` is not used (second and last panels).

    :param bool use_fgivenx:
        Use :mod:`fgivenx` to plot conditional posterior contours in the
        second and last panels?

    :param dict incident_contour_kwargs:
        Keyword arguments for :mod:`fgivenx` incident signal contours (second
        panel) that will take precedence over the corresponding class
        attributes. (See the :class:`~.SignalPlot` class if you choose
        not to modify these attributes on this present subclass.)

    :param dict registered_contour_kwargs:
        Keyword arguments for :mod:`fgivenx` registered signal contours (last
        panel) that will take precedence over the corresponding class
        attributes. (See the :class:`~.SignalPlot` class if you choose
        not to modify these attributes on this present subclass.)

    :param bool plot_truth:
        Plot the ground truth (injected) signal, if known and available, in
        the second and last panels.

    :param dict truth_line_kwargs:
        Keyword arguments for plotting the ground truth signal lines (second
        and last panels). Defaults to a dash-dotted line, blue if
        :mod:`fgivenx` is used and dark green otherwise.

    :param dict comp_truth_line_kwargs:
        Keyword arguments for plotting the ground truth component signal
        lines (second and last panels). Defaults to the keyword arguments
        used for the total ground truth signal lines.

    Remaining keyword arguments are passed to the :class:`~.SignalPlot`
    base class.

    """

    __figtype__ = 'signalplot_pulse'

    # do not change at runtime (see base class comments):
    __caching_targets__ = ['shifts', # hot region phase shifts
                           'signals', # count-rate signals
                           # incident flux signals integrated over energy bins
                           'incident_flux_signals']

    __rows__          = 4
    __columns__       = 1
    __ax_rows__       = 4
    __ax_columns__    = 2
    __height_ratios__ = [1]*4
    __width_ratios__  = [50, 1] # second column for colorbars
    __wspace__        = 0.025
    __hspace__        = 0.125

    @make_verbose('Instantiating a pulse-profile plotter for posterior checking',
                  'Pulse-profile plotter instantiated')
    def __init__(self,
                 num_phases=1000,
                 incident_cmap='inferno',
                 registered_cmap='inferno',
                 show_components=False,
                 expectation_line_kwargs=None,
                 comp_expectation_line_kwargs=None,
                 sample_line_kwargs=None,
                 use_fgivenx=False,
                 incident_contour_kwargs=None,
                 registered_contour_kwargs=None,
                 plot_truth=False,
                 truth_line_kwargs=None,
                 comp_truth_line_kwargs=None,
                 **kwargs):
        super(PulsePlot, self).__init__(**kwargs)

        # phase-bin edges over two cycles; the signals are evaluated at the
        # bin midpoints (one fewer point than there are edges)
        self._phase_edges = _np.linspace(0.0, 2.0, int(num_phases))
        self._phases = self._phase_edges[:-1] \
                        + 0.5*(self._phase_edges[1] - self._phase_edges[0])

        if use_fgivenx and fgivenx is None:
            raise ImportError('Install fgivenx to plot contours.')
        self._use_fgivenx = use_fgivenx
        if self._use_fgivenx:
            self._incident_contour_kwargs = \
                incident_contour_kwargs if incident_contour_kwargs else {}
            self._registered_contour_kwargs = \
                registered_contour_kwargs if registered_contour_kwargs else {}

        self._get_figure()

        # for the specific flux signal incident on an instrument
        self._ax = self._add_subplot(0,0)
        self._ax_1d = self._add_subplot(1,0)

        # for the count-rate signal registered by an instrument
        self._ax_registered = self._add_subplot(2,0)
        self._ax_registered_1d = self._add_subplot(3,0)

        # phase-axis properties for each axes object added so far (the
        # colorbar axes are only added below); only the bottom panel keeps
        # its x tick labels
        for ax in self._axes:
            ax.xaxis.set_major_locator(MultipleLocator(0.2))
            ax.xaxis.set_minor_locator(MultipleLocator(0.05))
            ax.set_xlim([0.0,2.0])
            if ax is not self._ax_registered_1d:
                ax.tick_params(axis='x', labelbottom=False)
            else:
                # raw string: '\p' is not a valid escape sequence
                ax.set_xlabel(r'$\phi$ [cycles]')

        self._ax.set_ylabel(r'$E$ [keV]')
        self._ax_1d.set_ylabel(r'photons/cm$^{2}$/s')
        self._ax_registered.set_ylabel('channel')
        self._ax_registered_1d.set_ylabel('counts/s')

        # colorbars
        self._ax_cb = self._add_subplot(0,1)
        self._ax_registered_cb = self._add_subplot(2,1)

        if self._use_fgivenx:
            self._ax_1d_cb = self._add_subplot(1,1)
            self._ax_registered_1d_cb = self._add_subplot(3,1)

        self._show_components = show_components
        self._incident_cmap = incident_cmap
        self._registered_cmap = registered_cmap

        if sample_line_kwargs is not None:
            self._sample_line_kwargs = sample_line_kwargs
        else:
            self._sample_line_kwargs = {}

        self._plot_truth = plot_truth
        if self._plot_truth:
            if truth_line_kwargs is None:
                self._truth_line_kwargs = \
                    dict(color=('b' if self._use_fgivenx else 'darkgreen'),
                         ls='-.',
                         lw=1.0,
                         alpha=1.0)
            else:
                self._truth_line_kwargs = truth_line_kwargs

            if comp_truth_line_kwargs is not None:
                self._comp_truth_line_kwargs = comp_truth_line_kwargs
            else:
                self._comp_truth_line_kwargs = self._truth_line_kwargs

        if expectation_line_kwargs is None:
            color = 'k' if self._use_fgivenx else 'r'
            self._expectation_line_kwargs = dict(color=color,
                                                 ls='-',
                                                 lw=1.0,
                                                 alpha=1.0)
        else:
            self._expectation_line_kwargs = expectation_line_kwargs

        if comp_expectation_line_kwargs is not None:
            self._comp_expectation_line_kwargs = comp_expectation_line_kwargs
        else:
            self._comp_expectation_line_kwargs = self._expectation_line_kwargs

        plt.close()

    @property
    def instruction_set(self):
        """ Which signal :meth:`__next__` handles when driven by fgivenx.

        ``0`` selects the incident signal and ``1`` the registered signal.
        When the attribute is absent, :meth:`__next__` handles and plots both.
        """
        return self._instruction_set

    @instruction_set.deleter
    def instruction_set(self):
        # deleting an unset instruction set is a no-op
        try:
            del self._instruction_set
        except AttributeError:
            pass

    @make_verbose('PulsePlot object iterating over samples',
                  'PulsePlot object finished iterating')
    def execute(self, thetas, wrapper):
        """ Loop over posterior samples.

        :param thetas:
            Two-dimensional array of posterior samples, one sample per row.

        :param wrapper:
            Callable invoked as ``wrapper(self, names)`` that returns a
            per-sample callback; ``names`` are the accumulator attributes
            (``'incident_sums'``/``'registered_sums'``) of this object.
            (Presumably the callback updates the signal for a sample and then
            calls :meth:`__next__` — behaviour defined outside this class.)

        Yields progress messages when :mod:`fgivenx` drives the iteration;
        otherwise yields ``None`` once per sample.
        """
        self._num_samples = thetas.shape[0]

        if self._use_fgivenx:
            self._instruction_set = 0
            self._add_incident_contours(wrapper(self, 'incident_sums'),
                                        thetas)
            yield 'Added conditional posterior contours for incident signal.'

            self._instruction_set = 1
            self._add_registered_contours(wrapper(self, 'registered_sums'),
                                          thetas)
            yield 'Added conditional posterior contours for registered signal.'
        else: # iterate manually instead of driven by fgivenx
            del self.instruction_set
            wrapped = wrapper(self, ['incident_sums', 'registered_sums'])
            for i in range(self._num_samples):
                wrapped(None, thetas[i,:])
                yield

    def __next__(self):
        """ Update posterior expected signals given the updated signal object.

        Plots signals if :mod:`fgivenx` is not used, otherwise returns
        callback information for :mod:`fgivenx`.

        .. note::

            You cannot make an iterator from an instance of this class.

        """
        try:
            self._instruction_set
        except AttributeError:
            # manual iteration: accumulate and plot both signals per sample
            incident = self._handle_incident()
            self._add_signal(self._ax_1d,
                             self._phases,
                             incident,
                             **self._sample_line_kwargs)

            registered = self._handle_registered()
            self._add_signal(self._ax_registered_1d,
                             self._phases,
                             registered,
                             **self._sample_line_kwargs)
        else:
            if self._instruction_set == 0:
                return self._handle_incident() # end execution here

            if self._instruction_set == 1:
                return self._handle_registered()

        return None # reached if not invoking fgivenx

    def _handle_incident(self):
        """ Instructions for handling the incident signal.

        Interpolates each component of the incident flux signal onto the
        plotting phases, adds it to the running per-component sums, and
        returns the total (all components) summed over axis 0.
        """
        ref = self._signal

        try:
            self._incident_sums
        except AttributeError:
            self._incident_sums = [None] * len(ref.incident_flux_signals)

        incident = None
        for i, component in enumerate(ref.incident_flux_signals):
            temp = interp(self._phases,
                          ref.phases[i],
                          component,
                          ref.shifts[i])

            # explicit None checks: catching TypeError would also swallow
            # numpy casting errors (UFuncTypeError) and drop the running sum
            if incident is None:
                incident = temp
            else:
                incident += temp

            if self._incident_sums[i] is None:
                self._incident_sums[i] = temp.copy()
            else:
                self._incident_sums[i] += temp

        return _np.sum(incident, axis=0)

    @property
    def incident_sums(self):
        """ Running per-component sums of the interpolated incident signals. """
        return self._incident_sums

    @incident_sums.deleter
    def incident_sums(self):
        del self._incident_sums

    @property
    def expected_incident(self):
        """ Get the expectations of the component incident signals. """
        return [component/self._num_samples for component in self._incident_sums]

    def _handle_registered(self):
        """ Instructions for handling the registered signal.

        Interpolates each component count-rate signal onto the plotting
        phases, adds it to the running per-component sums, and returns the
        total (all components) summed over axis 0.
        """
        ref = self._signal

        try:
            self._registered_sums
        except AttributeError:
            self._registered_sums = [None] * len(ref.signals)

        registered = None
        for i, (signal, shift) in enumerate(zip(ref.signals, ref.shifts)):
            temp = interp(self._phases,
                          ref.phases[i],
                          signal, shift)

            if registered is None:
                registered = temp
            else:
                registered += temp

            if self._registered_sums[i] is None:
                self._registered_sums[i] = temp.copy()
            else:
                self._registered_sums[i] += temp

        return _np.sum(registered, axis=0)

    @property
    def registered_sums(self):
        """ Running per-component sums of the interpolated count-rate signals. """
        return self._registered_sums

    @registered_sums.deleter
    def registered_sums(self):
        del self._registered_sums

    @property
    def expected_registered(self):
        """ Get the expectations of the component registered signals. """
        return [component/self._num_samples for component in self._registered_sums]

    @make_verbose('PulsePlot object finalizing',
                  'PulsePlot object finalized')
    def finalize(self):
        """ Execute final instructions.

        Draws the ground truth (if requested) and posterior-expected signals
        on the incident and registered panels. Component signals are only
        drawn if requested and the signal has more than one component.
        """
        ref = self._signal
        self._plot_components = self._show_components and ref.num_components > 1

        # add the incident signals
        if self._plot_truth:
            self._add_true_incident_signals()

        self._add_expected_incident_signals()

        # add the registered signals
        if self._plot_truth:
            self._add_true_registered_signals()

        self._add_expected_registered_signals()

    def _add_true_incident_signals(self):
        """ Render ground truth incident (component) signals. """
        ref = self._signal

        total = None
        for component, shift, phases in zip(ref.incident_flux_signals,
                                             ref.shifts,
                                             ref.phases):
            temp = interp(self._phases,
                          phases,
                          component,
                          shift)
            # copy so that accumulating the total never mutates an array
            # already handed to the plotting routine
            if total is None:
                total = temp.copy()
            else:
                total += temp

            if self._plot_components:
                self._add_signal(self._ax_1d,
                                 self._phases,
                                 temp,
                                 axis=0,
                                 **self._comp_truth_line_kwargs)

        self._add_signal(self._ax_1d,
                         self._phases,
                         total,
                         axis=0,
                         **self._truth_line_kwargs)

    def _add_expected_incident_signals(self):
        """ Render posterior-expected incident (component) signals. """
        ref = self._signal

        total = None
        for component in self.expected_incident:
            # copy so that the in-place updates below (accumulation and the
            # division by energy-interval widths) never alias a component
            if total is None:
                total = component.copy()
            else:
                total += component

            if self._plot_components:
                self._add_signal(self._ax_1d,
                                 self._phases,
                                 component,
                                 axis=0,
                                 **self._comp_expectation_line_kwargs)

        # 1D
        self._add_signal(self._ax_1d,
                         self._phases,
                         total,
                         axis=0,
                         **self._expectation_line_kwargs)

        self._ax_1d.yaxis.set_major_locator(_get_default_locator(None))
        self._ax_1d.yaxis.set_major_formatter(_get_default_formatter())
        self._ax_1d.yaxis.set_minor_locator(AutoMinorLocator())

        # 2D: rows of ``total`` are divided by the widths of the energy
        # intervals, giving the mean specific flux in each interval
        Delta_E = ref.energy_edges[1:] - ref.energy_edges[:-1]
        total /= Delta_E[:, _np.newaxis]

        # pcolormesh accepts a colormap name directly (matplotlib.cm.get_cmap
        # was removed in Matplotlib 3.9)
        incident = self._ax.pcolormesh(self._phase_edges,
                                       ref.energy_edges,
                                       total,
                                       cmap = self._incident_cmap,
                                       linewidth = 0,
                                       rasterized = self._rasterized)

        incident.set_edgecolor('face')
        self._ax.set_ylim([ref.energy_edges[0],
                           ref.energy_edges[-1]])
        self._ax.set_yscale('log')

        self._incident_cb = plt.colorbar(incident, cax=self._ax_cb,
                                         ticks=_get_default_locator(None),
                                         format=_get_default_formatter())
        self._incident_cb.ax.set_frame_on(True)
        self._incident_cb.ax.yaxis.set_minor_locator(AutoMinorLocator())

        self._incident_cb.set_label(label=r'photons/keV/cm$^{2}$/s',
                                    labelpad=15)

    def _add_true_registered_signals(self):
        """ Render ground truth registered (component) signals. """
        ref = self._signal

        total = None
        for component, shift, phases in zip(ref.signals, ref.shifts, ref.phases):
            temp = interp(self._phases,
                          phases,
                          component,
                          shift)
            if total is None:
                total = temp.copy()
            else:
                total += temp

            if self._plot_components:
                # component lines use the component kwargs, consistent with
                # the incident panel (previously the total-line kwargs were
                # used, ignoring ``comp_truth_line_kwargs``)
                self._add_signal(self._ax_registered_1d,
                                 self._phases,
                                 temp,
                                 axis=0,
                                 **self._comp_truth_line_kwargs)

        self._add_signal(self._ax_registered_1d,
                         self._phases,
                         total,
                         axis=0,
                         **self._truth_line_kwargs)

    def _add_expected_registered_signals(self):
        """ Render posterior-expected registered (component) signals. """
        ref = self._signal

        total = None
        for component in self.expected_registered:
            if total is None:
                total = component.copy()
            else:
                total += component

            if self._plot_components:
                self._add_signal(self._ax_registered_1d,
                                 self._phases,
                                 component,
                                 axis=0,
                                 **self._comp_expectation_line_kwargs)

        # 1D
        self._add_signal(self._ax_registered_1d,
                         self._phases,
                         total,
                         axis=0,
                         **self._expectation_line_kwargs)

        self._ax_registered_1d.yaxis.set_major_locator(_get_default_locator(None))
        self._ax_registered_1d.yaxis.set_major_formatter(_get_default_formatter())
        self._ax_registered_1d.yaxis.set_minor_locator(AutoMinorLocator())

        # 2D
        # NOTE(review): unlike the incident panel (bin edges), phase and
        # channel coordinates are passed here as point values whose sizes
        # presumably match ``total``; the rendering then depends on
        # Matplotlib's default shading — confirm this is intended.
        registered = self._ax_registered.pcolormesh(self._phases,
                                                    ref.data.channels,
                                                    total,
                                                    cmap = self._registered_cmap,
                                                    linewidth = 0,
                                                    rasterized = self._rasterized)

        registered.set_edgecolor('face')
        self._ax_registered.set_ylim([ref.data.channels[0],
                                      ref.data.channels[-1]])
        self._ax_registered.set_yscale('log')

        self._registered_cb = plt.colorbar(registered,
                                           cax=self._ax_registered_cb,
                                           ticks=_get_default_locator(None),
                                           format=_get_default_formatter())
        self._registered_cb.ax.set_frame_on(True)
        self._registered_cb.ax.yaxis.set_minor_locator(AutoMinorLocator())

        self._registered_cb.set_label(label=r'counts/s', labelpad=15)

    @make_verbose('Adding credible intervals on the incident photon flux '
                  'signal as function of phase',
                  'Credible intervals added')
    def _add_incident_contours(self, callback, thetas):
        """ Add contours to 1D incident photon flux signal axes objects. """
        self._add_contours(callback, thetas, self._phases,
                           self._ax_1d, self._ax_1d_cb,
                           **self._incident_contour_kwargs)

        label = r'$\pi(\mathrm{photons/cm}^{2}\mathrm{/s};\phi)$'
        self._ax_1d_cb.set_ylabel(label)

    @make_verbose('Adding credible intervals on the count-rate '
                  'signal as function of phase',
                  'Credible intervals added')
    def _add_registered_contours(self, callback, thetas):
        """ Add contours to 1D count-rate signal axes objects. """
        self._add_contours(callback, thetas, self._phases,
                           self._ax_registered_1d, self._ax_registered_1d_cb,
                           **self._registered_contour_kwargs)

        self._ax_registered_1d_cb.set_ylabel(r'$\pi(\mathrm{counts/s};\phi)$')