Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions mnms/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class Params(ABC):
def __init__(self, *args, data_model_name=None, subproduct=None,
maps_product=None, maps_subproduct='default',
enforce_equal_qid_kwargs=None, calibrated=False,
differenced=True, srcfree=True, iso_filt_method=None,
ivar_filt_method=None, filter_kwargs=None, ivar_fwhms=None,
ivar_lmaxs=None, masks_subproduct=None, mask_est_name=None,
mask_est_edgecut=0, mask_est_apodization=0,
mask_obs_name=None, mask_obs_edgecut=0,
calibrations_subproduct=None, differenced=True, srcfree=True,
iso_filt_method=None, ivar_filt_method=None,
filter_kwargs=None, ivar_fwhms=None, ivar_lmaxs=None,
masks_subproduct=None, mask_est_name=None, mask_est_edgecut=0,
mask_est_apodization=0, mask_obs_name=None, mask_obs_edgecut=0,
model_lim=None, model_lim0=None,
catalogs_subproduct=None, catalog_name=None,
kfilt_lbounds=None, dtype=np.float32, model_file_template=None,
Expand All @@ -49,7 +49,14 @@ def __init__(self, *args, data_model_name=None, subproduct=None,
what is supplied here, 'num_splits' is always enforced. All enforced kwargs
are available to be passed to model or sim filename templates.
calibrated : bool, optional
Whether to load calibrated raw data, by default False.
Whether calibration factors are applied to simulations after they
are drawn, by default False. This only sets the default behavior:
if True, calibration factors are applied unless negated at runtime;
if False, calibration factors are not applied unless supplied at
runtime.
calibrations_subproduct : str, optional
The calibrations subproduct within the supplied data model to use
if calibrated is True. Disregarded if calibrated is False.
differenced : bool, optional
Whether to take differences between splits or treat loaded maps as raw noise
(e.g., a time-domain sim) that will not be differenced, by default True.
Expand Down Expand Up @@ -144,6 +151,7 @@ def __init__(self, *args, data_model_name=None, subproduct=None,

# other instance properties
self._calibrated = calibrated
self._calibrations_subproduct = calibrations_subproduct
self._differenced = differenced
self._dtype = np.dtype(dtype) # better str(...) appearance
self._srcfree = srcfree
Expand Down Expand Up @@ -216,6 +224,7 @@ def param_formatted_dict(self):
maps_product=self._maps_product,
maps_subproduct=self._maps_subproduct,
calibrated=self._calibrated,
calibrations_subproduct=self._calibrations_subproduct,
catalogs_subproduct=self._catalogs_subproduct,
catalog_name=self._catalog_name,
differenced=self._differenced,
Expand Down
161 changes: 123 additions & 38 deletions mnms/noise_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def __init__(self, *qids, subproduct_kwargs=None, **kwargs):
# prepare cache
cache = {}
self._permissible_cache_keys = [
'mask_est', 'mask_obs', 'sqrt_ivar', 'cfact', 'dmap', 'model'
'mask_est', 'mask_obs', 'sqrt_ivar', 'cfact', 'dmap', 'model', 'cal'
]
for k in self._permissible_cache_keys:
if k not in cache:
Expand Down Expand Up @@ -377,10 +377,6 @@ def get_sqrt_ivar(self, split_num, downgrade=1):
# outer loop: qids, inner loop: scales (minimize i/o)
for i, (subproduct_kwargs, qid) in enumerate(self._super_qids):
with bench.show(f'Generating sqrt_ivars for {utils.kwargs_str(text_terminator=", ", **subproduct_kwargs)}qid {qid}'):
if self._calibrated:
mul = utils.get_mult_fact(self._data_model, qid, ivar=True)
else:
mul = 1

# we want to do this split-by-split in case we can save
# memory by downgrading one split at a time
Expand All @@ -390,7 +386,6 @@ def get_sqrt_ivar(self, split_num, downgrade=1):
**subproduct_kwargs
)
ivar = enmap.extract(ivar, self._full_shape, self._full_wcs)
ivar *= mul

if downgrade != 1:
if self._variant == 'cc':
Expand Down Expand Up @@ -492,10 +487,6 @@ def get_cfact(self, split_num, downgrade=1):

for i, (subproduct_kwargs, qid) in enumerate(self._super_qids):
with bench.show(f'Generating difference-map correction-factors for {utils.kwargs_str(text_terminator=", ", **subproduct_kwargs)}qid {qid}'):
if self._calibrated:
mul = utils.get_mult_fact(self._data_model, qid, ivar=True)
else:
mul = 1

# get the coadd from disk, this is the same for all splits
cvar = utils.read_map(
Expand All @@ -504,7 +495,6 @@ def get_cfact(self, split_num, downgrade=1):
**subproduct_kwargs
)
cvar = enmap.extract(cvar, self._full_shape, self._full_wcs)
cvar *= mul

# we want to do this split-by-split in case we can save
# memory by downgrading one split at a time
Expand All @@ -514,7 +504,6 @@ def get_cfact(self, split_num, downgrade=1):
**subproduct_kwargs
)
ivar = enmap.extract(ivar, self._full_shape, self._full_wcs)
ivar *= mul

cfact = utils.get_corr_fact(ivar, sum_ivar=cvar)

Expand Down Expand Up @@ -593,12 +582,6 @@ def get_dmap(self, split_num, downgrade=1, keep_mask_obs=True):

for i, (subproduct_kwargs, qid) in enumerate(self._super_qids):
with bench.show(f'Generating difference maps for {utils.kwargs_str(text_terminator=", ", **subproduct_kwargs)}qid {qid}'):
if self._calibrated:
mul_imap = utils.get_mult_fact(self._data_model, qid, ivar=False)
mul_ivar = utils.get_mult_fact(self._data_model, qid, ivar=True)
else:
mul_imap = 1
mul_ivar = 1

# get the coadd from disk, this is the same for all splits
if self._differenced:
Expand All @@ -608,7 +591,6 @@ def get_dmap(self, split_num, downgrade=1, keep_mask_obs=True):
**subproduct_kwargs
)
cmap = enmap.extract(cmap, self._full_shape, self._full_wcs)
cmap *= mul_imap
else:
cmap = 0

