import os
from argparse import ArgumentParser
import dill
import h5py
import numpy as np
import pandas as pd
from bilby.core.result import read_in_result
from bilby.core.utils import logger
from gwpopulation.backend import set_backend, SUPPORTED_BACKENDS
from gwpopulation.utils import to_numpy, xp
from tqdm import tqdm
from .data_analysis import load_model
from .data_collection import evaluate_prior
from .utils import prior_conversion
from .vt_helper import load_injection_data
[docs]
def create_parser():
import wcosmo
parser = ArgumentParser()
parser.add_argument("-r", "--result-file", help="Bilby result file")
parser.add_argument(
"-s", "--samples-file", help="File containing single event samples"
)
parser.add_argument("--injection-file", help="File containing injections")
parser.add_argument("-f", "--filename", default=None, help="Output file name")
parser.add_argument("--max-redshift", default=2.3, type=float)
parser.add_argument("--minimum-mass", default=2, type=float)
parser.add_argument("--maximum-mass", default=100, type=float)
parser.add_argument(
"--n-events", type=int, default=None, help="Number of events to draw"
)
parser.add_argument(
"--n-samples",
type=int,
default=5000,
help="The number of population samples to use, default=5000",
)
parser.add_argument(
"--vt-ifar-threshold",
type=float,
default=1,
help="IFAR threshold for resampling injections",
)
parser.add_argument(
"--vt-snr-threshold",
type=float,
default=11,
help="IFAR threshold for resampling injections. "
"This is only used for O1/O2 injections",
)
parser.add_argument(
"--distance-prior",
default="euclidean",
help="Distance prior format, e.g., euclidean, comoving",
)
parser.add_argument(
"--mass-prior",
default="flat-detector",
help="Mass prior, only flat-detector is implemented",
)
parser.add_argument(
"--spin-prior",
default="component",
help="Spin prior, this should be either component or gaussian",
)
parser.add_argument(
"--backend",
default="jax",
choices=SUPPORTED_BACKENDS,
help="The backend to use, default is jax",
)
parser.add_argument(
"--cosmo", action="store_true", help="Whether to fit cosmological parameters."
)
parser.add_argument(
"--cosmology",
type=str,
default="Planck15_LAL",
help=(
"Cosmology to use for the analysis, this should be one of "
f"{', '.join(wcosmo.available.keys())}. Some of these are fixed pre-defined "
"cosmologies while others are parameterized cosmologies. If a parameterized "
"cosmology is used the parameters relevant parameters should be included "
"in the prior specification."
),
)
return parser
[docs]
def main():
parser = create_parser()
args = parser.parse_args()
set_backend(args.backend)
result = read_in_result(args.result_file)
n_samples = min(args.n_samples, len(result.posterior))
all_samples = dict()
posterior = result.posterior.sample(n_samples, replace=False)
posterior = do_spin_conversion(posterior)
all_samples["posterior"] = posterior
with open(args.samples_file, "rb") as ff:
samples = dill.load(ff)
if "prior" not in samples["original"]:
data = dict(original=samples["original"])
evaluate_prior(data, args)
samples.update(data)
for key in samples["original"]:
samples["original"][key] = xp.asarray(samples["original"][key])
if args.n_events:
n_draws = args.n_events
logger.info(f"Number of draws set to {n_draws}.")
else:
n_draws = len(samples["names"])
logger.info(f"Number of draws equals number of events, {n_draws}.")
logger.info("Generating observed populations.")
args.models = result.meta_data["models"]
model = load_model(args)
observed_dataset = resample_events_per_population_sample(
posterior=posterior,
samples=samples["original"],
model=model,
n_draws=len(samples["names"]),
)
for ii, name in enumerate(samples["names"]):
new_posterior = pd.DataFrame()
for key in observed_dataset:
new_posterior[f"{name}_{key}"] = to_numpy(observed_dataset[key][:, ii])
all_samples[name] = new_posterior
if args.injection_file is None:
logger.info("Can't generate predicted populations without VT file.")
else:
logger.info("Generating predicted populations.")
args.models = result.meta_data["vt_models"]
vt_data = load_injection_data(
args.injection_file,
ifar_threshold=args.vt_ifar_threshold,
snr_threshold=args.vt_snr_threshold,
).to_dict()
model = load_model(args)
synthetic_dataset = resample_injections_per_population_sample(
posterior=posterior,
data=vt_data,
model=model,
n_draws=n_draws,
)
for ii in range(n_draws):
new_posterior = pd.DataFrame()
for key in synthetic_dataset:
new_posterior[f"synthetic_{key}_{ii}"] = to_numpy(
synthetic_dataset[key][:, ii]
)
all_samples[f"synthetic_{ii}"] = new_posterior
if args.filename is None:
filename = os.path.join(result.outdir, f"{result.label}_full_posterior.hdf5")
else:
filename = args.filename
save_to_common_format(
posterior=all_samples, events=samples["names"], filename=filename
)
[docs]
def do_spin_conversion(posterior):
"""Utility function to convert between beta distribution parameterizations."""
original_keys = list(posterior.keys())
posterior = prior_conversion(posterior)
for key in ["amax", "amax_1", "amax_2"]:
if key not in original_keys and key in posterior:
del posterior[key]
return posterior
[docs]
def resample_events_per_population_sample(posterior, samples, model, n_draws):
"""
Resample the input posteriors with a fiducial prior to the population
informed distribution. This returns a single sample for each event for
each passed hyperparameter sample.
See, e.g., section IIIC of `Moore and Gerosa <https://arxiv.org/abs/2108.02462>`_
for a description of the method.
Parameters
----------
posterior: pd.DataFrame
Hyper-parameter samples to use for the reweighting.
samples: dict
Posterior samples with the fiducial prior.
model: bilby.hyper.model.Model
Object that implements a `prob` method that will calculate the population
probability.
n_draws: int
The number of samples to draw. This should generally be the number of events.
This will return one sample per input event.
Returns
-------
observed_dataset: dict
The observed dataset of the events with the population informed prior.
"""
from .utils import maybe_jit
@maybe_jit
def extract_choices(parameters, seed=None):
model.parameters.update(parameters)
if hasattr(model, "cosmology_names"):
data, jacobian = model.detector_frame_to_source_frame(samples, **parameters)
else:
data = samples
jacobian = 1
weights = model.prob(data) / data["prior"] / jacobian
weights = (weights.T / xp.sum(weights, axis=-1)).T
if "jax" in xp.__name__:
seeds = random.split(random.PRNGKey(seed), len(weights))
choices = [
random.choice(seeds[ii], len(weights[ii]), p=weights[ii])
for ii in range(n_draws)
]
else:
choices = [
random.choice(
len(weights[ii]),
p=weights[ii],
size=1,
).squeeze()
for ii in range(n_draws)
]
return xp.asarray(choices)
points = posterior.to_dict(orient="records")
if "jax" in xp.__name__:
from jax import random
seeds = np.random.choice(100000, len(posterior))
else:
random = xp.random
seeds = [None] * len(posterior)
inputs = list(zip(points, seeds))
all_choices = xp.asarray(
[extract_choices(point, seed) for point, seed in tqdm(inputs)]
)
@maybe_jit
def resample(samples, choices):
return xp.asarray([samples[ii, choice] for ii, choice in enumerate(choices)])
observed_dataset = dict()
for key in samples:
observed_dataset[key] = to_numpy(
xp.asarray([resample(samples[key], choices) for choices in all_choices])
)
return observed_dataset
[docs]
def resample_injections_per_population_sample(posterior, data, model, n_draws):
"""
Resample the input data with a fiducial prior to the population
informed distribution. This returns a fixed number of samples for
each passed hyperparameter sample.
This is designed to be used with found injections.
See, e.g., section IIIC of `Moore and Gerosa <https://arxiv.org/abs/2108.02462>`_
for a description of the method.
Parameters
----------
posterior: pd.DataFrame
Hyper-parameter samples to use for the reweighting.
data: dict
Input data samples to be reweighted.
model: bilby.hyper.model.Model
Object that implements a `prob` method that will calculate the population
probability.
n_draws: int
The number of samples to draw. This should generally be the number of events.
This will return one sample per input event.
Returns
-------
observed_dataset: dict
The observed dataset of the input data with the population informed prior.
"""
from .utils import maybe_jit
@maybe_jit
def extract_choices(parameters, seed):
model.parameters.update(parameters)
if hasattr(model, "cosmology_names"):
samples, jacobian = model.detector_frame_to_source_frame(data, **parameters)
else:
samples = data
jacobian = 1
weights = model.prob(samples) / samples["prior"] / jacobian
weights /= xp.sum(weights)
if "jax" in xp.__name__:
choices = choice(
PRNGKey(seed),
len(weights),
p=weights,
shape=(n_draws,),
replace=False,
)
else:
weights = to_numpy(weights)
choices = choice(len(weights), p=weights, size=n_draws, replace=False)
return choices
if "jax" in xp.__name__:
from jax.random import PRNGKey, choice
seeds = np.random.choice(100000, len(posterior))
else:
# FIXME: cupy.random.choice doesn't support p != None and replace=False
# cupy==13.0.0 (240202)
from numpy.random import choice
seeds = [None] * len(posterior)
points = posterior.to_dict(orient="records")
inputs = list(zip(points, seeds))
all_choices = to_numpy(
xp.asarray([extract_choices(point, seed) for point, seed in tqdm(inputs)])
)
@maybe_jit
def resample(data, choices):
"""
For some reason even this simple function significantly benefits from jit
"""
return data[choices]
synthetic_dataset = dict()
for key in data:
if not isinstance(data[key], xp.ndarray):
continue
synthetic_dataset[key] = xp.asarray(
[resample(data[key], choices) for choices in all_choices]
)
return synthetic_dataset
[docs]
def data_frame_to_sarray(data_frame):
"""
Convert a pandas DataFrame object to a numpy structured array.
This is functionally equivalent to but more efficient than
`np.array(df.to_array())`.
lifted from https://stackoverflow.com/questions/30773073/save-pandas-dataframe-using-h5py-for-interoperabilty-with-other-hdf5-readers
Parameters
-----------
data_frame: pd.DataFrame
the data frame to convert
Returns
-------
output: np.ndarray
a numpy structured array representation of df
"""
types = [
(data_frame.columns[index], data_frame[key].dtype.type)
for (index, key) in enumerate(data_frame.columns)
]
dtype = np.dtype(types)
output = np.zeros(data_frame.values.shape[0], dtype)
for (index, key) in enumerate(output.dtype.names):
output[key] = data_frame.values[:, index]
return output