From 7fd2a6867fdcf40a61481554f6d7b47e48ae295b Mon Sep 17 00:00:00 2001
From: William Benoit
Date: Wed, 2 Jul 2025 09:46:59 -0700
Subject: [PATCH 01/32] Merge training waveforms into a single file

---
 aframe/pipelines/sandbox/configs/base.cfg |  2 +-
 aframe/tasks/data/waveforms/training.py   | 64 +++++++++++++++++++++--
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/aframe/pipelines/sandbox/configs/base.cfg b/aframe/pipelines/sandbox/configs/base.cfg
index e7cd32ba7..4d311478b 100644
--- a/aframe/pipelines/sandbox/configs/base.cfg
+++ b/aframe/pipelines/sandbox/configs/base.cfg
@@ -23,7 +23,7 @@ seed = 1122
 streams_per_gpu = 6

 # waveform parameters
-waveform_approximant = IMRPhenomXPHM
+waveform_approximant = IMRPhenomPv2
 waveform_duration = 8
 minimum_frequency = 20
 reference_frequency = 50
diff --git a/aframe/tasks/data/waveforms/training.py b/aframe/tasks/data/waveforms/training.py
index 092500df4..b9f9d66a4 100644
--- a/aframe/tasks/data/waveforms/training.py
+++ b/aframe/tasks/data/waveforms/training.py
@@ -1,3 +1,6 @@
+from pathlib import Path
+import shutil
+
 import law
 from luigi.util import inherits

@@ -10,7 +13,7 @@


 @inherits(WaveformParams)
-class TrainingWaveforms(
+class DeployTrainingWaveforms(
     AframeDataTask, DeployTask, law.LocalWorkflow, StaticMemoryWorkflow
 ):
     """
@@ -24,14 +27,20 @@ class TrainingWaveforms(

     output_dir = PathParameter(
         description="Directory where merged training waveforms will be saved",
-        default=paths().train_datadir / "training_waveforms",
+        default=paths().train_datadir,
+    )
+
+    tmp_dir = PathParameter(
+        description="Directory where temporary "
+        "waveforms will be saved before being merged",
+        default=paths().tmp_dir / "training_waveforms",
     )

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def output(self):
-        return s3_or_local(self.output_dir / f"waveforms-{self.branch}.hdf5")
+        return s3_or_local(self.tmp_dir / f"waveforms-{self.branch}.hdf5")

     def run(self):
         from data.waveforms.training import training_waveforms
@@ -52,3 +61,52 @@ def run(self):
         chunks = (min(64, num_signals), waveforms.get_waveforms().shape[-1])
         with self.output().open("w") as f:
             waveforms.write(f, chunks=chunks)
+
+
+@inherits(DeployTrainingWaveforms)
+class TrainingWaveforms(AframeDataTask):
+    """
+    Launch condorized generation of training waveforms,
+    and merge the results into a single file
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.output_file = self.output_dir / "training_waveforms.hdf5"
+
+    def output(self):
+        return s3_or_local(self.output_file)
+
+    def requires(self):
+        return DeployTrainingWaveforms.req(
+            self,
+            workflow=self.workflow,
+            request_memory=self.request_memory,
+            request_disk=self.request_disk,
+            request_cpus=self.request_cpus,
+        )
+
+    @property
+    def targets(self):
+        return list(self.input().collection.targets.values())
+
+    @property
+    def waveform_files(self):
+        return list(map(Path, [target.path for target in self.targets]))
+
+    def run(self):
+        from ledger.injections import (
+            WaveformPolarizationSet,
+            waveform_class_factory,
+        )
+
+        cls = waveform_class_factory(
+            ["cross", "plus"],
+            WaveformPolarizationSet,
+            "WaveformPolarizationSet",
+        )
+        with self.output().open("w") as f:
+            cls.aggregate(self.waveform_files, f, clean=True)
+
+        # clean up temporary directory
+        shutil.rmtree(self.tmp_dir)
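
The merge in `TrainingWaveforms.run` above delegates to `cls.aggregate`, whose
implementation lives in `ledger.injections` and is not shown in this series.
A minimal sketch of the equivalent operation, assuming each per-branch file
stores `cross` and `plus` datasets under a `waveforms` group (consistent with
the chunked writes in `DeployTrainingWaveforms`):

# Rough sketch only; the real logic is ledger.injections' aggregate method.
import h5py
import numpy as np

def merge_waveform_files(fnames, out_file):
    # collect each polarization from every per-branch file
    polarizations = {"cross": [], "plus": []}
    for fname in fnames:
        with h5py.File(fname, "r") as f:
            for key, chunks in polarizations.items():
                chunks.append(f["waveforms"][key][:])
    # write a single merged file with the same layout
    with h5py.File(out_file, "w") as f:
        group = f.create_group("waveforms")
        for key, chunks in polarizations.items():
            group.create_dataset(key, data=np.concatenate(chunks))

The temporary directory is removed afterward via `shutil.rmtree`, so only the
merged `training_waveforms.hdf5` survives as the task's output.
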
From 31d28d1e5fb8d0c633a1a136e180276860b7de8d Mon Sep 17 00:00:00 2001
From: William Benoit
Date: Wed, 2 Jul 2025 09:50:22 -0700
Subject: [PATCH 02/32] Update train task for optional realtime waveforms

---
 aframe/tasks/train/base.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/aframe/tasks/train/base.py b/aframe/tasks/train/base.py
index 49a7c3fea..1d8aa969a 100644
--- a/aframe/tasks/train/base.py
+++ b/aframe/tasks/train/base.py
@@ -57,11 +57,20 @@ class TrainBaseParameters(law.Task):
     )
     data_dir = PathParameter(
         description="Directory where training data is stored."
-        "It is expected to contain a `signals.hdf5` file of signals, "
-        "and a `/background` sub-directory containing background "
-        "files used for training",
+        "It is expected to contain a `val_waveforms.hdf5` file of "
+        "signals for validation, a `/background` sub-directory containing "
+        "background, and a `train_waveforms.hdf5` file containing "
+        "training signals if `precompute_train_waveforms` is set to True.",
         default=paths().train_datadir,
     )
+    precompute_train_waveforms = luigi.BoolParameter(
+        default=False,
+        description="Whether to pre-compute the waveforms used "
+        "during training. If True, the training waveforms will be "
+        "read from the `train_waveforms.hdf5` file in the data "
+        "directory. If False, the waveforms will be simulated "
+        "on-the-fly during training.",
+    )


 @inherits(TrainBaseParameters)
@@ -69,9 +78,9 @@ class TrainBase(law.Task):
     def requires(self):
         reqs = {}
         reqs["strain"] = FetchTrain.req(self)
-        reqs["train_waveforms"] = TrainingWaveforms.req(self)
         reqs["val_waveforms"] = ValidationWaveforms.req(self)
+        if self.precompute_train_waveforms:
+            reqs["train_waveforms"] = TrainingWaveforms.req(self)
         return reqs

From 661a849a5f9b0c9823efe8ce28ff473fcabde7ea Mon Sep 17 00:00:00 2001
From: William Benoit
Date: Wed, 2 Jul 2025 09:54:22 -0700
Subject: [PATCH 03/32] Update configs and apptainer.def

---
 projects/train/apptainer.def    |  1 +
 projects/train/configs/bbh.yaml |  1 -
 projects/train/configs/bns.yaml |  1 -
 projects/train/pyproject.toml   |  2 ++
 projects/train/train.yaml       | 28 +++++++++++++++++++++++++---
 5 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/projects/train/apptainer.def b/projects/train/apptainer.def
index ce8b0dd2f..22b95bb6d 100644
--- a/projects/train/apptainer.def
+++ b/projects/train/apptainer.def
@@ -10,6 +10,7 @@ From: ghcr.io/astral-sh/uv:0.6.10-python3.10-bookworm-slim
 ../../aframe /opt/aframe/aframe
 ../../pyproject.toml /opt/aframe/pyproject.toml
 ../../libs/ledger /opt/aframe/libs/ledger
+../../libs/prior /opt/aframe/libs/prior
 ../../libs/architectures /opt/aframe/libs/architectures

 %post
diff --git a/projects/train/configs/bbh.yaml b/projects/train/configs/bbh.yaml
index 350dccd7d..28ca602b3 100644
--- a/projects/train/configs/bbh.yaml
+++ b/projects/train/configs/bbh.yaml
@@ -68,7 +68,6 @@ data:
 #    alpha: -3
 #    decay_steps: 989
     # validation args
-    valid_frac: 0.25
     valid_stride: 0.5
     num_valid_views: 5
     valid_livetime: 57600
diff --git a/projects/train/configs/bns.yaml b/projects/train/configs/bns.yaml
index acc773d30..f740e3089 100644
--- a/projects/train/configs/bns.yaml
+++ b/projects/train/configs/bns.yaml
@@ -72,7 +72,6 @@ data:
 #    alpha: -3
 #    decay_steps: 989
     # validation args
-    valid_frac: 0.25
     valid_stride: 0.5
     num_valid_views: 5
     valid_livetime: 57600
diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml
index b7492391d..c01346f3e 100644
--- a/projects/train/pyproject.toml
+++ b/projects/train/pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
     "ml4gw>=0.7.2",
     "aframe",
     "ledger",
+    "priors",
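+    # priors is a local workspace package; see the editable path source below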
"architectures", "bayesian-optimization>=1.4.3,<2", "filelock>=3.13.1,<4", @@ -49,6 +50,7 @@ explicit = true utils = { path = "../../libs/utils", editable = true } aframe = { path = "../..", editable = true } ledger = { path = "../../libs/ledger", editable = true } +priors = { path = "../../libs/priors", editable = true } architectures = { path = "../../libs/architectures", editable = true } [build-system] diff --git a/projects/train/train.yaml b/projects/train/train.yaml index 9f7f422d7..06182c1d2 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -38,8 +38,6 @@ data: batch_size: 384 batches_per_epoch: 3700 num_files_per_batch: 10 - chunk_size: 10000 - chunks_per_epoch: 10 # kernel_length: psd_length: 8 # fduration: @@ -68,8 +66,32 @@ data: # max_snr: 100 # alpha: -3 # decay_steps: 989 + waveform_sampler: + class_path: train.data.waveforms.generator.cbc.CBCGenerator + init_args: + training_prior: priors.priors.end_o3_ratesandpops + val_waveform_file: ${oc.env:AFRAME_DATADIR}/train/val_waveforms.hdf5 + approximant: ml4gw.waveforms.IMRPhenomPv2 + f_min: 20 + f_ref: 40 + right_pad: 0.5 + kernel_length: 8 + # Extrinsic parameter distributions + dec: + class_path: ml4gw.distributions.Cosine + psi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 3.14159 + validate_args: false + phi: + class_path: torch.distributions.Uniform + init_args: + low: 0 + high: 6.28318 + validate_args: false # validation args - valid_frac: 0.25 valid_stride: 0.5 num_valid_views: 5 valid_livetime: 57600 From 43b575fd26d2f7e7781a6f2ce2c27e445c088638 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 2 Jul 2025 09:58:48 -0700 Subject: [PATCH 04/32] Update cli and callbacks --- projects/train/train/callbacks.py | 14 +++++++------- projects/train/train/cli.py | 13 +++++++++++++ 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index b23a42330..f7ed1c056 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -43,7 +43,7 @@ def on_train_end(self, trainer, pl_module): ) device = pl_module.device - [X], waveforms = next(iter(trainer.train_dataloader)) + [X] = next(iter(trainer.train_dataloader)) X = X.to(device) X, y = trainer.datamodule.augment(X, waveforms) if isinstance(X, tuple): @@ -75,8 +75,7 @@ def on_train_start(self, trainer, pl_module): save_dir = trainer.logger.save_dir # build training batch by hand - [X], waveforms = next(iter(trainer.train_dataloader)) - waveforms = trainer.datamodule.slice_waveforms(waveforms) + [X] = next(iter(trainer.train_dataloader)) X = X.to(device) X, y = trainer.datamodule.augment(X, waveforms) @@ -86,13 +85,14 @@ def on_train_start(self, trainer, pl_module): X = (X,) # build val batch by hand - [background, _, _], [signals] = next( + [background, _, _], [cross, plus] = next( iter(trainer.datamodule.val_dataloader()) ) background = background.to(device) - signals = signals.to(device) - X_bg, X_inj = trainer.datamodule.build_val_batches( - background, signals + cross = cross.to(device) + plus = plus.to(device) + X_bg, X_inj, _ = trainer.datamodule.build_val_batches( + background, cross, plus ) # Make background and injected validation data into # tuples for consistency if necessary diff --git a/projects/train/train/cli.py b/projects/train/train/cli.py index 2d96632f3..4d9a035a0 100644 --- a/projects/train/train/cli.py +++ b/projects/train/train/cli.py @@ -9,6 +9,7 @@ class AframeCLI(LightningCLI): def __init__(self, *args, 
**kwargs): # hack into init to hardcode # the WandbSaveConfig callback + kwargs["parser_kwargs"] = {"parser_mode": "omegaconf"} kwargs["save_config_callback"] = WandbSaveConfig super().__init__(*args, **kwargs) @@ -38,6 +39,18 @@ def add_arguments_to_parser(self, parser): "model.init_args.metric.init_args.stride", ) + parser.link_arguments( + "data.init_args.sample_rate", + "data.init_args.waveform_sampler.init_args.sample_rate", + apply_on="parse", + ) + + parser.link_arguments( + "data.init_args.fduration", + "data.init_args.waveform_sampler.init_args.fduration", + apply_on="parse", + ) + def main(args=None): cli = AframeCLI( From f5706df485bcea8e0e81d55e3a00b1b4331d497f Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 2 Jul 2025 10:15:22 -0700 Subject: [PATCH 05/32] Update datasets --- projects/train/train/data/base.py | 280 +++++++++--------- .../train/data/supervised/frequency_domain.py | 8 +- .../train/train/data/supervised/supervised.py | 17 +- .../train/data/supervised/time_domain.py | 8 +- 4 files changed, 167 insertions(+), 146 deletions(-) diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index b7b27b37f..cd1cd7e78 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -18,17 +18,18 @@ from train import augmentations as aug from train.data.utils import fs as fs_utils from train.metrics import get_timeslides -from train.waveform_sampler import ( - ChunkedWaveformDataset, - Hdf5WaveformLoader, - WaveformSampler, -) +from train.data.waveforms.sampler import WaveformSampler from utils import x_per_y from utils.preprocessing import PsdEstimator Tensor = torch.Tensor +Distribution = torch.distributions.Distribution TransformedDist = torch.distributions.TransformedDistribution +# TODO: +# Make the coalescence point/right pad parameterization consistent +# Ensure coalescence point placement is the same for training and validation + # TODO: using this right now because # lightning.pytorch.utilities.CombinedLoader @@ -56,15 +57,97 @@ def __iter__(self): class BaseAframeDataset(pl.LightningDataModule): + """ + Base LightningDataModule for loading data to train Aframe models. + Subclasses must define the `inject` method + which encodes how background strain, + cross/plus polarizations and parameters + are processed before being passed to a model + + Args: + data_dir: + Path to the directory containing the training data. + If this is a s3 bucket, it will be downloaded + to a local directory. + ifos: + List of interferometers to use for training. + sample_rate: + Sample rate in Hz of the training data. + dec: + The distribution of declinations to sample from + psi: + The distribution of polarization angles to sample from + phi: + The distribution of "right ascensions" to sample from + batches_per_epoch: + Number of batches per epoch. + num_files_per_batch: + Number of background files to sample from each batch. + waveform_sampler: + `WaveformSampler` object that produces waveforms and parameters + for training and validation. See `train.data.waveforms.sampler` + for methods this object should define. + max_num_workers: + Maximum number of workers to assign to each + training dataloader. + batch_size: + Batch size for training. + kernel_length: + Length in seconds of the time-domain kernel + passed to the model after whitening. + fduration: + The length of the whitening filter's impulse + response, in seconds. `fduration / 2` seconds + worth of data will be cropped from the edges + of the whitened timeseries. 
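+            For example, with fduration = 2, one second of data
+            is cropped from each edge of the whitened kernel.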
+ psd_length: + Length in seconds of the PSD used for whitening. + waveform_prob: + Probability that a batch element will contain a waveform. + left_pad: + Length in seconds of the minimum gap between the left + edge of whitened kernel and coalescence point of the signal. + right_pad: + Length in seconds of the minimum gap between the right + edge of whitened kernel and coalescence point of the signal. + fftlength: + Length in seconds of the FFT used for PSD estimation. + If `None`, will use the length of the kernel + fduration. + highpass: + Highpass filter frequency in Hz. + lowpass: + Lowpass filter frequency in Hz. + snr_sampler: + A callable that samples SNRs for the injected signals. + If `None`, SNRs will be left unchanged. + valid_stride: + Stride in seconds for the validation timeslides. + If `None`, will use `kernel_length + fduration`. + num_valid_views: + Number of views to inject into the validation data. + Each view will contain the same signal at a different + position in the background. + min_valid_duration: + Minimum duration in seconds of the validation data. + valid_livetime: + Total livetime in seconds of the validation data + to be generated via timeslides. + verbose: + Whether to log debug information during training. + """ + def __init__( self, # data loading args data_dir: str, ifos: Sequence[str], sample_rate: float, - valid_frac: float, + dec: Distribution, + psi: Distribution, + phi: Distribution, batches_per_epoch: int, num_files_per_batch: int, + waveform_sampler: WaveformSampler, # preprocessing args batch_size: int, kernel_length: float, @@ -72,8 +155,6 @@ def __init__( psd_length: float, # augmentation args waveform_prob: float = 1, - max_snr: float = 100, - snr_alpha: float = 3, left_pad: float = 0, right_pad: float = 0, fftlength: Optional[float] = None, @@ -87,15 +168,14 @@ def __init__( num_valid_views: int = 4, min_valid_duration: float = 15000, valid_livetime: float = (3600 * 12), + max_num_workers: int = 6, verbose: bool = False, - # waveform dataloader args - chunks_per_epoch: int = 1, - chunk_size: int = 10000, ) -> None: super().__init__() self.init_logging(verbose) self.num_ifos = len(ifos) - self.save_hyperparameters() + self.max_num_workers = max_num_workers + self.save_hyperparameters(ignore=["waveform_sampler"]) # Set up some of our data augmentation modules self.inverter = SignalInverter(0.5) @@ -105,11 +185,13 @@ def __init__( # downloaded first, either for loading signals # or to infer sample rate, so wait to construct # them until self.setup - self.waveform_sampler = None self.whitener = None self.projector = None self.psd_estimator = None self._on_device = False + + self.dec, self.psi, self.phi = dec, psi, phi + self.waveform_sampler = waveform_sampler self.snr_sampler = snr_sampler # generate our local node data directory @@ -198,17 +280,6 @@ def right_pad_size(self) -> int: """ return int(self.hparams.right_pad * self.hparams.sample_rate) - @property - def train_waveform_fnames(self) -> Sequence[str]: - data_dir = os.path.join(self.data_dir, "training_waveforms") - fnames = glob.glob(f"{data_dir}/waveforms*.hdf5") - return list(fnames) - - @property - def signal_time(self): - with h5py.File(self.train_waveform_fnames[0], "r") as f: - return f.attrs["coalescence_time"] - def train_val_split(self) -> tuple[Sequence[str], Sequence[str]]: fnames = glob.glob(f"{self.data_dir}/background/*.hdf5") fnames = sorted([Path(fname) for fname in fnames]) @@ -228,6 +299,13 @@ def val_batch_size(self): """Use larger batch sizes when we don't need 
gradients.""" return int(1 * self.hparams.batch_size) + @property + def num_workers(self): + local_world_size = len(self.trainer.device_ids) + return min( + self.max_num_workers, int(os.cpu_count() / local_world_size) + ) + # ================================================ # # Utilities for initial data loading and preparation # ================================================ # @@ -261,7 +339,9 @@ def slice_waveforms(self, waveforms: torch.Tensor) -> torch.Tensor: Slice waveforms to the correct length depending on requested left and right padding """ - signal_idx = int(self.signal_time * self.hparams.sample_rate) + signal_idx = waveforms.shape[-1] - int( + self.waveform_sampler.right_pad * self.hparams.sample_rate + ) kernel_size = int( self.hparams.kernel_length * self.hparams.sample_rate ) @@ -298,27 +378,6 @@ def get_slice_bounds(self, total, world_size, rank) -> tuple[int, int]: stop = (rank + 1) * per_dev return start, stop - def load_val_waveforms(self, f, world_size, rank): - waveform_set = self.waveform_set_cls.read(f) - - if waveform_set.coalescence_time != self.signal_time: - raise ValueError( - "Training waveforms and validation waveforms have different " - f"signal times, got {self.signal_time} and " - f"{waveform_set.coalescence_time}, respectively" - ) - - length = len(waveform_set.waveforms) - - if not rank: - self._logger.info(f"Validating on {length} waveforms") - stop, start = self.get_slice_bounds(length, world_size, rank) - - self._logger.info(f"Loading {start - stop} validation signals") - start, stop = -start, -stop or None - waveforms = torch.Tensor(waveform_set.waveforms[start:stop]) - return waveforms - def load_val_background(self, fnames: list[str]): self._logger.info("Loading validation background data") val_background = [] @@ -366,6 +425,16 @@ def build_transforms(self): self.hparams.lowpass, ) + def sample_extrinsic(self, X: torch.Tensor): + """ + Sample extrinsic parameters used to project waveforms + """ + N = len(X) + dec = self.dec.sample((N,)).to(X.device) + psi = self.psi.sample((N,)).to(X.device) + phi = self.phi.sample((N,)).to(X.device) + return dec, psi, phi + def setup(self, stage: str) -> None: world_size, rank = self.get_world_size_and_rank() self._logger = self.get_logger(world_size, rank) @@ -381,12 +450,6 @@ def setup(self, stage: str) -> None: self._logger.info(f"Validated sample rate {sample_rate}") - # now define some of the augmentation transforms - # that require sample rate information - self._logger.info("Constructing sample rate dependent transforms") - self.build_transforms() - self.transforms_to_device() - # load in our validation background up front and # compute which timeslides we'll do on this device # if we're doing distributed training so we'll know @@ -406,14 +469,20 @@ def setup(self, stage: str) -> None: self.val_batch_size, ) - self.waveform_sampler = WaveformSampler() - - val_waveform_file = os.path.join(self.data_dir, "val_waveforms.hdf5") - self.val_waveforms = self.load_val_waveforms( - val_waveform_file, world_size, rank + self.val_waveforms = self.waveform_sampler.get_val_waveforms( + world_size, rank + ) + self.waveform_sampler.get_train_waveforms( + world_size, rank, self.device ) self._logger.info("Initial dataloading complete") + # now define some of the augmentation transforms + # that require sample rate information + self._logger.info("Constructing sample rate dependent transforms") + self.build_transforms() + self.transforms_to_device() + # ================================================ # # Utilities for 
doing augmentation/preprocessing # after tensors have been transferred to GPU @@ -423,19 +492,6 @@ def device(self): """Return the device of the associated lightning module""" return self.trainer.lightning_module.device - def on_before_batch_transfer(self, batch, _): - """ - Slice loaded waveforms before sending to device - """ - # TODO: maybe pass indices as argument to - # waveform loader to reduce quantity of data - # we need to load - if self.trainer.training: - X, waveforms = batch - waveforms = self.slice_waveforms(waveforms) - batch = X, waveforms - return batch - def on_after_batch_transfer(self, batch, _): """ This is a method inherited from the DataModule @@ -447,14 +503,14 @@ def on_after_batch_transfer(self, batch, _): if self.trainer.training: # if we're training, perform random augmentations # on input data and use it to impact labels - [X], waveforms = batch - batch = self.augment(X, waveforms) + [batch] = batch + batch = self.inject(batch) elif self.trainer.validating or self.trainer.sanity_checking: # If we're in validation mode but we're not validating # on the local device, the relevant tensors will be # empty, so just pass them through with a 0 shift to # indicate that this should be ignored - [background, _, timeslide_idx], [signals] = batch + [background, _, timeslide_idx], [cross, plus] = batch # If we're validating, unfold the background # data into a batch of overlapping kernels now that @@ -462,12 +518,12 @@ def on_after_batch_transfer(self, batch, _): # much data from CPU to GPU. Once everything is # on-device, pre-inject signals into background. shift = self.timeslides[timeslide_idx].shift_size - X_bg, X_fg = self.build_val_batches(background, signals) + X_bg, X_fg = self.build_val_batches(background, cross, plus) batch = (shift, X_bg, X_fg) return batch @torch.no_grad() - def augment(self, X): + def inject(self, X): """ Override this in child classes to define application-specific augmentations @@ -479,7 +535,7 @@ def augment(self, X): # ================================================ # @torch.no_grad() def build_val_batches( - self, background: Tensor, signals: Tensor + self, background: Tensor, cross: Tensor, plus: Tensor ) -> tuple[Tensor, Tensor, Tensor]: """ Unfold a timeseries of background data @@ -489,7 +545,8 @@ def build_val_batches( Args: background: A tensor of background data - signals: A tensor of signals to inject + cross: A tensor of cross polarization waveforms + plus: A tensor of plus polarization waveforms Returns: raw strain background kernels, injected kernels, and psds @@ -502,6 +559,11 @@ def build_val_batches( # split data into kernel and psd data and estimate psd X, psd = self.psd_estimator(background) + + # Sample sky locations and project polarizations + dec, psi, phi = self.sample_extrinsic(cross) + signals = self.projector(dec, psi, phi, cross=cross, plus=plus) + # sometimes at the end of a segment, there won't be # enough background kernels and so we'll have to inject # our signals on overlapping data and ditch some at the end @@ -516,7 +578,9 @@ def build_val_batches( # the background, each showing a different, overlapping # portion of the signal kernel_size = X.size(-1) - signal_idx = int(self.signal_time * self.hparams.sample_rate) + signal_idx = signals.shape[-1] - int( + self.waveform_sampler.right_pad * self.hparams.sample_rate + ) max_start = int( signal_idx - self.left_pad_size - self.filter_size // 2 ) @@ -556,11 +620,10 @@ def val_dataloader(self) -> ZippedDataset: # we're going to go through, then batch the # signals so 
that they're spaced evenly # throughout all those batches. - num_waveforms = len(self.val_waveforms) + cross, plus = self.val_waveforms + num_waveforms = len(cross) signal_batch_size = (num_waveforms - 1) // self.valid_loader_length + 1 - - # self._logger.info(f"signal batch size: {signal_batch_size}") - signal_dataset = torch.utils.data.TensorDataset(self.val_waveforms) + signal_dataset = torch.utils.data.TensorDataset(cross, plus) signal_loader = torch.utils.data.DataLoader( signal_dataset, batch_size=signal_batch_size, @@ -575,10 +638,6 @@ def val_dataloader(self) -> ZippedDataset: return dataset def train_dataloader(self) -> torch.utils.data.DataLoader: - # divide batches per epoch up among all devices - world_size, _ = self.get_world_size_and_rank() - batches_per_epoch = self.hparams.batches_per_epoch // world_size - # build our strain dataset and dataloader dataset = Hdf5TimeSeriesDataset( self.train_fnames, @@ -593,55 +652,12 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: pin_memory = isinstance( self.trainer.accelerator, pl.accelerators.CUDAAccelerator ) - # multiprocess data loading - local_world_size = len(self.trainer.device_ids) - num_workers = min(6, int(os.cpu_count() / local_world_size)) self._logger.debug( - f"Using {num_workers} workers for strain data loading" + f"Using {self.num_workers} workers for strain data loading" ) dataloader = torch.utils.data.DataLoader( dataset, - num_workers=num_workers, + num_workers=self.num_workers, pin_memory=pin_memory, ) - - # build iterator for waveform loading - # that will load chunks of waveforms - # to be sampled from - waveform_loader = Hdf5WaveformLoader( - self.train_waveform_fnames, - batch_size=self.hparams.chunk_size, - batches_per_epoch=self.hparams.chunks_per_epoch or 1, - channels=["cross", "plus"], - path="waveforms", - ) - # calculate how many batches we'll sample from each chunk - # based on requested chunks per epoch and batches per epoch - batches_per_chunk = ( - int(batches_per_epoch // self.hparams.chunks_per_epoch) + 1 - ) - self._logger.info( - f"Training on pool of {waveform_loader.total} waveforms. 
" - f"Sampling {batches_per_chunk} batches per chunk " - f"from {self.hparams.chunks_per_epoch} chunks " - f"of size {self.hparams.chunk_size} each epoch" - ) - - # multiprocess waveform chunk loader - # so we don't have to wait for waveforms - waveform_loader = torch.utils.data.DataLoader( - waveform_loader, - num_workers=2, - pin_memory=pin_memory, - persistent_workers=True, - ) - - # build a dataset that will sample from - # iterator of chunks of waveforms - waveform_dataset = ChunkedWaveformDataset( - waveform_loader, - batch_size=self.hparams.batch_size, - batches_per_chunk=batches_per_chunk, - ) - - return ZippedDataset(dataloader, waveform_dataset) + return dataloader diff --git a/projects/train/train/data/supervised/frequency_domain.py b/projects/train/train/data/supervised/frequency_domain.py index 80a187e4a..54bba8ed3 100644 --- a/projects/train/train/data/supervised/frequency_domain.py +++ b/projects/train/train/data/supervised/frequency_domain.py @@ -23,8 +23,8 @@ def build_transforms(self, *args, **kwargs): spectrogram_shape=self.spectrogram_shape, ) - def augment(self, X): - X, y, psds = super().augment(X) + def inject(self, X): + X, y, psds = super().inject(X) X = self.whitener(X, psds) X = self.qtransform(X) return X, y @@ -117,8 +117,8 @@ def build_val_batches(self, *args, **kwargs): return X_bg, X_inj - def augment(self, X): - X, y, psds = super().augment(X) + def inject(self, X): + X, y, psds = super().inject(X) # fft whiten and bandpass in frequency domain X = self.whiten(X, psds) diff --git a/projects/train/train/data/supervised/supervised.py b/projects/train/train/data/supervised/supervised.py index 279d9d2f4..e5904a7da 100644 --- a/projects/train/train/data/supervised/supervised.py +++ b/projects/train/train/data/supervised/supervised.py @@ -43,19 +43,24 @@ def sample_prob(self): return self.hparams.waveform_prob + self.swap_prob + self.mute_prob @torch.no_grad() - def augment(self, X, waveforms): + def inject(self, X): X, psds = self.psd_estimator(X) X = self.inverter(X) X = self.reverser(X) # sample enough waveforms to do true injections, # swapping, and muting - *params, polarizations, mask = self.waveform_sampler( - X, self.sample_prob, waveforms + rvs = torch.rand(size=X.shape[:1], device=X.device) + mask = rvs < self.sample_prob + + dec, psi, phi = self.sample_extrinsic(X[mask]) + hc, hp = self.waveform_sampler.sample(X[mask]) + + snrs = self.snr_sampler.sample((mask.sum().item(),)).to(X.device) + responses = self.projector( + dec, psi, phi, snrs, psds[mask], cross=hc, plus=hp ) - N = len(params[0]) - snrs = self.snr_sampler.sample((N,)).to(X.device) - responses = self.projector(*params, snrs, psds[mask], **polarizations) + responses = self.slice_waveforms(responses) kernels = sample_kernels( responses, kernel_size=X.size(-1), coincident=True ) diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index 73dce48ec..ddec2a4c3 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -4,8 +4,8 @@ class TimeDomainSupervisedAframeDataset(SupervisedAframeDataset): - def build_val_batches(self, background, signals): - X_bg, X_inj, psds = super().build_val_batches(background, signals) + def build_val_batches(self, background, cross, plus): + X_bg, X_inj, psds = super().build_val_batches(background, cross, plus) X_bg = self.whitener(X_bg, psds) # whiten each view of injections X_fg = [] @@ -16,7 +16,7 @@ def build_val_batches(self, 
background, signals): X_fg = torch.stack(X_fg) return X_bg, X_fg - def augment(self, X, waveforms): - X, y, psds = super().augment(X, waveforms) + def inject(self, X): + X, y, psds = super().inject(X) X = self.whitener(X, psds) return X, y From 37d2d7eb0cc3b51c4166d98c307ddd194b4fd519 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 2 Jul 2025 10:16:25 -0700 Subject: [PATCH 06/32] Remove old waveform_sampler.py --- projects/train/train/waveform_sampler.py | 246 ----------------------- 1 file changed, 246 deletions(-) delete mode 100644 projects/train/train/waveform_sampler.py diff --git a/projects/train/train/waveform_sampler.py b/projects/train/train/waveform_sampler.py deleted file mode 100644 index 371847874..000000000 --- a/projects/train/train/waveform_sampler.py +++ /dev/null @@ -1,246 +0,0 @@ -import logging -import math -import warnings -from pathlib import Path -from typing import Iterable, Optional - -import h5py -import numpy as np -import torch -from ml4gw.distributions import Cosine -from torch.distributions.uniform import Uniform - - -# TODO: move to ml4gw -class Hdf5WaveformLoader(torch.utils.data.IterableDataset): - """ - Iterable dataset that loads samples of waveforms - from a set of HDF5 files. - - It is _strongly_ recommended that these files have been - written using [chunked storage] - (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage). - This has shown to produce increases in read-time speeds - of over an order of magnitude. - - Args: - fnames: - Paths to HDF5 files from which to sample data. - channels: - Datasets to read from the indicated files, which - will be stacked along dim 1 of the generated batches - during iteration. - batch_size: - Number of samples to load at each iteration. - batches_per_epoch: - Number of batches to generate during each call - to `__iter__`. - chunk_size: - Number of samples to load from each file at a time. - This is useful for reducing I/O overhead when reading. - path: - Optional path to location of datasets in hdf5 files. - `path` should be delimited by forward slashes. If `None` - it is assumed the datasets are at the root of the file. - """ - - def __init__( - self, - fnames: Iterable[Path], - channels: Iterable[str], - batch_size: int, - batches_per_epoch: int, - chunk_size: int = 1000, - path: Optional[Path] = None, - ): - self.fnames = fnames - self.channels = channels - self.batch_size = batch_size - self.batches_per_epoch = batches_per_epoch - self.chunk_size = chunk_size - - if path is not None: - self.path = path.split("/") - else: - self.path = None - - self.sizes = {} - self.mmap_files = {} - self.mmap_datasets = {} - - # for each file store the datasets - # of interest in a dictionary so we - # can access them at will without needing - # to reopen the files each time - for fname in self.fnames: - f, g = self.open(fname) - self.mmap_files[fname] = f - self.mmap_datasets[fname] = { - channel: g[channel] for channel in self.channels - } - - # store sizes of each dataset and warn if not chunked; - # assumes all dsets have same attributes - # like size and chunking behavior - dset = self.mmap_datasets[fname][self.channels[0]] - self.sizes[fname] = len(dset) - if dset.chunks is None: - warnings.warn( - "File {} contains datasets that were generated " - "without using chunked storage. This can have " - "severe performance impacts at data loading time. 
" - "If you need faster loading, try re-generating " - "your datset with chunked storage turned on.".format( - fnames - ), - stacklevel=2, - ) - - self.waveform_size = dset.shape[1] - self.probs = np.array([i / self.total for i in self.sizes.values()]) - - @property - def num_channels(self): - return len(self.channels) - - @property - def chunks_per_batch(self): - return math.ceil(self.batch_size / self.chunk_size) - - @property - def total(self): - return sum(self.sizes.values()) - - def __len__(self): - return self.batches_per_epoch - - def __del__(self): - # close all opened files when the object is destroyed - for f in self.mmap_files.values(): - f.close() - - def open(self, fname) -> tuple[h5py.File, h5py.Group]: - f = group = h5py.File(fname, "r") - if self.path is not None: - for path in self.path: - group = group[path] - return f, group - - def load_chunk(self, fname, start, size): - end = min(start + size, self.sizes[fname]) - return { - channel: self.mmap_datasets[fname][channel][start:end] - for channel in self.channels - } - - def sample_batch(self): - # allocate batch up front - batch = np.zeros( - (self.batch_size, self.num_channels, self.waveform_size) - ) - - for i in range(self.chunks_per_batch): - fname = np.random.choice(self.fnames, p=self.probs) - - chunk_size = min( - self.chunk_size, self.batch_size - i * self.chunk_size - ) - - # select a random starting index for the chunk - max_start = self.sizes[fname] - chunk_size - start = np.random.randint(0, max_start + 1) - - # load the chunk and insert it into the batch - chunk = self.load_chunk(fname, start, chunk_size) - batch_start = i * self.chunk_size - batch_end = batch_start + chunk_size - - for i, channel in enumerate(self.channels): - batch[batch_start:batch_end, i, :] = chunk[channel] - - return torch.tensor(batch) - - def __iter__(self): - for _ in range(self.batches_per_epoch): - yield self.sample_batch() - - -class ChunkedWaveformDataset(torch.utils.data.IterableDataset): - """ - Wrapper dataset that will loop through chunks of timeseries - data produced by another iterable and sample subsets - of waveforms from each chunk. - - Args: - chunk_it: - Iterator which will produce batches of waveform - data to sample subsets from. Should have shape - `(N, C, T)`, where `N` is the number of waveformns - to sample from, `C` is the number of channels, - and `T` is the number of samples along the - time dimension for each waveform. - batch_size: - Number of waveforms to sample at each iteration - batches_per_chunk: - Number of batches of waveforms to sample from - each chunk before moving on to the next one. 
- """ - - def __init__( - self, - chunk_it: Iterable, - batch_size: int, - batches_per_chunk: int, - ) -> None: - self.logger = logging.getLogger(__name__) - self.chunk_it = chunk_it - self.batch_size = batch_size - self.batches_per_chunk = batches_per_chunk - - def __len__(self): - return len(self.chunk_it) * self.batches_per_chunk - - def __iter__(self): - it = iter(self.chunk_it) - [chunk] = next(it) - - num_waveforms, _, _ = chunk.shape - while True: - # generate batches from the current chunk - for _ in range(self.batches_per_chunk): - idx = torch.randperm(num_waveforms)[: self.batch_size] - yield chunk[idx] - - try: - [chunk] = next(it) - except StopIteration: - break - num_waveforms, _, _ = chunk.shape - - -class WaveformSampler(torch.nn.Module): - def __init__(self) -> None: - super().__init__() - self.dec = Cosine() - self.psi = Uniform(0, torch.pi) - self.phi = Uniform(-torch.pi, torch.pi) - - def forward(self, X, prob, waveforms): - # determine batch size from X and prob - rvs = torch.rand(size=X.shape[:1], device=X.device) - mask = rvs < prob - N = mask.sum().item() - - # sample sky parameters for each injections - dec = self.dec.sample((N,)).to(X.device) - psi = self.psi.sample((N,)).to(X.device) - phi = self.phi.sample((N,)).to(X.device) - - # now sample the actual waveforms we want to inject - idx = torch.randperm(waveforms.shape[0])[:N] - waveforms = waveforms[idx].to(X.device).float() - - cross, plus = waveforms[:, 0], waveforms[:, 1] - polarizations = {"cross": cross, "plus": plus} - - return dec, psi, phi, polarizations, mask From d8817593b282218b900e96d437c4d0470a85552d Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 2 Jul 2025 12:28:45 -0700 Subject: [PATCH 07/32] Added waveform_sampler objects --- .../train/data/waveforms/generator/cbc.py | 75 +++++++++++ .../data/waveforms/generator/generator.py | 50 ++++++++ projects/train/train/data/waveforms/loader.py | 57 +++++++++ .../train/train/data/waveforms/sampler.py | 85 +++++++++++++ projects/train/uv.lock | 119 +++++++++++++++++- 5 files changed, 381 insertions(+), 5 deletions(-) create mode 100644 projects/train/train/data/waveforms/generator/cbc.py create mode 100644 projects/train/train/data/waveforms/generator/generator.py create mode 100644 projects/train/train/data/waveforms/loader.py create mode 100644 projects/train/train/data/waveforms/sampler.py diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py new file mode 100644 index 000000000..6e0e215aa --- /dev/null +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -0,0 +1,75 @@ +from typing import Callable + +import torch +from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator + +from ledger.injections import BilbyParameterSet + +from .generator import WaveformGenerator + + +class CBCGenerator(WaveformGenerator): + def __init__( + self, + *args, + approximant: Callable, + f_min: float, + f_ref: float, + right_pad: float, + **kwargs, + ): + """ + A lightweight wrapper around + `ml4gw.waveforms.generator.TimeDomainCBCWaveformGenerator` + to make it compatible with + `aframe.train.train.data.waveforms.generator.WaveformGenerator`. + Args: + *args: + Positional arguments passed to + `aframe.train.train.data.waveforms.generator.WaveformGenerator` + approximant: + A callable that takes parameter tensors + and returns the cross and plus polarizations. 
+                For example, `ml4gw.waveforms.IMRPhenomD()`
+            f_min:
+                Lowest frequency at which waveform signal content
+                is generated
+            f_ref:
+                Reference frequency
+            right_pad:
+                Position in seconds where the coalescence is placed
+                relative to the right edge of the window
+            **kwargs:
+                Keyword arguments passed to
+                `aframe.train.train.data.waveforms.generator.WaveformGenerator`
+        """
+        super().__init__(*args, **kwargs)
+        self.right_pad = right_pad
+        self.approximant = approximant
+        self.f_ref = f_ref
+        self.waveform_generator = TimeDomainCBCWaveformGenerator(
+            approximant,
+            self.sample_rate,
+            self.kernel_length,
+            f_min,
+            f_ref,
+            right_pad,
+        )
+
+    def convert(self, parameters):
+        # TODO: This assumes a detector-frame prior. Remove this
+        # when we switch to source-frame prior.
+        for key in ["mass_1", "mass_2", "chirp_mass", "total_mass"]:
+            if key in parameters:
+                parameters[key] *= 1 + parameters["redshift"]
+        parameter_set = BilbyParameterSet(**parameters)
+        generation_params = parameter_set.generation_params(
+            reference_frequency=self.f_ref
+        )
+        return generation_params
+
+    def forward(self, **parameters) -> torch.Tensor:
+        hc, hp = self.waveform_generator(**parameters)
+        waveforms = torch.stack([hc, hp], dim=1)
+        hc, hp = waveforms.transpose(1, 0)
+        return hc.float(), hp.float()
diff --git a/projects/train/train/data/waveforms/generator/generator.py b/projects/train/train/data/waveforms/generator/generator.py
new file mode 100644
index 000000000..374a0e367
--- /dev/null
+++ b/projects/train/train/data/waveforms/generator/generator.py
@@ -0,0 +1,50 @@
+from typing import Callable, TYPE_CHECKING
+
+from ..sampler import WaveformSampler
+import torch
+
+
+if TYPE_CHECKING:
+    pass
+
+
+class WaveformGenerator(WaveformSampler):
+    def __init__(
+        self,
+        *args,
+        training_prior: Callable,
+        **kwargs,
+    ):
+        """
+        A torch module for generating waveforms on the fly.
+        Args:
+            training_prior:
+                A callable that returns a prior distribution
+                for the parameters of the waveform generator.
+        """
+        super().__init__(*args, **kwargs)
+        self.training_prior, _ = training_prior()
+
+    def get_train_waveforms(self, *_):
+        """
+        Method is not implemented for this class, as
+        waveforms are generated on the fly.
+        """
+        pass
+
+    def sample(self, X: torch.Tensor):
+        N = len(X)
+        parameters = self.training_prior.sample(N)
+        generation_params = self.convert(parameters)
+        generation_params = {
+            k: torch.Tensor(v).to(X.device)
+            for k, v in generation_params.items()
+        }
+        hc, hp = self(**generation_params)
+        return hc, hp
+
+    def convert(self, parameters: dict) -> dict:
+        raise NotImplementedError
+
+    def forward(self):
+        raise NotImplementedError
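
These two classes implement the on-the-fly generation path configured in the
`waveform_sampler` block added to train.yaml in PATCH 03. A rough usage
sketch with illustrative values follows: the prior and approximant names come
from that config, while `fduration`, `sample_rate`, and the validation file
path are placeholders (the file must exist, since the base class reads
waveform counts and the coalescence placement from it).

# Hypothetical construction mirroring the train.yaml config; not a default.
import torch
from ml4gw.waveforms import IMRPhenomPv2
from priors.priors import end_o3_ratesandpops
from train.data.waveforms.generator.cbc import CBCGenerator

sampler = CBCGenerator(
    training_prior=end_o3_ratesandpops,
    approximant=IMRPhenomPv2(),
    f_min=20,
    f_ref=40,
    right_pad=0.5,
    # arguments consumed by the WaveformSampler base class
    fduration=1,
    kernel_length=8,
    sample_rate=2048,
    val_waveform_file="val_waveforms.hdf5",  # placeholder path
)

# sample() only uses the batch size and device of X, so any strain-shaped
# tensor works for demonstration purposes
X = torch.randn(8, 2, 1)
hc, hp = sampler.sample(X)
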
diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py
new file mode 100644
index 000000000..257beabe1
--- /dev/null
+++ b/projects/train/train/data/waveforms/loader.py
@@ -0,0 +1,57 @@
+from pathlib import Path
+
+import h5py
+import torch
+
+from .sampler import WaveformSampler
+
+
+class WaveformLoader(WaveformSampler):
+    """
+    Torch module for loading training and validation
+    waveforms from disk and sampling them during training.
+    TODO: modify this to sample waveforms from disk, taking
+    an index sampler object so that DDP training can sample
+    different waveforms for each device.
+    Args:
+        training_waveform_file:
+            Path to the training waveforms file
+    """
+
+    def __init__(
+        self,
+        *args,
+        training_waveform_file: Path,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.training_waveform_file = training_waveform_file
+
+        with h5py.File(training_waveform_file) as f:
+            self.num_train_waveforms = len(f["waveforms"]["cross"])
+            self.right_pad = f.attrs["duration"] - f.attrs["coalescence_time"]
+
+    def get_train_waveforms(self, world_size, rank, device):
+        """
+        Returns train waveforms for this device
+        """
+        start, stop = self.get_slice_bounds(
+            self.num_train_waveforms, world_size, rank
+        )
+        with h5py.File(self.training_waveform_file) as f:
+            waveforms = []
+            for key in f["waveforms"].keys():
+                waveforms.append(torch.Tensor(f["waveforms"][key][start:stop]))
+
+        self.train_waveforms = torch.stack(waveforms, dim=0).to(device)
+
+    def sample(self, X: torch.Tensor):
+        """
+        Sample method for generating training waveforms
+        """
+        N = len(X)
+        idx = torch.randperm(self.num_train_waveforms)[:N]
+        waveforms = self.train_waveforms[:, idx]
+
+        hc, hp = waveforms
+        return hc, hp
diff --git a/projects/train/train/data/waveforms/sampler.py b/projects/train/train/data/waveforms/sampler.py
new file mode 100644
index 000000000..748fe92c2
--- /dev/null
+++ b/projects/train/train/data/waveforms/sampler.py
@@ -0,0 +1,85 @@
+from pathlib import Path
+
+import h5py
+import torch
+from utils import x_per_y
+
+Distribution = torch.distributions.Distribution
+
+
+class WaveformSampler(torch.nn.Module):
+    """
+    Base object defining methods that waveform-producing classes
+    should implement. Should not be instantiated on its own.
+    Args:
+        fduration:
+            Desired length in seconds of the time domain
+            response of the whitening filter built from PSDs.
+            See `ml4gw.spectral.truncate_inverse_power_spectrum`
+        kernel_length:
+            Length in seconds of window passed to neural network.
+        sample_rate:
+            Sample rate in Hz of generated waveforms
+        val_waveform_file:
+            Path to the validation waveforms file.
+    """
+
+    def __init__(
+        self,
+        *args,
+        fduration: float,
+        kernel_length: float,
+        sample_rate: float,
+        val_waveform_file: Path,
+        **kwargs,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.fduration = fduration
+        self.kernel_length = kernel_length
+        self.sample_rate = sample_rate
+        self.val_waveform_file = val_waveform_file
+
+        with h5py.File(val_waveform_file) as f:
+            key = list(f["waveforms"].keys())[0]
+            self.num_val_waveforms = len(f["waveforms"][key])
+
+    @property
+    def duration(self):
+        """
+        Length of kernel before whitening removes
+        fduration / 2 from each side
+        """
+        return self.fduration + self.kernel_length
+
+    def get_slice_bounds(self, total, world_size, rank) -> tuple[int, int]:
+        """
+        Determine waveform indices to load for this device
+        given our rank and world size
+        """
+        per_dev = x_per_y(abs(total), world_size)
+        start = rank * per_dev
+        stop = (rank + 1) * per_dev
+        return start, stop
+
+    # Assuming that we're going to be loading validation waveforms
+    # from disk for now, so this function can be defined here.
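+    # For example (hypothetical numbers): with num_val_waveforms = 1000,
+    # world_size = 4, and rank = 1, get_slice_bounds returns (250, 500),
+    # so this device reads waveforms[250:500] from the validation file.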
+ def get_val_waveforms(self, world_size, rank): + """ + Returns validation waveforms for this device + """ + start, stop = self.get_slice_bounds( + self.num_val_waveforms, world_size, rank + ) + with h5py.File(self.val_waveform_file) as f: + waveforms = [] + for key in f["waveforms"].keys(): + waveforms.append(torch.Tensor(f["waveforms"][key][start:stop])) + + return torch.stack(waveforms, dim=0) + + def get_test_waveforms(self): + raise NotImplementedError + + def sample(self): + """Defines how to sample waveforms for training""" + raise NotImplementedError diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 9b536d7e1..bf834ed7c 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -38,9 +38,9 @@ requires-dist = [ { name = "cloudpathlib", specifier = ">=0.18.1,<0.19" }, { name = "jsonargparse", specifier = ">=4.27.1,<5" }, { name = "kr8s", specifier = ">=0.10.0,<0.11" }, - { name = "law", specifier = ">=0.1.19" }, + { name = "law", specifier = ">=0.1.20" }, { name = "luigi", specifier = "~=3.0" }, - { name = "ml4gw-hermes", specifier = ">=0.2.0" }, + { name = "ml4gw-hermes", specifier = ">=0.2.1" }, { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "psutil", specifier = ">=5.9.8,<6" }, { name = "pykube-ng", extras = ["oidc"], specifier = ">=23.6.0,<24" }, @@ -333,6 +333,49 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f9/49/6abb616eb3cbab6a7cca303dc02fdf3836de2e0b834bf966a7f5271a34d8/beautifulsoup4-4.13.3-py3-none-any.whl", hash = "sha256:99045d7d3f08f91f0d656bc9b7efbae189426cd913d830294a15eefa0ea4df16", size = 186015 }, ] +[[package]] +name = "bilby" +version = "2.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "attrs" }, + { name = "bilby-cython" }, + { name = "corner" }, + { name = "dill" }, + { name = "dynesty" }, + { name = "emcee" }, + { name = "h5py" }, + { name = "matplotlib" }, + { name = "numpy" }, + { name = "pandas" }, + { name = "scipy" }, + { name = "tqdm" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6f/60/1c82067bd5d87fcc4dbe045c194866f21a00679ab529cb4b78c3c131d40e/bilby-2.5.2.tar.gz", hash = "sha256:b600494f8b9ca1a01124b6c8cd12711621450f5f9e54b3f24a1272bb878f2796", size = 11511325 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/b0/8314d3a6659bd160f25c94dae61fe0a38fe8bdfa1decb8844cb994e75771/bilby-2.5.2-py3-none-any.whl", hash = "sha256:ec849e05e9eceb67fdf0b5e01d62aa256c22abdb2c7cc84064516989a97ff25d", size = 2267539 }, +] + +[[package]] +name = "bilby-cython" +version = "0.5.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/4c/93/6877dd78b375f3e83c8209aecc18f18d1cd88269a8a5c2d3c7dd9ad4fc53/bilby_cython-0.5.3.tar.gz", hash = "sha256:44400b2abc6fa592b13d69f33460cb156d6edb48a783f5539ee0a21e1d0a3508", size = 238517 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/e7/24a4d73fcee5e828b5819466ee8d5c7310119c3b08122a033ba6fe99d80e/bilby.cython-0.5.3-cp310-cp310-macosx_12_0_x86_64.whl", hash = "sha256:27a34164fbaf140f89e0e0e395c6ea3ae5e57209c07c7c717c4385fecb9dd398", size = 351950 }, + { url = "https://files.pythonhosted.org/packages/00/66/aaf5ebb0467f0a3cd760aaefd5789a474785675717f7a9c18918d65556ec/bilby.cython-0.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f72f2865d4a01cb241c400efe4764e85194fe43f362061778f291498f0538df", size = 940117 }, + { url = 
"https://files.pythonhosted.org/packages/13/e9/8c63ca64b19a0536616ee1aa77c46492c8da81a3589963b85fd929962dfa/bilby.cython-0.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:ad135a68e800b60c03cca967347289b3e7481d2c1376f8188510e5191fd812f0", size = 359692 }, + { url = "https://files.pythonhosted.org/packages/de/c1/3799f4bd0e5a13375b621f21d12fd21e596d744b363e683f815afe593f1f/bilby.cython-0.5.3-cp311-cp311-macosx_12_0_x86_64.whl", hash = "sha256:813f2b53a55b0fbbc2a86ef9c025e3020eb086473663c110161e3b120728176d", size = 352052 }, + { url = "https://files.pythonhosted.org/packages/e4/6a/78511b5e08ed89f492297746e3f488159d90e2b89680777a3a37565916eb/bilby.cython-0.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4642cb38f2cea8129261e25fe4237e2c06982753ef3ce2611072e541b7ecf6fb", size = 983911 }, + { url = "https://files.pythonhosted.org/packages/38/f3/ef24cb7c8a7335f869829e5c0909515544924dcd5df6d4f58b377e1daeb2/bilby.cython-0.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:d9d0c1c490619981234568dae7f11fab1544783a748492d39d0d690a70d12f6a", size = 359647 }, + { url = "https://files.pythonhosted.org/packages/8b/c5/ad5ca082b2610defc488679690df8137300c6bb396b24f783e3d74873fa4/bilby.cython-0.5.3-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:264ccd8ca1adabc794931ed6deb5082ad0ed4b52694be8158cb421a80a752bca", size = 351851 }, + { url = "https://files.pythonhosted.org/packages/13/26/f0b46d56d278665b484ec421dc571fb28bdd81635137d00e0edc2c8fddc9/bilby.cython-0.5.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:44e5e381c2861e26a4e1fd5c591ea0c3c9a0e2f0d8c78f28f8704abf2945cd8d", size = 1014120 }, + { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801 }, +] + [[package]] name = "bokeh" version = "3.6.3" @@ -640,6 +683,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/52/94/86bfae441707205634d80392e873295652fc313dfd93c233c52c4dc07874/contourpy-1.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:44a29502ca9c7b5ba389e620d44f2fbe792b1fb5734e8b931ad307071ec58c53", size = 218221 }, ] +[[package]] +name = "corner" +version = "2.2.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "matplotlib" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/61/2d728798e9ae3bf899d962f77486cb29888d57f6fbf9561bc1435a6b1a74/corner-2.2.3.tar.gz", hash = "sha256:471b7b63395d8f1dee176bb779348ade38d56abd23404a48802a593607745e1c", size = 5932840 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/4a/5bd0a8b981c5a93153d9eb7c63143b407cc7f8dfc9f91eedc9b6f5289eca/corner-2.2.3-py3-none-any.whl", hash = "sha256:39674b223482456c3a78234dc7bdefd21188a2d47bb8cd468104a0501f6659ec", size = 15946 }, +] + [[package]] name = "cryptography" version = "44.0.1" @@ -752,6 +807,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/6b/7c87867d255cbce8167ed99fc65635e9395d2af0f0c915428f5b17ec412d/Cython-3.0.12-py2.py3-none-any.whl", hash = "sha256:0038c9bae46c459669390e53a1ec115f8096b2e4647ae007ff1bf4e6dee92806", size = 1171640 }, ] +[[package]] +name = "dill" +version = "0.4.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/80/630b4b88364e9a8c8c5797f4602d0f76ef820909ee32f0bacb9f90654042/dill-0.4.0.tar.gz", hash = 
"sha256:0633f1d2df477324f53a895b02c901fb961bdbf65a17122586ea7019292cbcf0", size = 186976 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/50/3d/9373ad9c56321fdab5b41197068e1d8c25883b3fea29dd361f9b55116869/dill-0.4.0-py3-none-any.whl", hash = "sha256:44f54bf6412c2c8464c14e8243eb163690a9800dbe2c367330883b19c7561049", size = 119668 }, +] + [[package]] name = "distlib" version = "0.3.9" @@ -782,6 +846,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl", hash = "sha256:bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637", size = 36533 }, ] +[[package]] +name = "dynesty" +version = "2.1.5" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d0/20/6a8a27d803900ad8fdfe7d730a66d2304f3fda17616eb8debe10085fa81a/dynesty-2.1.5.tar.gz", hash = "sha256:fd742d483a1f78086bc8a36316594fa1e18eeb3a27f8293fca147b377b5cd311", size = 31034553 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c8/01/db367a6adfa49d280cda609ad5d94b9b257058b4499c59633e6d15c5e552/dynesty-2.1.5-py2.py3-none-any.whl", hash = "sha256:c1d20f300d6e0fb64675f01683d7511b181d2a5ac372295b1fe5f4410670530b", size = 108214 }, +] + [[package]] name = "einops" version = "0.8.1" @@ -791,6 +864,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/62/9773de14fe6c45c23649e98b83231fffd7b9892b6cf863251dc2afa73643/einops-0.8.1-py3-none-any.whl", hash = "sha256:919387eb55330f5757c6bea9165c5ff5cfe63a642682ea788a6d472576d81737", size = 64359 }, ] +[[package]] +name = "emcee" +version = "3.1.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/cb/53/1045ee878cb24281387079f8ee4f0ade1622c6aae1ed1fd91a53e4fa5b19/emcee-3.1.6.tar.gz", hash = "sha256:11af4daf6ab8f9ca69681e3c29054665db7bbd87fd4eb8e437d2c3a1248c637d", size = 2871117 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/ef/2196b9bf88ffa1bde45853c72df021fbd07a8fa91a0f59a22d14a050dc04/emcee-3.1.6-py2.py3-none-any.whl", hash = "sha256:f2d63752023bdccf744461450e512a5b417ae7d28f18e12acd76a33de87580cb", size = 47351 }, +] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1678,7 +1763,7 @@ wheels = [ [[package]] name = "ml4gw-hermes" -version = "0.2.0" +version = "0.2.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, @@ -1688,9 +1773,9 @@ dependencies = [ { name = "tblib" }, { name = "tritonclient", extra = ["all"] }, ] -sdist = { url = "https://files.pythonhosted.org/packages/01/03/75ed3d4b7bc1b8fea50c139a325b70386b4a7852b329cd8cec0e124e0126/ml4gw_hermes-0.2.0.tar.gz", hash = "sha256:e254f2b3dff8676e75d35c0eb0e7f41d76feb3a74614ab154493ac2f344294c5", size = 43060 } +sdist = { url = "https://files.pythonhosted.org/packages/34/75/4f4930e02da76fe78eca85af9240a1ee901a791a55b264bb0d8048c24719/ml4gw_hermes-0.2.1.tar.gz", hash = "sha256:9349daecee00515f7ff812099eb2ac46472eb1af8ca4cea1adf2c12d4243507c", size = 42163 } wheels = [ - { url = "https://files.pythonhosted.org/packages/72/b7/9d7b67be9b09af2b04a236156781e84f2ef48dafeb76cb28f8e6c3e6d3a1/ml4gw_hermes-0.2.0-py3-none-any.whl", hash = "sha256:30e32135e44d7babed8f4636268a11c47b811a6250b3bc9dc8c066d929d1a388", size = 57589 }, + { url = 
"https://files.pythonhosted.org/packages/73/db/82ff800c799e2b788cedf45b23d6128629737f556567777a4dc91452920d/ml4gw_hermes-0.2.1-py3-none-any.whl", hash = "sha256:b7ff0a1f917240cdf9763e771d5ae88857dc04d39f38c6f68e41d40e88cec28e", size = 57806 }, ] [[package]] @@ -2137,6 +2222,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "priors" +version = "0.1.0" +source = { editable = "../../libs/priors" } +dependencies = [ + { name = "astropy" }, + { name = "bilby" }, + { name = "numpy" }, + { name = "utils" }, +] + +[package.metadata] +requires-dist = [ + { name = "astropy", specifier = ">=5.0" }, + { name = "bilby", specifier = ">=2.2.2,<3" }, + { name = "numpy", specifier = ">=1.26.4,<2" }, + { name = "utils", editable = "../../libs/utils" }, +] + +[package.metadata.requires-dev] +dev = [{ name = "pytest", specifier = "~=7.3" }] + [[package]] name = "prometheus-client" version = "0.21.1" @@ -3354,6 +3461,7 @@ dependencies = [ { name = "lightning" }, { name = "lightray" }, { name = "ml4gw" }, + { name = "priors" }, { name = "ray", extra = ["default", "tune"] }, { name = "s3fs" }, { name = "torch" }, @@ -3382,6 +3490,7 @@ requires-dist = [ { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, { name = "torch", specifier = "==2.5.0" }, From 24f5ecbb7bf428e7895fce9365312f04a7e20dfc Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 2 Jul 2025 13:19:56 -0700 Subject: [PATCH 08/32] Make sure training and validation waveforms have the same coalescence point --- projects/train/apptainer.def | 2 +- projects/train/train.yaml | 3 +-- projects/train/train/data/base.py | 1 - projects/train/train/data/waveforms/generator/cbc.py | 6 +----- projects/train/train/data/waveforms/loader.py | 10 +++++++++- projects/train/train/data/waveforms/sampler.py | 1 + 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/projects/train/apptainer.def b/projects/train/apptainer.def index 22b95bb6d..283c97ceb 100644 --- a/projects/train/apptainer.def +++ b/projects/train/apptainer.def @@ -10,7 +10,7 @@ From: ghcr.io/astral-sh/uv:0.6.10-python3.10-bookworm-slim ../../aframe /opt/aframe/aframe ../../pyproject.toml /opt/aframe/pyproject.toml ../../libs/ledger /opt/aframe/libs/ledger -../../libs/prior /opt/aframe/libs/prior +../../libs/priors /opt/aframe/libs/priors ../../libs/architectures /opt/aframe/libs/architectures %post diff --git a/projects/train/train.yaml b/projects/train/train.yaml index 06182c1d2..ffc4462e8 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -70,11 +70,10 @@ data: class_path: train.data.waveforms.generator.cbc.CBCGenerator init_args: training_prior: priors.priors.end_o3_ratesandpops - val_waveform_file: ${oc.env:AFRAME_DATADIR}/train/val_waveforms.hdf5 + val_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/train/val_waveforms.hdf5 approximant: ml4gw.waveforms.IMRPhenomPv2 f_min: 20 f_ref: 40 - right_pad: 0.5 kernel_length: 8 # Extrinsic parameter distributions dec: diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index cd1cd7e78..f11e1c33d 100644 --- 
a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -28,7 +28,6 @@ # TODO: # Make the coalescence point/right pad parameterization consistent -# Ensure coalescence point placement is the same for training and validation # TODO: using this right now because diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py index 6e0e215aa..cd572fc53 100644 --- a/projects/train/train/data/waveforms/generator/cbc.py +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -36,15 +36,11 @@ def __init__( is generated f_ref: Reference frequency - right_pad: - Position in seconds where coalesence is placed - relative to the right edge of the window **kwargs: Keyword arguments passed to `aframe.train.train.data.waveforms.generator.WaveformGenerator` """ super().__init__(*args, **kwargs) - self.right_pad = right_pad self.approximant = approximant self.f_ref = f_ref self.waveform_generator = TimeDomainCBCWaveformGenerator( @@ -53,7 +49,7 @@ def __init__( self.kernel_length, f_min, f_ref, - right_pad, + self.right_pad, ) def convert(self, parameters): diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index 257beabe1..3956e2011 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -29,7 +29,6 @@ def __init__( with h5py.File(training_waveform_file) as f: self.num_train_waveforms = len(f["waveforms"]["cross"]) - self.right_pad = f.attrs["duration"] - f.attrs["coalescence_time"] def get_train_waveforms(self, world_size, rank, device): """ @@ -43,6 +42,15 @@ def get_train_waveforms(self, world_size, rank, device): for key in f["waveforms"].keys(): waveforms.append(torch.Tensor(f["waveforms"][key][start:stop])) + if ( + self.right_pad + != f.attrs["duration"] - f.attrs["coalescence_time"] + ): + raise ValueError( + "Training and validation waveform files do not have " + "the same coalescence time and/or duration" + ) + self.train_waveforms = torch.stack(waveforms, dim=0).to(device) def sample(self, X: torch.Tensor): diff --git a/projects/train/train/data/waveforms/sampler.py b/projects/train/train/data/waveforms/sampler.py index 748fe92c2..632a15d5b 100644 --- a/projects/train/train/data/waveforms/sampler.py +++ b/projects/train/train/data/waveforms/sampler.py @@ -42,6 +42,7 @@ def __init__( with h5py.File(val_waveform_file) as f: key = list(f["waveforms"].keys())[0] self.num_val_waveforms = len(f["waveforms"][key]) + self.right_pad = f.attrs["duration"] - f.attrs["coalescence_time"] @property def duration(self): From e723c7483ce69eb50efbce4070f95b992f033a8f Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 3 Jul 2025 07:12:06 -0700 Subject: [PATCH 09/32] Add ml4gw generation params --- libs/ledger/ledger/injections.py | 19 +++++++++++++++++++ projects/train/train/callbacks.py | 2 +- projects/train/train/data/base.py | 1 + .../train/data/waveforms/generator/cbc.py | 6 ++---- .../train/train/data/waveforms/sampler.py | 8 -------- 5 files changed, 23 insertions(+), 13 deletions(-) diff --git a/libs/ledger/ledger/injections.py b/libs/ledger/ledger/injections.py index d111ca25e..a070ebe7f 100644 --- a/libs/ledger/ledger/injections.py +++ b/libs/ledger/ledger/injections.py @@ -158,6 +158,25 @@ def redshift(self, cosmology=DEFAULT_COSMOLOGY): cosmology.luminosity_distance, self.luminosity_distance * Mpc ).value + @property + def ml4gw_generation_params(self): + params = { + "mass_1": self.mass1, + 
"mass_2": self.mass2, + "chirp_mass": chirp_mass(self.mass1, self.mass2), + "mass_ratio": self.mass2 / self.mass1, + "s1x": self.spin1x, + "s1y": self.spin1y, + "s1z": self.spin1z, + "s2x": self.spin2x, + "s2y": self.spin2y, + "s2z": self.spin2z, + "inclination": self.inclination, + "distance": self.luminosity_distance, + "phic": self.phase, + } + return params + @property def generation_params(self): params = { diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index f7ed1c056..fdd69cb7b 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -91,7 +91,7 @@ def on_train_start(self, trainer, pl_module): background = background.to(device) cross = cross.to(device) plus = plus.to(device) - X_bg, X_inj, _ = trainer.datamodule.build_val_batches( + X_bg, X_inj = trainer.datamodule.build_val_batches( background, cross, plus ) # Make background and injected validation data into diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index f11e1c33d..2fb178b28 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -28,6 +28,7 @@ # TODO: # Make the coalescence point/right pad parameterization consistent +# Move waveform slicing to the waveform sampler? # TODO: using this right now because diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py index cd572fc53..cc150b9f0 100644 --- a/projects/train/train/data/waveforms/generator/cbc.py +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -15,7 +15,6 @@ def __init__( approximant: Callable, f_min: float, f_ref: float, - right_pad: float, **kwargs, ): """ @@ -59,9 +58,8 @@ def convert(self, parameters): if key in parameters: parameters[key] *= 1 + parameters["redshift"] parameter_set = BilbyParameterSet(**parameters) - generation_params = parameter_set.generation_params( - reference_frequency=self.f_ref - ) + lal_params = parameter_set.convert_to_lal_param_set(self.f_ref) + generation_params = lal_params.ml4gw_generation_params return generation_params def forward(self, **parameters) -> torch.Tensor: diff --git a/projects/train/train/data/waveforms/sampler.py b/projects/train/train/data/waveforms/sampler.py index 632a15d5b..644a21d79 100644 --- a/projects/train/train/data/waveforms/sampler.py +++ b/projects/train/train/data/waveforms/sampler.py @@ -44,14 +44,6 @@ def __init__( self.num_val_waveforms = len(f["waveforms"][key]) self.right_pad = f.attrs["duration"] - f.attrs["coalescence_time"] - @property - def duration(self): - """ - Length of kernel before whitening removes - fduration / 2 from each side - """ - return self.fduration + self.kernel_length - def get_slice_bounds(self, total, world_size, rank) -> tuple[int, int]: """ Determine waveform indices to load for this device From 382d521b6ed6466d981feee718dad8b107144f2f Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 3 Jul 2025 09:45:34 -0700 Subject: [PATCH 10/32] Update ledgers and tests with right_pad --- libs/ledger/ledger/injections.py | 45 +++++++++++++--------------- libs/ledger/tests/test_injections.py | 25 ++++++++-------- libs/ledger/uv.lock | 7 +++-- 3 files changed, 37 insertions(+), 40 deletions(-) diff --git a/libs/ledger/ledger/injections.py b/libs/ledger/ledger/injections.py index a070ebe7f..114d5eb39 100644 --- a/libs/ledger/ledger/injections.py +++ b/libs/ledger/ledger/injections.py @@ -253,7 +253,7 @@ def convert_to_bilby_param_set(self, reference_frequency: float): class 
InjectionMetadata(Ledger):
     sample_rate: np.ndarray = metadata()
     duration: np.ndarray = metadata()
-    coalescence_time: float = metadata()
+    right_pad: float = metadata()
     num_injections: int = metadata(default=0)

     def __post_init__(self):
@@ -311,15 +311,15 @@ class _WaveformGenerator:
     waveform_approximant: str
     sample_rate: float
     waveform_duration: float
-    coalescence_time: float
+    right_pad: float
     minimum_frequency: float
     reference_frequency: float

     def shift_coalescence(self, waveforms: np.ndarray, t_final: float):
         """
         Shift a pair of polarizations such that the coalescence point is moved
-        to the time specified by self.coalescence_time. The shift is
-        accomplished by rolling the array
+        to sit `self.right_pad` seconds from the right edge.
+        The shift is accomplished by rolling the array

         Args:
             waveforms:
@@ -333,11 +333,7 @@ def shift_coalescence(self, waveforms: np.ndarray, t_final: float):
         Returns:
             The stacked, shifted polarizations
         """
-        shift_time = (
-            t_final
-            - (self.waveform_duration - self.coalescence_time)
-            + 1 / self.sample_rate
-        )
+        shift_time = t_final - self.right_pad + 1 / self.sample_rate
         shift_idx = int(shift_time * self.sample_rate)
         return np.roll(waveforms, shift_idx, axis=-1)

@@ -422,20 +418,20 @@ def from_parameters(
         sample_rate: float,
         waveform_duration: float,
         waveform_approximant: str,
-        coalescence_time: float,
+        right_pad: float,
         ex: Optional[Executor] = None,
     ):
-        if waveform_duration < coalescence_time:
+        if waveform_duration < right_pad:
             raise ValueError(
-                "Coalescence time must be less than waveform duration; "
-                f"got values of {coalescence_time} and {waveform_duration}"
+                "Right padding must be less than waveform duration; "
+                f"got values of {right_pad} and {waveform_duration}"
             )

         waveform_generator = _WaveformGenerator(
             waveform_approximant=waveform_approximant,
             sample_rate=sample_rate,
             waveform_duration=waveform_duration,
-            coalescence_time=coalescence_time,
+            right_pad=right_pad,
             minimum_frequency=minimum_frequency,
             reference_frequency=reference_frequency,
         )
@@ -467,7 +463,7 @@ def from_parameters(
         polarizations["sample_rate"] = sample_rate
         polarizations["duration"] = waveform_duration
         polarizations["num_injections"] = len(params)
-        polarizations["coalescence_time"] = coalescence_time
+        polarizations["right_pad"] = right_pad
         return cls(**polarizations)

@@ -579,14 +575,14 @@ def read(
         if all(i is None for i in [start, end, shifts]):
             return cls._load_with_idx(f, None)

-        coalescence_time = f.attrs["coalescence_time"]
+        left_pad = f.attrs["duration"] - f.attrs["right_pad"]
         times = f["parameters"]["injection_time"][:]

         mask = True
         if start is not None:
-            mask &= (times + coalescence_time) >= start
+            mask &= (times + left_pad) >= start
         if end is not None:
-            mask &= (times - coalescence_time) <= end
+            mask &= (times - left_pad) <= end
         if shifts is not None:
             shifts = np.array(shifts)
             ndim = shifts.ndim
@@ -631,10 +627,10 @@ def inject(self, x: np.ndarray, start: float):
             initial timestamp `start`
         """
         stop = start + x.shape[-1] / self.sample_rate
-        post_coalescence_time = self.duration - self.coalescence_time
+        left_pad = self.duration - self.right_pad

-        mask = self.injection_time >= (start - self.coalescence_time)
-        mask &= self.injection_time <= (stop + post_coalescence_time)
+        mask = self.injection_time >= (start - left_pad)
+        mask &= self.injection_time <= (stop + self.right_pad)
         if not mask.any():
             return x

@@ -645,7 +641,7 @@ def inject(self, x: np.ndarray, start: float):
         # potentially pad x to inject waveforms
         # that fall over the boundaries of
chunks pad = [] - earliest = (times - self.coalescence_time - start).min() + earliest = (times - left_pad - start).min() if earliest < 0: # For consistency, we want to round down here # E.g., if earliest = -0.1 and sample_rate = 2048, @@ -658,7 +654,7 @@ def inject(self, x: np.ndarray, start: float): else: pad.append(0) - latest = (times + post_coalescence_time - stop).max() + latest = (times + self.right_pad - stop).max() if latest > 0: num_late = int(latest * self.sample_rate) pad.append(num_late) @@ -672,9 +668,8 @@ def inject(self, x: np.ndarray, start: float): # create matrix of indices of waveform_size for each waveform waveforms = waveforms.transpose((1, 0, 2)) _, num_waveforms, waveform_size = waveforms.shape - coalescence_time_idx = int(self.coalescence_time * self.sample_rate) - idx = np.arange(waveform_size) - coalescence_time_idx + idx = np.arange(waveform_size) - int(left_pad * self.sample_rate) idx = idx[None] idx = np.repeat(idx, num_waveforms, axis=0) diff --git a/libs/ledger/tests/test_injections.py b/libs/ledger/tests/test_injections.py index 3c9fef40c..53f5ce3c3 100644 --- a/libs/ledger/tests/test_injections.py +++ b/libs/ledger/tests/test_injections.py @@ -115,8 +115,8 @@ def waveform_duration(self): def sample_rate(self): return 2048 - @pytest.fixture(params=[0, 4, 7, 8]) - def coalescence_time(self, request): + @pytest.fixture(params=[0, 0.5, 1, 4]) + def right_pad(self, request): return request.param @pytest.fixture(params=[4, 10]) @@ -134,16 +134,16 @@ def test_align_waveforms( self, sample_rate, waveform_duration, - coalescence_time, + right_pad, dummy_signal, ): - if coalescence_time == waveform_duration: - coalescence_time -= 1 / sample_rate + if right_pad == 0: + right_pad += 1 / sample_rate gen = _WaveformGenerator( waveform_approximant="", sample_rate=sample_rate, waveform_duration=waveform_duration, - coalescence_time=coalescence_time, + right_pad=right_pad, minimum_frequency=20, reference_frequency=20, ) @@ -151,9 +151,8 @@ def test_align_waveforms( t_final = dummy_signal.shape[-1] // 2 / sample_rate waveforms = gen.align_waveforms(dummy_signal, t_final) assert waveforms.shape[-1] == int(sample_rate * waveform_duration) - assert ( - np.argmax(waveforms) - == (coalescence_time % waveform_duration) * sample_rate + assert np.argmax(waveforms) == int( + (waveform_duration - right_pad) * sample_rate ) @@ -192,7 +191,7 @@ def ligo_response_set(self, response_set_cls, duration, sample_rate, N): sample_rate=sample_rate, duration=duration, num_injections=N, - coalescence_time=duration / 2, + right_pad=duration / 2, **kwargs, ) assert str(exc.value).startswith("Specified waveform duration") @@ -205,7 +204,7 @@ def ligo_response_set(self, response_set_cls, duration, sample_rate, N): sample_rate=sample_rate, duration=duration, num_injections=N - 1, - coalescence_time=duration / 2, + right_pad=duration / 2, **kwargs, ) assert str(exc.value).startswith("LigoResponseSet") @@ -214,7 +213,7 @@ def ligo_response_set(self, response_set_cls, duration, sample_rate, N): sample_rate=sample_rate, duration=duration, num_injections=N, - coalescence_time=duration / 2, + right_pad=duration / 2, **kwargs, ) @@ -320,7 +319,7 @@ def test_read_with_shifts(self, ligo_response_set, tmp_path, N): # now try with times ligo_response_set.injection_time = np.arange(N) % (N // 2) ligo_response_set.duration = 2 - ligo_response_set.coalescence_time = ligo_response_set.duration / 2 + ligo_response_set.right_pad = ligo_response_set.duration / 2 for ifo in "hl": key = f"{ifo}1" old = 
getattr(ligo_response_set, key) diff --git a/libs/ledger/uv.lock b/libs/ledger/uv.lock index bc2d52714..811da5aa7 100644 --- a/libs/ledger/uv.lock +++ b/libs/ledger/uv.lock @@ -3614,12 +3614,15 @@ dependencies = [ [package.metadata] requires-dist = [ - { name = "h5py", specifier = ">=3.6,<4.0" }, + { name = "h5py", specifier = "~=3.6" }, { name = "ml4gw", specifier = ">=0.7.2" }, - { name = "numpy", specifier = ">=1.26.4,<2.0.0" }, + { name = "numpy", specifier = ">=1.26.4,<2" }, { name = "s3fs", specifier = ">=2024,<2025" }, ] +[package.metadata.requires-dev] +dev = [{ name = "pytest", specifier = ">=8.3.0,<9" }] + [[package]] name = "wadler-lindig" version = "0.1.4" From 0299d498884dc9e342618c52c3014e3f649d839e Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 3 Jul 2025 11:04:27 -0700 Subject: [PATCH 11/32] Make everything in terms of right_pad --- aframe/pipelines/sandbox/configs/base.cfg | 8 ++++---- aframe/pipelines/sandbox/configs/bbh.cfg | 2 +- aframe/pipelines/sandbox/configs/bns.cfg | 2 +- aframe/pipelines/sandbox/configs/review.cfg | 4 ++-- aframe/tasks/data/waveforms/base.py | 5 +++-- aframe/tasks/data/waveforms/testing.py | 2 +- aframe/tasks/data/waveforms/training.py | 2 +- aframe/tasks/data/waveforms/validation.py | 2 +- docs/projects/data.md | 4 ++-- projects/data/data/waveforms/rejection.py | 6 +++--- projects/data/data/waveforms/testing.py | 9 +++++---- projects/data/data/waveforms/training.py | 9 +++++---- projects/train/train/data/waveforms/loader.py | 7 ++----- projects/train/train/data/waveforms/sampler.py | 2 +- 14 files changed, 32 insertions(+), 32 deletions(-) diff --git a/aframe/pipelines/sandbox/configs/base.cfg b/aframe/pipelines/sandbox/configs/base.cfg index 4d311478b..79f15d130 100644 --- a/aframe/pipelines/sandbox/configs/base.cfg +++ b/aframe/pipelines/sandbox/configs/base.cfg @@ -27,7 +27,7 @@ waveform_approximant = IMRPhenomPv2 waveform_duration = 8 minimum_frequency = 20 reference_frequency = 50 -coalescence_time = 6 +right_pad = 2 # training parameters kernel_length = 1.5 @@ -70,7 +70,7 @@ waveform_duration = &::luigi_base::waveform_duration minimum_frequency = &::luigi_base::minimum_frequency reference_frequency = &::luigi_base::reference_frequency waveform_approximant = &::luigi_base::waveform_approximant -coalescence_time = &::luigi_base::coalescence_time +right_pad = &::luigi_base::right_pad prior = &::luigi_base::prior request_memory = 64GB @@ -86,7 +86,7 @@ waveform_duration = &::luigi_base::waveform_duration minimum_frequency = &::luigi_base::minimum_frequency reference_frequency = &::luigi_base::reference_frequency waveform_approximant = &::luigi_base::waveform_approximant -coalescence_time = &::luigi_base::coalescence_time +right_pad = &::luigi_base::right_pad prior = &::luigi_base::prior snr_threshold = 4 request_memory = 16GB @@ -124,7 +124,7 @@ minimum_frequency = &::luigi_base::minimum_frequency reference_frequency = &::luigi_base::reference_frequency waveform_duration = &::luigi_base::waveform_duration waveform_approximant = &::luigi_base::waveform_approximant -coalescence_time = &::luigi_base::coalescence_time +right_pad = &::luigi_base::right_pad highpass = &::luigi_base::highpass lowpass = &::luigi_base::lowpass request_memory = 16GB diff --git a/aframe/pipelines/sandbox/configs/bbh.cfg b/aframe/pipelines/sandbox/configs/bbh.cfg index 768db8aa2..aee9d6caf 100644 --- a/aframe/pipelines/sandbox/configs/bbh.cfg +++ b/aframe/pipelines/sandbox/configs/bbh.cfg @@ -11,7 +11,7 @@ inherit = 
$AFRAME_REPO/aframe/pipelines/sandbox/configs/base.cfg # override bbh specific parameters [luigi_base] waveform_duration = 8 -coalescence_time = 4 +right_pad = 2 kernel_length = 1.5 prior = priors.priors.end_o3_ratesandpops diff --git a/aframe/pipelines/sandbox/configs/bns.cfg b/aframe/pipelines/sandbox/configs/bns.cfg index 82c705e9a..faecdea95 100644 --- a/aframe/pipelines/sandbox/configs/bns.cfg +++ b/aframe/pipelines/sandbox/configs/bns.cfg @@ -13,7 +13,7 @@ inherit = $AFRAME_REPO/aframe/pipelines/sandbox/configs/base.cfg # override bns specific parameters [luigi_base] waveform_duration = 90 -coalescence_time = 45 +right_pad = 45 kernel_length = 8 prior = priors.priors.end_o3_ratesandpops_bns q = 45.6 diff --git a/aframe/pipelines/sandbox/configs/review.cfg b/aframe/pipelines/sandbox/configs/review.cfg index b2f167d6f..1c678ee30 100644 --- a/aframe/pipelines/sandbox/configs/review.cfg +++ b/aframe/pipelines/sandbox/configs/review.cfg @@ -27,7 +27,7 @@ waveform_approximant = IMRPhenomPv2 waveform_duration = 8 minimum_frequency = 20 reference_frequency = 50 -coalescence_time = 6 +right_pad = 2 # training parameters kernel_length = 1.5 @@ -79,7 +79,7 @@ minimum_frequency = &::luigi_base::minimum_frequency reference_frequency = &::luigi_base::reference_frequency waveform_duration = &::luigi_base::waveform_duration waveform_approximant = &::luigi_base::waveform_approximant -coalescence_time = &::luigi_base::coalescence_time +right_pad = &::luigi_base::right_pad highpass = &::luigi_base::highpass lowpass = &::luigi_base::lowpass request_memory = 6GB diff --git a/aframe/tasks/data/waveforms/base.py b/aframe/tasks/data/waveforms/base.py index 7df9aaa7f..c9584dc9b 100644 --- a/aframe/tasks/data/waveforms/base.py +++ b/aframe/tasks/data/waveforms/base.py @@ -29,9 +29,10 @@ class WaveformParams(law.Task): default="IMRPhenomXPHM", description="Approximant to use for waveform generation", ) - coalescence_time = luigi.FloatParameter( + right_pad = luigi.FloatParameter( description="Location of the defining point of the signal " - "within the generated waveform" + "within the generated waveform relative to the right edge " + "of the waveform (in seconds)", ) diff --git a/aframe/tasks/data/waveforms/testing.py b/aframe/tasks/data/waveforms/testing.py index f77c21f3d..319753084 100644 --- a/aframe/tasks/data/waveforms/testing.py +++ b/aframe/tasks/data/waveforms/testing.py @@ -182,7 +182,7 @@ def run(self): self.sample_rate, self.waveform_duration, self.waveform_approximant, - self.coalescence_time, + self.right_pad, self.highpass, self.lowpass, self.snr_threshold, diff --git a/aframe/tasks/data/waveforms/training.py b/aframe/tasks/data/waveforms/training.py index b9f9d66a4..54baab85f 100644 --- a/aframe/tasks/data/waveforms/training.py +++ b/aframe/tasks/data/waveforms/training.py @@ -56,7 +56,7 @@ def run(self): minimum_frequency=self.minimum_frequency, reference_frequency=self.reference_frequency, waveform_approximant=self.waveform_approximant, - coalescence_time=self.coalescence_time, + right_pad=self.right_pad, ) chunks = (min(64, num_signals), waveforms.get_waveforms().shape[-1]) with self.output().open("w") as f: diff --git a/aframe/tasks/data/waveforms/validation.py b/aframe/tasks/data/waveforms/validation.py index 7de8ea9b5..8c1491fcf 100644 --- a/aframe/tasks/data/waveforms/validation.py +++ b/aframe/tasks/data/waveforms/validation.py @@ -124,7 +124,7 @@ def run(self): self.sample_rate, self.waveform_duration, self.waveform_approximant, - self.coalescence_time, + self.right_pad, 
self.highpass, self.lowpass, self.snr_threshold, diff --git a/docs/projects/data.md b/docs/projects/data.md index 5dbd6febb..bacc1968a 100644 --- a/docs/projects/data.md +++ b/docs/projects/data.md @@ -76,7 +76,7 @@ apptainer run $AFRAME_CONTAINER_ROOT/data.sif \ --minimum_frequency 20 \ --reference_frequency 50 \ --waveform_approximant IMRPhenomXPHM \ - --coalescence_time 6 \ + --right_pad 2 \ --output_file ~/aframe/data/train/train_waveforms.hdf5 ``` @@ -93,7 +93,7 @@ apptainer run $AFRAME_CONTAINER_ROOT/data.sif \ --sample_rate 2048 \ --waveform_duration 8 \ --waveform_approximant IMRPhenomXPHM \ - --coalescence_time 6 \ + --right_pad 2 \ --highpass 32 \ --snr_threshold 4 \ --psd ~/aframe/data/train/background/background-1240579783-7829.hdf5 diff --git a/projects/data/data/waveforms/rejection.py b/projects/data/data/waveforms/rejection.py index ad4b0ec52..acd36c5fe 100644 --- a/projects/data/data/waveforms/rejection.py +++ b/projects/data/data/waveforms/rejection.py @@ -25,7 +25,7 @@ def rejection_sample( sample_rate: float, waveform_duration: float, waveform_approximant: str, - coalescence_time: float, + right_pad: float, highpass: float, lowpass: float, snr_threshold: float, @@ -70,7 +70,7 @@ def rejection_sample( sample_rate, waveform_duration, waveform_approximant, - coalescence_time, + right_pad, ) polarizations = { "cross": torch.Tensor(polarization_set.cross), @@ -145,7 +145,7 @@ def rejection_sample( idx += num_accepted num_signals -= num_accepted - parameters["coalescence_time"] = coalescence_time + parameters["right_pad"] = right_pad parameters["sample_rate"] = sample_rate parameters["duration"] = waveform_duration parameters["num_injections"] = num_injections diff --git a/projects/data/data/waveforms/testing.py b/projects/data/data/waveforms/testing.py index a9bfb2f92..f71a98012 100644 --- a/projects/data/data/waveforms/testing.py +++ b/projects/data/data/waveforms/testing.py @@ -23,7 +23,7 @@ def testing_waveforms( sample_rate: float, waveform_duration: float, waveform_approximant: str, - coalescence_time: float, + right_pad: float, highpass: float, lowpass: float, snr_threshold: float, @@ -71,9 +71,10 @@ def testing_waveforms( Duration of waveform in seconds waveform_approximant: Name of the waveform approximant to use. - coalescence_time: + right_pad: Location of the defining point of the signal within - the generated waveform + the generated waveform relative to the right edge + of the waveform (in seconds). highpass: The frequency to use for a highpass filter, specified in Hz @@ -140,7 +141,7 @@ def testing_waveforms( sample_rate, waveform_duration, waveform_approximant, - coalescence_time, + right_pad, highpass, lowpass, snr_threshold, diff --git a/projects/data/data/waveforms/training.py b/projects/data/data/waveforms/training.py index dd047f7f6..4441f1307 100644 --- a/projects/data/data/waveforms/training.py +++ b/projects/data/data/waveforms/training.py @@ -14,7 +14,7 @@ def training_waveforms( minimum_frequency: float, reference_frequency: float, waveform_approximant: str, - coalescence_time: float, + right_pad: float, ): """ Generates random training waveforms polarizations from a @@ -39,9 +39,10 @@ def training_waveforms( reference to waveform_approximant: Name of the waveform approximant to use. - coalescence_time: + right_pad: Location of the defining point of the signal within - the generated waveform + the generated waveform relative to the right edge + of the waveform (in seconds). 
Returns: An IntrinsicParameterSet generated from the sampled parameters @@ -59,7 +60,7 @@ def training_waveforms( sample_rate, waveform_duration, waveform_approximant, - coalescence_time, + right_pad, ) return waveforms diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index 3956e2011..5fdd33771 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -42,13 +42,10 @@ def get_train_waveforms(self, world_size, rank, device): for key in f["waveforms"].keys(): waveforms.append(torch.Tensor(f["waveforms"][key][start:stop])) - if ( - self.right_pad - != f.attrs["duration"] - f.attrs["coalescence_time"] - ): + if self.right_pad != f.attrs["right_pad"]: raise ValueError( "Training and validation waveform files do not have " - "the same coalescence time and/or duration" + "the same right pad" ) self.train_waveforms = torch.stack(waveforms, dim=0).to(device) diff --git a/projects/train/train/data/waveforms/sampler.py b/projects/train/train/data/waveforms/sampler.py index 644a21d79..9562e17d7 100644 --- a/projects/train/train/data/waveforms/sampler.py +++ b/projects/train/train/data/waveforms/sampler.py @@ -42,7 +42,7 @@ def __init__( with h5py.File(val_waveform_file) as f: key = list(f["waveforms"].keys())[0] self.num_val_waveforms = len(f["waveforms"][key]) - self.right_pad = f.attrs["duration"] - f.attrs["coalescence_time"] + self.right_pad = f.attrs["right_pad"] def get_slice_bounds(self, total, world_size, rank) -> tuple[int, int]: """ From 1084f61a24280787712e83581fac3376686ce898 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 3 Jul 2025 11:40:27 -0700 Subject: [PATCH 12/32] Use ledger objects and remove unused functions/arguments --- projects/train/train/cli.py | 6 +++ projects/train/train/data/base.py | 21 ---------- .../train/data/waveforms/generator/cbc.py | 5 ++- projects/train/train/data/waveforms/loader.py | 25 +++++------ .../train/train/data/waveforms/sampler.py | 42 +++++++++---------- 5 files changed, 41 insertions(+), 58 deletions(-) diff --git a/projects/train/train/cli.py b/projects/train/train/cli.py index 4d9a035a0..b43e65467 100644 --- a/projects/train/train/cli.py +++ b/projects/train/train/cli.py @@ -51,6 +51,12 @@ def add_arguments_to_parser(self, parser): apply_on="parse", ) + parser.link_arguments( + "data.init_args.ifos", + "data.init_args.waveform_sampler.init_args.ifos", + apply_on="parse", + ) + def main(args=None): cli = AframeCLI( diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 2fb178b28..0e6f0df1f 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -14,12 +14,10 @@ from ml4gw.transforms import Whiten from ml4gw.utils.slicing import unfold_windows -from ledger.injections import WaveformSet, waveform_class_factory from train import augmentations as aug from train.data.utils import fs as fs_utils from train.metrics import get_timeslides from train.data.waveforms.sampler import WaveformSampler -from utils import x_per_y from utils.preprocessing import PsdEstimator Tensor = torch.Tensor @@ -310,15 +308,6 @@ def num_workers(self): # Utilities for initial data loading and preparation # ================================================ # - @property - def waveform_set_cls(self): - cls = waveform_class_factory( - self.hparams.ifos, - WaveformSet, - "WaveformSet", - ) - return cls - def prepare_data(self): """ Download s3 data if it doesn't exist. 
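As a concrete reading of the `right_pad` convention adopted across the patches above, a minimal sketch; the values are the ones used in the sandbox configs, where `coalescence_time = 6` became `right_pad = 2` at a fixed 8 s duration:

    sample_rate = 2048        # Hz
    waveform_duration = 8     # s
    right_pad = 2             # s between the coalescence point and the right edge

    # equivalent position measured from the left edge
    coalescence_time = waveform_duration - right_pad       # 6 s
    coalescence_idx = int(coalescence_time * sample_rate)  # sample 12288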
@@ -368,16 +357,6 @@ def slice_waveforms(self, waveforms: torch.Tensor) -> torch.Tensor: return waveforms - def get_slice_bounds(self, total, world_size, rank) -> tuple[int, int]: - """ - Figure which chunk of waveforms we should be - slicing given our rank and world size - """ - per_dev = x_per_y(abs(total), world_size) - start = rank * per_dev - stop = (rank + 1) * per_dev - return start, stop - def load_val_background(self, fnames: list[str]): self._logger.info("Loading validation background data") val_background = [] diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py index cc150b9f0..fe8ea498b 100644 --- a/projects/train/train/data/waveforms/generator/cbc.py +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -13,6 +13,7 @@ def __init__( self, *args, approximant: Callable, + duration: float, f_min: float, f_ref: float, **kwargs, @@ -30,6 +31,8 @@ def __init__( A callable that takes parameter tensors and returns the cross and plus polarizations. For example, `ml4gw.waveforms.IMRPhenomD()` + duration: + Duration of the waveform in seconds f_min: Lowest frequency at which waveform signal content is generated @@ -45,7 +48,7 @@ def __init__( self.waveform_generator = TimeDomainCBCWaveformGenerator( approximant, self.sample_rate, - self.kernel_length, + duration, f_min, f_ref, self.right_pad, diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index 5fdd33771..678a61ede 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -1,6 +1,5 @@ from pathlib import Path -import h5py import torch from .sampler import WaveformSampler @@ -27,8 +26,13 @@ def __init__( super().__init__(*args, **kwargs) self.training_waveform_file = training_waveform_file - with h5py.File(training_waveform_file) as f: - self.num_train_waveforms = len(f["waveforms"]["cross"]) + waveform_set = self.waveform_set_cls.read(training_waveform_file) + if waveform_set.right_pad != self.right_pad: + raise ValueError( + "Training waveform file does not have the same " + "right pad as validation waveform file" + ) + self.num_train_waveforms = len(waveform_set) def get_train_waveforms(self, world_size, rank, device): """ @@ -37,18 +41,9 @@ def get_train_waveforms(self, world_size, rank, device): start, stop = self.get_slice_bounds( self.num_train_waveforms, world_size, rank ) - with h5py.File(self.val_waveform_file) as f: - waveforms = [] - for key in f["waveforms"].keys(): - waveforms.append(torch.Tensor(f["waveforms"][key][start:stop])) - - if self.right_pad != f.attrs["right_pad"]: - raise ValueError( - "Training and validation waveform files do not have " - "the same right pad" - ) - - self.train_waveforms = torch.stack(waveforms, dim=0).to(device) + waveform_set = self.waveform_set_cls.read(self.training_waveform_file) + waveforms = torch.Tensor(waveform_set.waveforms[start:stop]) + self.train_waveforms = waveforms.to(device) def sample(self, X: torch.Tensor): """ diff --git a/projects/train/train/data/waveforms/sampler.py b/projects/train/train/data/waveforms/sampler.py index 9562e17d7..ea87be969 100644 --- a/projects/train/train/data/waveforms/sampler.py +++ b/projects/train/train/data/waveforms/sampler.py @@ -1,9 +1,11 @@ from pathlib import Path +from typing import List -import h5py import torch from utils import x_per_y +from ledger.injections import WaveformSet, waveform_class_factory + Distribution = torch.distributions.Distribution 
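The `get_slice_bounds` helper removed from the datamodule above survives on the waveform sampler; a minimal sketch of its sharding arithmetic, assuming `utils.x_per_y` is a ceiling division (its definition is not shown in this series):

    def x_per_y(x, y):
        # assumed behavior of utils.x_per_y
        return (x + y - 1) // y

    total, world_size = 1000, 4
    per_dev = x_per_y(total, world_size)  # 250 waveforms per device
    bounds = [
        (rank * per_dev, (rank + 1) * per_dev) for rank in range(world_size)
    ]
    # [(0, 250), (250, 500), (500, 750), (750, 1000)]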
@@ -12,12 +14,8 @@ class WaveformSampler(torch.nn.Module): Base object defining methods that waveform producing classes should implement. Should not be instantiated on its own. Args: - fduration: - Desired length in seconds of the time domain - response of the whitening filter built from PSDs. - See `ml4gw.spectral.truncate_inverse_power_spectrum` - kernel_length: - Length in seconds of window passed to neural network. + ifos: + List of interferometers that are being trained on. sample_rate: Sample rate in Hz of generated waveforms val_waveform_file: @@ -27,22 +25,28 @@ class WaveformSampler(torch.nn.Module): def __init__( self, *args, - fduration: float, - kernel_length: float, + ifos: List[str], sample_rate: float, val_waveform_file: Path, **kwargs, ) -> None: super().__init__(*args, **kwargs) - self.fduration = fduration - self.kernel_length = kernel_length + self.ifos = ifos self.sample_rate = sample_rate self.val_waveform_file = val_waveform_file - with h5py.File(val_waveform_file) as f: - key = list(f["waveforms"].keys())[0] - self.num_val_waveforms = len(f["waveforms"][key]) - self.right_pad = f.attrs["right_pad"] + waveform_set = self.waveform_set_cls.read(val_waveform_file) + self.num_val_waveforms = len(waveform_set) + self.right_pad = waveform_set.right_pad + + @property + def waveform_set_cls(self): + cls = waveform_class_factory( + self.ifos, + WaveformSet, + "WaveformSet", + ) + return cls def get_slice_bounds(self, total, world_size, rank) -> tuple[int, int]: """ @@ -63,12 +67,8 @@ def get_val_waveforms(self, world_size, rank): start, stop = self.get_slice_bounds( self.num_val_waveforms, world_size, rank ) - with h5py.File(self.val_waveform_file) as f: - waveforms = [] - for key in f["waveforms"].keys(): - waveforms.append(torch.Tensor(f["waveforms"][key][start:stop])) - - return torch.stack(waveforms, dim=0) + waveform_set = self.waveform_set_cls.read(self.val_waveform_file) + return torch.Tensor(waveform_set.waveforms[start:stop]) def get_test_waveforms(self): raise NotImplementedError From b99f4293f25f4162e802d0134b1605eaacc4053c Mon Sep 17 00:00:00 2001 From: William Benoit Date: Mon, 7 Jul 2025 06:45:45 -0700 Subject: [PATCH 13/32] Update batch file name in export task --- projects/train/train/data/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 0e6f0df1f..9f753b4ec 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -25,8 +25,8 @@ TransformedDist = torch.distributions.TransformedDistribution # TODO: -# Make the coalescence point/right pad parameterization consistent # Move waveform slicing to the waveform sampler? 
+# Make separate training prior # TODO: using this right now because From d9f3ccdc7beb710ae984fe99afd8834d1a2d60d6 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 8 Jul 2025 04:29:41 -0700 Subject: [PATCH 14/32] Fix validation issue and add constraint to training waveforms --- projects/train/train.yaml | 4 ++-- projects/train/train/callbacks.py | 7 +++---- projects/train/train/data/base.py | 15 ++++++--------- .../train/train/data/supervised/time_domain.py | 4 ++-- .../train/train/data/waveforms/generator/cbc.py | 10 ++++++++++ 5 files changed, 23 insertions(+), 17 deletions(-) diff --git a/projects/train/train.yaml b/projects/train/train.yaml index ffc4462e8..dd8660d7f 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -74,7 +74,7 @@ data: approximant: ml4gw.waveforms.IMRPhenomPv2 f_min: 20 f_ref: 40 - kernel_length: 8 + duration: 8 # Extrinsic parameter distributions dec: class_path: ml4gw.distributions.Cosine @@ -131,7 +131,7 @@ trainer: # strategy: set to ddp if len(devices) > 1 #precision: 16-mixed accelerator: auto - max_epochs: 400 + max_epochs: 200 check_val_every_n_epoch: 1 log_every_n_steps: 20 benchmark: true diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index fdd69cb7b..f755b81d2 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -85,14 +85,13 @@ def on_train_start(self, trainer, pl_module): X = (X,) # build val batch by hand - [background, _, _], [cross, plus] = next( + [background, _, _], [signals] = next( iter(trainer.datamodule.val_dataloader()) ) background = background.to(device) - cross = cross.to(device) - plus = plus.to(device) + signals = signals.to(device) X_bg, X_inj = trainer.datamodule.build_val_batches( - background, cross, plus + background, signals ) # Make background and injected validation data into # tuples for consistency if necessary diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 9f753b4ec..5d6291dc9 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -489,7 +489,7 @@ def on_after_batch_transfer(self, batch, _): # on the local device, the relevant tensors will be # empty, so just pass them through with a 0 shift to # indicate that this should be ignored - [background, _, timeslide_idx], [cross, plus] = batch + [background, _, timeslide_idx], [signals] = batch # If we're validating, unfold the background # data into a batch of overlapping kernels now that @@ -497,7 +497,7 @@ def on_after_batch_transfer(self, batch, _): # much data from CPU to GPU. Once everything is # on-device, pre-inject signals into background. 
shift = self.timeslides[timeslide_idx].shift_size - X_bg, X_fg = self.build_val_batches(background, cross, plus) + X_bg, X_fg = self.build_val_batches(background, signals) batch = (shift, X_bg, X_fg) return batch @@ -514,7 +514,9 @@ def inject(self, X): # ================================================ # @torch.no_grad() def build_val_batches( - self, background: Tensor, cross: Tensor, plus: Tensor + self, + background: Tensor, + signals: Tensor, ) -> tuple[Tensor, Tensor, Tensor]: """ Unfold a timeseries of background data @@ -524,8 +526,7 @@ def build_val_batches( Args: background: A tensor of background data - cross: A tensor of cross polarization waveforms - plus: A tensor of plus polarization waveforms + signals: A tensor of signals to inject Returns: raw strain background kernels, injected kernels, and psds @@ -539,10 +540,6 @@ def build_val_batches( # split data into kernel and psd data and estimate psd X, psd = self.psd_estimator(background) - # Sample sky locations and project polarizations - dec, psi, phi = self.sample_extrinsic(cross) - signals = self.projector(dec, psi, phi, cross=cross, plus=plus) - # sometimes at the end of a segment, there won't be # enough background kernels and so we'll have to inject # our signals on overlapping data and ditch some at the end diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index ddec2a4c3..35418b3be 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -4,8 +4,8 @@ class TimeDomainSupervisedAframeDataset(SupervisedAframeDataset): - def build_val_batches(self, background, cross, plus): - X_bg, X_inj, psds = super().build_val_batches(background, cross, plus) + def build_val_batches(self, background, signals): + X_bg, X_inj, psds = super().build_val_batches(background, signals) X_bg = self.whitener(X_bg, psds) # whiten each view of injections X_fg = [] diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py index fe8ea498b..28cd6d288 100644 --- a/projects/train/train/data/waveforms/generator/cbc.py +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -1,9 +1,11 @@ from typing import Callable import torch +from bilby.core.prior import ConditionalPriorDict, Constraint from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator from ledger.injections import BilbyParameterSet +from priors.utils import mass_constraints from .generator import WaveformGenerator @@ -43,6 +45,14 @@ def __init__( `aframe.train.train.data.waveforms.generator.WaveformGenerator` """ super().__init__(*args, **kwargs) + + # For CBC generation, need to make sure that the mass ratio + # does not exceed what ml4gw can handle. 
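+        # (mass_ratio here is m2 / m1, the convention used by the
+        # ml4gw_generation_params property added earlier in this series,
+        # so the constraint below keeps it strictly below 1)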
+ self.training_prior = ConditionalPriorDict( + self.training_prior, conversion_function=mass_constraints + ) + self.training_prior["mass_ratio"] = Constraint(0.02, 0.999) + self.approximant = approximant self.f_ref = f_ref self.waveform_generator = TimeDomainCBCWaveformGenerator( From cb84e2b6b0ac9daa80b3c42e419d2fae40fb0117 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 12 Aug 2025 12:18:41 -0700 Subject: [PATCH 15/32] Add multimodal functionality --- aframe/tasks/export/export.py | 2 +- libs/utils/utils/preprocessing.py | 21 +++++++++++++++---- projects/train/train/callbacks.py | 4 +++- .../train/train/data/supervised/multimodal.py | 4 ++-- 4 files changed, 23 insertions(+), 8 deletions(-) diff --git a/aframe/tasks/export/export.py b/aframe/tasks/export/export.py index 5c0f2f9f9..3e4731a32 100644 --- a/aframe/tasks/export/export.py +++ b/aframe/tasks/export/export.py @@ -29,7 +29,7 @@ class ExportParams(law.Task): ) train_task = luigi.TaskParameter() platform = luigi.Parameter( - default="TENSORRT", + default="TORCHSCRIPT", description="Platform to use for exporting model for inference", ) diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index 19160fd33..948ef0869 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -111,12 +111,14 @@ def __init__( highpass: Optional[float] = None, lowpass: Optional[float] = None, return_whitened: bool = False, + return_asd: bool = False, ) -> None: super().__init__() self.stride_size = int(sample_rate / inference_sampling_rate) self.kernel_size = int(kernel_length * sample_rate) self.augmentor = augmentor self.return_whitened = return_whitened + self.return_asd = return_asd # do foreground length calculation in units of samples, # then convert back to length to guard for intification @@ -148,8 +150,14 @@ def forward(self, x: Tensor) -> Tensor: "but found shape {}".format(x.shape) ) - x, psd = self.psd_estimator(x) - whitened = self.whitener(x.double(), psd) + x, psd = self.psd_estimator(x.double()) + whitened = self.whitener(x, psd) + + x = x.float() + + asd = psd**0.5 + asd *= 1e23 + asd = asd.float() # unfold x and then put it into the expected shape. 
# Note that if x has both signal and background @@ -160,9 +168,14 @@ def forward(self, x: Tensor) -> Tensor: if self.augmentor is not None: x = self.augmentor(x) - if self.return_whitened: + if self.return_whitened and self.return_asd: + return x, whitened, asd + elif self.return_whitened: return x, whitened - return x + elif self.return_asd: + return x, asd + else: + return x class MultiModalPreprocessor(torch.nn.Module): diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index f755b81d2..b3c09a95d 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -90,7 +90,7 @@ def on_train_start(self, trainer, pl_module): ) background = background.to(device) signals = signals.to(device) - X_bg, X_inj = trainer.datamodule.build_val_batches( + X_bg, X_inj, val_asds = trainer.datamodule.build_val_batches( background, signals ) # Make background and injected validation data into @@ -108,6 +108,7 @@ def on_train_start(self, trainer, pl_module): for i, x in enumerate(X): h5file[f"input_{i}"] = x.cpu().numpy() h5file["y"] = y.cpu().numpy() + h5file["asds"] = train_asds.cpu().numpy() s3_file.write(f.getvalue()) with s3.open(f"{save_dir}/val_batch.hdf5", "wb") as s3_file: @@ -122,6 +123,7 @@ def on_train_start(self, trainer, pl_module): for i, x in enumerate(X): f[f"input_{i}"] = x.cpu().numpy() f["y"] = y.cpu().numpy() + f["asds"] = train_asds.cpu().numpy() with h5py.File( os.path.join(save_dir, "val_batch.hdf5"), "w" diff --git a/projects/train/train/data/supervised/multimodal.py b/projects/train/train/data/supervised/multimodal.py index 06f48746f..4c02b0c79 100644 --- a/projects/train/train/data/supervised/multimodal.py +++ b/projects/train/train/data/supervised/multimodal.py @@ -37,7 +37,7 @@ def on_after_batch_transfer(self, batch, _): # if we're training, perform random augmentations # on input data and use it to impact labels [X], waveforms = batch - (X, X_fft), y = self.augment(X, waveforms) + (X, X_fft), y = self.inject(X, waveforms) batch = (X, X_fft, y) elif self.trainer.validating or self.trainer.sanity_checking: # If we're in validation mode but we're not validating @@ -86,7 +86,7 @@ def compute_frequency_domain_data(self, X, psds): return X_fft - def augment(self, X, waveforms): + def inject(self, X, waveforms): X, y, psds = super().augment(X, waveforms) X = self.whitener(X, psds) X_fft = self.compute_frequency_domain_data(X, psds) From be0000153dbc10d201aea5e84f3da8efb9da74d6 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 15 Jul 2025 07:05:28 -0700 Subject: [PATCH 16/32] Switch to amplfi-style prior for training --- projects/train/configs/training_prior.yaml | 72 ++++++++++++++++ projects/train/conversion.py | 68 +++++++++++++++ projects/train/train.yaml | 2 +- .../train/data/waveforms/generator/cbc.py | 22 ----- .../data/waveforms/generator/generator.py | 23 ++--- projects/train/train/prior.py | 85 +++++++++++++++++++ 6 files changed, 234 insertions(+), 38 deletions(-) create mode 100644 projects/train/configs/training_prior.yaml create mode 100644 projects/train/conversion.py create mode 100644 projects/train/train/prior.py diff --git a/projects/train/configs/training_prior.yaml b/projects/train/configs/training_prior.yaml new file mode 100644 index 000000000..c7d2f97e0 --- /dev/null +++ b/projects/train/configs/training_prior.yaml @@ -0,0 +1,72 @@ +# default prior for use with `IMRPhenomPv2` +class_path: aframe.train.prior.AmplfiPrior +init_args: + conversion_function: 
aframe.train.conversion.precessing_to_lalsimulation_parameters + priors: + chirp_mass: + class_path: torch.distributions.Uniform + init_args: + low: 5 + high: 250 + validate_args: false + mass_ratio: + class_path: torch.distributions.Uniform + init_args: + low: 0.05 + high: 0.999 + validate_args: false + distance: + class_path: ml4gw.distributions.UniformComovingVolume + init_args: + minimum: 0 + maximum: 2 + distance_type: redshift + validate_args: false + inclination: + class_path: ml4gw.distributions.Sine + init_args: + low: 0 + high: 3.14159 + validate_args: false + phic: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 6.28318 + validate_args: false + a_1: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 0.999 + validate_args: false + a_2: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 0.999 + validate_args: false + tilt_1: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 3.14159 + validate_args: false + tilt_2: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 3.14159 + validate_args: false + phi_jl: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 6.28318 + validate_args: false + phi_12: + class_path: torch.distributions.Uniform + init_args: + low: 0.0 + high: 6.28318 + validate_args: false diff --git a/projects/train/conversion.py b/projects/train/conversion.py new file mode 100644 index 000000000..d12cb7bf5 --- /dev/null +++ b/projects/train/conversion.py @@ -0,0 +1,68 @@ +import torch +from ml4gw.waveforms.conversion import ( + bilby_spins_to_lalsim, + chirp_mass_and_mass_ratio_to_components, +) + + +def precessing_to_lalsimulation_parameters( + parameters: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """ + Convert precessing spin parameters to lalsimulation parameters + """ + mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components( + parameters["chirp_mass"], parameters["mass_ratio"] + ) + + parameters["mass_1"] = mass_1 + parameters["mass_2"] = mass_2 + + # TODO: hard coding f_ref = 40 here b/c not sure best way to link this + # to the f_ref specified in the config file + incl, s1x, s1y, s1z, s2x, s2y, s2z = bilby_spins_to_lalsim( + parameters["inclination"], + parameters["phi_jl"], + parameters["tilt_1"], + parameters["tilt_2"], + parameters["phi_12"], + parameters["a_1"], + parameters["a_2"], + parameters["mass_1"], + parameters["mass_2"], + 40, + torch.zeros(len(mass_1), device=mass_1.device), + ) + + parameters["s1x"] = s1x + parameters["s1y"] = s1y + parameters["s1z"] = s1z + parameters["s2x"] = s2x + parameters["s2y"] = s2y + parameters["s2z"] = s2z + parameters["inclination"] = incl + return parameters + + +def aligned_to_lalsimulation_parameters( + parameters: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """ + Convert aligned spin parameters to lalsimulation parameters + """ + mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components( + parameters["chirp_mass"], parameters["mass_ratio"] + ) + + parameters["mass_1"] = mass_1 + parameters["mass_2"] = mass_2 + + parameters["s1x"] = torch.zeros_like(mass_1) + parameters["s1y"] = torch.zeros_like(mass_1) + + parameters["s2x"] = torch.zeros_like(mass_1) + parameters["s2y"] = torch.zeros_like(mass_1) + + parameters["s1z"] = parameters["chi1"] + parameters["s2z"] = parameters["chi2"] + return parameters diff --git a/projects/train/train.yaml b/projects/train/train.yaml index dd8660d7f..007bfa67d 100644 --- a/projects/train/train.yaml +++ 
b/projects/train/train.yaml @@ -69,7 +69,7 @@ data: waveform_sampler: class_path: train.data.waveforms.generator.cbc.CBCGenerator init_args: - training_prior: priors.priors.end_o3_ratesandpops + training_prior: ./training_prior.yaml val_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/train/val_waveforms.hdf5 approximant: ml4gw.waveforms.IMRPhenomPv2 f_min: 20 diff --git a/projects/train/train/data/waveforms/generator/cbc.py b/projects/train/train/data/waveforms/generator/cbc.py index 28cd6d288..adb572752 100644 --- a/projects/train/train/data/waveforms/generator/cbc.py +++ b/projects/train/train/data/waveforms/generator/cbc.py @@ -1,12 +1,8 @@ from typing import Callable import torch -from bilby.core.prior import ConditionalPriorDict, Constraint from ml4gw.waveforms.generator import TimeDomainCBCWaveformGenerator -from ledger.injections import BilbyParameterSet -from priors.utils import mass_constraints - from .generator import WaveformGenerator @@ -46,13 +42,6 @@ def __init__( """ super().__init__(*args, **kwargs) - # For CBC generation, need to make sure that the mass ratio - # does not exceed what ml4gw can handle. - self.training_prior = ConditionalPriorDict( - self.training_prior, conversion_function=mass_constraints - ) - self.training_prior["mass_ratio"] = Constraint(0.02, 0.999) - self.approximant = approximant self.f_ref = f_ref self.waveform_generator = TimeDomainCBCWaveformGenerator( @@ -64,17 +53,6 @@ def __init__( self.right_pad, ) - def convert(self, parameters): - # TODO: This assumes a detector-frame prior. Remove this - # when we switch to source-frame prior. - for key in ["mass_1", "mass_2", "chirp_mass", "total_mass"]: - if key in parameters: - parameters[key] *= 1 + parameters["redshift"] - parameter_set = BilbyParameterSet(**parameters) - lal_params = parameter_set.convert_to_lal_param_set(self.f_ref) - generation_params = lal_params.ml4gw_generation_params - return generation_params - def forward(self, **parameters) -> torch.Tensor: hc, hp = self.waveform_generator(**parameters) waveforms = torch.stack([hc, hp], dim=1) diff --git a/projects/train/train/data/waveforms/generator/generator.py b/projects/train/train/data/waveforms/generator/generator.py index 374a0e367..30cfaa5f6 100644 --- a/projects/train/train/data/waveforms/generator/generator.py +++ b/projects/train/train/data/waveforms/generator/generator.py @@ -1,5 +1,6 @@ -from typing import Callable, TYPE_CHECKING +from typing import TYPE_CHECKING +from ....prior import AframePrior from ..sampler import WaveformSampler import torch @@ -12,18 +13,18 @@ class WaveformGenerator(WaveformSampler): def __init__( self, *args, - training_prior: Callable, + training_prior: AframePrior, **kwargs, ): """ A torch module for generating waveforms on the fly. Args: training_prior: - A callable that returns a prior distribution - for the parameters of the waveform generator. 
+            A callable that takes an integer N and returns a
+            dictionary of parameter Tensors, each of length `N`
         """
         super().__init__(*args, **kwargs)
-        self.training_prior, _ = training_prior()
+        self.training_prior = training_prior

     def get_train_waveforms(self, *_):
         """
@@ -34,17 +35,9 @@
     def sample(self, X: torch.Tensor):
         N = len(X)
-        parameters = self.training_prior.sample(N)
-        generation_params = self.convert(parameters)
-        generation_params = {
-            k: torch.Tensor(v).to(X.device)
-            for k, v in generation_params.items()
-        }
-        hc, hp = self(**generation_params)
+        parameters = self.training_prior(N, device=X.device)
+        hc, hp = self(**parameters)
         return hc, hp

-    def convert(self, parameters: dict) -> dict:
-        raise NotImplementedError
-
     def forward(self):
         raise NotImplementedError
diff --git a/projects/train/train/prior.py b/projects/train/train/prior.py
new file mode 100644
index 000000000..d0c2b5b36
--- /dev/null
+++ b/projects/train/train/prior.py
@@ -0,0 +1,85 @@
+from typing import Callable, Optional
+
+import torch
+
+
+class AframePrior:
+    def __init__(
+        self,
+        priors: dict[str, torch.distributions.Distribution],
+        conversion_function: Optional[Callable] = None,
+    ):
+        """
+        A class for sampling parameters from a prior distribution
+
+        Args:
+            priors:
+                A dictionary of parameter samplers that take an integer N
+                and return a tensor of shape (N, ...) representing
+                samples from the prior distribution
+            conversion_function:
+                A callable that takes a dictionary of sampled parameters
+                and returns a dictionary of waveform generation parameters
+        """
+        super().__init__()
+        self.priors = priors
+        self.conversion_function = conversion_function or (lambda x: x)
+
+    def __call__(
+        self,
+        N: int,
+        device: str = "cpu",
+    ) -> dict[str, torch.Tensor]:
+        """
+        Generates random samples from the prior
+
+        Args:
+            N: Number of samples to generate
+            device: Device to place the samples
+        """
+        # sample parameters from prior
+        parameters = {
+            k: v.sample((N,)).to(device) for k, v in self.priors.items()
+        }
+        # perform any necessary conversions
+        # from sampled parameters to
+        # waveform generation parameters
+        parameters = self.conversion_function(parameters)
+        return parameters
+
+    def log_prob(self, samples: dict[str, torch.Tensor]) -> torch.Tensor:
+        """
+        Calculate the log probability of samples under the prior
+
+        Args:
+            samples:
+                Dictionary where key is parameter and
+                value is tensor of samples
+        """
+
+        first = samples[list(samples.keys())[0]]
+        log_probs = torch.zeros(len(first), device=first.device)
+        for param, tensor in samples.items():
+            log_probs += self.priors[param].log_prob(tensor).to(first.device)
+        return log_probs
+
+
+class ParameterTransformer(torch.nn.Module):
+    """
+    Helper class for applying preprocessing
+    transformations to inference parameters
+    """
+
+    def __init__(self, **transforms: Callable):
+        super().__init__()
+        self.transforms = transforms
+
+    def forward(
+        self,
+        parameters: dict[str, torch.Tensor],
+    ):
+        # transform parameters
+        transformed = {k: v(parameters[k]) for k, v in self.transforms.items()}
+        # update parameter dict
+        parameters.update(transformed)
+        return parameters
From 03d38b7953d3b828f7af53996b36391b6464b05c Mon Sep 17 00:00:00 2001
From: William Benoit
Date: Tue, 15 Jul 2025 07:08:18 -0700
Subject: [PATCH 17/32] Update aframe-init to copy prior config

---
 scripts/aframe_init.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/scripts/aframe_init.py b/scripts/aframe_init.py
index cad8a9db6..9c692c025 100644
---
a/scripts/aframe_init.py +++ b/scripts/aframe_init.py @@ -21,6 +21,7 @@ root / "aframe" / "pipelines" / "sandbox" / "configs" / "base.cfg", root / "projects" / "train" / "train.yaml", root / "projects" / "export" / "export.yaml", + root / "projects" / "train" / "configs" / "training_prior.yaml", ] REVIEW_CONFIGS = [ From f20a93fda596d52598d0ec4eb76b13ed8081b489 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 15 Jul 2025 09:20:28 -0700 Subject: [PATCH 18/32] Fix class paths and update ml4gw dep --- projects/train/configs/training_prior.yaml | 4 +- projects/train/pyproject.toml | 7 +- projects/train/{ => train}/conversion.py | 0 projects/train/uv.lock | 1026 +++++++++++++++++++- 4 files changed, 998 insertions(+), 39 deletions(-) rename projects/train/{ => train}/conversion.py (100%) diff --git a/projects/train/configs/training_prior.yaml b/projects/train/configs/training_prior.yaml index c7d2f97e0..12c66b5f6 100644 --- a/projects/train/configs/training_prior.yaml +++ b/projects/train/configs/training_prior.yaml @@ -1,7 +1,7 @@ # default prior for use with `IMRPhenomPv2` -class_path: aframe.train.prior.AmplfiPrior +class_path: train.prior.AframePrior init_args: - conversion_function: aframe.train.conversion.precessing_to_lalsimulation_parameters + conversion_function: train.conversion.precessing_to_lalsimulation_parameters priors: chirp_mass: class_path: torch.distributions.Uniform diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index c01346f3e..d47050cc4 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.7.2", + "ml4gw>=0.7.5", "aframe", "ledger", "priors", @@ -37,7 +37,10 @@ dependencies = [ train = "train.cli:main" [dependency-groups] -dev = ["pytest~=7.3"] +dev = [ + "pytest~=7.3", + "jupyter>=1.0.0", +] [tool.uv] diff --git a/projects/train/conversion.py b/projects/train/train/conversion.py similarity index 100% rename from projects/train/conversion.py rename to projects/train/train/conversion.py diff --git a/projects/train/uv.lock b/projects/train/uv.lock index bf834ed7c..57dc9764c 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -220,6 +220,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/46/eb/e7f063ad1fec6b3178a3cd82d1a3c4de82cccf283fc42746168188e1cdd5/anyio-4.8.0-py3-none-any.whl", hash = "sha256:b5011f270ab5eb0abf13385f851315585cc37ef330dd88e27ec3d34d651fd47a", size = 96041 }, ] +[[package]] +name = "appnope" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/35/5d/752690df9ef5b76e169e68d6a129fa6d08a7100ca7f754c89495db3c6019/appnope-0.1.4.tar.gz", hash = "sha256:1de3860566df9caf38f01f86f65e0e13e379af54f9e4bee1e66b48f2efffd1ee", size = 4170 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/29/5ecc3a15d5a33e31b26c11426c45c501e439cb865d0bff96315d86443b78/appnope-0.1.4-py2.py3-none-any.whl", hash = "sha256:502575ee11cd7a28c0205f379b525beefebab9d161b7c964670864014ed7213c", size = 4321 }, +] + [[package]] name = "architectures" version = "0.1.0" @@ -242,6 +251,52 @@ requires-dist = [ [package.metadata.requires-dev] dev = [{ name = "pytest", specifier = ">=8.2.1,<9" }] +[[package]] +name = "argon2-cffi" +version = "25.1.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "argon2-cffi-bindings" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/0e/89/ce5af8a7d472a67cc819d5d998aa8c82c5d860608c4db9f46f1162d7dab9/argon2_cffi-25.1.0.tar.gz", hash = "sha256:694ae5cc8a42f4c4e2bf2ca0e64e51e23a040c6a517a85074683d3959e1346c1", size = 45706 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4f/d3/a8b22fa575b297cd6e3e3b0155c7e25db170edf1c74783d6a31a2490b8d9/argon2_cffi-25.1.0-py3-none-any.whl", hash = "sha256:fdc8b074db390fccb6eb4a3604ae7231f219aa669a2652e0f20e16ba513d5741", size = 14657 }, +] + +[[package]] +name = "argon2-cffi-bindings" +version = "21.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b9/e9/184b8ccce6683b0aa2fbb7ba5683ea4b9c5763f1356347f1312c32e3c66e/argon2-cffi-bindings-21.2.0.tar.gz", hash = "sha256:bb89ceffa6c791807d1305ceb77dbfacc5aa499891d2c55661c6459651fc39e3", size = 1779911 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d4/13/838ce2620025e9666aa8f686431f67a29052241692a3dd1ae9d3692a89d3/argon2_cffi_bindings-21.2.0-cp36-abi3-macosx_10_9_x86_64.whl", hash = "sha256:ccb949252cb2ab3a08c02024acb77cfb179492d5701c7cbdbfd776124d4d2367", size = 29658 }, + { url = "https://files.pythonhosted.org/packages/b3/02/f7f7bb6b6af6031edb11037639c697b912e1dea2db94d436e681aea2f495/argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9524464572e12979364b7d600abf96181d3541da11e23ddf565a32e70bd4dc0d", size = 80583 }, + { url = "https://files.pythonhosted.org/packages/ec/f7/378254e6dd7ae6f31fe40c8649eea7d4832a42243acaf0f1fff9083b2bed/argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b746dba803a79238e925d9046a63aa26bf86ab2a2fe74ce6b009a1c3f5c8f2ae", size = 86168 }, + { url = "https://files.pythonhosted.org/packages/74/f6/4a34a37a98311ed73bb80efe422fed95f2ac25a4cacc5ae1d7ae6a144505/argon2_cffi_bindings-21.2.0-cp36-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:58ed19212051f49a523abb1dbe954337dc82d947fb6e5a0da60f7c8471a8476c", size = 82709 }, + { url = "https://files.pythonhosted.org/packages/74/2b/73d767bfdaab25484f7e7901379d5f8793cccbb86c6e0cbc4c1b96f63896/argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:bd46088725ef7f58b5a1ef7ca06647ebaf0eb4baff7d1d0d177c6cc8744abd86", size = 83613 }, + { url = "https://files.pythonhosted.org/packages/4f/fd/37f86deef67ff57c76f137a67181949c2d408077e2e3dd70c6c42912c9bf/argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_i686.whl", hash = "sha256:8cd69c07dd875537a824deec19f978e0f2078fdda07fd5c42ac29668dda5f40f", size = 84583 }, + { url = "https://files.pythonhosted.org/packages/6f/52/5a60085a3dae8fded8327a4f564223029f5f54b0cb0455a31131b5363a01/argon2_cffi_bindings-21.2.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f1152ac548bd5b8bcecfb0b0371f082037e47128653df2e8ba6e914d384f3c3e", size = 88475 }, + { url = "https://files.pythonhosted.org/packages/8b/95/143cd64feb24a15fa4b189a3e1e7efbaeeb00f39a51e99b26fc62fbacabd/argon2_cffi_bindings-21.2.0-cp36-abi3-win32.whl", hash = "sha256:603ca0aba86b1349b147cab91ae970c63118a0f30444d4bc80355937c950c082", size = 27698 }, + { url = "https://files.pythonhosted.org/packages/37/2c/e34e47c7dee97ba6f01a6203e0383e15b60fb85d78ac9a15cd066f6fe28b/argon2_cffi_bindings-21.2.0-cp36-abi3-win_amd64.whl", hash = "sha256:b2ef1c30440dbbcba7a5dc3e319408b59676e2e039e2ae11a8775ecf482b192f", size = 30817 }, + { url = 
"https://files.pythonhosted.org/packages/5a/e4/bf8034d25edaa495da3c8a3405627d2e35758e44ff6eaa7948092646fdcc/argon2_cffi_bindings-21.2.0-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:e415e3f62c8d124ee16018e491a009937f8cf7ebf5eb430ffc5de21b900dad93", size = 53104 }, +] + +[[package]] +name = "arrow" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "python-dateutil" }, + { name = "types-python-dateutil" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/2e/00/0f6e8fcdb23ea632c866620cc872729ff43ed91d284c866b515c6342b173/arrow-1.3.0.tar.gz", hash = "sha256:d4540617648cb5f895730f1ad8c82a65f2dad0166f57b75f3ca54759c4d67a85", size = 131960 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/ed/e97229a566617f2ae958a6b13e7cc0f585470eac730a73e9e82c32a3cdd2/arrow-1.3.0-py3-none-any.whl", hash = "sha256:c728b120ebc00eb84e01882a6f5e7927a53960aa990ce7dd2b10f39005a67f80", size = 66419 }, +] + [[package]] name = "astropy" version = "6.1.7" @@ -287,6 +342,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/14/28e7254183bb53502b323b7e170eaeb86501cc774e76f82ef04216699afb/astropy_iers_data-0.2025.2.24.0.34.4-py3-none-any.whl", hash = "sha256:be8b3b75b09b1aa1d22de9b5243854e00d19936aca6d8f64466d16197c04bb28", size = 1946502 }, ] +[[package]] +name = "asttokens" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/4a/e7/82da0a03e7ba5141f05cce0d302e6eed121ae055e0456ca228bf693984bc/asttokens-3.0.0.tar.gz", hash = "sha256:0dcd8baa8d62b0c1d118b399b2ddba3c4aff271d0d7a9e0d4c1681c79035bbc7", size = 61978 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/25/8a/c46dcc25341b5bce5472c718902eb3d38600a903b14fa6aeecef3f21a46f/asttokens-3.0.0-py3-none-any.whl", hash = "sha256:e3078351a059199dd5138cb1c706e6430c05eff2ff136af5eb4790f9d28932e2", size = 26918 }, +] + +[[package]] +name = "async-lru" +version = "2.0.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b2/4d/71ec4d3939dc755264f680f6c2b4906423a304c3d18e96853f0a595dfe97/async_lru-2.0.5.tar.gz", hash = "sha256:481d52ccdd27275f42c43a928b4a50c3bfb2d67af4e78b170e3e0bb39c66e5bb", size = 10380 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/03/49/d10027df9fce941cb8184e78a02857af36360d33e1721df81c5ed2179a1a/async_lru-2.0.5-py3-none-any.whl", hash = "sha256:ab95404d8d2605310d345932697371a5f40def0487c03d6d0ad9138de52c9943", size = 6069 }, +] + [[package]] name = "async-timeout" version = "5.0.1" @@ -305,6 +381,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fc/30/d4986a882011f9df997a55e6becd864812ccfcd821d64aac8570ee39f719/attrs-25.1.0-py3-none-any.whl", hash = "sha256:c75a69e28a550a7e93789579c22aa26b0f5b83b75dc4e08fe092980051e1090a", size = 63152 }, ] +[[package]] +name = "babel" +version = "2.17.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7d/6b/d52e42361e1aa00709585ecc30b3f9684b3ab62530771402248b1b1d6240/babel-2.17.0.tar.gz", hash = "sha256:0c54cffb19f690cdcc52a3b50bcbf71e07a808d1c80d549f2459b9d2cf0afb9d", size = 9951852 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = 
"sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537 }, +] + [[package]] name = "bayesian-optimization" version = "1.5.1" @@ -376,6 +461,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/de/02429d598ec5ed4c70113a2c3e8b76a5b113885f85eacdcdaf19cbb6d23d/bilby.cython-0.5.3-cp312-cp312-win_amd64.whl", hash = "sha256:2758256339d7c3703014b265d3a77e0299d5c6264f962bc311c989ac453cbd60", size = 357801 }, ] +[[package]] +name = "bleach" +version = "6.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/76/9a/0e33f5054c54d349ea62c277191c020c2d6ef1d65ab2cb1993f91ec846d1/bleach-6.2.0.tar.gz", hash = "sha256:123e894118b8a599fd80d3ec1a6d4cc7ce4e5882b1317a7e1ba69b56e95f991f", size = 203083 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/55/96142937f66150805c25c4d0f31ee4132fd33497753400734f9dfdcbdc66/bleach-6.2.0-py3-none-any.whl", hash = "sha256:117d9c6097a7c3d22fd578fcd8d35ff1e125df6736f554da4e432fdd63f31e5e", size = 163406 }, +] + +[package.optional-dependencies] +css = [ + { name = "tinycss2" }, +] + [[package]] name = "bokeh" version = "3.6.3" @@ -639,6 +741,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b3/61/39e7db0cb326c9c8f6a49fad4fc9c2f1241f05a4e10f0643fc31ce26a7e0/colorful-0.5.6-py2.py3-none-any.whl", hash = "sha256:eab8c1c809f5025ad2b5238a50bd691e26850da8cac8f90d660ede6ea1af9f1e", size = 201369 }, ] +[[package]] +name = "comm" +version = "0.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/a8/fb783cb0abe2b5fded9f55e5703015cdf1c9c85b3669087c538dd15a6a86/comm-0.2.2.tar.gz", hash = "sha256:3fd7a84065306e07bea1773df6eb8282de51ba82f77c72f9c85716ab11fe980e", size = 6210 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/75/49e5bfe642f71f272236b5b2d2691cf915a7283cc0ceda56357b61daa538/comm-0.2.2-py3-none-any.whl", hash = "sha256:e6fb86cb70ff661ee8c9c14e7d36d6de3b4066f1441be4063df9c5009f0a64d3", size = 7180 }, +] + [[package]] name = "contourpy" version = "1.3.1" @@ -807,6 +921,45 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/27/6b/7c87867d255cbce8167ed99fc65635e9395d2af0f0c915428f5b17ec412d/Cython-3.0.12-py2.py3-none-any.whl", hash = "sha256:0038c9bae46c459669390e53a1ec115f8096b2e4647ae007ff1bf4e6dee92806", size = 1171640 }, ] +[[package]] +name = "debugpy" +version = "1.8.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/bd/75/087fe07d40f490a78782ff3b0a30e3968936854105487decdb33446d4b0e/debugpy-1.8.14.tar.gz", hash = "sha256:7cd287184318416850aa8b60ac90105837bb1e59531898c07569d197d2ed5322", size = 1641444 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fc/df/156df75a41aaebd97cee9d3870fe68f8001b6c1c4ca023e221cfce69bece/debugpy-1.8.14-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:93fee753097e85623cab1c0e6a68c76308cd9f13ffdf44127e6fab4fbf024339", size = 2076510 }, + { url = "https://files.pythonhosted.org/packages/69/cd/4fc391607bca0996db5f3658762106e3d2427beaef9bfd363fd370a3c054/debugpy-1.8.14-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d937d93ae4fa51cdc94d3e865f535f185d5f9748efb41d0d49e33bf3365bd79", size = 3559614 }, + { url = 
"https://files.pythonhosted.org/packages/1a/42/4e6d2b9d63e002db79edfd0cb5656f1c403958915e0e73ab3e9220012eec/debugpy-1.8.14-cp310-cp310-win32.whl", hash = "sha256:c442f20577b38cc7a9aafecffe1094f78f07fb8423c3dddb384e6b8f49fd2987", size = 5208588 }, + { url = "https://files.pythonhosted.org/packages/97/b1/cc9e4e5faadc9d00df1a64a3c2d5c5f4b9df28196c39ada06361c5141f89/debugpy-1.8.14-cp310-cp310-win_amd64.whl", hash = "sha256:f117dedda6d969c5c9483e23f573b38f4e39412845c7bc487b6f2648df30fe84", size = 5241043 }, + { url = "https://files.pythonhosted.org/packages/67/e8/57fe0c86915671fd6a3d2d8746e40485fd55e8d9e682388fbb3a3d42b86f/debugpy-1.8.14-cp311-cp311-macosx_14_0_universal2.whl", hash = "sha256:1b2ac8c13b2645e0b1eaf30e816404990fbdb168e193322be8f545e8c01644a9", size = 2175064 }, + { url = "https://files.pythonhosted.org/packages/3b/97/2b2fd1b1c9569c6764ccdb650a6f752e4ac31be465049563c9eb127a8487/debugpy-1.8.14-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cf431c343a99384ac7eab2f763980724834f933a271e90496944195318c619e2", size = 3132359 }, + { url = "https://files.pythonhosted.org/packages/c0/ee/b825c87ed06256ee2a7ed8bab8fb3bb5851293bf9465409fdffc6261c426/debugpy-1.8.14-cp311-cp311-win32.whl", hash = "sha256:c99295c76161ad8d507b413cd33422d7c542889fbb73035889420ac1fad354f2", size = 5133269 }, + { url = "https://files.pythonhosted.org/packages/d5/a6/6c70cd15afa43d37839d60f324213843174c1d1e6bb616bd89f7c1341bac/debugpy-1.8.14-cp311-cp311-win_amd64.whl", hash = "sha256:7816acea4a46d7e4e50ad8d09d963a680ecc814ae31cdef3622eb05ccacf7b01", size = 5158156 }, + { url = "https://files.pythonhosted.org/packages/d9/2a/ac2df0eda4898f29c46eb6713a5148e6f8b2b389c8ec9e425a4a1d67bf07/debugpy-1.8.14-cp312-cp312-macosx_14_0_universal2.whl", hash = "sha256:8899c17920d089cfa23e6005ad9f22582fd86f144b23acb9feeda59e84405b84", size = 2501268 }, + { url = "https://files.pythonhosted.org/packages/10/53/0a0cb5d79dd9f7039169f8bf94a144ad3efa52cc519940b3b7dde23bcb89/debugpy-1.8.14-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f6bb5c0dcf80ad5dbc7b7d6eac484e2af34bdacdf81df09b6a3e62792b722826", size = 4221077 }, + { url = "https://files.pythonhosted.org/packages/f8/d5/84e01821f362327bf4828728aa31e907a2eca7c78cd7c6ec062780d249f8/debugpy-1.8.14-cp312-cp312-win32.whl", hash = "sha256:281d44d248a0e1791ad0eafdbbd2912ff0de9eec48022a5bfbc332957487ed3f", size = 5255127 }, + { url = "https://files.pythonhosted.org/packages/33/16/1ed929d812c758295cac7f9cf3dab5c73439c83d9091f2d91871e648093e/debugpy-1.8.14-cp312-cp312-win_amd64.whl", hash = "sha256:5aa56ef8538893e4502a7d79047fe39b1dae08d9ae257074c6464a7b290b806f", size = 5297249 }, + { url = "https://files.pythonhosted.org/packages/97/1a/481f33c37ee3ac8040d3d51fc4c4e4e7e61cb08b8bc8971d6032acc2279f/debugpy-1.8.14-py2.py3-none-any.whl", hash = "sha256:5cd9a579d553b6cb9759a7908a41988ee6280b961f24f63336835d9418216a20", size = 5256230 }, +] + +[[package]] +name = "decorator" +version = "5.2.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/43/fa/6d96a0978d19e17b68d634497769987b16c8f4cd0a7a05048bec693caa6b/decorator-5.2.1.tar.gz", hash = "sha256:65f266143752f734b0a7cc83c46f4618af75b8c5911b00ccb61d0ac9b6da0360", size = 56711 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4e/8c/f3147f5c4b73e7550fe5f9352eaa956ae838d5c51eb58e7a25b9f3e2643b/decorator-5.2.1-py3-none-any.whl", hash = 
"sha256:d316bb415a2d9e2d2b3abcc4084c6502fc09240e292cd76a76afc106a1c8e04a", size = 9190 }, +] + +[[package]] +name = "defusedxml" +version = "0.7.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0f/d5/c66da9b79e5bdb124974bfe172b4daf3c984ebd9c2a06e2b8a4dc7331c72/defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69", size = 75520 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/6c/aa3f2f849e01cb6a001cd8554a88d4c77c5c1a31c95bdf1cf9301e6d9ef4/defusedxml-0.7.1-py2.py3-none-any.whl", hash = "sha256:a352e7e428770286cc899e2542b6cdaedb2b4953ff269a210103ec58f6198a61", size = 25604 }, +] + [[package]] name = "dill" version = "0.4.0" @@ -885,6 +1038,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b", size = 16453 }, ] +[[package]] +name = "executing" +version = "2.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/91/50/a9d80c47ff289c611ff12e63f7c5d13942c65d68125160cefd768c73e6e4/executing-2.2.0.tar.gz", hash = "sha256:5d108c028108fe2551d1a7b2e8b713341e2cb4fc0aa7dcf966fa4327a5226755", size = 978693 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, +] + +[[package]] +name = "fastjsonschema" +version = "2.21.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8b/50/4b769ce1ac4071a1ef6d86b1a3fb56cdc3a37615e8c5519e1af96cdac366/fastjsonschema-2.21.1.tar.gz", hash = "sha256:794d4f0a58f848961ba16af7b9c85a3e88cd360df008c59aac6fc5ae9323b5d4", size = 373939 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/90/2b/0817a2b257fe88725c25589d89aec060581aabf668707a8d03b2e9e0cb2a/fastjsonschema-2.21.1-py3-none-any.whl", hash = "sha256:c9e5b7e908310918cf494a434eeb31384dd84a98b57a30bcb1f535015b554667", size = 23924 }, +] + [[package]] name = "filelock" version = "3.17.0" @@ -927,6 +1098,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/bf/ff/44934a031ce5a39125415eb405b9efb76fe7f9586b75291d66ae5cbfc4e6/fonttools-4.56.0-py3-none-any.whl", hash = "sha256:1088182f68c303b50ca4dc0c82d42083d176cba37af1937e1a976a31149d4d14", size = 1089800 }, ] +[[package]] +name = "fqdn" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/30/3e/a80a8c077fd798951169626cde3e239adeba7dab75deb3555716415bd9b0/fqdn-1.5.1.tar.gz", hash = "sha256:105ed3677e767fb5ca086a0c1f4bb66ebc3c100be518f0e0d755d9eae164d89f", size = 6015 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cf/58/8acf1b3e91c58313ce5cb67df61001fc9dcd21be4fadb76c1a2d540e09ed/fqdn-1.5.1-py3-none-any.whl", hash = "sha256:3a179af3761e4df6eb2e026ff9e1a3033d3587bf980a0b1b2e1e5d08d7358014", size = 9121 }, +] + [[package]] name = "frozenlist" version = "1.5.0" @@ -1327,6 +1507,129 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ef/a6/62565a6e1cf69e10f5727360368e451d4b7f58beeac6173dc9db836a5b46/iniconfig-2.0.0-py3-none-any.whl", hash = 
"sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374", size = 5892 }, ] +[[package]] +name = "ipykernel" +version = "6.29.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "appnope", marker = "sys_platform == 'darwin'" }, + { name = "comm" }, + { name = "debugpy" }, + { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "matplotlib-inline" }, + { name = "nest-asyncio" }, + { name = "packaging" }, + { name = "psutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e9/5c/67594cb0c7055dc50814b21731c22a601101ea3b1b50a9a1b090e11f5d0f/ipykernel-6.29.5.tar.gz", hash = "sha256:f093a22c4a40f8828f8e330a9c297cb93dcab13bd9678ded6de8e5cf81c56215", size = 163367 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/94/5c/368ae6c01c7628438358e6d337c19b05425727fbb221d2a3c4303c372f42/ipykernel-6.29.5-py3-none-any.whl", hash = "sha256:afdb66ba5aa354b09b91379bac28ae4afebbb30e8b39510c9690afb7a10421b5", size = 117173 }, +] + +[[package]] +name = "ipython" +version = "8.37.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version < '3.11' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version < '3.11' and sys_platform == 'linux'", + "(python_full_version < '3.11' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", +] +dependencies = [ + { name = "colorama", marker = "python_full_version < '3.11' and sys_platform == 'win32'" }, + { name = "decorator", marker = "python_full_version < '3.11'" }, + { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, + { name = "jedi", marker = "python_full_version < '3.11'" }, + { name = "matplotlib-inline", marker = "python_full_version < '3.11'" }, + { name = "pexpect", marker = "python_full_version < '3.11' and sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit", marker = "python_full_version < '3.11'" }, + { name = "pygments", marker = "python_full_version < '3.11'" }, + { name = "stack-data", marker = "python_full_version < '3.11'" }, + { name = "traitlets", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/31/10ac88f3357fc276dc8a64e8880c82e80e7459326ae1d0a211b40abf6665/ipython-8.37.0.tar.gz", hash = "sha256:ca815841e1a41a1e6b73a0b08f3038af9b2252564d01fc405356d34033012216", size = 5606088 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/91/d0/274fbf7b0b12643cbbc001ce13e6a5b1607ac4929d1b11c72460152c9fc3/ipython-8.37.0-py3-none-any.whl", hash = "sha256:ed87326596b878932dbcb171e3e698845434d8c61b8d8cd474bf663041a9dcf2", size = 831864 }, +] + +[[package]] +name = "ipython" +version = "9.4.0" +source = { registry = "https://pypi.org/simple" } +resolution-markers = [ + "python_full_version >= '3.12' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version >= '3.12' and sys_platform == 'linux'", + "(python_full_version >= '3.12' and platform_machine 
!= 'x86_64' and sys_platform == 'darwin') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version == '3.11.*' and platform_machine == 'x86_64' and sys_platform == 'darwin'", + "python_full_version == '3.11.*' and sys_platform == 'linux'", + "(python_full_version == '3.11.*' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", +] +dependencies = [ + { name = "colorama", marker = "python_full_version >= '3.11' and sys_platform == 'win32'" }, + { name = "decorator", marker = "python_full_version >= '3.11'" }, + { name = "ipython-pygments-lexers", marker = "python_full_version >= '3.11'" }, + { name = "jedi", marker = "python_full_version >= '3.11'" }, + { name = "matplotlib-inline", marker = "python_full_version >= '3.11'" }, + { name = "pexpect", marker = "python_full_version >= '3.11' and sys_platform != 'emscripten' and sys_platform != 'win32'" }, + { name = "prompt-toolkit", marker = "python_full_version >= '3.11'" }, + { name = "pygments", marker = "python_full_version >= '3.11'" }, + { name = "stack-data", marker = "python_full_version >= '3.11'" }, + { name = "traitlets", marker = "python_full_version >= '3.11'" }, + { name = "typing-extensions", marker = "python_full_version == '3.11.*'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/80/406f9e3bde1c1fd9bf5a0be9d090f8ae623e401b7670d8f6fdf2ab679891/ipython-9.4.0.tar.gz", hash = "sha256:c033c6d4e7914c3d9768aabe76bbe87ba1dc66a92a05db6bfa1125d81f2ee270", size = 4385338 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/63/f8/0031ee2b906a15a33d6bfc12dd09c3dfa966b3cb5b284ecfb7549e6ac3c4/ipython-9.4.0-py3-none-any.whl", hash = "sha256:25850f025a446d9b359e8d296ba175a36aedd32e83ca9b5060430fe16801f066", size = 611021 }, +] + +[[package]] +name = "ipython-pygments-lexers" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pygments", marker = "python_full_version >= '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ef/4c/5dd1d8af08107f88c7f741ead7a40854b8ac24ddf9ae850afbcf698aa552/ipython_pygments_lexers-1.1.1.tar.gz", hash = "sha256:09c0138009e56b6854f9535736f4171d855c8c08a563a0dcd8022f78355c7e81", size = 8393 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d9/33/1f075bf72b0b747cb3288d011319aaf64083cf2efef8354174e3ed4540e2/ipython_pygments_lexers-1.1.1-py3-none-any.whl", hash = "sha256:a9462224a505ade19a605f71f8fa63c2048833ce50abc86768a0d81d876dc81c", size = 8074 }, +] + +[[package]] +name = "ipywidgets" +version = "8.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "comm" }, + { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jupyterlab-widgets" }, + { name = "traitlets" }, + { name = "widgetsnbextension" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3e/48/d3dbac45c2814cb73812f98dd6b38bbcc957a4e7bb31d6ea9c03bf94ed87/ipywidgets-8.1.7.tar.gz", hash = "sha256:15f1ac050b9ccbefd45dccfbb2ef6bed0029d8278682d569d71b8dd96bee0376", size = 116721 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/58/6a/9166369a2f092bd286d24e6307de555d63616e8ddb373ebad2b5635ca4cd/ipywidgets-8.1.7-py3-none-any.whl", hash = "sha256:764f2602d25471c213919b8a1997df04bef869251db4ca8efba1b76b1bd9f7bb", size = 139806 }, +] + +[[package]] +name = "isoduration" +version = "20.11.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "arrow" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7c/1a/3c8edc664e06e6bd06cce40c6b22da5f1429aa4224d0c590f3be21c91ead/isoduration-20.11.0.tar.gz", hash = "sha256:ac2f9015137935279eac671f94f89eb00584f940f5dc49462a0c4ee692ba1bd9", size = 11649 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/55/e5326141505c5d5e34c5e0935d2908a74e4561eca44108fbfb9c13d2911a/isoduration-20.11.0-py3-none-any.whl", hash = "sha256:b2904c2a4228c3d44f409c8ae8e2370eb21a26f7ac2ec5446df141dde3452042", size = 11321 }, +] + [[package]] name = "jaxtyping" version = "0.2.38" @@ -1339,6 +1642,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/db/7e/da7b57a1f3af7303a0f3c8594d820fc0d3a9bbe3810a357eb21eb166e76b/jaxtyping-0.2.38-py3-none-any.whl", hash = "sha256:bc209ab8ec29917b6f0c7dec4a8ea1fc276f7d94f25b71c01d1243ec2b21ae12", size = 56375 }, ] +[[package]] +name = "jedi" +version = "0.19.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "parso" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/3a/79a912fbd4d8dd6fbb02bf69afd3bb72cf0c729bb3063c6f4498603db17a/jedi-0.19.2.tar.gz", hash = "sha256:4770dc3de41bde3966b02eb84fbcf557fb33cce26ad23da12c742fb50ecb11f0", size = 1231287 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c0/5a/9cac0c82afec3d09ccd97c8b6502d48f165f9124db81b4bcb90b4af974ee/jedi-0.19.2-py2.py3-none-any.whl", hash = "sha256:a8ef22bde8490f57fe5c7681a3c83cb58874daf72b4784de3cce5b6ef6edb5b9", size = 1572278 }, +] + [[package]] name = "jinja2" version = "3.1.5" @@ -1369,6 +1684,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/91/29/df4b9b42f2be0b623cbd5e2140cafcaa2bef0759a00b7b70104dcfe2fb51/joblib-1.4.2-py3-none-any.whl", hash = "sha256:06d478d5674cbc267e7496a410ee875abd68e4340feff4490bcb7afb88060ae6", size = 301817 }, ] +[[package]] +name = "json5" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/12/be/c6c745ec4c4539b25a278b70e29793f10382947df0d9efba2fa09120895d/json5-0.12.0.tar.gz", hash = "sha256:0b4b6ff56801a1c7dc817b0241bca4ce474a0e6a163bfef3fc594d3fd263ff3a", size = 51907 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/9f/3500910d5a98549e3098807493851eeef2b89cdd3032227558a104dfe926/json5-0.12.0-py3-none-any.whl", hash = "sha256:6d37aa6c08b0609f16e1ec5ff94697e2cbbfbad5ac112afa05794da9ab7810db", size = 36079 }, +] + [[package]] name = "jsonargparse" version = "4.37.0" @@ -1387,6 +1711,15 @@ signatures = [ { name = "typeshed-client" }, ] +[[package]] +name = "jsonpointer" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6a/0a/eebeb1fa92507ea94016a2a790b93c2ae41a7e18778f85471dc54475ed25/jsonpointer-3.0.0.tar.gz", hash = "sha256:2b2d729f2091522d61c3b31f82e11870f60b68f43fbc705cb76bf4b832af59ef", size = 9114 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/71/92/5e77f98553e9e75130c78900d000368476aed74276eb8ae8796f65f00918/jsonpointer-3.0.0-py2.py3-none-any.whl", hash = 
"sha256:13e088adc14fca8b6aa8177c044e12701e6ad4b28ff10e65f2267a90109c9942", size = 7595 }, +] + [[package]] name = "jsonschema" version = "4.23.0" @@ -1402,6 +1735,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/69/4a/4f9dbeb84e8850557c02365a0eee0649abe5eb1d84af92a25731c6c0f922/jsonschema-4.23.0-py3-none-any.whl", hash = "sha256:fbadb6f8b144a8f8cf9f0b89ba94501d143e50411a1278633f56a7acf7fd5566", size = 88462 }, ] +[package.optional-dependencies] +format-nongpl = [ + { name = "fqdn" }, + { name = "idna" }, + { name = "isoduration" }, + { name = "jsonpointer" }, + { name = "rfc3339-validator" }, + { name = "rfc3986-validator" }, + { name = "uri-template" }, + { name = "webcolors" }, +] + [[package]] name = "jsonschema-specifications" version = "2024.10.1" @@ -1414,6 +1759,208 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459 }, ] +[[package]] +name = "jupyter" +version = "1.1.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ipykernel" }, + { name = "ipywidgets" }, + { name = "jupyter-console" }, + { name = "jupyterlab" }, + { name = "nbconvert" }, + { name = "notebook" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/58/f3/af28ea964ab8bc1e472dba2e82627d36d470c51f5cd38c37502eeffaa25e/jupyter-1.1.1.tar.gz", hash = "sha256:d55467bceabdea49d7e3624af7e33d59c37fff53ed3a350e1ac957bed731de7a", size = 5714959 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/38/64/285f20a31679bf547b75602702f7800e74dbabae36ef324f716c02804753/jupyter-1.1.1-py2.py3-none-any.whl", hash = "sha256:7a59533c22af65439b24bbe60373a4e95af8f16ac65a6c00820ad378e3f7cc83", size = 2657 }, +] + +[[package]] +name = "jupyter-client" +version = "8.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-core" }, + { name = "python-dateutil" }, + { name = "pyzmq" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/22/bf9f12fdaeae18019a468b68952a60fe6dbab5d67cd2a103cac7659b41ca/jupyter_client-8.6.3.tar.gz", hash = "sha256:35b3a0947c4a6e9d589eb97d7d4cd5e90f910ee73101611f01283732bd6d9419", size = 342019 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/11/85/b0394e0b6fcccd2c1eeefc230978a6f8cb0c5df1e4cd3e7625735a0d7d1e/jupyter_client-8.6.3-py3-none-any.whl", hash = "sha256:e8a19cc986cc45905ac3362915f410f3af85424b4c0905e94fa5f2cb08e8f23f", size = 106105 }, +] + +[[package]] +name = "jupyter-console" +version = "6.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ipykernel" }, + { name = "ipython", version = "8.37.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.11'" }, + { name = "ipython", version = "9.4.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.11'" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "prompt-toolkit" }, + { name = "pygments" }, + { name = "pyzmq" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bd/2d/e2fd31e2fc41c14e2bcb6c976ab732597e907523f6b2420305f9fc7fdbdb/jupyter_console-6.6.3.tar.gz", hash = "sha256:566a4bf31c87adbfadf22cdf846e3069b59a71ed5da71d6ba4d8aaad14a53539", size = 34363 } +wheels = [ + { url = 
"https://files.pythonhosted.org/packages/ca/77/71d78d58f15c22db16328a476426f7ac4a60d3a5a7ba3b9627ee2f7903d4/jupyter_console-6.6.3-py3-none-any.whl", hash = "sha256:309d33409fcc92ffdad25f0bcdf9a4a9daa61b6f341177570fdac03de5352485", size = 24510 }, +] + +[[package]] +name = "jupyter-core" +version = "5.8.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "platformdirs" }, + { name = "pywin32", marker = "platform_python_implementation != 'PyPy' and sys_platform == 'win32'" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/1b/72906d554acfeb588332eaaa6f61577705e9ec752ddb486f302dafa292d9/jupyter_core-5.8.1.tar.gz", hash = "sha256:0a5f9706f70e64786b75acba995988915ebd4601c8a52e534a40b51c95f59941", size = 88923 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2f/57/6bffd4b20b88da3800c5d691e0337761576ee688eb01299eae865689d2df/jupyter_core-5.8.1-py3-none-any.whl", hash = "sha256:c28d268fc90fb53f1338ded2eb410704c5449a358406e8a948b75706e24863d0", size = 28880 }, +] + +[[package]] +name = "jupyter-events" +version = "0.12.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jsonschema", extra = ["format-nongpl"] }, + { name = "packaging" }, + { name = "python-json-logger" }, + { name = "pyyaml" }, + { name = "referencing" }, + { name = "rfc3339-validator" }, + { name = "rfc3986-validator" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9d/c3/306d090461e4cf3cd91eceaff84bede12a8e52cd821c2d20c9a4fd728385/jupyter_events-0.12.0.tar.gz", hash = "sha256:fc3fce98865f6784c9cd0a56a20644fc6098f21c8c33834a8d9fe383c17e554b", size = 62196 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e2/48/577993f1f99c552f18a0428731a755e06171f9902fa118c379eb7c04ea22/jupyter_events-0.12.0-py3-none-any.whl", hash = "sha256:6464b2fa5ad10451c3d35fabc75eab39556ae1e2853ad0c0cc31b656731a97fb", size = 19430 }, +] + +[[package]] +name = "jupyter-lsp" +version = "2.2.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-server" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/85/b4/3200b0b09c12bc3b72d943d923323c398eff382d1dcc7c0dbc8b74630e40/jupyter-lsp-2.2.5.tar.gz", hash = "sha256:793147a05ad446f809fd53ef1cd19a9f5256fd0a2d6b7ce943a982cb4f545001", size = 48741 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/e0/7bd7cff65594fd9936e2f9385701e44574fc7d721331ff676ce440b14100/jupyter_lsp-2.2.5-py3-none-any.whl", hash = "sha256:45fbddbd505f3fbfb0b6cb2f1bc5e15e83ab7c79cd6e89416b248cb3c00c11da", size = 69146 }, +] + +[[package]] +name = "jupyter-server" +version = "2.16.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "argon2-cffi" }, + { name = "jinja2" }, + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "jupyter-events" }, + { name = "jupyter-server-terminals" }, + { name = "nbconvert" }, + { name = "nbformat" }, + { name = "overrides" }, + { name = "packaging" }, + { name = "prometheus-client" }, + { name = "pywinpty", marker = "(os_name == 'nt' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "pyzmq" }, + { name = "send2trash" }, + { name = "terminado" }, + { name = "tornado" }, + { name = "traitlets" }, + { name = "websocket-client" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/41/c8/ba2bbcd758c47f1124c4ca14061e8ce60d9c6fd537faee9534a95f83521a/jupyter_server-2.16.0.tar.gz", hash = "sha256:65d4b44fdf2dcbbdfe0aa1ace4a842d4aaf746a2b7b168134d5aaed35621b7f6", size = 728177 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/1f/5ebbced977171d09a7b0c08a285ff9a20aafb9c51bde07e52349ff1ddd71/jupyter_server-2.16.0-py3-none-any.whl", hash = "sha256:3d8db5be3bc64403b1c65b400a1d7f4647a5ce743f3b20dbdefe8ddb7b55af9e", size = 386904 }, +] + +[[package]] +name = "jupyter-server-terminals" +version = "0.5.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywinpty", marker = "(os_name == 'nt' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "terminado" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/fc/d5/562469734f476159e99a55426d697cbf8e7eb5efe89fb0e0b4f83a3d3459/jupyter_server_terminals-0.5.3.tar.gz", hash = "sha256:5ae0295167220e9ace0edcfdb212afd2b01ee8d179fe6f23c899590e9b8a5269", size = 31430 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/2d/2b32cdbe8d2a602f697a649798554e4f072115438e92249624e532e8aca6/jupyter_server_terminals-0.5.3-py3-none-any.whl", hash = "sha256:41ee0d7dc0ebf2809c668e0fc726dfaf258fcd3e769568996ca731b6194ae9aa", size = 13656 }, +] + +[[package]] +name = "jupyterlab" +version = "4.4.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "async-lru" }, + { name = "httpx" }, + { name = "ipykernel" }, + { name = "jinja2" }, + { name = "jupyter-core" }, + { name = "jupyter-lsp" }, + { name = "jupyter-server" }, + { name = "jupyterlab-server" }, + { name = "notebook-shim" }, + { name = "packaging" }, + { name = "setuptools" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "tornado" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e2/4d/7ca5b46ea56742880d71a768a9e6fb8f8482228427eb89492d55c5d0bb7d/jupyterlab-4.4.4.tar.gz", hash = "sha256:163fee1ef702e0a057f75d2eed3ed1da8a986d59eb002cbeb6f0c2779e6cd153", size = 23044296 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f8/82/66910ce0995dbfdb33609f41c99fe32ce483b9624a3e7d672af14ff63b9f/jupyterlab-4.4.4-py3-none-any.whl", hash = "sha256:711611e4f59851152eb93316c3547c3ec6291f16bb455f1f4fa380d25637e0dd", size = 12296310 }, +] + +[[package]] +name = "jupyterlab-pygments" +version = "0.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/90/51/9187be60d989df97f5f0aba133fa54e7300f17616e065d1ada7d7646b6d6/jupyterlab_pygments-0.3.0.tar.gz", hash = "sha256:721aca4d9029252b11cfa9d185e5b5af4d54772bb8072f9b7036f4170054d35d", size = 512900 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b1/dd/ead9d8ea85bf202d90cc513b533f9c363121c7792674f78e0d8a854b63b4/jupyterlab_pygments-0.3.0-py3-none-any.whl", hash = "sha256:841a89020971da1d8693f1a99997aefc5dc424bb1b251fd6322462a1b8842780", size = 15884 }, +] + +[[package]] +name = "jupyterlab-server" +version = "2.27.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "babel" }, + { name = "jinja2" }, + { name = "json5" }, + { name = "jsonschema" }, + { name = "jupyter-server" }, + { name = "packaging" }, + { name = "requests" }, +] +sdist = { url = 
"https://files.pythonhosted.org/packages/0a/c9/a883ce65eb27905ce77ace410d83587c82ea64dc85a48d1f7ed52bcfa68d/jupyterlab_server-2.27.3.tar.gz", hash = "sha256:eb36caca59e74471988f0ae25c77945610b887f777255aa21f8065def9e51ed4", size = 76173 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/54/09/2032e7d15c544a0e3cd831c51d77a8ca57f7555b2e1b2922142eddb02a84/jupyterlab_server-2.27.3-py3-none-any.whl", hash = "sha256:e697488f66c3db49df675158a77b3b017520d772c6e1548c7d9bcc5df7944ee4", size = 59700 }, +] + +[[package]] +name = "jupyterlab-widgets" +version = "3.0.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b9/7d/160595ca88ee87ac6ba95d82177d29ec60aaa63821d3077babb22ce031a5/jupyterlab_widgets-3.0.15.tar.gz", hash = "sha256:2920888a0c2922351a9202817957a68c07d99673504d6cd37345299e971bb08b", size = 213149 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/6a/ca128561b22b60bd5a0c4ea26649e68c8556b82bc70a0c396eebc977fe86/jupyterlab_widgets-3.0.15-py3-none-any.whl", hash = "sha256:d59023d7d7ef71400d51e6fee9a88867f6e65e10a4201605d2d7f3e8f012a31c", size = 216571 }, +] + [[package]] name = "kiwisolver" version = "1.4.8" @@ -1746,19 +2293,44 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c9/b4/680aa700d99b48e8c4393fa08e9ab8c49c0555ee6f4c9c0a5e8ea8dfde5d/matplotlib-3.10.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ae80dc3a4add4665cf2faa90138384a7ffe2a4e37c58d83e115b54287c4f06ef", size = 8587361 }, ] +[[package]] +name = "matplotlib-inline" +version = "0.1.7" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/99/5b/a36a337438a14116b16480db471ad061c36c3694df7c2084a0da7ba538b7/matplotlib_inline-0.1.7.tar.gz", hash = "sha256:8423b23ec666be3d16e16b60bdd8ac4e86e840ebd1dd11a30b9f117f2fa0ab90", size = 8159 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/8e/9ad090d3553c280a8060fbf6e24dc1c0c29704ee7d1c372f0c174aa59285/matplotlib_inline-0.1.7-py3-none-any.whl", hash = "sha256:df192d39a4ff8f21b1895d72e6a13f5fcc5099f00fa84384e0ea28c2cc0653ca", size = 9899 }, +] + +[[package]] +name = "mistune" +version = "3.1.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions", marker = "python_full_version < '3.11'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c4/79/bda47f7dd7c3c55770478d6d02c9960c430b0cf1773b72366ff89126ea31/mistune-3.1.3.tar.gz", hash = "sha256:a7035c21782b2becb6be62f8f25d3df81ccb4d6fa477a6525b15af06539f02a0", size = 94347 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/01/4d/23c4e4f09da849e127e9f123241946c23c1e30f45a88366879e064211815/mistune-3.1.3-py3-none-any.whl", hash = "sha256:1a32314113cff28aa6432e99e522677c8587fd83e3d51c29b82a52409c842bd9", size = 53410 }, +] + [[package]] name = "ml4gw" -version = "0.7.2" +version = "0.7.5" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, { name = "numpy" }, + { name = "scipy" }, { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/18/3d/9c58325b0f3d606cebc0e52e8b8600aeeca699096184a5ca0c4a2a2b3024/ml4gw-0.7.2.tar.gz", hash = "sha256:fc9f61fbc6e2fd9ae6b8654d4e0468788b56cf8631f1e4d8b4a55dba80931a90", size = 101166 } +sdist = { url = 
"https://files.pythonhosted.org/packages/fd/bb/99c1e75bbacd81f8cc887380cd0dc260066959cc81926fc2472b8842dc0a/ml4gw-0.7.5.tar.gz", hash = "sha256:3c776664bd8594d3b87c450cf0a0e5369f2ba8bfe2ad5923236b31d768eb8fb5", size = 811225 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d2/2d/efb9febf79d4b54392f8b45910fed080d78ea2ba377319e34919f9cda1cc/ml4gw-0.7.2-py3-none-any.whl", hash = "sha256:a8f2f508420eba1c6dc045762b1c6452ce51b43ffa3c138e2ca4bd8436540ff5", size = 123216 }, + { url = "https://files.pythonhosted.org/packages/32/7d/13ae4a5199dc081b7855202131a2ea69e322e544c45ed2cc6d7ca5786aa8/ml4gw-0.7.5-py3-none-any.whl", hash = "sha256:430fb2994a820c659806e4bff745395a04e466bc4c50c0b3c14f680b6fab187c", size = 125108 }, ] [[package]] @@ -1898,6 +2470,70 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/b7/b9e70fde2c0f0c9af4cc5277782a89b66d35948ea3369ec9f598358c3ac5/multidict-6.1.0-py3-none-any.whl", hash = "sha256:48e171e52d1c4d33888e529b999e5900356b9ae588c2f09a52dcefb158b27506", size = 10051 }, ] +[[package]] +name = "nbclient" +version = "0.10.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-client" }, + { name = "jupyter-core" }, + { name = "nbformat" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/87/66/7ffd18d58eae90d5721f9f39212327695b749e23ad44b3881744eaf4d9e8/nbclient-0.10.2.tar.gz", hash = "sha256:90b7fc6b810630db87a6d0c2250b1f0ab4cf4d3c27a299b0cde78a4ed3fd9193", size = 62424 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/34/6d/e7fa07f03a4a7b221d94b4d586edb754a9b0dc3c9e2c93353e9fa4e0d117/nbclient-0.10.2-py3-none-any.whl", hash = "sha256:4ffee11e788b4a27fabeb7955547e4318a5298f34342a4bfd01f2e1faaeadc3d", size = 25434 }, +] + +[[package]] +name = "nbconvert" +version = "7.16.6" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "beautifulsoup4" }, + { name = "bleach", extra = ["css"] }, + { name = "defusedxml" }, + { name = "jinja2" }, + { name = "jupyter-core" }, + { name = "jupyterlab-pygments" }, + { name = "markupsafe" }, + { name = "mistune" }, + { name = "nbclient" }, + { name = "nbformat" }, + { name = "packaging" }, + { name = "pandocfilters" }, + { name = "pygments" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a3/59/f28e15fc47ffb73af68a8d9b47367a8630d76e97ae85ad18271b9db96fdf/nbconvert-7.16.6.tar.gz", hash = "sha256:576a7e37c6480da7b8465eefa66c17844243816ce1ccc372633c6b71c3c0f582", size = 857715 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/9a/cd673b2f773a12c992f41309ef81b99da1690426bd2f96957a7ade0d3ed7/nbconvert-7.16.6-py3-none-any.whl", hash = "sha256:1375a7b67e0c2883678c48e506dc320febb57685e5ee67faa51b18a90f3a712b", size = 258525 }, +] + +[[package]] +name = "nbformat" +version = "5.10.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastjsonschema" }, + { name = "jsonschema" }, + { name = "jupyter-core" }, + { name = "traitlets" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6d/fd/91545e604bc3dad7dca9ed03284086039b294c6b3d75c0d2fa45f9e9caf3/nbformat-5.10.4.tar.gz", hash = "sha256:322168b14f937a5d11362988ecac2a4952d3d8e3a2cbeb2319584631226d5b3a", size = 142749 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a9/82/0340caa499416c78e5d8f5f05947ae4bc3cba53c9f038ab6e9ed964e22f1/nbformat-5.10.4-py3-none-any.whl", hash = "sha256:3b48d6c8fbca4b299bf3982ea7db1af21580e4fec269ad087b9e81588891200b", 
size = 78454 }, +] + +[[package]] +name = "nest-asyncio" +version = "1.6.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/83/f8/51569ac65d696c8ecbee95938f89d4abf00f47d58d48f6fbabfe8f0baefe/nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe", size = 7418 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/c4/c2971a3ba4c6103a3d10c4b0f24f461ddc027f0f09763220cf35ca1401b3/nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c", size = 5195 }, +] + [[package]] name = "networkx" version = "3.4.2" @@ -1907,6 +2543,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/54/dd730b32ea14ea797530a4479b2ed46a6fb250f682a9cfb997e968bf0261/networkx-3.4.2-py3-none-any.whl", hash = "sha256:df5d4365b724cf81b8c6a7312509d0c22386097011ad1abe274afd5e9d3bbc5f", size = 1723263 }, ] +[[package]] +name = "notebook" +version = "7.4.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-server" }, + { name = "jupyterlab" }, + { name = "jupyterlab-server" }, + { name = "notebook-shim" }, + { name = "tornado" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/b1/4e/a40b5a94eb01fc51746db7854296d88b84905ab18ee0fcef853a60d708a3/notebook-7.4.4.tar.gz", hash = "sha256:392fd501e266f2fb3466c6fcd3331163a2184968cb5c5accf90292e01dfe528c", size = 13883628 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b3/c0/e64d2047fd752249b0b69f6aee2a7049eb94e7273e5baabc8b8ad05cc068/notebook-7.4.4-py3-none-any.whl", hash = "sha256:32840f7f777b6bff79bb101159336e9b332bdbfba1495b8739e34d1d65cbc1c0", size = 14288000 }, +] + +[[package]] +name = "notebook-shim" +version = "0.2.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "jupyter-server" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/d2/92fa3243712b9a3e8bafaf60aac366da1cada3639ca767ff4b5b3654ec28/notebook_shim-0.2.4.tar.gz", hash = "sha256:b4b2cfa1b65d98307ca24361f5b30fe785b53c3fd07b7a47e89acb5e6ac638cb", size = 13167 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/33/bd5b9137445ea4b680023eb0469b2bb969d61303dedb2aac6560ff3d14a1/notebook_shim-0.2.4-py3-none-any.whl", hash = "sha256:411a5be4e9dc882a074ccbcae671eda64cceb068767e9a3419096986560e1cef", size = 13307 }, +] + [[package]] name = "numpy" version = "1.26.4" @@ -2094,6 +2758,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/68/162c97ea78c957d68ecf78a5c5041d2e25bd5562bdf5d89a6cbf7f8429bf/opencensus_context-0.1.3-py2.py3-none-any.whl", hash = "sha256:073bb0590007af276853009fac7e4bab1d523c3f03baf4cb4511ca38967c6039", size = 5060 }, ] +[[package]] +name = "overrides" +version = "7.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/36/86/b585f53236dec60aba864e050778b25045f857e17f6e5ea0ae95fe80edd2/overrides-7.7.0.tar.gz", hash = "sha256:55158fa3d93b98cc75299b1e67078ad9003ca27945c76162c1c0766d6f91820a", size = 22812 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/2c/ab/fc8290c6a4c722e5514d80f62b2dc4c4df1a68a41d1364e625c35990fcf3/overrides-7.7.0-py3-none-any.whl", hash = "sha256:c7ed9d062f78b8e4c1a7b70bd8796b35ead4d9f510227ef9c5dc7626c60d7e49", size = 17832 }, +] + [[package]] name = "packaging" version = "24.2" @@ -2138,6 +2811,24 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/29/d4/1244ab8edf173a10fd601f7e13b9566c1b525c4f365d6bee918e68381889/pandas-2.2.3-cp312-cp312-win_amd64.whl", hash = "sha256:59ef3764d0fe818125a5097d2ae867ca3fa64df032331b7e0917cf5d7bf66b13", size = 11504248 }, ] +[[package]] +name = "pandocfilters" +version = "1.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/70/6f/3dd4940bbe001c06a65f88e36bad298bc7a0de5036115639926b0c5c0458/pandocfilters-1.5.1.tar.gz", hash = "sha256:002b4a555ee4ebc03f8b66307e287fa492e4a77b4ea14d3f934328297bb4939e", size = 8454 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ef/af/4fbc8cab944db5d21b7e2a5b8e9211a03a79852b1157e2c102fcc61ac440/pandocfilters-1.5.1-py2.py3-none-any.whl", hash = "sha256:93be382804a9cdb0a7267585f157e5d1731bbe5545a85b268d6f5fe6232de2bc", size = 8663 }, +] + +[[package]] +name = "parso" +version = "0.8.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/66/94/68e2e17afaa9169cf6412ab0f28623903be73d1b32e208d9e8e541bb086d/parso-0.8.4.tar.gz", hash = "sha256:eb3a7b58240fb99099a345571deecc0f9540ea5f4dd2fe14c2a99d6b281ab92d", size = 400609 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c6/ac/dac4a63f978e4dcb3c6d3a78c4d8e0192a113d288502a1216950c41b1027/parso-0.8.4-py2.py3-none-any.whl", hash = "sha256:a418670a20291dacd2dddc80c377c5c3791378ee1e8d12bffc35420643d43f18", size = 103650 }, +] + [[package]] name = "pegasus-wms-api" version = "5.0.9" @@ -2156,6 +2847,18 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/da/bd/eb4c316b7abd8927f14aeed58abe2943caf854efefa692d8298f3cff7b58/pegasus-wms.common-5.0.9.tar.gz", hash = "sha256:f729a6095a749ba2afe30d147dd3ca817992f1c6042b80f95ee6ab03e5ee0745", size = 47143 } +[[package]] +name = "pexpect" +version = "4.9.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/92/cc564bf6381ff43ce1f4d06852fc19a2f11d180f23dc32d9588bee2f149d/pexpect-4.9.0.tar.gz", hash = "sha256:ee7d41123f3c9911050ea2c2dac107568dc43b2d3b0c7557a33212c398ead30f", size = 166450 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/c3/059298687310d527a58bb01f3b1965787ee3b40dce76752eda8b44e9a2c5/pexpect-4.9.0-py2.py3-none-any.whl", hash = "sha256:7236d1e080e4936be2dc3e326cec0af72acf9212a7e1d060210e70a47e253523", size = 63772 }, +] + [[package]] name = "pillow" version = "11.1.0" @@ -2253,6 +2956,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ff/c2/ab7d37426c179ceb9aeb109a85cda8948bb269b7561a0be870cc656eefe4/prometheus_client-0.21.1-py3-none-any.whl", hash = "sha256:594b45c410d6f4f8888940fe80b5cc2521b305a1fafe1c58609ef715a001f301", size = 54682 }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.51" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/bb/6e/9d084c929dfe9e3bfe0c6a47e31f78a25c54627d64a66e884a8bf5474f1c/prompt_toolkit-3.0.51.tar.gz", hash = "sha256:931a162e3b27fc90c86f1b48bb1fb2c528c2761475e57c9c06de13311c7b54ed", size = 428940 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ce/4f/5249960887b1fbe561d9ff265496d170b55a735b76724f10ef19f9e40716/prompt_toolkit-3.0.51-py3-none-any.whl", hash = "sha256:52742911fde84e2d423e2f9a4cf1de7d7ac4e51958f648d9540e0fb8db077b07", size = 387810 }, +] + [[package]] name 
= "propcache" version = "0.3.0" @@ -2349,6 +3064,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/05/33/2d74d588408caedd065c2497bdb5ef83ce6082db01289a1e1147f6639802/psutil-5.9.8-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:d16bbddf0693323b8c6123dd804100241da461e41d6e332fb0ba6058f630f8c8", size = 249898 }, ] +[[package]] +name = "ptyprocess" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/20/e5/16ff212c1e452235a90aeb09066144d0c5a6a8c0834397e03f5224495c4e/ptyprocess-0.7.0.tar.gz", hash = "sha256:5c5d0a3b48ceee0b48485e0c26037c0acd7d29765ca3fbb5cb3831d347423220", size = 70762 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/22/a6/858897256d0deac81a172289110f31629fc4cee19b6f01283303e18c8db3/ptyprocess-0.7.0-py2.py3-none-any.whl", hash = "sha256:4b41f3967fce3af57cc7e94b888626c18bf37a083e3651ca8feeb66d492fef35", size = 13993 }, +] + +[[package]] +name = "pure-eval" +version = "0.2.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/cd/05/0a34433a064256a578f1783a10da6df098ceaa4a57bbeaa96a6c0352786b/pure_eval-0.2.3.tar.gz", hash = "sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42", size = 19752 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/37/efad0257dc6e593a18957422533ff0f87ede7c9c6ea010a2177d738fb82f/pure_eval-0.2.3-py3-none-any.whl", hash = "sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0", size = 11842 }, +] + [[package]] name = "py-spy" version = "0.4.0" @@ -2561,6 +3294,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b4/11/97233cf23ad5411ac6f13b1d6ee3888f90ace4f974d9bf9db887aa428912/pyerfa-2.0.1.5-cp39-abi3-win_amd64.whl", hash = "sha256:66292d437dcf75925b694977aa06eb697126e7b86553e620371ed3e48b5e0ad0", size = 349410 }, ] +[[package]] +name = "pygments" +version = "2.19.2" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/b0/77/a5b8c569bf593b0140bde72ea885a803b82086995367bf2037de0159d924/pygments-2.19.2.tar.gz", hash = "sha256:636cb2477cec7f8952536970bc533bc43743542f70392ae026374600add5b887", size = 4968631 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c7/21/705964c7812476f378728bdf590ca4b771ec72385c533964653c68e86bdc/pygments-2.19.2-py3-none-any.whl", hash = "sha256:86540386c03d588bb81d44bc3928634ff26449851e99741617ecb9037ee5ec0b", size = 1225217 }, +] + [[package]] name = "pyjwt" version = "2.10.1" @@ -2684,6 +3426,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ec/57/56b9bcc3c9c6a792fcbaf139543cee77261f3651ca9da0c93f5c1221264b/python_dateutil-2.9.0.post0-py2.py3-none-any.whl", hash = "sha256:a8b2bc7bffae282281c8140a97d3aa9c14da0b136dfe83f850eea9a5f7470427", size = 229892 }, ] +[[package]] +name = "python-json-logger" +version = "3.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/9e/de/d3144a0bceede957f961e975f3752760fbe390d57fbe194baf709d8f1f7b/python_json_logger-3.3.0.tar.gz", hash = "sha256:12b7e74b17775e7d565129296105bbe3910842d9d0eb083fc83a6a617aa8df84", size = 16642 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/20/0f2523b9e50a8052bc6a8b732dfc8568abbdc42010aef03a2d750bdab3b2/python_json_logger-3.3.0-py3-none-any.whl", hash = "sha256:dd980fae8cffb24c13caf6e158d3d61c0d6d22342f932cb6e9deedab3d35eec7", size = 15163 }, +] + [[package]] name = 
"python-jsonpath" version = "1.3.0" @@ -2796,6 +3547,17 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/0f/d40f8373608caed2255781a3ad9a51d03a594a1248cd632d6a298daca693/pywin32-308-cp312-cp312-win_arm64.whl", hash = "sha256:9b4de86c8d909aed15b7011182c8cab38c8850de36e6afb1f0db22b8959e3091", size = 7976033 }, ] +[[package]] +name = "pywinpty" +version = "2.0.15" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2d/7c/917f9c4681bb8d34bfbe0b79d36bbcd902651aeab48790df3d30ba0202fb/pywinpty-2.0.15.tar.gz", hash = "sha256:312cf39153a8736c617d45ce8b6ad6cd2107de121df91c455b10ce6bba7a39b2", size = 29017 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a6/b7/855db919ae526d2628f3f2e6c281c4cdff7a9a8af51bb84659a9f07b1861/pywinpty-2.0.15-cp310-cp310-win_amd64.whl", hash = "sha256:8e7f5de756a615a38b96cd86fa3cd65f901ce54ce147a3179c45907fa11b4c4e", size = 1405161 }, + { url = "https://files.pythonhosted.org/packages/5e/ac/6884dcb7108af66ad53f73ef4dad096e768c9203a6e6ce5e6b0c4a46e238/pywinpty-2.0.15-cp311-cp311-win_amd64.whl", hash = "sha256:9a6bcec2df2707aaa9d08b86071970ee32c5026e10bcc3cc5f6f391d85baf7ca", size = 1405249 }, + { url = "https://files.pythonhosted.org/packages/88/e5/9714def18c3a411809771a3fbcec70bffa764b9675afb00048a620fca604/pywinpty-2.0.15-cp312-cp312-win_amd64.whl", hash = "sha256:83a8f20b430bbc5d8957249f875341a60219a4e971580f2ba694fbfb54a45ebc", size = 1405243 }, +] + [[package]] name = "pyyaml" version = "6.0.2" @@ -2831,6 +3593,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0c/e8/4f648c598b17c3d06e8753d7d13d57542b30d56e6c2dedf9c331ae56312e/PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8", size = 156338 }, ] +[[package]] +name = "pyzmq" +version = "27.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "cffi", marker = "implementation_name == 'pypy'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/f1/06/50a4e9648b3e8b992bef8eb632e457307553a89d294103213cfd47b3da69/pyzmq-27.0.0.tar.gz", hash = "sha256:b1f08eeb9ce1510e6939b6e5dcd46a17765e2333daae78ecf4606808442e52cf", size = 280478 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9c/09/1681d4b047626d352c083770618ac29655ab1f5c20eee31dc94c000b9b7b/pyzmq-27.0.0-cp310-cp310-macosx_10_15_universal2.whl", hash = "sha256:b973ee650e8f442ce482c1d99ca7ab537c69098d53a3d046676a484fd710c87a", size = 1329291 }, + { url = "https://files.pythonhosted.org/packages/9d/b2/9c9385225fdd54db9506ed8accbb9ea63ca813ba59d43d7f282a6a16a30b/pyzmq-27.0.0-cp310-cp310-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:661942bc7cd0223d569d808f2e5696d9cc120acc73bf3e88a1f1be7ab648a7e4", size = 905952 }, + { url = "https://files.pythonhosted.org/packages/41/73/333c72c7ec182cdffe25649e3da1c3b9f3cf1cede63cfdc23d1384d4a601/pyzmq-27.0.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:50360fb2a056ffd16e5f4177eee67f1dd1017332ea53fb095fe7b5bf29c70246", size = 666165 }, + { url = "https://files.pythonhosted.org/packages/a5/fe/fc7b9c1a50981928e25635a926653cb755364316db59ccd6e79cfb9a0b4f/pyzmq-27.0.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:cf209a6dc4b420ed32a7093642843cbf8703ed0a7d86c16c0b98af46762ebefb", size = 853755 }, + { url = 
"https://files.pythonhosted.org/packages/8c/4c/740ed4b6e8fa160cd19dc5abec8db68f440564b2d5b79c1d697d9862a2f7/pyzmq-27.0.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:c2dace4a7041cca2fba5357a2d7c97c5effdf52f63a1ef252cfa496875a3762d", size = 1654868 }, + { url = "https://files.pythonhosted.org/packages/97/00/875b2ecfcfc78ab962a59bd384995186818524ea957dc8ad3144611fae12/pyzmq-27.0.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:63af72b2955fc77caf0a77444baa2431fcabb4370219da38e1a9f8d12aaebe28", size = 2033443 }, + { url = "https://files.pythonhosted.org/packages/60/55/6dd9c470c42d713297c5f2a56f7903dc1ebdb4ab2edda996445c21651900/pyzmq-27.0.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:e8c4adce8e37e75c4215297d7745551b8dcfa5f728f23ce09bf4e678a9399413", size = 1891288 }, + { url = "https://files.pythonhosted.org/packages/28/5d/54b0ef50d40d7c65a627f4a4b4127024ba9820f2af8acd933a4d30ae192e/pyzmq-27.0.0-cp310-cp310-win32.whl", hash = "sha256:5d5ef4718ecab24f785794e0e7536436698b459bfbc19a1650ef55280119d93b", size = 567936 }, + { url = "https://files.pythonhosted.org/packages/18/ea/dedca4321de748ca48d3bcdb72274d4d54e8d84ea49088d3de174bd45d88/pyzmq-27.0.0-cp310-cp310-win_amd64.whl", hash = "sha256:e40609380480b3d12c30f841323f42451c755b8fece84235236f5fe5ffca8c1c", size = 628686 }, + { url = "https://files.pythonhosted.org/packages/d4/a7/fcdeedc306e71e94ac262cba2d02337d885f5cdb7e8efced8e5ffe327808/pyzmq-27.0.0-cp310-cp310-win_arm64.whl", hash = "sha256:6b0397b0be277b46762956f576e04dc06ced265759e8c2ff41a0ee1aa0064198", size = 559039 }, + { url = "https://files.pythonhosted.org/packages/44/df/84c630654106d9bd9339cdb564aa941ed41b023a0264251d6743766bb50e/pyzmq-27.0.0-cp311-cp311-macosx_10_15_universal2.whl", hash = "sha256:21457825249b2a53834fa969c69713f8b5a79583689387a5e7aed880963ac564", size = 1332718 }, + { url = "https://files.pythonhosted.org/packages/c1/8e/f6a5461a07654d9840d256476434ae0ff08340bba562a455f231969772cb/pyzmq-27.0.0-cp311-cp311-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:1958947983fef513e6e98eff9cb487b60bf14f588dc0e6bf35fa13751d2c8251", size = 908248 }, + { url = "https://files.pythonhosted.org/packages/7c/93/82863e8d695a9a3ae424b63662733ae204a295a2627d52af2f62c2cd8af9/pyzmq-27.0.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c0dc628b5493f9a8cd9844b8bee9732ef587ab00002157c9329e4fc0ef4d3afa", size = 668647 }, + { url = "https://files.pythonhosted.org/packages/f3/85/15278769b348121eacdbfcbd8c4d40f1102f32fa6af5be1ffc032ed684be/pyzmq-27.0.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f7bbe9e1ed2c8d3da736a15694d87c12493e54cc9dc9790796f0321794bbc91f", size = 856600 }, + { url = "https://files.pythonhosted.org/packages/d4/af/1c469b3d479bd095edb28e27f12eee10b8f00b356acbefa6aeb14dd295d1/pyzmq-27.0.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:dc1091f59143b471d19eb64f54bae4f54bcf2a466ffb66fe45d94d8d734eb495", size = 1657748 }, + { url = "https://files.pythonhosted.org/packages/8c/f4/17f965d0ee6380b1d6326da842a50e4b8b9699745161207945f3745e8cb5/pyzmq-27.0.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:7011ade88c8e535cf140f8d1a59428676fbbce7c6e54fefce58bf117aefb6667", size = 2034311 }, + { url = "https://files.pythonhosted.org/packages/e0/6e/7c391d81fa3149fd759de45d298003de6cfab343fb03e92c099821c448db/pyzmq-27.0.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:2c386339d7e3f064213aede5d03d054b237937fbca6dd2197ac8cf3b25a6b14e", size = 1893630 }, + { url = 
"https://files.pythonhosted.org/packages/0e/e0/eaffe7a86f60e556399e224229e7769b717f72fec0706b70ab2c03aa04cb/pyzmq-27.0.0-cp311-cp311-win32.whl", hash = "sha256:0546a720c1f407b2172cb04b6b094a78773491497e3644863cf5c96c42df8cff", size = 567706 }, + { url = "https://files.pythonhosted.org/packages/c9/05/89354a8cffdcce6e547d48adaaf7be17007fc75572123ff4ca90a4ca04fc/pyzmq-27.0.0-cp311-cp311-win_amd64.whl", hash = "sha256:15f39d50bd6c9091c67315ceb878a4f531957b121d2a05ebd077eb35ddc5efed", size = 630322 }, + { url = "https://files.pythonhosted.org/packages/fa/07/4ab976d5e1e63976719389cc4f3bfd248a7f5f2bb2ebe727542363c61b5f/pyzmq-27.0.0-cp311-cp311-win_arm64.whl", hash = "sha256:c5817641eebb391a2268c27fecd4162448e03538387093cdbd8bf3510c316b38", size = 558435 }, + { url = "https://files.pythonhosted.org/packages/93/a7/9ad68f55b8834ede477842214feba6a4c786d936c022a67625497aacf61d/pyzmq-27.0.0-cp312-abi3-macosx_10_15_universal2.whl", hash = "sha256:cbabc59dcfaac66655c040dfcb8118f133fb5dde185e5fc152628354c1598e52", size = 1305438 }, + { url = "https://files.pythonhosted.org/packages/ba/ee/26aa0f98665a22bc90ebe12dced1de5f3eaca05363b717f6fb229b3421b3/pyzmq-27.0.0-cp312-abi3-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:cb0ac5179cba4b2f94f1aa208fbb77b62c4c9bf24dd446278b8b602cf85fcda3", size = 895095 }, + { url = "https://files.pythonhosted.org/packages/cf/85/c57e7ab216ecd8aa4cc7e3b83b06cc4e9cf45c87b0afc095f10cd5ce87c1/pyzmq-27.0.0-cp312-abi3-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53a48f0228eab6cbf69fde3aa3c03cbe04e50e623ef92ae395fce47ef8a76152", size = 651826 }, + { url = "https://files.pythonhosted.org/packages/69/9a/9ea7e230feda9400fb0ae0d61d7d6ddda635e718d941c44eeab22a179d34/pyzmq-27.0.0-cp312-abi3-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:111db5f395e09f7e775f759d598f43cb815fc58e0147623c4816486e1a39dc22", size = 839750 }, + { url = "https://files.pythonhosted.org/packages/08/66/4cebfbe71f3dfbd417011daca267539f62ed0fbc68105357b68bbb1a25b7/pyzmq-27.0.0-cp312-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c8878011653dcdc27cc2c57e04ff96f0471e797f5c19ac3d7813a245bcb24371", size = 1641357 }, + { url = "https://files.pythonhosted.org/packages/ac/f6/b0f62578c08d2471c791287149cb8c2aaea414ae98c6e995c7dbe008adfb/pyzmq-27.0.0-cp312-abi3-musllinux_1_2_i686.whl", hash = "sha256:c0ed2c1f335ba55b5fdc964622254917d6b782311c50e138863eda409fbb3b6d", size = 2020281 }, + { url = "https://files.pythonhosted.org/packages/37/b9/4f670b15c7498495da9159edc374ec09c88a86d9cd5a47d892f69df23450/pyzmq-27.0.0-cp312-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:e918d70862d4cfd4b1c187310015646a14e1f5917922ab45b29f28f345eeb6be", size = 1877110 }, + { url = "https://files.pythonhosted.org/packages/66/31/9dee25c226295b740609f0d46db2fe972b23b6f5cf786360980524a3ba92/pyzmq-27.0.0-cp312-abi3-win32.whl", hash = "sha256:88b4e43cab04c3c0f0d55df3b1eef62df2b629a1a369b5289a58f6fa8b07c4f4", size = 559297 }, + { url = "https://files.pythonhosted.org/packages/9b/12/52da5509800f7ff2d287b2f2b4e636e7ea0f001181cba6964ff6c1537778/pyzmq-27.0.0-cp312-abi3-win_amd64.whl", hash = "sha256:dce4199bf5f648a902ce37e7b3afa286f305cd2ef7a8b6ec907470ccb6c8b371", size = 619203 }, + { url = "https://files.pythonhosted.org/packages/93/6d/7f2e53b19d1edb1eb4f09ec7c3a1f945ca0aac272099eab757d15699202b/pyzmq-27.0.0-cp312-abi3-win_arm64.whl", hash = "sha256:56e46bbb85d52c1072b3f809cc1ce77251d560bc036d3a312b96db1afe76db2e", size = 551927 }, + { url = 
"https://files.pythonhosted.org/packages/09/6f/be6523a7f3821c0b5370912ef02822c028611360e0d206dd945bdbf9eaef/pyzmq-27.0.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:656c1866505a5735d0660b7da6d7147174bbf59d4975fc2b7f09f43c9bc25745", size = 835950 }, + { url = "https://files.pythonhosted.org/packages/c6/1e/a50fdd5c15018de07ab82a61bc460841be967ee7bbe7abee3b714d66f7ac/pyzmq-27.0.0-pp310-pypy310_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:74175b9e12779382432dd1d1f5960ebe7465d36649b98a06c6b26be24d173fab", size = 799876 }, + { url = "https://files.pythonhosted.org/packages/88/a1/89eb5b71f5a504f8f887aceb8e1eb3626e00c00aa8085381cdff475440dc/pyzmq-27.0.0-pp310-pypy310_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8c6de908465697a8708e4d6843a1e884f567962fc61eb1706856545141d0cbb", size = 567400 }, + { url = "https://files.pythonhosted.org/packages/56/aa/4571dbcff56cfb034bac73fde8294e123c975ce3eea89aff31bf6dc6382b/pyzmq-27.0.0-pp310-pypy310_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c644aaacc01d0df5c7072826df45e67301f191c55f68d7b2916d83a9ddc1b551", size = 747031 }, + { url = "https://files.pythonhosted.org/packages/46/e0/d25f30fe0991293c5b2f5ef3b070d35fa6d57c0c7428898c3ab4913d0297/pyzmq-27.0.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:10f70c1d9a446a85013a36871a296007f6fe4232b530aa254baf9da3f8328bc0", size = 544726 }, + { url = "https://files.pythonhosted.org/packages/98/a6/92394373b8dbc1edc9d53c951e8d3989d518185174ee54492ec27711779d/pyzmq-27.0.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:cd1dc59763effd1576f8368047c9c31468fce0af89d76b5067641137506792ae", size = 835948 }, + { url = "https://files.pythonhosted.org/packages/56/f3/4dc38d75d9995bfc18773df3e41f2a2ca9b740b06f1a15dbf404077e7588/pyzmq-27.0.0-pp311-pypy311_pp73-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:60e8cc82d968174650c1860d7b716366caab9973787a1c060cf8043130f7d0f7", size = 799874 }, + { url = "https://files.pythonhosted.org/packages/ab/ba/64af397e0f421453dc68e31d5e0784d554bf39013a2de0872056e96e58af/pyzmq-27.0.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:14fe7aaac86e4e93ea779a821967360c781d7ac5115b3f1a171ced77065a0174", size = 567400 }, + { url = "https://files.pythonhosted.org/packages/63/87/ec956cbe98809270b59a22891d5758edae147a258e658bf3024a8254c855/pyzmq-27.0.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6ad0562d4e6abb785be3e4dd68599c41be821b521da38c402bc9ab2a8e7ebc7e", size = 747031 }, + { url = "https://files.pythonhosted.org/packages/be/8a/4a3764a68abc02e2fbb0668d225b6fda5cd39586dd099cee8b2ed6ab0452/pyzmq-27.0.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:9df43a2459cd3a3563404c1456b2c4c69564daa7dbaf15724c09821a3329ce46", size = 544726 }, +] + [[package]] name = "ray" version = "2.42.1" @@ -2929,6 +3742,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6f/bb/5deac77a9af870143c684ab46a7934038a53eb4aa975bc0687ed6ca2c610/requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5", size = 23892 }, ] +[[package]] +name = "rfc3339-validator" +version = "0.1.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/ea/a9387748e2d111c3c2b275ba970b735e04e15cdb1eb30693b6b5708c4dbd/rfc3339_validator-0.1.4.tar.gz", hash = 
"sha256:138a2abdf93304ad60530167e51d2dfb9549521a836871b88d7f4695d0022f6b", size = 5513 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7b/44/4e421b96b67b2daff264473f7465db72fbdf36a07e05494f50300cc7b0c6/rfc3339_validator-0.1.4-py2.py3-none-any.whl", hash = "sha256:24f6ec1eda14ef823da9e36ec7113124b39c04d50a4d3d3a3c2859577e7791fa", size = 3490 }, +] + +[[package]] +name = "rfc3986-validator" +version = "0.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/da/88/f270de456dd7d11dcc808abfa291ecdd3f45ff44e3b549ffa01b126464d0/rfc3986_validator-0.1.1.tar.gz", hash = "sha256:3d44bde7921b3b9ec3ae4e3adca370438eccebc676456449b145d533b240d055", size = 6760 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9e/51/17023c0f8f1869d8806b979a2bffa3f861f26a3f1a66b094288323fba52f/rfc3986_validator-0.1.1-py2.py3-none-any.whl", hash = "sha256:2f235c432ef459970b4306369336b9d5dbdda31b510ca1e327636e01f528bfa9", size = 4242 }, +] + [[package]] name = "rpds-py" version = "0.23.1" @@ -3066,40 +3900,37 @@ wheels = [ [[package]] name = "scipy" -version = "1.15.2" +version = "1.14.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "numpy" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b7/b9/31ba9cd990e626574baf93fbc1ac61cf9ed54faafd04c479117517661637/scipy-1.15.2.tar.gz", hash = "sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec", size = 59417316 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/95/df/ef233fff6838fe6f7840d69b5ef9f20d2b5c912a8727b21ebf876cb15d54/scipy-1.15.2-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9", size = 38692502 }, - { url = "https://files.pythonhosted.org/packages/5c/20/acdd4efb8a68b842968f7bc5611b1aeb819794508771ad104de418701422/scipy-1.15.2-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5", size = 30085508 }, - { url = "https://files.pythonhosted.org/packages/42/55/39cf96ca7126f1e78ee72a6344ebdc6702fc47d037319ad93221063e6cf4/scipy-1.15.2-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e", size = 22359166 }, - { url = "https://files.pythonhosted.org/packages/51/48/708d26a4ab8a1441536bf2dfcad1df0ca14a69f010fba3ccbdfc02df7185/scipy-1.15.2-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9", size = 25112047 }, - { url = "https://files.pythonhosted.org/packages/dd/65/f9c5755b995ad892020381b8ae11f16d18616208e388621dfacc11df6de6/scipy-1.15.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3", size = 35536214 }, - { url = "https://files.pythonhosted.org/packages/de/3c/c96d904b9892beec978562f64d8cc43f9cca0842e65bd3cd1b7f7389b0ba/scipy-1.15.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d", size = 37646981 }, - { url = "https://files.pythonhosted.org/packages/3d/74/c2d8a24d18acdeae69ed02e132b9bc1bb67b7bee90feee1afe05a68f9d67/scipy-1.15.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58", size = 37230048 }, - { url = 
"https://files.pythonhosted.org/packages/42/19/0aa4ce80eca82d487987eff0bc754f014dec10d20de2f66754fa4ea70204/scipy-1.15.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa", size = 40010322 }, - { url = "https://files.pythonhosted.org/packages/d0/d2/f0683b7e992be44d1475cc144d1f1eeae63c73a14f862974b4db64af635e/scipy-1.15.2-cp310-cp310-win_amd64.whl", hash = "sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65", size = 41233385 }, - { url = "https://files.pythonhosted.org/packages/40/1f/bf0a5f338bda7c35c08b4ed0df797e7bafe8a78a97275e9f439aceb46193/scipy-1.15.2-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4", size = 38703651 }, - { url = "https://files.pythonhosted.org/packages/de/54/db126aad3874601048c2c20ae3d8a433dbfd7ba8381551e6f62606d9bd8e/scipy-1.15.2-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1", size = 30102038 }, - { url = "https://files.pythonhosted.org/packages/61/d8/84da3fffefb6c7d5a16968fe5b9f24c98606b165bb801bb0b8bc3985200f/scipy-1.15.2-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971", size = 22375518 }, - { url = "https://files.pythonhosted.org/packages/44/78/25535a6e63d3b9c4c90147371aedb5d04c72f3aee3a34451f2dc27c0c07f/scipy-1.15.2-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655", size = 25142523 }, - { url = "https://files.pythonhosted.org/packages/e0/22/4b4a26fe1cd9ed0bc2b2cb87b17d57e32ab72c346949eaf9288001f8aa8e/scipy-1.15.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e", size = 35491547 }, - { url = "https://files.pythonhosted.org/packages/32/ea/564bacc26b676c06a00266a3f25fdfe91a9d9a2532ccea7ce6dd394541bc/scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0", size = 37634077 }, - { url = "https://files.pythonhosted.org/packages/43/c2/bfd4e60668897a303b0ffb7191e965a5da4056f0d98acfb6ba529678f0fb/scipy-1.15.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40", size = 37231657 }, - { url = "https://files.pythonhosted.org/packages/4a/75/5f13050bf4f84c931bcab4f4e83c212a36876c3c2244475db34e4b5fe1a6/scipy-1.15.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462", size = 40035857 }, - { url = "https://files.pythonhosted.org/packages/b9/8b/7ec1832b09dbc88f3db411f8cdd47db04505c4b72c99b11c920a8f0479c3/scipy-1.15.2-cp311-cp311-win_amd64.whl", hash = "sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737", size = 41217654 }, - { url = "https://files.pythonhosted.org/packages/4b/5d/3c78815cbab499610f26b5bae6aed33e227225a9fa5290008a733a64f6fc/scipy-1.15.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd", size = 38756184 }, - { url = "https://files.pythonhosted.org/packages/37/20/3d04eb066b471b6e171827548b9ddb3c21c6bbea72a4d84fc5989933910b/scipy-1.15.2-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301", size = 30163558 }, - { url = 
"https://files.pythonhosted.org/packages/a4/98/e5c964526c929ef1f795d4c343b2ff98634ad2051bd2bbadfef9e772e413/scipy-1.15.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93", size = 22437211 }, - { url = "https://files.pythonhosted.org/packages/1d/cd/1dc7371e29195ecbf5222f9afeedb210e0a75057d8afbd942aa6cf8c8eca/scipy-1.15.2-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20", size = 25232260 }, - { url = "https://files.pythonhosted.org/packages/f0/24/1a181a9e5050090e0b5138c5f496fee33293c342b788d02586bc410c6477/scipy-1.15.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e", size = 35198095 }, - { url = "https://files.pythonhosted.org/packages/c0/53/eaada1a414c026673eb983f8b4a55fe5eb172725d33d62c1b21f63ff6ca4/scipy-1.15.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8", size = 37297371 }, - { url = "https://files.pythonhosted.org/packages/e9/06/0449b744892ed22b7e7b9a1994a866e64895363572677a316a9042af1fe5/scipy-1.15.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11", size = 36872390 }, - { url = "https://files.pythonhosted.org/packages/6a/6f/a8ac3cfd9505ec695c1bc35edc034d13afbd2fc1882a7c6b473e280397bb/scipy-1.15.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53", size = 39700276 }, - { url = "https://files.pythonhosted.org/packages/f5/6f/e6e5aff77ea2a48dd96808bb51d7450875af154ee7cbe72188afb0b37929/scipy-1.15.2-cp312-cp312-win_amd64.whl", hash = "sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded", size = 40942317 }, +sdist = { url = "https://files.pythonhosted.org/packages/62/11/4d44a1f274e002784e4dbdb81e0ea96d2de2d1045b2132d5af62cc31fd28/scipy-1.14.1.tar.gz", hash = "sha256:5a275584e726026a5699459aa72f828a610821006228e841b94275c4a7c08417", size = 58620554 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/64/68/3bc0cfaf64ff507d82b1e5d5b64521df4c8bf7e22bc0b897827cbee9872c/scipy-1.14.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:b28d2ca4add7ac16ae8bb6632a3c86e4b9e4d52d3e34267f6e1b0c1f8d87e389", size = 39069598 }, + { url = "https://files.pythonhosted.org/packages/43/a5/8d02f9c372790326ad405d94f04d4339482ec082455b9e6e288f7100513b/scipy-1.14.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d0d2821003174de06b69e58cef2316a6622b60ee613121199cb2852a873f8cf3", size = 29879676 }, + { url = "https://files.pythonhosted.org/packages/07/42/0e0bea9666fcbf2cb6ea0205db42c81b1f34d7b729ba251010edf9c80ebd/scipy-1.14.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:8bddf15838ba768bb5f5083c1ea012d64c9a444e16192762bd858f1e126196d0", size = 23088696 }, + { url = "https://files.pythonhosted.org/packages/15/47/298ab6fef5ebf31b426560e978b8b8548421d4ed0bf99263e1eb44532306/scipy-1.14.1-cp310-cp310-macosx_14_0_x86_64.whl", hash = "sha256:97c5dddd5932bd2a1a31c927ba5e1463a53b87ca96b5c9bdf5dfd6096e27efc3", size = 25470699 }, + { url = "https://files.pythonhosted.org/packages/d8/df/cdb6be5274bc694c4c22862ac3438cb04f360ed9df0aecee02ce0b798380/scipy-1.14.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2ff0a7e01e422c15739ecd64432743cf7aae2b03f3084288f399affcefe5222d", size = 35606631 }, + { url 
= "https://files.pythonhosted.org/packages/47/78/b0c2c23880dd1e99e938ad49ccfb011ae353758a2dc5ed7ee59baff684c3/scipy-1.14.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e32dced201274bf96899e6491d9ba3e9a5f6b336708656466ad0522d8528f69", size = 41178528 }, + { url = "https://files.pythonhosted.org/packages/5d/aa/994b45c34b897637b853ec04334afa55a85650a0d11dacfa67232260fb0a/scipy-1.14.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:8426251ad1e4ad903a4514712d2fa8fdd5382c978010d1c6f5f37ef286a713ad", size = 42784535 }, + { url = "https://files.pythonhosted.org/packages/e7/1c/8daa6df17a945cb1a2a1e3bae3c49643f7b3b94017ff01a4787064f03f84/scipy-1.14.1-cp310-cp310-win_amd64.whl", hash = "sha256:a49f6ed96f83966f576b33a44257d869756df6cf1ef4934f59dd58b25e0327e5", size = 44772117 }, + { url = "https://files.pythonhosted.org/packages/b2/ab/070ccfabe870d9f105b04aee1e2860520460ef7ca0213172abfe871463b9/scipy-1.14.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:2da0469a4ef0ecd3693761acbdc20f2fdeafb69e6819cc081308cc978153c675", size = 39076999 }, + { url = "https://files.pythonhosted.org/packages/a7/c5/02ac82f9bb8f70818099df7e86c3ad28dae64e1347b421d8e3adf26acab6/scipy-1.14.1-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:c0ee987efa6737242745f347835da2cc5bb9f1b42996a4d97d5c7ff7928cb6f2", size = 29894570 }, + { url = "https://files.pythonhosted.org/packages/ed/05/7f03e680cc5249c4f96c9e4e845acde08eb1aee5bc216eff8a089baa4ddb/scipy-1.14.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:3a1b111fac6baec1c1d92f27e76511c9e7218f1695d61b59e05e0fe04dc59617", size = 23103567 }, + { url = "https://files.pythonhosted.org/packages/5e/fc/9f1413bef53171f379d786aabc104d4abeea48ee84c553a3e3d8c9f96a9c/scipy-1.14.1-cp311-cp311-macosx_14_0_x86_64.whl", hash = "sha256:8475230e55549ab3f207bff11ebfc91c805dc3463ef62eda3ccf593254524ce8", size = 25499102 }, + { url = "https://files.pythonhosted.org/packages/c2/4b/b44bee3c2ddc316b0159b3d87a3d467ef8d7edfd525e6f7364a62cd87d90/scipy-1.14.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:278266012eb69f4a720827bdd2dc54b2271c97d84255b2faaa8f161a158c3b37", size = 35586346 }, + { url = "https://files.pythonhosted.org/packages/93/6b/701776d4bd6bdd9b629c387b5140f006185bd8ddea16788a44434376b98f/scipy-1.14.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fef8c87f8abfb884dac04e97824b61299880c43f4ce675dd2cbeadd3c9b466d2", size = 41165244 }, + { url = "https://files.pythonhosted.org/packages/06/57/e6aa6f55729a8f245d8a6984f2855696c5992113a5dc789065020f8be753/scipy-1.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b05d43735bb2f07d689f56f7b474788a13ed8adc484a85aa65c0fd931cf9ccd2", size = 42817917 }, + { url = "https://files.pythonhosted.org/packages/ea/c2/5ecadc5fcccefaece775feadcd795060adf5c3b29a883bff0e678cfe89af/scipy-1.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:716e389b694c4bb564b4fc0c51bc84d381735e0d39d3f26ec1af2556ec6aad94", size = 44781033 }, + { url = "https://files.pythonhosted.org/packages/c0/04/2bdacc8ac6387b15db6faa40295f8bd25eccf33f1f13e68a72dc3c60a99e/scipy-1.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:631f07b3734d34aced009aaf6fedfd0eb3498a97e581c3b1e5f14a04164a456d", size = 39128781 }, + { url = "https://files.pythonhosted.org/packages/c8/53/35b4d41f5fd42f5781dbd0dd6c05d35ba8aa75c84ecddc7d44756cd8da2e/scipy-1.14.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:af29a935803cc707ab2ed7791c44288a682f9c8107bc00f0eccc4f92c08d6e07", size = 29939542 }, + { 
url = "https://files.pythonhosted.org/packages/66/67/6ef192e0e4d77b20cc33a01e743b00bc9e68fb83b88e06e636d2619a8767/scipy-1.14.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:2843f2d527d9eebec9a43e6b406fb7266f3af25a751aa91d62ff416f54170bc5", size = 23148375 }, + { url = "https://files.pythonhosted.org/packages/f6/32/3a6dedd51d68eb7b8e7dc7947d5d841bcb699f1bf4463639554986f4d782/scipy-1.14.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:eb58ca0abd96911932f688528977858681a59d61a7ce908ffd355957f7025cfc", size = 25578573 }, + { url = "https://files.pythonhosted.org/packages/f0/5a/efa92a58dc3a2898705f1dc9dbaf390ca7d4fba26d6ab8cfffb0c72f656f/scipy-1.14.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:30ac8812c1d2aab7131a79ba62933a2a76f582d5dbbc695192453dae67ad6310", size = 35319299 }, + { url = "https://files.pythonhosted.org/packages/8e/ee/8a26858ca517e9c64f84b4c7734b89bda8e63bec85c3d2f432d225bb1886/scipy-1.14.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f9ea80f2e65bdaa0b7627fb00cbeb2daf163caa015e59b7516395fe3bd1e066", size = 40849331 }, + { url = "https://files.pythonhosted.org/packages/a5/cd/06f72bc9187840f1c99e1a8750aad4216fc7dfdd7df46e6280add14b4822/scipy-1.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:edaf02b82cd7639db00dbff629995ef185c8df4c3ffa71a5562a595765a06ce1", size = 42544049 }, + { url = "https://files.pythonhosted.org/packages/aa/7d/43ab67228ef98c6b5dd42ab386eae2d7877036970a0d7e3dd3eb47a0d530/scipy-1.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:2ff38e22128e6c03ff73b6bb0f85f897d2362f8c052e3b8ad00532198fbdae3f", size = 44521212 }, ] [[package]] @@ -3125,6 +3956,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a6/24/4d91e05817e92e3a61c8a21e08fd0f390f5301f1c448b137c57c4bc6e543/semver-3.0.4-py3-none-any.whl", hash = "sha256:9c824d87ba7f7ab4a1890799cec8596f15c1241cb473404ea1cb0c55e4b04746", size = 17912 }, ] +[[package]] +name = "send2trash" +version = "1.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/fd/3a/aec9b02217bb79b87bbc1a21bc6abc51e3d5dcf65c30487ac96c0908c722/Send2Trash-1.8.3.tar.gz", hash = "sha256:b18e7a3966d99871aefeb00cfbcfdced55ce4871194810fc71f4aa484b953abf", size = 17394 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/b0/4562db6223154aa4e22f939003cb92514c79f3d4dccca3444253fd17f902/Send2Trash-1.8.3-py3-none-any.whl", hash = "sha256:0c31227e0bd08961c7665474a3d1ef7193929fedda4233843689baa056be46c9", size = 18072 }, +] + [[package]] name = "sentry-sdk" version = "2.22.0" @@ -3252,6 +4092,20 @@ dependencies = [ ] sdist = { url = "https://files.pythonhosted.org/packages/42/ad/0ed334e53b3f093817fe9973d08ceacc83854784e69547aeb1202ad8538a/spython-0.2.14.tar.gz", hash = "sha256:49e22fbbdebe456b27ca17d30061489db8e0f95e62be3623267a23b85e3ce0f0", size = 69374 } +[[package]] +name = "stack-data" +version = "0.6.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "asttokens" }, + { name = "executing" }, + { name = "pure-eval" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/28/e3/55dcc2cfbc3ca9c29519eb6884dd1415ecb53b0e934862d3559ddcb7e20b/stack_data-0.6.3.tar.gz", hash = "sha256:836a778de4fec4dcd1dcd89ed8abff8a221f58308462e1c4aa2a3cf30148f0b9", size = 44707 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f1/7b/ce1eafaf1a76852e2ec9b22edecf1daa58175c090266e9f6c64afcd81d91/stack_data-0.6.3-py3-none-any.whl", hash = 
"sha256:d5558e0c25a4cb0853cddad3d77da9891a08cb85dd9f9f91b9f8cd66e511e695", size = 24521 }, +] + [[package]] name = "sympy" version = "1.13.1" @@ -3296,6 +4150,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/44/71/f3e7c9b2ab67e28c572ab4e9d5fa3499e0d252650f96d8a3a03e26677f53/tensorboardX-2.6.2.2-py2.py3-none-any.whl", hash = "sha256:160025acbf759ede23fd3526ae9d9bfbfd8b68eb16c38a010ebe326dc6395db8", size = 101700 }, ] +[[package]] +name = "terminado" +version = "0.18.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "ptyprocess", marker = "os_name != 'nt'" }, + { name = "pywinpty", marker = "(os_name == 'nt' and platform_machine != 'x86_64' and sys_platform == 'darwin') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, + { name = "tornado" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8a/11/965c6fd8e5cc254f1fe142d547387da17a8ebfd75a3455f637c663fb38a0/terminado-0.18.1.tar.gz", hash = "sha256:de09f2c4b85de4765f7714688fff57d3e75bad1f909b589fde880460c753fd2e", size = 32701 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/9e/2064975477fdc887e47ad42157e214526dcad8f317a948dee17e1659a62f/terminado-0.18.1-py3-none-any.whl", hash = "sha256:a4468e1b37bb318f8a86514f65814e1afc977cf29b3992a4500d9dd305dcceb0", size = 14154 }, +] + [[package]] name = "threadpoolctl" version = "3.5.0" @@ -3305,6 +4173,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/4b/2c/ffbf7a134b9ab11a67b0cf0726453cedd9c5043a4fe7a35d1cefa9a1bcfb/threadpoolctl-3.5.0-py3-none-any.whl", hash = "sha256:56c1e26c150397e58c4926da8eeee87533b1e32bef131bd4bf6a2f45f3185467", size = 18414 }, ] +[[package]] +name = "tinycss2" +version = "1.4.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "webencodings" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/7a/fd/7a5ee21fd08ff70d3d33a5781c255cbe779659bd03278feb98b19ee550f4/tinycss2-1.4.0.tar.gz", hash = "sha256:10c0972f6fc0fbee87c3edb76549357415e94548c1ae10ebccdea16fb404a9b7", size = 87085 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e6/34/ebdc18bae6aa14fbee1a08b63c015c72b64868ff7dae68808ab500c492e2/tinycss2-1.4.0-py3-none-any.whl", hash = "sha256:3a49cf47b7675da0b15d0c6e1df8df4ebd96e9394bb905a5775adb0d884c5289", size = 26610 }, +] + [[package]] name = "tomli" version = "2.2.1" @@ -3473,6 +4353,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "jupyter" }, { name = "pytest" }, ] @@ -3489,7 +4370,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.7.2" }, + { name = "ml4gw", specifier = ">=0.7.5" }, { name = "priors", editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, @@ -3501,7 +4382,19 @@ requires-dist = [ ] [package.metadata.requires-dev] -dev = [{ name = "pytest", specifier = "~=7.3" }] +dev = [ + { name = "jupyter", specifier = ">=1.0.0" }, + { name = "pytest", specifier = "~=7.3" }, +] + +[[package]] +name = "traitlets" +version = "5.14.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/eb/79/72064e6a701c2183016abbbfedaba506d81e30e232a68c9f0d6f6fcd1574/traitlets-5.14.3.tar.gz", hash = "sha256:9ed0579d3502c94b4b3732ac120375cda96f923114522847de4b3bb98b96b6b7", size = 
161621 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/00/c0/8f5d070730d7836adc9c9b6408dec68c6ced86b304a9b26a14df072a6e8c/traitlets-5.14.3-py3-none-any.whl", hash = "sha256:b74e89e397b1ed28cc831db7aea759ba6640cb3de13090ca145426688ff1ac4f", size = 85359 }, +] [[package]] name = "triton" @@ -3542,6 +4435,15 @@ all = [ { name = "python-rapidjson" }, ] +[[package]] +name = "types-python-dateutil" +version = "2.9.0.20250708" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/95/6bdde7607da2e1e99ec1c1672a759d42f26644bbacf939916e086db34870/types_python_dateutil-2.9.0.20250708.tar.gz", hash = "sha256:ccdbd75dab2d6c9696c350579f34cffe2c281e4c5f27a585b2a2438dd1d5c8ab", size = 15834 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/72/52/43e70a8e57fefb172c22a21000b03ebcc15e47e97f5cb8495b9c2832efb4/types_python_dateutil-2.9.0.20250708-py3-none-any.whl", hash = "sha256:4d6d0cc1cc4d24a2dc3816024e502564094497b713f7befda4d5bc7a8e3fd21f", size = 17724 }, +] + [[package]] name = "typeshed-client" version = "2.7.0" @@ -3573,6 +4475,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/0f/dd/84f10e23edd882c6f968c21c2434fe67bd4a528967067515feca9e611e5e/tzdata-2025.1-py2.py3-none-any.whl", hash = "sha256:7e127113816800496f027041c570f50bcd464a020098a3b6b199517772303639", size = 346762 }, ] +[[package]] +name = "uri-template" +version = "1.3.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/31/c7/0336f2bd0bcbada6ccef7aaa25e443c118a704f828a0620c6fa0207c1b64/uri-template-1.3.0.tar.gz", hash = "sha256:0e00f8eb65e18c7de20d595a14336e9f337ead580c70934141624b6d1ffdacc7", size = 21678 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e7/00/3fca040d7cf8a32776d3d81a00c8ee7457e00f80c649f1e4a863c8321ae9/uri_template-1.3.0-py3-none-any.whl", hash = "sha256:a44a133ea12d44a0c0f06d7d42a52d71282e77e2f937d8abd5655b8d56fc1363", size = 11140 }, +] + [[package]] name = "urllib3" version = "1.26.20" @@ -3658,6 +4569,51 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/1f/d3/1996ef42e58a049d6b2d6c3e3f0c8d7d38a286f707c515672a928fd9eb6c/wandb-0.18.7-py3-none-win_amd64.whl", hash = "sha256:4ba9fda6dd7db02a23c6b302411fe26c3fcfea4947cc130a65e1de19812d324e", size = 15472555 }, ] +[[package]] +name = "wcwidth" +version = "0.2.13" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 }, +] + +[[package]] +name = "webcolors" +version = "24.11.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/29/061ec845fb58521848f3739e466efd8250b4b7b98c1b6c5bf4d40b419b7e/webcolors-24.11.1.tar.gz", hash = "sha256:ecb3d768f32202af770477b8b65f318fa4f566c22948673a977b00d589dd80f6", size = 45064 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/60/e8/c0e05e4684d13459f93d312077a9a2efbe04d59c393bc2b8802248c908d4/webcolors-24.11.1-py3-none-any.whl", hash = 
"sha256:515291393b4cdf0eb19c155749a096f779f7d909f7cceea072791cb9095b92e9", size = 14934 }, +] + +[[package]] +name = "webencodings" +version = "0.5.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/0b/02/ae6ceac1baeda530866a85075641cec12989bd8d31af6d5ab4a3e8c92f47/webencodings-0.5.1.tar.gz", hash = "sha256:b36a1c245f2d304965eb4e0a82848379241dc04b865afcc4aab16748587e1923", size = 9721 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl", hash = "sha256:a0af1213f3c2226497a97e2b3aa01a7e4bee4f403f95be16fc9acd2947514a78", size = 11774 }, +] + +[[package]] +name = "websocket-client" +version = "1.8.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/e6/30/fba0d96b4b5fbf5948ed3f4681f7da2f9f64512e1d303f94b4cc174c24a5/websocket_client-1.8.0.tar.gz", hash = "sha256:3239df9f44da632f96012472805d40a23281a991027ce11d2f45a6f24ac4c3da", size = 54648 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/84/44687a29792a70e111c5c477230a72c4b957d88d16141199bf9acb7537a3/websocket_client-1.8.0-py3-none-any.whl", hash = "sha256:17b44cc997f5c498e809b22cdf2d9c7a9e71c02c8cc2b6c56e7c2d1239bfa526", size = 58826 }, +] + +[[package]] +name = "widgetsnbextension" +version = "4.0.14" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/41/53/2e0253c5efd69c9656b1843892052a31c36d37ad42812b5da45c62191f7e/widgetsnbextension-4.0.14.tar.gz", hash = "sha256:a3629b04e3edb893212df862038c7232f62973373869db5084aed739b437b5af", size = 1097428 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ca/51/5447876806d1088a0f8f71e16542bf350918128d0a69437df26047c8e46f/widgetsnbextension-4.0.14-py3-none-any.whl", hash = "sha256:4875a9eaf72fbf5079dc372a51a9f268fc38d46f767cbf85c43a36da5cb9b575", size = 2196503 }, +] + [[package]] name = "wrapt" version = "1.17.2" From 9cd1ba8c1d9cb70c64a769770e592f6c9819e35e Mon Sep 17 00:00:00 2001 From: William Benoit Date: Sat, 16 Aug 2025 03:51:27 -0700 Subject: [PATCH 19/32] Re-organize data flow to get multimodal export and inference working --- libs/utils/utils/preprocessing.py | 18 +++++++++++++----- projects/export/export/snapshotter.py | 15 +++++++++++++++ projects/train/train/callbacks.py | 7 +++---- .../train/train/data/supervised/multimodal.py | 8 ++++---- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index 948ef0869..8e084c60d 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -111,14 +111,12 @@ def __init__( highpass: Optional[float] = None, lowpass: Optional[float] = None, return_whitened: bool = False, - return_asd: bool = False, ) -> None: super().__init__() self.stride_size = int(sample_rate / inference_sampling_rate) self.kernel_size = int(kernel_length * sample_rate) self.augmentor = augmentor self.return_whitened = return_whitened - self.return_asd = return_asd # do foreground length calculation in units of samples, # then convert back to length to guard for intification @@ -136,6 +134,9 @@ def __init__( ) self.whitener = Whiten(fduration, sample_rate, highpass, lowpass) + freqs = torch.fft.rfftfreq(size, d=1 / sample_rate) + self.freq_mask = (freqs > highpass) & (freqs < lowpass) + def forward(self, x: Tensor) -> Tensor: # Get the number of 
channels so we know how to # reshape `x` appropriately after unfolding to @@ -155,9 +156,16 @@ def forward(self, x: Tensor) -> Tensor: x = x.float() - asd = psd**0.5 - asd *= 1e23 - asd = asd.float() + if self.return_asd: + asd = psd**0.5 + asd = asd.float() + asd = torch.nn.functional.interpolate( + asd.unsqueeze(0), + size=(len(self.freq_mask),), + mode="linear", + ) + asd = asd[:, :, self.freq_mask] + asd = asd.expand(x.shape[0], -1, -1) # unfold x and then put it into the expected shape. # Note that if x has both signal and background diff --git a/projects/export/export/snapshotter.py b/projects/export/export/snapshotter.py index 67411dad5..834f16bcd 100644 --- a/projects/export/export/snapshotter.py +++ b/projects/export/export/snapshotter.py @@ -5,7 +5,11 @@ from hermes.quiver import Platform from hermes.quiver.streaming import utils as streaming_utils +<<<<<<< HEAD from utils.preprocessing import BackgroundSnapshotter +======= +from utils.preprocessing import BackgroundSnapshotter, MultiModalPreprocessor +>>>>>>> b074291 (Re-organize data flow to get multimodal export and inference working) if TYPE_CHECKING: from hermes.quiver.model import EnsembleModel, ExposedTensor @@ -39,6 +43,17 @@ def add_streaming_input_preprocessor( ) -> "ExposedTensor": """Create a snapshotter model and add it to the repository""" +<<<<<<< HEAD +======= + batch_size, num_ifos, *kernel_size = input.shape + if q is not None: + if len(kernel_size) != 2: + raise ValueError( + "If q is not None, the input kernel should be 2D, " + f"got {len(kernel_size)} dimension(s)" + ) + +>>>>>>> b074291 (Re-organize data flow to get multimodal export and inference working) snapshotter = BackgroundSnapshotter( psd_length=psd_length, kernel_length=kernel_length, diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index b3c09a95d..2075a0896 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -90,8 +90,8 @@ def on_train_start(self, trainer, pl_module): ) background = background.to(device) signals = signals.to(device) - X_bg, X_inj, val_asds = trainer.datamodule.build_val_batches( - background, signals + X_bg, X_inj, val_X_bg_fft, val_X_fg_fft = ( + trainer.datamodule.build_val_batches(background, signals) ) # Make background and injected validation data into # tuples for consistency if necessary @@ -108,7 +108,6 @@ def on_train_start(self, trainer, pl_module): for i, x in enumerate(X): h5file[f"input_{i}"] = x.cpu().numpy() h5file["y"] = y.cpu().numpy() - h5file["asds"] = train_asds.cpu().numpy() s3_file.write(f.getvalue()) with s3.open(f"{save_dir}/val_batch.hdf5", "wb") as s3_file: @@ -123,7 +122,7 @@ def on_train_start(self, trainer, pl_module): for i, x in enumerate(X): f[f"input_{i}"] = x.cpu().numpy() f["y"] = y.cpu().numpy() - f["asds"] = train_asds.cpu().numpy() + f["X_fft"] = X_fft.cpu().numpy() with h5py.File( os.path.join(save_dir, "val_batch.hdf5"), "w" diff --git a/projects/train/train/data/supervised/multimodal.py b/projects/train/train/data/supervised/multimodal.py index 4c02b0c79..9cf9df96c 100644 --- a/projects/train/train/data/supervised/multimodal.py +++ b/projects/train/train/data/supervised/multimodal.py @@ -36,8 +36,8 @@ def on_after_batch_transfer(self, batch, _): if self.trainer.training: # if we're training, perform random augmentations # on input data and use it to impact labels - [X], waveforms = batch - (X, X_fft), y = self.inject(X, waveforms) + [X] = batch + (X, X_fft), y = self.inject(X) batch = (X, X_fft, y) elif 
self.trainer.validating or self.trainer.sanity_checking: # If we're in validation mode but we're not validating @@ -86,8 +86,8 @@ def compute_frequency_domain_data(self, X, psds): return X_fft - def inject(self, X, waveforms): - X, y, psds = super().augment(X, waveforms) + def inject(self, X): + X, y, psds = super().inject(X) X = self.whitener(X, psds) X_fft = self.compute_frequency_domain_data(X, psds) From 8997aebd5b649b3cb33d5b261c5cbf4634dc845c Mon Sep 17 00:00:00 2001 From: William Benoit Date: Tue, 19 Aug 2025 11:50:30 -0700 Subject: [PATCH 20/32] Add option for constraints on prior --- libs/ledger/ledger/injections.py | 10 ++++++++-- libs/priors/priors/priors.py | 8 +++++--- projects/train/train/callbacks.py | 1 + projects/train/train/constraints.py | 18 ++++++++++++++++++ projects/train/train/conversion.py | 22 ++++++++++++++++++++++ projects/train/train/prior.py | 27 ++++++++++++++++++++++++++- 6 files changed, 80 insertions(+), 6 deletions(-) create mode 100644 projects/train/train/constraints.py diff --git a/libs/ledger/ledger/injections.py b/libs/ledger/ledger/injections.py index 114d5eb39..b379f45b7 100644 --- a/libs/ledger/ledger/injections.py +++ b/libs/ledger/ledger/injections.py @@ -506,8 +506,14 @@ def waveforms(self) -> np.ndarray: if self._waveforms is None: fields = sorted(self.waveform_fields) waveforms = [getattr(self, i) for i in fields] - waveforms = np.stack(waveforms, axis=1) - self._waveforms = waveforms + shape = ( + waveforms[0].shape[0], + len(fields), + waveforms[0].shape[-1], + ) + self._waveforms = np.zeros(shape, dtype=np.float32) + for i, field in enumerate(fields): + self._waveforms[:, i, :] = getattr(self, field) return self._waveforms def num_waveform_fields(self): diff --git a/libs/priors/priors/priors.py b/libs/priors/priors/priors.py index 9220783c3..77eace1a8 100644 --- a/libs/priors/priors/priors.py +++ b/libs/priors/priors/priors.py @@ -204,11 +204,13 @@ def end_o3_ratesandpops_bns( prior["redshift"] = UniformSourceFrame( 0, 0.15, name="redshift", cosmology=cosmology ) - spin_prior = uniform_spin() - for key, value in spin_prior.items(): - prior[key] = value + prior["psi"] = 0 prior["a_1"] = Uniform(0, 0.4) prior["a_2"] = Uniform(0, 0.4) + prior["tilt_1"] = 0 + prior["tilt_2"] = 0 + prior["phi_12"] = 0 + prior["phi_jl"] = 0 detector_frame_prior = False return prior, detector_frame_prior diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index 2075a0896..2d35ce284 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -70,6 +70,7 @@ def on_train_end(self, trainer, pl_module): class SaveAugmentedBatch(Callback): def on_train_start(self, trainer, pl_module): if trainer.global_rank == 0: + breakpoint() # find device module is on device = pl_module.device save_dir = trainer.logger.save_dir diff --git a/projects/train/train/constraints.py b/projects/train/train/constraints.py new file mode 100644 index 000000000..aa2806a96 --- /dev/null +++ b/projects/train/train/constraints.py @@ -0,0 +1,18 @@ +import torch + + +def mass_ratio_constraint( + parameters: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """ + Enforce mass_2 < mass_1. 
Assumes that the given parameter dictionary
+    already contains the component masses
+    """
+    try:
+        mask = parameters["mass_1"] > parameters["mass_2"]
+    except KeyError as exc:
+        raise ValueError(
+            "Parameter dictionary did not contain component masses"
+        ) from exc
+
+    return {key: parameters[key][mask] for key in parameters.keys()}
diff --git a/projects/train/train/conversion.py b/projects/train/train/conversion.py
index d12cb7bf5..fbacb9928 100644
--- a/projects/train/train/conversion.py
+++ b/projects/train/train/conversion.py
@@ -66,3 +66,25 @@ def aligned_to_lalsimulation_parameters(
     parameters["s1z"] = parameters["chi1"]
     parameters["s2z"] = parameters["chi2"]
     return parameters
+
+
+def component_aligned_to_lalsimulation_parameters(
+    parameters: dict[str, torch.Tensor],
+) -> dict[str, torch.Tensor]:
+    """
+    Convert aligned spin parameters to lalsimulation parameters
+    and component masses to chirp mass and mass ratio
+    """
+    m1, m2 = parameters["mass_1"], parameters["mass_2"]
+    parameters["chirp_mass"] = (m1 * m2) ** (3 / 5) / (m1 + m2) ** (1 / 5)
+    parameters["mass_ratio"] = m2 / m1
+
+    parameters["s1x"] = torch.zeros_like(m1)
+    parameters["s1y"] = torch.zeros_like(m1)
+
+    parameters["s2x"] = torch.zeros_like(m1)
+    parameters["s2y"] = torch.zeros_like(m1)
+
+    parameters["s1z"] = parameters["chi1"]
+    parameters["s2z"] = parameters["chi2"]
+    return parameters
diff --git a/projects/train/train/prior.py b/projects/train/train/prior.py
index d0c2b5b36..06c51a3b2 100644
--- a/projects/train/train/prior.py
+++ b/projects/train/train/prior.py
@@ -8,6 +8,7 @@ def __init__(
         self,
         priors: dict[str, torch.distributions.Distribution],
         conversion_function: Optional[Callable] = None,
+        constraint_function: Optional[Callable] = None,
     ):
         """
         A class for sampling parameters from a prior distribution
@@ -20,10 +21,15 @@
             conversion_function:
                 A callable that takes a dictionary of sampled
                 parameters and returns a dictionary of waveform
                 generation parameters
+            constraint_function:
+                A callable that takes a dictionary of sampled parameters,
+                discards any that don't satisfy the constraint, and returns
+                the remaining parameters
         """
         super().__init__()
         self.priors = priors
         self.conversion_function = conversion_function or (lambda x: x)
+        self.constraint_function = constraint_function or (lambda x: x)
 
     def __call__(
         self,
@@ -45,11 +51,30 @@
         # to from sampled parameters to
         # waveform generation parameters
         parameters = self.conversion_function(parameters)
+
+        # Discard any samples that don't meet the constraint
+        parameters = self.constraint_function(parameters)
+
+        keys = list(parameters.keys())
+        while len(parameters[keys[0]]) < N:
+            new_params = {
+                k: v.sample((N,)).to(device) for k, v in self.priors.items()
+            }
+            new_params = self.conversion_function(new_params)
+            new_params = self.constraint_function(new_params)
+            parameters = {
+                key: torch.cat([parameters[key], new_params[key]])
+                for key in keys
+            }
+
        return parameters
 
    def log_prob(self, samples: dict[str, torch.Tensor]) -> torch.Tensor:
        """
-        Calculate the log probability of samples under the prior
+        Calculate the log probability of samples under the prior.
+        TODO: This calculation is incorrect if a constraint function
+        is specified. We don't use this function anywhere currently,
+        so not too big a deal, but we'll need to fix this if we do.
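+        (With a constraint, samples follow the prior truncated to the
+        constraint region, so a correct implementation would subtract
+        the log of the prior probability of satisfying the constraint.)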
Args: samples: From 061bb7cda24743d06baf3e53cfda970c429b4639 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 27 Aug 2025 09:38:07 -0500 Subject: [PATCH 21/32] Return only N sampled parameters --- projects/train/train/callbacks.py | 1 - projects/train/train/prior.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index 2d35ce284..2075a0896 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -70,7 +70,6 @@ def on_train_end(self, trainer, pl_module): class SaveAugmentedBatch(Callback): def on_train_start(self, trainer, pl_module): if trainer.global_rank == 0: - breakpoint() # find device module is on device = pl_module.device save_dir = trainer.logger.save_dir diff --git a/projects/train/train/prior.py b/projects/train/train/prior.py index 06c51a3b2..48fab8c76 100644 --- a/projects/train/train/prior.py +++ b/projects/train/train/prior.py @@ -67,6 +67,8 @@ def __call__( for key in keys } + return {k: v[:N] for k, v in parameters.items()} + return parameters def log_prob(self, samples: dict[str, torch.Tensor]) -> torch.Tensor: From 13c96c9cc5545eefd316b05568a8cfc11990a774 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Wed, 27 Aug 2025 09:58:41 -0500 Subject: [PATCH 22/32] Update ml4gw to 0.7.7 --- projects/train/pyproject.toml | 2 +- projects/train/uv.lock | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/projects/train/pyproject.toml b/projects/train/pyproject.toml index d47050cc4..8adbbc07e 100644 --- a/projects/train/pyproject.toml +++ b/projects/train/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "fsspec[s3]>=2024,<2025", "urllib3>=1.25.4,<1.27", "utils", - "ml4gw>=0.7.5", + "ml4gw>=0.7.7", "aframe", "ledger", "priors", diff --git a/projects/train/uv.lock b/projects/train/uv.lock index 57dc9764c..be5a7d7b8 100644 --- a/projects/train/uv.lock +++ b/projects/train/uv.lock @@ -2319,7 +2319,7 @@ wheels = [ [[package]] name = "ml4gw" -version = "0.7.5" +version = "0.7.7" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jaxtyping" }, @@ -2328,9 +2328,9 @@ dependencies = [ { name = "torch" }, { name = "torchaudio" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/fd/bb/99c1e75bbacd81f8cc887380cd0dc260066959cc81926fc2472b8842dc0a/ml4gw-0.7.5.tar.gz", hash = "sha256:3c776664bd8594d3b87c450cf0a0e5369f2ba8bfe2ad5923236b31d768eb8fb5", size = 811225 } +sdist = { url = "https://files.pythonhosted.org/packages/3a/50/c4a36b1bdf7fdd99405d9c2643da087a848603bdc64875f5f76824ad897f/ml4gw-0.7.7.tar.gz", hash = "sha256:80223a40b9a36e6cb11883924a95e5bf1b52d39c906cb3c1da5c0812a9f2afd8", size = 115393 } wheels = [ - { url = "https://files.pythonhosted.org/packages/32/7d/13ae4a5199dc081b7855202131a2ea69e322e544c45ed2cc6d7ca5786aa8/ml4gw-0.7.5-py3-none-any.whl", hash = "sha256:430fb2994a820c659806e4bff745395a04e466bc4c50c0b3c14f680b6fab187c", size = 125108 }, + { url = "https://files.pythonhosted.org/packages/7e/a9/5f8648cce66b07481da4d58a91ba16143ebfcf97d31b2929745910a4f1b7/ml4gw-0.7.7-py3-none-any.whl", hash = "sha256:0ca8c36afa45ca84a249070027ab89c73ab99b681ae6d7eda25de4643728dcc9", size = 125625 }, ] [[package]] @@ -4370,7 +4370,7 @@ requires-dist = [ { name = "ledger", editable = "../../libs/ledger" }, { name = "lightning", specifier = "==2.2.1" }, { name = "lightray", specifier = ">=0.2.3" }, - { name = "ml4gw", specifier = ">=0.7.5" }, + { name = "ml4gw", specifier = ">=0.7.7" }, { name = "priors", 
editable = "../../libs/priors" }, { name = "ray", extras = ["default", "tune"], specifier = ">=2.8.0,<3" }, { name = "s3fs", specifier = ">=2024,<2025" }, From 8192bfe0101fc6aa1b641fb39887ab7ef21807a6 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 4 Sep 2025 10:59:56 -0700 Subject: [PATCH 23/32] Restore accidentally-deleted snapshotter.py --- projects/export/export/snapshotter.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/projects/export/export/snapshotter.py b/projects/export/export/snapshotter.py index 834f16bcd..555bdcdcd 100644 --- a/projects/export/export/snapshotter.py +++ b/projects/export/export/snapshotter.py @@ -5,11 +5,7 @@ from hermes.quiver import Platform from hermes.quiver.streaming import utils as streaming_utils -<<<<<<< HEAD from utils.preprocessing import BackgroundSnapshotter -======= -from utils.preprocessing import BackgroundSnapshotter, MultiModalPreprocessor ->>>>>>> b074291 (Re-organize data flow to get multimodal export and inference working) if TYPE_CHECKING: from hermes.quiver.model import EnsembleModel, ExposedTensor @@ -43,17 +39,8 @@ def add_streaming_input_preprocessor( ) -> "ExposedTensor": """Create a snapshotter model and add it to the repository""" -<<<<<<< HEAD -======= batch_size, num_ifos, *kernel_size = input.shape - if q is not None: - if len(kernel_size) != 2: - raise ValueError( - "If q is not None, the input kernel should be 2D, " - f"got {len(kernel_size)} dimension(s)" - ) ->>>>>>> b074291 (Re-organize data flow to get multimodal export and inference working) snapshotter = BackgroundSnapshotter( psd_length=psd_length, kernel_length=kernel_length, From a9f6b115aa7f683fbbf5f4ea0d42fa4045a2fbbb Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 9 Oct 2025 09:19:18 -0700 Subject: [PATCH 24/32] Clean up after rebase --- aframe/tasks/export/export.py | 2 +- aframe/tasks/train/base.py | 2 +- libs/utils/utils/preprocessing.py | 25 ++----------------------- projects/export/export/snapshotter.py | 2 -- projects/train/train/callbacks.py | 9 ++++----- 5 files changed, 8 insertions(+), 32 deletions(-) diff --git a/aframe/tasks/export/export.py b/aframe/tasks/export/export.py index 3e4731a32..5c0f2f9f9 100644 --- a/aframe/tasks/export/export.py +++ b/aframe/tasks/export/export.py @@ -29,7 +29,7 @@ class ExportParams(law.Task): ) train_task = luigi.TaskParameter() platform = luigi.Parameter( - default="TORCHSCRIPT", + default="TENSORRT", description="Platform to use for exporting model for inference", ) diff --git a/aframe/tasks/train/base.py b/aframe/tasks/train/base.py index 1d8aa969a..2925c5e21 100644 --- a/aframe/tasks/train/base.py +++ b/aframe/tasks/train/base.py @@ -60,7 +60,7 @@ class TrainBaseParameters(law.Task): "It is expected to contain a `val_waveforms.hdf5` file of " "signals for validation, a `/background` sub-directory containing " "background, and a `train_waveforms.hdf5` file containing " - "training signals if `generate_train_waveforms` is set to False.", + "training signals if `precompute_train_waveforms` is set to True.", default=paths().train_datadir, ) precompute_train_waveforms = luigi.BoolParameter( diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index 8e084c60d..e06e6c788 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -134,9 +134,6 @@ def __init__( ) self.whitener = Whiten(fduration, sample_rate, highpass, lowpass) - freqs = torch.fft.rfftfreq(size, d=1 / sample_rate) - self.freq_mask = (freqs > highpass) & (freqs < 
lowpass) - def forward(self, x: Tensor) -> Tensor: # Get the number of channels so we know how to # reshape `x` appropriately after unfolding to @@ -154,19 +151,6 @@ def forward(self, x: Tensor) -> Tensor: x, psd = self.psd_estimator(x.double()) whitened = self.whitener(x, psd) - x = x.float() - - if self.return_asd: - asd = psd**0.5 - asd = asd.float() - asd = torch.nn.functional.interpolate( - asd.unsqueeze(0), - size=(len(self.freq_mask),), - mode="linear", - ) - asd = asd[:, :, self.freq_mask] - asd = asd.expand(x.shape[0], -1, -1) - # unfold x and then put it into the expected shape. # Note that if x has both signal and background # batch elements, they will be interleaved along @@ -176,14 +160,9 @@ def forward(self, x: Tensor) -> Tensor: if self.augmentor is not None: x = self.augmentor(x) - if self.return_whitened and self.return_asd: - return x, whitened, asd - elif self.return_whitened: + if self.return_whitened: return x, whitened - elif self.return_asd: - return x, asd - else: - return x + return x class MultiModalPreprocessor(torch.nn.Module): diff --git a/projects/export/export/snapshotter.py b/projects/export/export/snapshotter.py index 555bdcdcd..67411dad5 100644 --- a/projects/export/export/snapshotter.py +++ b/projects/export/export/snapshotter.py @@ -39,8 +39,6 @@ def add_streaming_input_preprocessor( ) -> "ExposedTensor": """Create a snapshotter model and add it to the repository""" - batch_size, num_ifos, *kernel_size = input.shape - snapshotter = BackgroundSnapshotter( psd_length=psd_length, kernel_length=kernel_length, diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index 2075a0896..f1fdd0030 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -45,7 +45,7 @@ def on_train_end(self, trainer, pl_module): device = pl_module.device [X] = next(iter(trainer.train_dataloader)) X = X.to(device) - X, y = trainer.datamodule.augment(X, waveforms) + X, y = trainer.datamodule.inject(X) if isinstance(X, tuple): X = tuple(i.cpu() for i in X) else: @@ -78,7 +78,7 @@ def on_train_start(self, trainer, pl_module): [X] = next(iter(trainer.train_dataloader)) X = X.to(device) - X, y = trainer.datamodule.augment(X, waveforms) + X, y = trainer.datamodule.inject(X) # If X is not a tuple, make it one for consistency # of format for saving to file below if not isinstance(X, tuple): @@ -90,8 +90,8 @@ def on_train_start(self, trainer, pl_module): ) background = background.to(device) signals = signals.to(device) - X_bg, X_inj, val_X_bg_fft, val_X_fg_fft = ( - trainer.datamodule.build_val_batches(background, signals) + X_bg, X_inj = trainer.datamodule.build_val_batches( + background, signals ) # Make background and injected validation data into # tuples for consistency if necessary @@ -122,7 +122,6 @@ def on_train_start(self, trainer, pl_module): for i, x in enumerate(X): f[f"input_{i}"] = x.cpu().numpy() f["y"] = y.cpu().numpy() - f["X_fft"] = X_fft.cpu().numpy() with h5py.File( os.path.join(save_dir, "val_batch.hdf5"), "w" From b7cb50980e934904f59df165c48860eedc9d1056 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 9 Oct 2025 13:08:12 -0700 Subject: [PATCH 25/32] Address comments --- aframe/tasks/data/waveforms/training.py | 4 ++-- projects/train/train.yaml | 4 +++- .../{frequency_domain.py => time_frequency_domain.py} | 0 3 files changed, 5 insertions(+), 3 deletions(-) rename projects/train/train/data/supervised/{frequency_domain.py => time_frequency_domain.py} (100%) diff --git 
a/aframe/tasks/data/waveforms/training.py b/aframe/tasks/data/waveforms/training.py index 54baab85f..c04d49a88 100644 --- a/aframe/tasks/data/waveforms/training.py +++ b/aframe/tasks/data/waveforms/training.py @@ -66,8 +66,8 @@ def run(self): @inherits(DeployTrainingWaveforms) class TrainingWaveforms(AframeDataTask): """ - Launch condorized generation of validation waveforms via - rejection sampling, and merge results into a single file + Launch condorized generation of training waveforms, + and merge results into a single file """ def __init__(self, *args, **kwargs): diff --git a/projects/train/train.yaml b/projects/train/train.yaml index 007bfa67d..d9dc84aca 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -34,10 +34,12 @@ data: # data_dir: # ifos: - # preprocessing args + # dataloading args batch_size: 384 batches_per_epoch: 3700 num_files_per_batch: 10 + + # preprocessing args # kernel_length: psd_length: 8 # fduration: diff --git a/projects/train/train/data/supervised/frequency_domain.py b/projects/train/train/data/supervised/time_frequency_domain.py similarity index 100% rename from projects/train/train/data/supervised/frequency_domain.py rename to projects/train/train/data/supervised/time_frequency_domain.py From 6369b7b6081c2a2a6c981bfdb682a42091d8b0d0 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Thu, 9 Oct 2025 14:25:56 -0700 Subject: [PATCH 26/32] Restored original method of loading waveforms from disk --- projects/train/train.yaml | 5 + projects/train/train/data/base.py | 78 +++++- .../train/train/data/supervised/multimodal.py | 33 --- .../train/train/data/supervised/supervised.py | 10 +- .../train/train/data/waveforms/__init__.py | 2 + projects/train/train/data/waveforms/loader.py | 248 ++++++++++++++++-- projects/train/train/model/supervised.py | 4 +- 7 files changed, 311 insertions(+), 69 deletions(-) create mode 100644 projects/train/train/data/waveforms/__init__.py diff --git a/projects/train/train.yaml b/projects/train/train.yaml index d9dc84aca..5b81a2984 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -38,6 +38,9 @@ data: batch_size: 384 batches_per_epoch: 3700 num_files_per_batch: 10 + # Used for loading waveforms from disk + # chunks_per_epoch: 10 + # chunk_size: 10000 # preprocessing args # kernel_length: @@ -68,6 +71,8 @@ data: # max_snr: 100 # alpha: -3 # decay_steps: 989 + # If loading waveforms from disk, waveform_sampler should be + # an instance of train.data.waveforms.WaveformLoader waveform_sampler: class_path: train.data.waveforms.generator.cbc.CBCGenerator init_args: diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 5d6291dc9..ed632a41f 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -17,7 +17,12 @@ from train import augmentations as aug from train.data.utils import fs as fs_utils from train.metrics import get_timeslides -from train.data.waveforms.sampler import WaveformSampler +from train.data.waveforms import ( + ChunkedWaveformDataset, + Hdf5WaveformLoader, + WaveformLoader, + WaveformSampler, +) from utils.preprocessing import PsdEstimator Tensor = torch.Tensor @@ -130,6 +135,13 @@ class BaseAframeDataset(pl.LightningDataModule): valid_livetime: Total livetime in seconds of the validation data to be generated via timeslides. + chunks_per_epoch: + Number of chunks of waveforms to load from disk + each epoch. Not used if generating waveforms + during training. + chunk_size: + Number of waveforms to load in each chunk. 
+ Not used if generating waveforms during training. verbose: Whether to log debug information during training. """ @@ -167,6 +179,9 @@ def __init__( min_valid_duration: float = 15000, valid_livetime: float = (3600 * 12), max_num_workers: int = 6, + # dataloading args + chunks_per_epoch: int = 1, + chunk_size: int = 10000, verbose: bool = False, ) -> None: super().__init__() @@ -190,6 +205,10 @@ def __init__( self.dec, self.psi, self.phi = dec, psi, phi self.waveform_sampler = waveform_sampler + # If we're using a `WaveformLoader`, we're loading + # training waveforms from disk, so have a flag to + # indicate that + self.waveforms_from_disk = isinstance(waveform_sampler, WaveformLoader) self.snr_sampler = snr_sampler # generate our local node data directory @@ -482,8 +501,12 @@ def on_after_batch_transfer(self, batch, _): if self.trainer.training: # if we're training, perform random augmentations # on input data and use it to impact labels - [batch] = batch - batch = self.inject(batch) + if self.waveforms_from_disk: + [batch], waveforms = batch + batch = self.inject(batch, waveforms) + else: + [batch] = batch + batch = self.inject(batch) elif self.trainer.validating or self.trainer.sanity_checking: # If we're in validation mode but we're not validating # on the local device, the relevant tensors will be @@ -636,4 +659,51 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: num_workers=self.num_workers, pin_memory=pin_memory, ) - return dataloader + + # If we're not loading waveforms from disk, just return + # the background dataloader + if not self.waveforms_from_disk: + return dataloader + + # build iterator for waveform loading + # that will load chunks of waveforms + # to be sampled from + waveform_loader = Hdf5WaveformLoader( + self.train_waveform_fnames, + batch_size=self.hparams.chunk_size, + batches_per_epoch=self.hparams.chunks_per_epoch or 1, + channels=["cross", "plus"], + path="waveforms", + ) + # calculate how many batches we'll sample from each chunk + # based on requested chunks per epoch and batches per epoch + world_size, _ = self.get_world_size_and_rank() + batches_per_epoch = self.hparams.batches_per_epoch // world_size + batches_per_chunk = ( + int(batches_per_epoch // self.hparams.chunks_per_epoch) + 1 + ) + self._logger.info( + f"Training on pool of {waveform_loader.total} waveforms. 
" + f"Sampling {batches_per_chunk} batches per chunk " + f"from {self.hparams.chunks_per_epoch} chunks " + f"of size {self.hparams.chunk_size} each epoch" + ) + + # multiprocess waveform chunk loader + # so we don't have to wait for waveforms + waveform_loader = torch.utils.data.DataLoader( + waveform_loader, + num_workers=2, + pin_memory=pin_memory, + persistent_workers=True, + ) + + # build a dataset that will sample from + # iterator of chunks of waveforms + waveform_dataset = ChunkedWaveformDataset( + waveform_loader, + batch_size=self.hparams.batch_size, + batches_per_chunk=batches_per_chunk, + ) + + return ZippedDataset(dataloader, waveform_dataset) diff --git a/projects/train/train/data/supervised/multimodal.py b/projects/train/train/data/supervised/multimodal.py index 9cf9df96c..64c914eec 100644 --- a/projects/train/train/data/supervised/multimodal.py +++ b/projects/train/train/data/supervised/multimodal.py @@ -25,39 +25,6 @@ def build_val_batches(self, background, signals): return (X_bg, X_bg_fft), (X_fg, X_fg_fft) - def on_after_batch_transfer(self, batch, _): - """ - This is a method inherited from the DataModule - base class that gets called after data returned - by a dataloader gets put on the local device, - but before it gets passed to the LightningModule. - Use this to do on-device augmentation/preprocessing. - """ - if self.trainer.training: - # if we're training, perform random augmentations - # on input data and use it to impact labels - [X] = batch - (X, X_fft), y = self.inject(X) - batch = (X, X_fft, y) - elif self.trainer.validating or self.trainer.sanity_checking: - # If we're in validation mode but we're not validating - # on the local device, the relevant tensors will be - # empty, so just pass them through with a 0 shift to - # indicate that this should be ignored - [background, _, timeslide_idx], [signals] = batch - - # If we're validating, unfold the background - # data into a batch of overlapping kernels now that - # we're on the GPU so that we're not transferring as - # much data from CPU to GPU. Once everything is - # on-device, pre-inject signals into background. - shift = self.timeslides[timeslide_idx].shift_size - (X_bg, X_bg_fft), (X_fg, X_fg_fft) = self.build_val_batches( - background, signals - ) - batch = (shift, X_bg, X_fg, X_bg_fft, X_fg_fft) - return batch - def compute_frequency_domain_data(self, X, psds): asds = psds**0.5 diff --git a/projects/train/train/data/supervised/supervised.py b/projects/train/train/data/supervised/supervised.py index e5904a7da..49ea5ea50 100644 --- a/projects/train/train/data/supervised/supervised.py +++ b/projects/train/train/data/supervised/supervised.py @@ -43,7 +43,7 @@ def sample_prob(self): return self.hparams.waveform_prob + self.swap_prob + self.mute_prob @torch.no_grad() - def inject(self, X): + def inject(self, X, waveforms=None): X, psds = self.psd_estimator(X) X = self.inverter(X) X = self.reverser(X) @@ -54,7 +54,13 @@ def inject(self, X): mask = rvs < self.sample_prob dec, psi, phi = self.sample_extrinsic(X[mask]) - hc, hp = self.waveform_sampler.sample(X[mask]) + # If waveforms were passed, we're loading them from + # disk and we can slice out the ones we want. + # If not, we're generating them on the fly. 
+ if waveforms is not None: + hc, hp = waveforms[mask, 0], waveforms[mask, 1] + else: + hc, hp = self.waveform_sampler.sample(X[mask]) snrs = self.snr_sampler.sample((mask.sum().item(),)).to(X.device) responses = self.projector( diff --git a/projects/train/train/data/waveforms/__init__.py b/projects/train/train/data/waveforms/__init__.py new file mode 100644 index 000000000..9bc09b6db --- /dev/null +++ b/projects/train/train/data/waveforms/__init__.py @@ -0,0 +1,2 @@ +from .loader import ChunkedWaveformDataset, Hdf5WaveformLoader, WaveformLoader +from .sampler import WaveformSampler diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index 678a61ede..b35229851 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -1,5 +1,12 @@ from pathlib import Path +import logging +import math +import warnings +from typing import Iterable, Optional + +import h5py +import numpy as np import torch from .sampler import WaveformSampler @@ -7,14 +14,12 @@ class WaveformLoader(WaveformSampler): """ - Torch module for loading training and validation - waveforms from disk and sampling them during training. - TODO: modify this to sample waveforms from disk, taking - an index sampler object so that DDP training can sample - different waveforms for each device. - Args: - training_waveform_file: - Path to the training waveforms file + Module that should be used if training waveforms + are loaded from disk. Its main purpose is to make + waveform handling consistent whether the waveforms + are generated during training or loaded from disk. + The actual loading of the training waveforms is done + by `Hdf5WaveformLoader` in `train.data.base.train_dataloader()`. """ def __init__( @@ -32,26 +37,213 @@ def __init__( "Training waveform file does not have the same " "right pad as validation waveform file" ) - self.num_train_waveforms = len(waveform_set) def get_train_waveforms(self, world_size, rank, device): - """ - Returns train waveforms for this device - """ - start, stop = self.get_slice_bounds( - self.num_train_waveforms, world_size, rank + pass + + +# TODO: move to ml4gw +class Hdf5WaveformLoader(torch.utils.data.IterableDataset): + """ + Iterable dataset that loads samples of waveforms + from a set of HDF5 files. + + It is _strongly_ recommended that these files have been + written using [chunked storage] + (https://docs.h5py.org/en/stable/high/dataset.html#chunked-storage). + This has been shown to produce increases in read-time speeds + of over an order of magnitude. + + Args: + fnames: + Paths to HDF5 files from which to sample data. + channels: + Datasets to read from the indicated files, which + will be stacked along dim 1 of the generated batches + during iteration. + batch_size: + Number of samples to load at each iteration. + batches_per_epoch: + Number of batches to generate during each call + to `__iter__`. + chunk_size: + Number of samples to load from each file at a time. + This is useful for reducing I/O overhead when reading. + path: + Optional path to location of datasets in hdf5 files. + `path` should be delimited by forward slashes. If `None` + it is assumed the datasets are at the root of the file. 
+ """ + + def __init__( + self, + fnames: Iterable[Path], + channels: Iterable[str], + batch_size: int, + batches_per_epoch: int, + chunk_size: int = 1000, + path: Optional[Path] = None, + ): + self.fnames = fnames + self.channels = channels + self.batch_size = batch_size + self.batches_per_epoch = batches_per_epoch + self.chunk_size = chunk_size + + if path is not None: + self.path = path.split("/") + else: + self.path = None + + self.sizes = {} + self.mmap_files = {} + self.mmap_datasets = {} + + # for each file store the datasets + # of interest in a dictionary so we + # can access them at will without needing + # to reopen the files each time + for fname in self.fnames: + f, g = self.open(fname) + self.mmap_files[fname] = f + self.mmap_datasets[fname] = { + channel: g[channel] for channel in self.channels + } + + # store sizes of each dataset and warn if not chunked; + # assumes all dsets have same attributes + # like size and chunking behavior + dset = self.mmap_datasets[fname][self.channels[0]] + self.sizes[fname] = len(dset) + if dset.chunks is None: + warnings.warn( + "File {} contains datasets that were generated " + "without using chunked storage. This can have " + "severe performance impacts at data loading time. " + "If you need faster loading, try re-generating " + "your datset with chunked storage turned on.".format( + fnames + ), + stacklevel=2, + ) + + self.waveform_size = dset.shape[1] + self.probs = np.array([i / self.total for i in self.sizes.values()]) + + @property + def num_channels(self): + return len(self.channels) + + @property + def chunks_per_batch(self): + return math.ceil(self.batch_size / self.chunk_size) + + @property + def total(self): + return sum(self.sizes.values()) + + def __len__(self): + return self.batches_per_epoch + + def __del__(self): + # close all opened files when the object is destroyed + for f in self.mmap_files.values(): + f.close() + + def open(self, fname) -> tuple[h5py.File, h5py.Group]: + f = group = h5py.File(fname, "r") + if self.path is not None: + for path in self.path: + group = group[path] + return f, group + + def load_chunk(self, fname, start, size): + end = min(start + size, self.sizes[fname]) + return { + channel: self.mmap_datasets[fname][channel][start:end] + for channel in self.channels + } + + def sample_batch(self): + # allocate batch up front + batch = np.zeros( + (self.batch_size, self.num_channels, self.waveform_size) ) - waveform_set = self.waveform_set_cls.read(self.training_waveform_file) - waveforms = torch.Tensor(waveform_set.waveforms[start:stop]) - self.train_waveforms = waveforms.to(device) - - def sample(self, X: torch.Tensor): - """ - Sample method for generating training waveforms - """ - N = len(X) - idx = torch.randperm(self.num_train_waveforms)[:N] - waveforms = self.train_waveforms[:, idx] - - hc, hp = waveforms - return hc, hp + + for i in range(self.chunks_per_batch): + fname = np.random.choice(self.fnames, p=self.probs) + + chunk_size = min( + self.chunk_size, self.batch_size - i * self.chunk_size + ) + + # select a random starting index for the chunk + max_start = self.sizes[fname] - chunk_size + start = np.random.randint(0, max_start + 1) + + # load the chunk and insert it into the batch + chunk = self.load_chunk(fname, start, chunk_size) + batch_start = i * self.chunk_size + batch_end = batch_start + chunk_size + + for i, channel in enumerate(self.channels): + batch[batch_start:batch_end, i, :] = chunk[channel] + + return torch.tensor(batch) + + def __iter__(self): + for _ in 
range(self.batches_per_epoch): + yield self.sample_batch() + + +class ChunkedWaveformDataset(torch.utils.data.IterableDataset): + """ + Wrapper dataset that will loop through chunks of timeseries + data produced by another iterable and sample subsets + of waveforms from each chunk. + + Args: + chunk_it: + Iterator which will produce batches of waveform + data to sample subsets from. Should have shape + `(N, C, T)`, where `N` is the number of waveforms + to sample from, `C` is the number of channels, + and `T` is the number of samples along the + time dimension for each waveform. + batch_size: + Number of waveforms to sample at each iteration. + batches_per_chunk: + Number of batches of waveforms to sample from + each chunk before moving on to the next one. + """ + + def __init__( + self, + chunk_it: Iterable, + batch_size: int, + batches_per_chunk: int, + ) -> None: + self.logger = logging.getLogger(__name__) + self.chunk_it = chunk_it + self.batch_size = batch_size + self.batches_per_chunk = batches_per_chunk + + def __len__(self): + return len(self.chunk_it) * self.batches_per_chunk + + def __iter__(self): + it = iter(self.chunk_it) + [chunk] = next(it) + + num_waveforms, _, _ = chunk.shape + while True: + # generate batches from the current chunk + for _ in range(self.batches_per_chunk): + idx = torch.randperm(num_waveforms)[: self.batch_size] + yield chunk[idx] + + try: + [chunk] = next(it) + except StopIteration: + break + num_waveforms, _, _ = chunk.shape diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index b8d33324f..9087b3bf9 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -33,12 +33,12 @@ def score(self, X, X_fft): return self(X, X_fft) def train_step(self, batch: tuple[Tensor, Tensor]) -> Tensor: - X, X_fft, y = batch + (X, X_fft), y = batch y_hat = self(X, X_fft) return torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y) def validation_step(self, batch, _) -> None: - shift, X_bg, X_inj, X_bg_fft, X_inj_fft = batch + shift, (X_bg, X_inj), (X_bg_fft, X_inj_fft) = batch y_bg = self.score(X_bg, X_bg_fft) From b490a7e2d6fb299757091f4b1f450c7c20589e51 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Fri, 10 Oct 2025 06:15:29 -0700 Subject: [PATCH 27/32] Get things working for all cases --- projects/train/train.yaml | 2 +- projects/train/train/augmentations.py | 4 ++- projects/train/train/callbacks.py | 27 ++++++++++++++----- projects/train/train/data/base.py | 7 +++-- .../train/train/data/supervised/__init__.py | 2 +- .../train/train/data/supervised/multimodal.py | 4 +-- .../train/data/supervised/time_domain.py | 4 +-- .../data/supervised/time_frequency_domain.py | 8 +++--- projects/train/train/data/waveforms/loader.py | 3 ++- projects/train/train/model/supervised.py | 2 +- 10 files changed, 39 insertions(+), 24 deletions(-) diff --git a/projects/train/train.yaml b/projects/train/train.yaml index 5b81a2984..bd733f299 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -77,7 +77,7 @@ data: class_path: train.data.waveforms.generator.cbc.CBCGenerator init_args: training_prior: ./training_prior.yaml - val_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/train/val_waveforms.hdf5 + val_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/val_waveforms.hdf5 approximant: ml4gw.waveforms.IMRPhenomPv2 f_min: 20 f_ref: 40 duration: 8 diff --git a/projects/train/train/augmentations.py b/projects/train/train/augmentations.py index 345625339..9b272caee 100644 ---
a/projects/train/train/augmentations.py +++ b/projects/train/train/augmentations.py @@ -55,7 +55,9 @@ def forward(self, X): if num > 0: channel = torch.randint(X.shape[1], size=(num,)) indices = torch.randint(X.shape[0], size=(num,)) - X[indices, channel] = torch.zeros(X.shape[-1], device=X.device) + X[indices, channel] = torch.zeros( + X.shape[-1], device=X.device, dtype=X.dtype + ) return X, indices diff --git a/projects/train/train/callbacks.py b/projects/train/train/callbacks.py index f1fdd0030..40fee25ac 100644 --- a/projects/train/train/callbacks.py +++ b/projects/train/train/callbacks.py @@ -43,9 +43,16 @@ def on_train_end(self, trainer, pl_module): ) device = pl_module.device - [X] = next(iter(trainer.train_dataloader)) - X = X.to(device) - X, y = trainer.datamodule.inject(X) + # Handle the case of loading training waveforms from disk + if trainer.datamodule.waveforms_from_disk: + [X], waveforms = next(iter(trainer.train_dataloader)) + X = X.to(device) + waveforms = waveforms.to(device) + X, y = trainer.datamodule.inject(X, waveforms) + else: + [X] = next(iter(trainer.train_dataloader)) + X = X.to(device) + X, y = trainer.datamodule.inject(X) if isinstance(X, tuple): X = tuple(i.cpu() for i in X) else: @@ -75,10 +82,16 @@ def on_train_start(self, trainer, pl_module): save_dir = trainer.logger.save_dir # build training batch by hand - [X] = next(iter(trainer.train_dataloader)) - X = X.to(device) - - X, y = trainer.datamodule.inject(X) + # Handle the case of loading training waveforms from disk + if trainer.datamodule.waveforms_from_disk: + [X], waveforms = next(iter(trainer.train_dataloader)) + X = X.to(device) + waveforms = waveforms.to(device) + X, y = trainer.datamodule.inject(X, waveforms) + else: + [X] = next(iter(trainer.train_dataloader)) + X = X.to(device) + X, y = trainer.datamodule.inject(X) # If X is not a tuple, make it one for consistency # of format for saving to file below if not isinstance(X, tuple): diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index ed632a41f..fc7a08459 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -619,10 +619,9 @@ def val_dataloader(self) -> ZippedDataset: # we're going to go through, then batch the # signals so that they're spaced evenly # throughout all those batches. 
- cross, plus = self.val_waveforms - num_waveforms = len(cross) + num_waveforms = len(self.val_waveforms) signal_batch_size = (num_waveforms - 1) // self.valid_loader_length + 1 - signal_dataset = torch.utils.data.TensorDataset(cross, plus) + signal_dataset = torch.utils.data.TensorDataset(self.val_waveforms) signal_loader = torch.utils.data.DataLoader( signal_dataset, batch_size=signal_batch_size, @@ -669,7 +668,7 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: # that will load chunks of waveforms # to be sampled from waveform_loader = Hdf5WaveformLoader( - self.train_waveform_fnames, + [self.waveform_sampler.training_waveform_file], batch_size=self.hparams.chunk_size, batches_per_epoch=self.hparams.chunks_per_epoch or 1, channels=["cross", "plus"], diff --git a/projects/train/train/data/supervised/__init__.py b/projects/train/train/data/supervised/__init__.py index 291ef42a8..4ca80c178 100644 --- a/projects/train/train/data/supervised/__init__.py +++ b/projects/train/train/data/supervised/__init__.py @@ -1,4 +1,4 @@ -from .frequency_domain import ( +from .time_frequency_domain import ( FrequencyDomainSupervisedAframeDataset, SpectrogramDomainSupervisedAframeDataset, ) diff --git a/projects/train/train/data/supervised/multimodal.py b/projects/train/train/data/supervised/multimodal.py index 64c914eec..034e396d7 100644 --- a/projects/train/train/data/supervised/multimodal.py +++ b/projects/train/train/data/supervised/multimodal.py @@ -53,8 +53,8 @@ def compute_frequency_domain_data(self, X, psds): return X_fft - def inject(self, X): - X, y, psds = super().inject(X) + def inject(self, X, waveforms=None): + X, y, psds = super().inject(X, waveforms) X = self.whitener(X, psds) X_fft = self.compute_frequency_domain_data(X, psds) diff --git a/projects/train/train/data/supervised/time_domain.py b/projects/train/train/data/supervised/time_domain.py index 35418b3be..a26806614 100644 --- a/projects/train/train/data/supervised/time_domain.py +++ b/projects/train/train/data/supervised/time_domain.py @@ -16,7 +16,7 @@ def build_val_batches(self, background, signals): X_fg = torch.stack(X_fg) return X_bg, X_fg - def inject(self, X): - X, y, psds = super().inject(X) + def inject(self, X, waveforms=None): + X, y, psds = super().inject(X, waveforms) X = self.whitener(X, psds) return X, y diff --git a/projects/train/train/data/supervised/time_frequency_domain.py b/projects/train/train/data/supervised/time_frequency_domain.py index 54bba8ed3..09c9f0486 100644 --- a/projects/train/train/data/supervised/time_frequency_domain.py +++ b/projects/train/train/data/supervised/time_frequency_domain.py @@ -23,8 +23,8 @@ def build_transforms(self, *args, **kwargs): spectrogram_shape=self.spectrogram_shape, ) - def inject(self, X): - X, y, psds = super().inject(X) + def inject(self, X, waveforms=None): + X, y, psds = super().inject(X, waveforms) X = self.whitener(X, psds) X = self.qtransform(X) return X, y @@ -117,8 +117,8 @@ def build_val_batches(self, *args, **kwargs): return X_bg, X_inj - def inject(self, X): - X, y, psds = super().inject(X) + def inject(self, X, waveforms=None): + X, y, psds = super().inject(X, waveforms) # fft whiten and bandpass in frequency domain X = self.whiten(X, psds) diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index b35229851..1268021be 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -9,6 +9,7 @@ import numpy as np import torch +from ledger.injections 
import WaveformPolarizationSet from .sampler import WaveformSampler @@ -31,7 +32,7 @@ def __init__( super().__init__(*args, **kwargs) self.training_waveform_file = training_waveform_file - waveform_set = self.waveform_set_cls.read(training_waveform_file) + waveform_set = WaveformPolarizationSet.read(training_waveform_file) if waveform_set.right_pad != self.right_pad: raise ValueError( "Training waveform file does not have the same " diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 9087b3bf9..5fe42b241 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -38,7 +38,7 @@ def train_step(self, batch: tuple[Tensor, Tensor]) -> Tensor: return torch.nn.functional.binary_cross_entropy_with_logits(y_hat, y) def validation_step(self, batch, _) -> None: - shift, (X_bg, X_inj), (X_bg_fft, X_inj_fft) = batch + shift, (X_bg, X_bg_fft), (X_inj, X_inj_fft) = batch y_bg = self.score(X_bg, X_bg_fft) From 9511e3923f64b6f091f56c2754245ff655f12de8 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Fri, 10 Oct 2025 06:20:52 -0700 Subject: [PATCH 28/32] Make pre-generating training waveforms the default --- aframe/pipelines/sandbox/configs/bbh.cfg | 1 + aframe/pipelines/sandbox/configs/bns.cfg | 1 + projects/train/train.yaml | 24 ++++++++++++++---------- 3 files changed, 16 insertions(+), 10 deletions(-) diff --git a/aframe/pipelines/sandbox/configs/bbh.cfg b/aframe/pipelines/sandbox/configs/bbh.cfg index aee9d6caf..a1aa04d59 100644 --- a/aframe/pipelines/sandbox/configs/bbh.cfg +++ b/aframe/pipelines/sandbox/configs/bbh.cfg @@ -29,3 +29,4 @@ lowpass = &::luigi_base::lowpass fduration = &::luigi_base::fduration seed = &::luigi_base::seed use_wandb = true +precompute_train_waveforms = true diff --git a/aframe/pipelines/sandbox/configs/bns.cfg b/aframe/pipelines/sandbox/configs/bns.cfg index faecdea95..867f8c775 100644 --- a/aframe/pipelines/sandbox/configs/bns.cfg +++ b/aframe/pipelines/sandbox/configs/bns.cfg @@ -40,3 +40,4 @@ seed = &::luigi_base::seed fftlength = &::luigi_base::fftlength q = &::luigi_base::q use_wandb = true +precompute_train_waveforms = true diff --git a/projects/train/train.yaml b/projects/train/train.yaml index bd733f299..e7f7f679b 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -39,8 +39,8 @@ data: batches_per_epoch: 3700 num_files_per_batch: 10 # Used for loading waveforms from disk - # chunks_per_epoch: 10 - # chunk_size: 10000 + chunks_per_epoch: 10 + chunk_size: 10000 # preprocessing args # kernel_length: @@ -73,15 +73,19 @@ data: # decay_steps: 989 # If loading waveforms from disk, waveform_sampler should be # an instance of train.data.waveforms.WaveformLoader - waveform_sampler: - class_path: train.data.waveforms.generator.cbc.CBCGenerator + waveform_sampler: + class_path: train.data.waveforms.WaveformLoader init_args: - training_prior: ./training_prior.yaml - val_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/val_waveforms.hdf5 - approximant: ml4gw.waveforms.IMRPhenomPv2 - f_min: 20 - f_ref: 40 - duration: 8 + training_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/training_waveforms.hdf5 + # waveform_sampler: + # class_path: train.data.waveforms.generator.cbc.CBCGenerator + # init_args: + # training_prior: ./training_prior.yaml + # val_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/val_waveforms.hdf5 + # approximant: ml4gw.waveforms.IMRPhenomPv2 + # f_min: 20 + # f_ref: 40 + # duration: 8 # Extrinsic parameter distributions dec: class_path: 
ml4gw.distributions.Cosine From b1ef8e2751d1c4fa772acc580541de332c0c1c31 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Fri, 10 Oct 2025 08:14:38 -0700 Subject: [PATCH 29/32] Clean up leftover pieces --- aframe/tasks/train/base.py | 2 +- libs/ledger/ledger/injections.py | 19 --------- libs/priors/priors/priors.py | 8 ++-- projects/train/train/conversion.py | 66 +++++++++++++++--------------- projects/train/train/data/base.py | 4 -- projects/train/train/prior.py | 2 - 6 files changed, 36 insertions(+), 65 deletions(-) diff --git a/aframe/tasks/train/base.py b/aframe/tasks/train/base.py index 2925c5e21..3bf5a9e70 100644 --- a/aframe/tasks/train/base.py +++ b/aframe/tasks/train/base.py @@ -59,7 +59,7 @@ class TrainBaseParameters(law.Task): description="Directory where training data is stored." "It is expected to contain a `val_waveforms.hdf5` file of " "signals for validation, a `/background` sub-directory containing " - "background, and a `train_waveforms.hdf5` file containing " + "background, and a `training_waveforms.hdf5` file containing " "training signals if `precompute_train_waveforms` is set to True.", default=paths().train_datadir, ) diff --git a/libs/ledger/ledger/injections.py b/libs/ledger/ledger/injections.py index b379f45b7..c45c30892 100644 --- a/libs/ledger/ledger/injections.py +++ b/libs/ledger/ledger/injections.py @@ -158,25 +158,6 @@ def redshift(self, cosmology=DEFAULT_COSMOLOGY): cosmology.luminosity_distance, self.luminosity_distance * Mpc ).value - @property - def ml4gw_generation_params(self): - params = { - "mass_1": self.mass1, - "mass_2": self.mass2, - "chirp_mass": chirp_mass(self.mass1, self.mass2), - "mass_ratio": self.mass2 / self.mass1, - "s1x": self.spin1x, - "s1y": self.spin1y, - "s1z": self.spin1z, - "s2x": self.spin2x, - "s2y": self.spin2y, - "s2z": self.spin2z, - "inclination": self.inclination, - "distance": self.luminosity_distance, - "phic": self.phase, - } - return params - @property def generation_params(self): params = { diff --git a/libs/priors/priors/priors.py b/libs/priors/priors/priors.py index 77eace1a8..9220783c3 100644 --- a/libs/priors/priors/priors.py +++ b/libs/priors/priors/priors.py @@ -204,13 +204,11 @@ def end_o3_ratesandpops_bns( prior["redshift"] = UniformSourceFrame( 0, 0.15, name="redshift", cosmology=cosmology ) - prior["psi"] = 0 + spin_prior = uniform_spin() + for key, value in spin_prior.items(): + prior[key] = value prior["a_1"] = Uniform(0, 0.4) prior["a_2"] = Uniform(0, 0.4) - prior["tilt_1"] = 0 - prior["tilt_2"] = 0 - prior["phi_12"] = 0 - prior["phi_jl"] = 0 detector_frame_prior = False return prior, detector_frame_prior diff --git a/projects/train/train/conversion.py b/projects/train/train/conversion.py index fbacb9928..c3e5e594a 100644 --- a/projects/train/train/conversion.py +++ b/projects/train/train/conversion.py @@ -5,18 +5,42 @@ ) +def add_mass_params( + parameters: dict[str, torch.Tensor], +) -> dict[str, torch.Tensor]: + """ + Add mass_1 and mass_2 to parameter dictionary if + chirp_mass and mass_ratio are present, or vice-versa. + + Raises a ValueError if neither pair are present. 
+ """ + if "chirp_mass" in parameters and "mass_ratio" in parameters: + mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components( + parameters["chirp_mass"], parameters["mass_ratio"] + ) + + parameters["mass_1"] = mass_1 + parameters["mass_2"] = mass_2 + elif "mass_1" in parameters and "mass_2" in parameters: + m1, m2 = parameters["mass_1"], parameters["mass_2"] + parameters["chirp_mass"] = (m1 * m2) ** (3 / 5) / (m1 + m2) ** (1 / 5) + parameters["mass_ratio"] = m2 / m1 + else: + raise ValueError( + "Parameter dictionary did not contain either " + "(chirp mass, mass ratio) or (mass 1, mass 2)" + ) + return parameters + + def precessing_to_lalsimulation_parameters( parameters: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]: """ Convert precessing spin parameters to lalsimulation parameters """ - mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components( - parameters["chirp_mass"], parameters["mass_ratio"] - ) - - parameters["mass_1"] = mass_1 - parameters["mass_2"] = mass_2 + parameters = add_mass_params(parameters) + mass_1 = parameters["mass_1"] # TODO: hard coding f_ref = 40 here b/c not sure best way to link this # to the f_ref specified in the config file @@ -50,12 +74,8 @@ def aligned_to_lalsimulation_parameters( """ Convert aligned spin parameters to lalsimulation parameters """ - mass_1, mass_2 = chirp_mass_and_mass_ratio_to_components( - parameters["chirp_mass"], parameters["mass_ratio"] - ) - - parameters["mass_1"] = mass_1 - parameters["mass_2"] = mass_2 + parameters = add_mass_params(parameters) + mass_1 = parameters["mass_1"] parameters["s1x"] = torch.zeros_like(mass_1) parameters["s1y"] = torch.zeros_like(mass_1) @@ -66,25 +86,3 @@ def aligned_to_lalsimulation_parameters( parameters["s1z"] = parameters["chi1"] parameters["s2z"] = parameters["chi2"] return parameters - - -def component_aligned_to_lalsimulation_parameters( - parameters: dict[str, torch.Tensor], -) -> dict[str, torch.Tensor]: - """ - Convert aligned spin parameters to lalsimulation parameters - and compnent masses to chirp mass and mass ratio - """ - m1, m2 = parameters["mass_1"], parameters["mass_2"] - parameters["chirp_mass"] = (m1 * m2) ** (3 / 5) / (m1 + m2) ** (1 / 5) - parameters["mass_ratio"] = m2 / m1 - - parameters["s1x"] = torch.zeros_like(m1) - parameters["s1y"] = torch.zeros_like(m1) - - parameters["s2x"] = torch.zeros_like(m1) - parameters["s2y"] = torch.zeros_like(m1) - - parameters["s1z"] = parameters["chi1"] - parameters["s2z"] = parameters["chi2"] - return parameters diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index fc7a08459..0ad2c51fb 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -29,10 +29,6 @@ Distribution = torch.distributions.Distribution TransformedDist = torch.distributions.TransformedDistribution -# TODO: -# Move waveform slicing to the waveform sampler? -# Make separate training prior - # TODO: using this right now because # lightning.pytorch.utilities.CombinedLoader diff --git a/projects/train/train/prior.py b/projects/train/train/prior.py index 48fab8c76..f4c18448f 100644 --- a/projects/train/train/prior.py +++ b/projects/train/train/prior.py @@ -69,8 +69,6 @@ def __call__( return {k: v[:N] for k, v in parameters.items()} - return parameters - def log_prob(self, samples: dict[str, torch.Tensor]) -> torch.Tensor: """ Calculate the log probability of samples under the prior. 
From 55644650e68d536013d0381eef985274ad2f5fb3 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Fri, 10 Oct 2025 08:38:23 -0700 Subject: [PATCH 30/32] Reinstate slicing before loading for training waveforms --- projects/train/train/data/base.py | 14 ++++++++++ .../train/train/data/supervised/supervised.py | 26 +++++++++++++++---- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 0ad2c51fb..6d3c4c937 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -486,6 +486,20 @@ def device(self): """Return the device of the associated lightning module""" return self.trainer.lightning_module.device + def on_before_batch_transfer(self, batch, _): + """ + Slice loaded waveforms before sending to device + if not generating waveforms during training + """ + # TODO: maybe pass indices as argument to + # waveform loader to reduce quantity of data + # we need to load + if self.trainer.training and self.waveforms_from_disk: + X, waveforms = batch + waveforms = self.slice_waveforms(waveforms) + batch = X, waveforms + return batch + def on_after_batch_transfer(self, batch, _): """ This is a method inherited from the DataModule diff --git a/projects/train/train/data/supervised/supervised.py b/projects/train/train/data/supervised/supervised.py index 49ea5ea50..66be2c57c 100644 --- a/projects/train/train/data/supervised/supervised.py +++ b/projects/train/train/data/supervised/supervised.py @@ -44,6 +44,12 @@ def sample_prob(self): @torch.no_grad() def inject(self, X, waveforms=None): + if self.waveforms_from_disk and waveforms is None: + raise ValueError( + "Waveforms should be passed to the `inject` method " + "if waveforms are being loaded from disk, got None" + ) + X, psds = self.psd_estimator(X) X = self.inverter(X) X = self.reverser(X) @@ -54,11 +60,18 @@ def inject(self, X, waveforms=None): mask = rvs < self.sample_prob dec, psi, phi = self.sample_extrinsic(X[mask]) - # If waveforms were passed, we're loading them from - # disk and we can slice out the ones we want. + # If we're loading waveforms from disk, we can + # slice out the ones we want. # If not, we're generating them on the fly. - if waveforms is not None: - hc, hp = waveforms[mask, 0], waveforms[mask, 1] + if self.waveforms_from_disk: + # TODO: Can we just use `mask` to slice out the + # waveforms we want here? 
Copying this from the + # old `WaveformSampler` in case it handles edge + # cases I'm not thinking of + N = mask.sum().item() + idx = torch.randperm(waveforms.shape[0])[:N] + waveforms = waveforms[idx].to(X.device).float() + hc, hp = waveforms[:, 0], waveforms[:, 1] else: hc, hp = self.waveform_sampler.sample(X[mask]) @@ -66,7 +79,10 @@ def inject(self, X, waveforms=None): responses = self.projector( dec, psi, phi, snrs, psds[mask], cross=hc, plus=hp ) - responses = self.slice_waveforms(responses) + # If we're loading waveforms from disk, we'll have sliced + # the waveforms already in `on_before_batch_transfer` + if not self.waveforms_from_disk: + responses = self.slice_waveforms(responses) kernels = sample_kernels( responses, kernel_size=X.size(-1), coincident=True ) From b14c8c2c8f3d29051f5314244f22777b584559b6 Mon Sep 17 00:00:00 2001 From: William Benoit Date: Fri, 10 Oct 2025 08:41:48 -0700 Subject: [PATCH 31/32] Fix typo in docstring --- aframe/tasks/train/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aframe/tasks/train/base.py b/aframe/tasks/train/base.py index 3bf5a9e70..46f01a2b7 100644 --- a/aframe/tasks/train/base.py +++ b/aframe/tasks/train/base.py @@ -67,7 +67,7 @@ class TrainBaseParameters(law.Task): default=False, description="Whether to pre-compute the waveforms used " "during training. If True, the training waveforms will be " - "read from the `train_waveforms.hdf5` file in the data " + "read from the `training_waveforms.hdf5` file in the data " "directory. If False, the waveforms will be simulated " "on-the-fly during training.", ) From d2f7b9420c073b5cdeeabc223f584d8dbd39c8ed Mon Sep 17 00:00:00 2001 From: William Benoit Date: Fri, 10 Oct 2025 10:31:55 -0700 Subject: [PATCH 32/32] Allow for multiple training waveform files --- projects/train/train.yaml | 2 +- projects/train/train/data/base.py | 2 +- projects/train/train/data/waveforms/loader.py | 13 ++++++++++--- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/projects/train/train.yaml b/projects/train/train.yaml index e7f7f679b..fd19c4a0c 100644 --- a/projects/train/train.yaml +++ b/projects/train/train.yaml @@ -76,7 +76,7 @@ data: waveform_sampler: class_path: train.data.waveforms.WaveformLoader init_args: - training_waveform_file: ${oc.env:AFRAME_TRAIN_DATA_DIR}/training_waveforms.hdf5 + training_waveform_path: ${oc.env:AFRAME_TRAIN_DATA_DIR}/training_waveforms.hdf5 # waveform_sampler: # class_path: train.data.waveforms.generator.cbc.CBCGenerator # init_args: diff --git a/projects/train/train/data/base.py b/projects/train/train/data/base.py index 6d3c4c937..cfd7d7638 100644 --- a/projects/train/train/data/base.py +++ b/projects/train/train/data/base.py @@ -678,7 +678,7 @@ def train_dataloader(self) -> torch.utils.data.DataLoader: # that will load chunks of waveforms # to be sampled from waveform_loader = Hdf5WaveformLoader( - [self.waveform_sampler.training_waveform_file], + self.waveform_sampler.training_waveform_files, batch_size=self.hparams.chunk_size, batches_per_epoch=self.hparams.chunks_per_epoch or 1, channels=["cross", "plus"], diff --git a/projects/train/train/data/waveforms/loader.py b/projects/train/train/data/waveforms/loader.py index 1268021be..3538ccdaf 100644 --- a/projects/train/train/data/waveforms/loader.py +++ b/projects/train/train/data/waveforms/loader.py @@ -26,13 +26,20 @@ class WaveformLoader(WaveformSampler): def __init__( self, *args, - training_waveform_file: Path, + training_waveform_path: Path, **kwargs, ) -> None: super().__init__(*args, 
**kwargs) - self.training_waveform_file = training_waveform_file + if training_waveform_path.is_dir(): + self.training_waveform_files = list( + training_waveform_path.iterdir() + ) + else: + self.training_waveform_files = [training_waveform_path] - waveform_set = WaveformPolarizationSet.read(training_waveform_file) + waveform_set = WaveformPolarizationSet.read( + self.training_waveform_files[0] + ) if waveform_set.right_pad != self.right_pad: raise ValueError( "Training waveform file does not have the same "
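
Taken together, the pieces added in this series form a three-layer loading path: `Hdf5WaveformLoader` reads whole chunks of (cross, plus) polarizations out of the merged training file, a `torch.utils.data.DataLoader` keeps that I/O in background workers, and `ChunkedWaveformDataset` samples training batches from whichever chunk is currently in memory. The sketch below mirrors what `BaseAframeDataset.train_dataloader()` assembles from these classes; the file path and the specific batch counts are illustrative placeholders, not values from any shipped config.

import torch

from train.data.waveforms import ChunkedWaveformDataset, Hdf5WaveformLoader

# Read 10000-waveform chunks from the "waveforms" group of the
# merged training file, ten chunks per epoch. The path here is
# a placeholder for wherever TrainingWaveforms wrote its output.
chunk_loader = Hdf5WaveformLoader(
    ["training_waveforms.hdf5"],
    channels=["cross", "plus"],
    batch_size=10000,  # waveforms per chunk ("chunk_size" in train.yaml)
    batches_per_epoch=10,  # "chunks_per_epoch" in train.yaml
    path="waveforms",
)

# Move chunk reads off the main process so training steps
# don't stall on disk I/O
chunk_loader = torch.utils.data.DataLoader(
    chunk_loader,
    num_workers=2,
    persistent_workers=True,
)

# Draw 384-waveform training batches from each in-memory chunk;
# 371 matches the batches_per_epoch // chunks_per_epoch + 1
# computed in train_dataloader()
waveform_dataset = ChunkedWaveformDataset(
    chunk_loader,
    batch_size=384,
    batches_per_chunk=371,
)

for waveforms in waveform_dataset:
    # waveforms: (384, 2, waveform_size) tensor of cross and plus
    # polarizations, later sliced in on_before_batch_transfer and
    # projected onto the detectors in inject()
    ...

In the pipeline itself this iterator is zipped against the background strain dataloader, which is why `on_after_batch_transfer` unpacks `[batch], waveforms` when `waveforms_from_disk` is set.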