Expand All @@ -620,7 +602,6 @@ def get_dmap(self, split_num, downgrade=1, keep_mask_obs=True):
**subproduct_kwargs
)
cvar = enmap.extract(cvar, self._full_shape, self._full_wcs)
cvar *= mul_ivar

# we want to do this split-by-split in case we can save
# memory by downgrading one split at a time
Expand All @@ -630,7 +611,6 @@ def get_dmap(self, split_num, downgrade=1, keep_mask_obs=True):
**subproduct_kwargs
)
imap = enmap.extract(imap, self._full_shape, self._full_wcs)
imap *= mul_imap

# need to reload ivar at full res and get ivar_eff
# if inpainting or kspace filtering
Expand All @@ -641,7 +621,6 @@ def get_dmap(self, split_num, downgrade=1, keep_mask_obs=True):
**subproduct_kwargs
)
ivar = enmap.extract(ivar, self._full_shape, self._full_wcs)
ivar *= mul_ivar
if self._differenced:
ivar_eff = utils.get_ivar_eff(ivar, sum_ivar=cvar, use_zero=True)
else:
Expand Down Expand Up @@ -751,7 +730,9 @@ def _empty(self, shape, wcs, ivar=False, num_arrays=None, num_splits=None):
ivar=ivar, maps_subproduct=self._maps_subproduct,
srcfree=self._srcfree, **subproduct_kwargs
)
shape = (shape[0], *footprint_shape)
npol = shape[0]
assert npol in (1, 3), f'Got {npol=}, must be 1 or 3'
shape = (npol, *footprint_shape)

if num_arrays is None:
num_arrays = self._num_arrays
Expand All @@ -761,14 +742,60 @@ def _empty(self, shape, wcs, ivar=False, num_arrays=None, num_splits=None):
shape = (num_arrays, num_splits, *shape)
return enmap.empty(shape, wcs=footprint_wcs, dtype=self._dtype)

