Source code for gwpopulation_pipe.utils

import os

from gwpopulation.conversions import convert_to_beta_parameters
from gwpopulation.hyperpe import HyperparameterLikelihood, xp
from gwpopulation.models.spin import (
    iid_spin,
    iid_spin_magnitude_beta,
    independent_spin_magnitude_beta,
)


def get_cosmology(name):
    """
    Look up a cosmology object by name.

    Parameters
    ----------
    name: str
        Name of the requested cosmology. ``"planck15_lal"`` is matched
        case-insensitively; every other name must match a key of
        ``wcosmo.astropy.available`` exactly.

    Returns
    -------
    cosmology
        A ``FlatLambdaCDM`` with ``H0=67.90`` and ``Om0=0.3065`` for the
        ``planck15_lal`` alias, otherwise the entry stored under ``name``
        in ``wcosmo.astropy.available``.

    Raises
    ------
    ValueError
        If ``name`` is neither the alias nor a key of ``available``.
    """
    # Imported lazily, so wcosmo is only loaded when a cosmology is requested.
    from wcosmo.astropy import FlatLambdaCDM, available

    # Only this alias is compared after lower-casing.
    if name.lower() == "planck15_lal":
        return FlatLambdaCDM(H0=67.90, Om0=0.3065, name="Planck15_LAL")
    if name not in available.keys():
        raise ValueError(f"Unknown cosmology {name}")
    return available[name]
def prior_conversion(parameters):
    """
    Wrapper around conversion for prior constraints.

    Any of ``amax``, ``amax_1`` and ``amax_2`` that is missing from
    ``parameters`` is set to ``1`` (the passed object is modified in place)
    before the parameters are handed to ``convert_to_beta_parameters``.

    Parameters
    ----------
    parameters: dict
        The parameters to convert.

    Returns
    -------
    converted
        The first element returned by ``convert_to_beta_parameters``.
    """
    for amax_key in ("amax", "amax_1", "amax_2"):
        if amax_key in parameters:
            continue
        parameters[amax_key] = 1
    # The second element of the result is not needed here.
    converted, _unused = convert_to_beta_parameters(parameters)
    return converted
def get_path_or_local(path: str) -> str:
    """
    Resolve a path, falling back to its basename.

    The path is returned unchanged if it exists. Otherwise the basename of
    the path is tried; being a bare file name, it is resolved relative to
    the current working directory.

    Parameters
    ----------
    path: str
        The path to check.

    Returns
    -------
    path: str
        The path to use: either ``path`` itself or its basename.

    Raises
    ------
    ValueError
        If neither the path nor the basename exist.
    """
    if os.path.exists(path):
        return path

    local_name = os.path.basename(path)
    if os.path.exists(local_name):
        return local_name

    raise ValueError(f"Cannot find {path} or {local_name}")
# Hyperparameter names associated with each gwpopulation spin model, keyed by
# the model function itself.
# NOTE(review): presumably the arguments each model is known to accept;
# confirm against the code that reads this table.
KNOWN_ARGUMENTS = {
    iid_spin: ["mu_chi", "sigma_chi", "xi_spin", "sigma_spin"],
    iid_spin_magnitude_beta: ["mu_chi", "sigma_chi"],
    independent_spin_magnitude_beta: [
        "mu_chi_1",
        "mu_chi_2",
        "sigma_chi_1",
        "sigma_chi_2",
    ],
}
class MinimumEffectiveSamplesLikelihood(HyperparameterLikelihood):
    """
    Hyperparameter likelihood that rejects poorly sampled Monte Carlo
    integrals.

    The Bayes factor of every event whose effective sample count does not
    exceed ``self.n_posteriors`` is multiplied by zero, so its ln Bayes
    factor becomes minus infinity.
    """

    def _compute_per_event_ln_bayes_factors(self, return_uncertainty=True):
        """
        Compute the per event ln Bayes factors and associated variance.

        The effective number of samples of each per-event Monte Carlo
        integral must be strictly greater than ``self.n_posteriors`` (the
        total number of events); otherwise the ln BF of that event is set
        to minus infinity.

        Parameters
        ----------
        return_uncertainty: bool
            Whether to also return the variance of the ln BFs.

        Returns
        -------
        ln_per_event_bfs: array-like
            The ln BF per event subject to having a sufficient number of
            independent samples.
        variance: array-like
            The variances (uncertainties) in the ln BF per event. Only
            returned if `return_uncertainty` is True. This output will
            generally not be used for convergence criteria, as this
            function already enforces a threshold on the effective number
            of samples.
        """
        bayes_factors, effective_samples, variance = (
            self.per_event_bayes_factors_and_n_effective_and_variances()
        )
        # The boolean mask zeroes the BF of under-sampled events, so the
        # logarithm below evaluates to -inf for them.
        bayes_factors *= effective_samples > self.n_posteriors
        ln_bayes_factors = xp.log(bayes_factors)
        if not return_uncertainty:
            return ln_bayes_factors
        return ln_bayes_factors, variance

    def per_event_bayes_factors_and_n_effective_and_variances(self):
        """
        Compute the per event BFs, the effective number of samples behind
        each BF, and the associated uncertainty (variance) in the *ln* BF.

        Called by `_compute_per_event_ln_bayes_factors`. Computes the same
        quantities as the superclass function
        `_compute_per_event_ln_bayes_factors` but additionally provides
        the effective sample size.

        Returns
        -------
        per_event_bfs: array-like
            The BF per event, computed by reweighting single-event
            likelihood samples into the `hyper_prior` model.
        n_effectives: array-like
            The effective sample size for each Monte Carlo sum computation
            of the BFs. The BF is computed for each event, so this array
            has length n_events.
        variance: array-like
            The variances (uncertainties) in the ln BF per event.
        """
        weights = self.hyper_prior.prob(self.data) / self.sampling_prior
        squared_weights = weights**2
        n_samples = self.samples_per_posterior

        # The effective sample size uses the raw weight sum, so it has to
        # be evaluated before the sum is normalised into a mean.
        per_event_bfs = xp.sum(weights, axis=-1)
        n_effectives = xp.nan_to_num(
            per_event_bfs**2 / xp.sum(squared_weights, axis=-1)
        )
        per_event_bfs /= n_samples

        squared_bfs = per_event_bfs**2
        mean_squared_weight = xp.mean(squared_weights, axis=-1)
        variance = (mean_squared_weight - squared_bfs) / (n_samples * squared_bfs)
        return per_event_bfs, n_effectives, variance

    def per_event_bayes_factors_and_n_effective(self):
        """
        Return the per event BFs and effective sample sizes, discarding
        the variances computed by
        `per_event_bayes_factors_and_n_effective_and_variances`.

        Returns
        -------
        per_event_bfs: array-like
            The BF per event.
        n_effectives: array-like
            The effective sample size for each event's BF.
        """
        bayes_factors, effective_samples, _variance = (
            self.per_event_bayes_factors_and_n_effective_and_variances()
        )
        return bayes_factors, effective_samples
def maybe_jit(func):
    """
    Compile ``func`` with ``jax.jit`` when gwpopulation runs on JAX.

    Parameters
    ----------
    func: callable
        The function to wrap.

    Returns
    -------
    callable
        ``jax.jit(func)`` if ``gwpopulation.backend.__backend__`` equals
        ``"jax"`` at call time, otherwise ``func`` unchanged.
    """
    from gwpopulation.backend import __backend__

    if __backend__ != "jax":
        return func

    # jax is only imported once the active backend is known to be jax.
    from jax import jit

    return jit(func)