diff --git a/projects/online/online/subprocesses/p_astro.py b/projects/online/online/subprocesses/p_astro.py index 0a3c3777..d961a4c8 100644 --- a/projects/online/online/subprocesses/p_astro.py +++ b/projects/online/online/subprocesses/p_astro.py @@ -1,4 +1,5 @@ import logging +import time from pathlib import Path from typing import TYPE_CHECKING @@ -6,6 +7,7 @@ from .utils import subprocess_wrapper +import h5py import pickle from ledger.events import EventSet, RecoveredInjectionSet @@ -19,6 +21,10 @@ logger = logging.getLogger("pastro-process") +TIMEOUT = 10 +TIMESTEP = 1e-3 +MAX_RETRIES = int(TIMEOUT / TIMESTEP) + def fit_or_load_pastro( model_path: Path, @@ -96,9 +102,51 @@ def pastro_subprocess( while True: event = pastro_queue.get() logger.info("Calculating p_astro") - pastro = pastro_model(event.detection_statistic) + pastro = float(pastro_model(event.detection_statistic)) graceid = pastro_queue.get() - logger.info(f"Submitting p_astro: {pastro} for {graceid}") - gdb.submit_pastro(float(pastro), graceid, event.event_dir) + event_dir = outdir / "events" / event.event_dir + posterior_file = event_dir / "amplfi.posterior_samples.hdf5" + + retries = 0 + while True: + try: + with h5py.File(posterior_file, "r") as f: + samples = f["posterior_samples"][:] + m1_source = samples["mass_1_source"] + m2_source = samples["mass_2_source"] + + logger.info("Read posteriors from file") + num_samples = len(m1_source) + bns_frac = sum(m1_source < 3) / num_samples + bbh_frac = sum(m2_source > 3) / num_samples + nsbh_frac = 1 - bns_frac - bbh_frac + + probs = { + "BBH": pastro * bbh_frac, + "NSBH": pastro * nsbh_frac, + "BNS": pastro * bns_frac, + "Terrestrial": 1 - pastro, + } + + break + except Exception: + time.sleep(TIMESTEP) + retries += 1 + + if retries >= MAX_RETRIES: + logging.info( + f"Posterior file not found after {TIMEOUT} seconds, " + "assigning all probability to BBH" + ) + probs = { + "BBH": pastro, + "NSBH": 0, + "BNS": 0, + "Terrestrial": 1 - pastro, + } + break + + logger.info(f"Submitting p_astro: {probs} for {graceid}") + gdb.submit_pastro(probs, graceid, event.event_dir) logger.info(f"Submitted p_astro for {graceid}") diff --git a/projects/online/online/utils/gdb.py b/projects/online/online/utils/gdb.py index d7a32aff..c5a81b5a 100644 --- a/projects/online/online/utils/gdb.py +++ b/projects/online/online/utils/gdb.py @@ -7,7 +7,6 @@ import h5py from gwpy.time import tconvert from ligo.gracedb.rest import GraceDb as _GraceDb -from ligo.em_bright import em_bright from ligo.skymap.tool.ligo_skymap_plot import main as ligo_skymap_plot from ligo.skymap.io.fits import write_sky_map from online.utils.searcher import Event @@ -207,19 +206,9 @@ def submit_low_latency_pe( posterior_samples = posterior_df.to_records(index=False) with h5py.File(filename, "w") as f: f.create_dataset("posterior_samples", data=posterior_samples) - - _, has_ns, _, _ = em_bright.source_classification_pe( - filename, num_eos_draws=10 - ) - if has_ns > 0: - self.logger.info( - f"Event {graceid} had HasNS = {has_ns}, so {filename} " - " was not uploaded." - ) - else: - self.write_log( - graceid, "posterior", filename=filename, tag_name="pe" - ) + self.logger.debug("Submitting posterior samples to GraceDB") + self.write_log(graceid, "posterior", filename=filename, tag_name="pe") + self.logger.debug("Posterior samples submitted") # update event with source parameters self.update_event(event, graceid, result) @@ -339,18 +328,13 @@ def submit_skymap_plots(self, graceid: str, event_dir: Path): # tag_name="sky_loc", # ) - def submit_pastro(self, pastro: float, graceid: str, event_dir: Path): + def submit_pastro( + self, probs: dict[str, float], graceid: str, event_dir: Path + ): event_dir = self.write_dir / event_dir fname = event_dir / "aframe.p_astro.json" - pastro = { - "BBH": pastro, - "Terrestrial": 1 - pastro, - "NSBH": 0, - "BNS": 0, - } - with open(fname, "w") as f: - json.dump(pastro, f) + json.dump(probs, f) self.write_log( graceid, @@ -371,3 +355,6 @@ def create_event(self, filename: str, **_): def write_log(self, *args, **kwargs): pass + + def replace_event(self, *args, **kwargs): + pass diff --git a/projects/online/online/utils/pe.py b/projects/online/online/utils/pe.py index 1cfee256..e769b973 100644 --- a/projects/online/online/utils/pe.py +++ b/projects/online/online/utils/pe.py @@ -162,6 +162,8 @@ def postprocess_samples( ) posterior["mass_1"] = mass_1 posterior["mass_2"] = mass_2 + posterior["mass_1_source"] = mass_1 / (1 + z_vals) + posterior["mass_2_source"] = mass_2 / (1 + z_vals) # add time column so ligo-skymap-from-samples # can add the gpstime metadata attribute posterior["time"] = np.ones_like(mass_1) * event_time diff --git a/projects/online/pyproject.toml b/projects/online/pyproject.toml index cecc6589..b88ad2ab 100644 --- a/projects/online/pyproject.toml +++ b/projects/online/pyproject.toml @@ -23,7 +23,6 @@ dependencies = [ "matplotlib==3.9.4", "ligo-skymap>=2.4.0,<3", "ligo-gracedb>=2.14.1", - "ligo-em-bright>=1.2.2", "tables>=3.9", ] diff --git a/projects/online/uv.lock b/projects/online/uv.lock index 41574095..a0cdcf26 100644 --- a/projects/online/uv.lock +++ b/projects/online/uv.lock @@ -1646,15 +1646,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/31/b4/b9b800c45527aadd64d5b442f9b932b00648617eb5d63d2c7a6587b7cafc/jmespath-1.0.1-py3-none-any.whl", hash = "sha256:02e2e4cc71b5bcab88332eebf907519190dd9e6e82107fa7f83b1003a6252980", size = 20256 }, ] -[[package]] -name = "joblib" -version = "1.5.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/dc/fe/0f5a938c54105553436dbff7a61dc4fed4b1b2c98852f8833beaf4d5968f/joblib-1.5.1.tar.gz", hash = "sha256:f4f86e351f39fe3d0d32a9f2c3d8af1ee4cec285aafcb27003dda5205576b444", size = 330475 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7d/4f/1195bbac8e0c2acc5f740661631d8d750dc38d4a32b23ee5df3cde6f4e0d/joblib-1.5.1-py3-none-any.whl", hash = "sha256:4719a31f054c7d766948dcd83e9613686b27114f190f717cec7eaa2084f8a74a", size = 307746 }, -] - [[package]] name = "json5" version = "0.12.0" @@ -2080,23 +2071,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/67/9f/c885e222ec4e50fe9f6e3067ee6cc9f8631e3b642c69db83934423043921/lightray-0.2.4-py3-none-any.whl", hash = "sha256:c0868ee362f967a86cc57b4af9578b362e51ef09c998603d7ff62218130198fe", size = 7206 }, ] -[[package]] -name = "ligo-em-bright" -version = "1.2.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "astropy" }, - { name = "h5py" }, - { name = "lalsuite" }, - { name = "numpy" }, - { name = "pandas" }, - { name = "scikit-learn" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/52/ed/9d82d8a9b7861ea9d985083224c8f41b766df822b566207779d54fab25b6/ligo_em_bright-1.2.2.tar.gz", hash = "sha256:1746b5a4f5d2a492f08faa2e215b291163a32300f429e243c725641a6c7e099c", size = 26521 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/49/c3/3e6ee53a12a75cb63081835651b9a3fe695d5d7fd669bc9737337dcf98b6/ligo_em_bright-1.2.2-py3-none-any.whl", hash = "sha256:1c3ef50b9d537019d9b6bdd2f60322f8ded26d598034d6d453b2b2246434dd64", size = 30272 }, -] - [[package]] name = "ligo-gracedb" version = "2.14.2" @@ -2787,7 +2761,6 @@ dependencies = [ { name = "architectures" }, { name = "arrakis" }, { name = "ledger" }, - { name = "ligo-em-bright" }, { name = "ligo-gracedb" }, { name = "ligo-skymap" }, { name = "matplotlib" }, @@ -2812,7 +2785,6 @@ requires-dist = [ { name = "architectures", editable = "../../libs/architectures" }, { name = "arrakis", specifier = ">=0.2.0,<0.3" }, { name = "ledger", editable = "../../libs/ledger" }, - { name = "ligo-em-bright", specifier = ">=1.2.2" }, { name = "ligo-gracedb", specifier = ">=2.14.1" }, { name = "ligo-skymap", specifier = ">=2.4.0,<3" }, { name = "matplotlib", specifier = "==3.9.4" }, @@ -3949,30 +3921,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/26/2a/b5b09ff2781c7ea694b337f65f52a00319396578c049af8890ff9b4b8232/safe_netrc-1.0.1-py3-none-any.whl", hash = "sha256:5f0dd6a5e304b1da3be220f15efedbf09e50779fe90462143c228c781b9d8218", size = 10891 }, ] -[[package]] -name = "scikit-learn" -version = "1.5.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "joblib" }, - { name = "numpy" }, - { name = "scipy" }, - { name = "threadpoolctl" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/92/72/2961b9874a9ddf2b0f95f329d4e67f67c3301c1d88ba5e239ff25661bb85/scikit_learn-1.5.1.tar.gz", hash = "sha256:0ea5d40c0e3951df445721927448755d3fe1d80833b0b7308ebff5d2a45e6414", size = 6958368 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/03/86/ab9f95e338c5ef5b4e79463ee91e55aae553213835e59bf038bc0cc21bf8/scikit_learn-1.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:154297ee43c0b83af12464adeab378dee2d0a700ccd03979e2b821e7dd7cc1c2", size = 12087598 }, - { url = "https://files.pythonhosted.org/packages/7d/d7/fb80c63062b60b1fa5dcb2d4dd3a4e83bd8c68cdc83cf6ff8c016228f184/scikit_learn-1.5.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:b5e865e9bd59396220de49cb4a57b17016256637c61b4c5cc81aaf16bc123bbe", size = 10979067 }, - { url = "https://files.pythonhosted.org/packages/c1/f8/fd3fa610cac686952d8c78b8b44cf5263c6c03885bd8e5d5819c684b44e8/scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:909144d50f367a513cee6090873ae582dba019cb3fca063b38054fa42704c3a4", size = 12485469 }, - { url = "https://files.pythonhosted.org/packages/32/63/ed228892adad313aab0d0f9261241e7bf1efe36730a2788ad424bcad00ca/scikit_learn-1.5.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:689b6f74b2c880276e365fe84fe4f1befd6a774f016339c65655eaff12e10cbf", size = 13335048 }, - { url = "https://files.pythonhosted.org/packages/5d/55/0403bf2031250ac982c8053397889fbc5a3a2b3798b913dae4f51c3af6a4/scikit_learn-1.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:9a07f90846313a7639af6a019d849ff72baadfa4c74c778821ae0fad07b7275b", size = 10988436 }, - { url = "https://files.pythonhosted.org/packages/b1/8d/cf392a56e24627093a467642c8b9263052372131359b570df29aaf4811ab/scikit_learn-1.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:5944ce1faada31c55fb2ba20a5346b88e36811aab504ccafb9f0339e9f780395", size = 12102404 }, - { url = "https://files.pythonhosted.org/packages/d5/2c/734fc9269bdb6768905ac41b82d75264b26925b1e462f4ebf45fe4f17646/scikit_learn-1.5.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:0828673c5b520e879f2af6a9e99eee0eefea69a2188be1ca68a6121b809055c1", size = 11037398 }, - { url = "https://files.pythonhosted.org/packages/d3/a9/15774b178bcd1cde1c470adbdb554e1504dce7c302e02ff736c90d65e014/scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:508907e5f81390e16d754e8815f7497e52139162fd69c4fdbd2dfa5d6cc88915", size = 12089887 }, - { url = "https://files.pythonhosted.org/packages/8a/5d/047cde25131eef3a38d03317fa7d25d6f60ce6e8ccfd24ac88b3e309fc00/scikit_learn-1.5.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:97625f217c5c0c5d0505fa2af28ae424bd37949bb2f16ace3ff5f2f81fb4498b", size = 13079093 }, - { url = "https://files.pythonhosted.org/packages/cb/be/dec2a8d31d133034a8ec51ae68ac564ec9bde1c78a64551f1438c3690b9e/scikit_learn-1.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:da3f404e9e284d2b0a157e1b56b6566a34eb2798205cba35a211df3296ab7a74", size = 10945350 }, -] - [[package]] name = "scipy" version = "1.14.1" @@ -4241,15 +4189,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154 }, ] -[[package]] -name = "threadpoolctl" -version = "3.6.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b7/4d/08c89e34946fce2aec4fbb45c9016efd5f4d7f24af8e5d93296e935631d8/threadpoolctl-3.6.0.tar.gz", hash = "sha256:8ab8b4aa3491d812b623328249fab5302a68d2d71745c8a4c719a2fcaba9f44e", size = 21274 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/32/d5/f9a850d79b0851d1d4ef6456097579a9005b31fea68726a4ae5f2d82ddd9/threadpoolctl-3.6.0-py3-none-any.whl", hash = "sha256:43a0b8fd5a2928500110039e43a5eed8480b918967083ea48dc3ab9f13c4a7fb", size = 18638 }, -] - [[package]] name = "tinycss2" version = "1.4.0"