def get_cal(self, calibrations_subproduct, alm=False):
    """Get calibration and polarization efficiency factors from this noise
    model's data model's calibration product under 'calibrations_subproduct'.
    The factors will broadcast against a simulation of any shape, i.e.,
    they have shape:

    (nmaps, 1, npol, 1, 1)

    where nmaps is the number of super qids, and npol is 1 or 3, depending
    on the data shape. For each map, if npol is 1, the factor is given by
    c where c is a calibration. If npol is 3, the factors are (c, c/p, c/p)
    where p is the polarization efficiency.

    Parameters
    ----------
    calibrations_subproduct : str
        Name of the calibrations subproduct entry in this noise model's data
        model. Passed as 'subproduct' to the data model's read_calibration.
    alm : bool, optional
        Whether the cals need to broadcast against an alm instead, in which
        case the last singleton axis is removed, by default False.

    Returns
    -------
    (nmaps, 1, npol, 1, 1) np.ndarray
        Calibration (and possibly polarization efficiency) factors, with
        this noise model's dtype. If alm is True, the shape is instead
        (nmaps, 1, npol, 1).
    """
    # get an array that will broadcast against dmaps: ask _empty for a
    # (1, 1) footprint and a single split, and only keep the resulting shape
    shape = self._empty((1, 1), self._full_wcs, ivar=False, num_splits=1).shape
    cals = np.zeros(shape, dtype=self._dtype)
    npol = shape[-3]

    read_calibration = self._data_model.read_calibration

    for i, (subproduct_kwargs, qid) in enumerate(self._super_qids):
        # first component always gets the bare calibration c
        cals[i, 0, 0] = read_calibration(
            qid, which='cals', subproduct=calibrations_subproduct,
            **subproduct_kwargs)

        # remaining components (if any) get c/p
        if npol == 3:
            cals[i, 0, 1:] = cals[i, 0, 0] / read_calibration(
                qid, which='poleffs', subproduct=calibrations_subproduct,
                **subproduct_kwargs)

    # alms have one fewer trailing axis than maps
    if alm:
        cals = cals[..., 0]

    return cals

def cache_data(self, cacheprod, data, *args, **kwargs):
"""Add some data to the cache.

Parameters
----------
cacheprod : str
The "cache product", must be one of 'mask_est', 'mask_obs',
'sqrt_ivar', 'cfact', 'dmap', or 'model'.
'sqrt_ivar', 'cfact', 'dmap', 'model', or 'cal'.
data : any
Item to be stored.
args : tuple, optional
Expand All @@ -793,9 +820,7 @@ def get_from_cache(self, cacheprod, *args, **kwargs):
----------
cacheprod : str
The "cache product", must be one of 'mask_est', 'mask_obs',
'sqrt_ivar', 'cfact', 'dmap', or 'model'.
data : any
Item to be stored.
'sqrt_ivar', 'cfact', 'dmap', 'model', or 'cal'.
args : tuple, optional
data will be stored under a key formed by (*args, **kwargs),
where the args are order-sensitive and the kwargs are
Expand All @@ -822,12 +847,12 @@ def cache_clear(self, *args, **kwargs):
----------
args : tuple, optional
If provided, the first arg must be the "cacheprod", i.e., one of
'mask_est', 'mask_obs', 'sqrt_ivar', 'cfact', 'dmap', or 'model'. If
no subsequent args are provided (and no kwargs are provided), all
data under that "cacheprod" is deleted. If provided, subsequent
args are used with kwargs to form a key (*args, **kwargs), where
the args are order-sensitive and the kwargs are order-insensitive.
Then, the data under that key only is deleted.
'mask_est', 'mask_obs', 'sqrt_ivar', 'cfact', 'dmap', 'model', or
'cal'. If no subsequent args are provided (and no kwargs are
provided), all data under that "cacheprod" is deleted. If provided,
subsequent args are used with kwargs to form a key (*args,
**kwargs), where the args are order-sensitive and the kwargs are
order-insensitive. Then, the data under that key only is deleted.
kwargs : dict, optional
If provided, used with args to form a key (*args, **kwargs), where
the args are order-sensitive and the kwargs are order-insensitive.
Expand Down Expand Up @@ -1632,10 +1657,11 @@ def filter_model(cls, inp, iso_filt_method=None, ivar_filt_method=None,

return inp, out

def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
def get_sim(self, split_num, sim_num, lmax, seed='auto', alm=False,
calibrate='auto', calibrations_subproduct=None,
check_on_disk=True, generate=True, keep_model=True,
keep_mask_obs=True, keep_sqrt_ivar=True, write=False,
verbose=False):
keep_mask_obs=True, keep_sqrt_ivar=True, keep_cal=True,
write=False, verbose=False):
"""Load or generate a sim from this NoiseModel. Will load necessary
products from disk if not yet stored in instance attributes.

