Source code for xpsi.PostProcessing._backends

from ._global_imports import *
from ._run import Run

try:
    from getdist.mcsamples import MCSamples
except ImportError:
    _warning('Cannot create a GetDist sample backend.')

try:
    from nestcheck.data_processing import process_multinest_run
    from nestcheck.data_processing import process_polychord_run
except ImportError:
    _warning('Cannot use nestcheck sample backend.')

class NestedBackend(Run):
    """ Container for nested samples generated by a single run, and backends
    for analysis of the run.

    The other keyword arguments are generic properties passed to the parent
    class, such as the identification (ID) string of the run.

    :param str root:
        The root filename of the sample file collection.

    :param str base_dir:
        The directory containing the sample file collection.

    :param bool use_nestcheck:
        Invoke :mod:`nestcheck` for nested sampling error analysis?

    :param callable transform:
        A function to transform the parameter vector to another space. It is
        called as ``transform(vector, old_API=True)`` and the length of the
        returned vector determines how many parameter columns are written
        to the transformed sample files.

    :param bool overwrite_transformed:
        If true, the transformed sample files are regenerated even when
        they already exist on disk.

    :param kwargs:
        Passed through to the parent class. If the key ``implementation``
        is present it selects the :mod:`nestcheck` loader: ``'multinest'``
        or ``'polychord'``. If absent, MultiNest is assumed.

    :raises ValueError:
        If ``use_nestcheck`` is true and ``implementation`` is given but is
        neither ``'multinest'`` nor ``'polychord'``.

    """

    def __init__(self, root, base_dir, use_nestcheck, transform=None,
                 overwrite_transformed=False, **kwargs):
        filerootpath = _os.path.join(base_dir, root)
        # keep the untransformed path: the original files are still read
        # from it after ``filerootpath`` is redirected below
        _filerootpath = filerootpath

        if transform is not None:
            samples = _np.loadtxt(filerootpath + '.txt')
            # the first two columns are copied through untransformed
            ndims = samples.shape[1] - 2
            # probe with the first sample to learn how many columns the
            # transform adds; needed even if the file is not rewritten
            temp = transform(samples[0, 2:], old_API=True)
            ntransform = len(temp) - ndims

            _exists = _os.path.isfile(filerootpath + '_transformed.txt')
            if not _exists or overwrite_transformed:
                transformed = _np.zeros((samples.shape[0],
                                         samples.shape[1] + ntransform))
                transformed[:, :2] = samples[:, :2]
                for i, sample in enumerate(samples):
                    transformed[i, 2:] = transform(sample[2:], old_API=True)

                _np.savetxt(filerootpath + '_transformed.txt', transformed)

            # from here on, everything refers to the transformed file set
            filerootpath += '_transformed'
            root += '_transformed'

        super(NestedBackend, self).__init__(filepath=filerootpath + '.txt',
                                            **kwargs)

        if getdist is not None:
            # getdist backend
            self._gd_bcknd = MCSamples(
                root=filerootpath,
                settings=self.kde_settings,
                sampler='nested',
                names=self.names,
                ranges=self.bounds,
                labels=[self.labels[name] for name in self.names])
            self._gd_bcknd.readChains(getdist.chains.chainFiles(filerootpath))

        self.use_nestcheck = use_nestcheck

        if self.use_nestcheck:  # nestcheck backend
            if transform is not None:
                for ext in ['dead-birth.txt', 'phys_live-birth.txt']:
                    # the transformed file is written with a "-" separator,
                    # so the existence check must test that same path
                    target = filerootpath + "-" + ext
                    _exists = _os.path.isfile(target)
                    if not _exists or overwrite_transformed:
                        samples = _np.loadtxt(_filerootpath + ext)
                        transformed = _np.zeros(
                            (samples.shape[0],
                             samples.shape[1] + ntransform))
                        # trailing (non-parameter) columns are shifted right
                        # to make room for the additional parameters
                        transformed[:, ndims + ntransform:] = \
                            samples[:, ndims:]
                        for i, sample in enumerate(samples):
                            transformed[i, :ndims + ntransform] = \
                                transform(sample[:ndims], old_API=True)

                        _np.savetxt(target, transformed)

                # .stats file with same root needed, but do not need to modify
                # the .stats file contents
                if not _os.path.isfile(filerootpath + '.stats'):
                    if _os.path.isfile(_filerootpath + '.stats'):
                        from shutil import copyfile as _copyfile
                        _copyfile(_filerootpath + '.stats',
                                  filerootpath + '.stats')

            if 'implementation' not in kwargs:
                print('Root %r sampling implementation not specified... '
                      'assuming MultiNest for nestcheck...' % root)
                self._nc_bcknd = self._load_multinest_run(root, base_dir)
            elif kwargs['implementation'] == 'multinest':
                self._nc_bcknd = self._load_multinest_run(root, base_dir)
            elif kwargs['implementation'] == 'polychord':
                self._nc_bcknd = process_polychord_run(root,
                                                       base_dir=base_dir)
            else:
                raise ValueError('Cannot process with nestcheck.')

    @staticmethod
    def _load_multinest_run(root, base_dir):
        """ Load a MultiNest run with :mod:`nestcheck`.

        The root is tried as given first; if the files are not found, the
        root with a trailing ``"-"`` is tried instead.

        """
        try:
            return process_multinest_run(root, base_dir=base_dir)
        except FileNotFoundError:
            return process_multinest_run(root + "-", base_dir=base_dir)

    @property
    def getdist_backend(self):
        """ Get the :class:`getdist.mcsamples.MCSamples` instance. """
        return self._gd_bcknd

    @property
    def nestcheck_backend(self):
        """ Get the :mod:`nestcheck` backend for the nested samples. """
        return self._nc_bcknd

    @property
    def margeStats(self):
        """ Return the marginal statistics using :mod:`getdist`. """
        # the GetDist sample object is stored as ``_gd_bcknd``; this class
        # never assigns the ``_mcsamples`` attribute that was read here
        return self._gd_bcknd.getMargeStats()