Expand All @@ -1649,7 +1675,7 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
9999, ie, one cannot have more than 10_000 of the same sim, of the same split,
from the same noise model (including the 'notes').
seed : int or iterable of ints
Seed for random draw. If -1 (the default), a seed unique to the split_num,
Seed for random draw. If 'auto' (the default), a seed unique to the split_num,
sim_num, data_model_name, maps_product, maps_subproduct, noise_model_name,
qids, and subproduct_kwargs of this noise model instance is used. User can
manually pass 'None' or their own int or iterable of ints; however, they
Expand All @@ -1658,6 +1684,17 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
Bandlimit for output noise covariance.
alm : bool, optional
Generate simulated alms instead of a simulated map, by default False.
calibrate : bool or 'auto', optional
Whether to apply calibration and polarization efficiency factors to
the simulation, by default 'auto'. If 'auto', the 'calibrated'
setting in the configuration file for this noise model will be used
(along with the calibrations_subproduct if the setting is True).
If True, see calibrations_subproduct. If False, do not apply any
calibration factor. See notes about writing simulations to disk.
calibrations_subproduct : str, optional
If calibrate is True, this will be used to override the calibration
subproduct versus what is indicated in the configuration file. This
argument is ignored if calibrate is 'auto' or False.
check_on_disk : bool, optional
If True, first check if an identical sim (including the noise model 'notes')
exists on-disk. If it does not, generate the sim if 'generate' is True, and
Expand All @@ -1677,6 +1714,12 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
keep_sqrt_ivar : bool, optional
Store the loaded, possibly downgraded, sqrt_ivar in the instance
attributes, by default True.
keep_cal : bool, optional
Store the loaded calibration and polarization efficiency factors, by
default True. Will be stored according to the specific calibrations
subproduct used in this call, i.e. the subproduct from the config
file if calibrate is 'auto' or the passed subproduct for this call
if calibrate is True.
write : bool, optional
Save a generated sim to disk, by default False.
verbose : bool, optional
Expand All @@ -1688,6 +1731,16 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
A sim of this noise model with the specified sim num, with shape
(num_arrays, num_splits=1, num_pol, ny, nx), even if some of these
axes have size 1. As implemented, num_splits is always 1.

Notes
-----
If a simulation is requested to be calibrated and written to disk, only
the uncalibrated version is written to disk, while the calibrated version
is returned. This is so metadata about the calibration does not have to
be tracked along with the file on-disk -- it is always uncalibrated. If
a simulation is requested to be calibrated and it already exists on-disk
and therefore is loaded from disk, the calibration factors will be
applied at runtime.
"""
_filter_kwargs = {} if self._filter_kwargs is None else self._filter_kwargs

Expand All @@ -1701,9 +1754,31 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
self._full_shape, self._full_wcs, downgrade
)

# get the cal before possibly loading sim from disk. we adopt
# convention that sims on disk are never calibrated, so that sims drawn
# on the fly and sims on disk are treated equivalently
if calibrate == 'auto' and self._calibrated:
_calibrations_subproduct = self._calibrations_subproduct
_calibrate = True
elif calibrate is True: # 'auto' is truthy, so need to be careful
_calibrations_subproduct = calibrations_subproduct
_calibrate = True
else:
_calibrate = False

if _calibrate:
try:
cal = self.get_from_cache('cal', _calibrations_subproduct, alm=alm) # NOTE: assumes cal indep. of split
cal_from_cache = True
except KeyError:
cal = self.get_cal(_calibrations_subproduct, alm=alm)
cal_from_cache = False

if check_on_disk:
res = self._check_sim_on_disk(split_num, sim_num, lmax, alm=alm, generate=generate)
if res is not False:
if _calibrate:
res *= cal
return res
else: # generate == True
pass
Expand Down Expand Up @@ -1732,7 +1807,7 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,
sqrt_ivar *= mask_obs

# seed = utils.get_seed(split_num, sim_num, self.noise_model_name, *self._qids, n_max_strs=5) OLD
if seed == -1:
if seed == 'auto':
seed = utils.get_seed(split_num, sim_num, self._data_model_name,
self._maps_product, self._maps_subproduct,
self.noise_model_name, *self._super_qid_strs)
Expand Down Expand Up @@ -1777,11 +1852,21 @@ def get_sim(self, split_num, sim_num, lmax, seed=-1, alm=False,

if keep_sqrt_ivar and not sqrt_ivar_from_cache:
self.cache_data('sqrt_ivar', sqrt_ivar, split_num=split_num, downgrade=downgrade)

if _calibrate:
if keep_cal and not cal_from_cache:
self.cache_data('cal', cal, _calibrations_subproduct, alm=alm)

if write:
fn = self.get_sim_fn(split_num, sim_num, lmax, alm=alm, to_write=True)
self.write_sim(fn, sim, alm=alm)

# only calibrate after writing. see note above: we adopt
# convention that sims on disk are never calibrated, so that sims drawn
# on the fly and sims on disk are treated equivalently
if _calibrate:
sim *= cal

return sim

def _check_sim_on_disk(self, split_num, sim_num, lmax, alm=False, generate=True):
Expand Down
Loading