From 25d5cc1e9eb90d8ada67336ad33381a4274872bf Mon Sep 17 00:00:00 2001 From: Steven James Henderson Date: Fri, 6 Dec 2024 09:52:09 -0800 Subject: [PATCH 1/9] Added architecture to concat HL and V neural networks --- hermes | 1 + libs/architectures/architectures/concat.py | 70 ++++++++++++++++++++++ ml4gw | 1 + pycondor | 1 + 4 files changed, 73 insertions(+) create mode 160000 hermes create mode 100644 libs/architectures/architectures/concat.py create mode 160000 ml4gw create mode 160000 pycondor diff --git a/hermes b/hermes new file mode 160000 index 000000000..86c6a2478 --- /dev/null +++ b/hermes @@ -0,0 +1 @@ +Subproject commit 86c6a2478c93ac2e7cf50039de7528023579bd5e diff --git a/libs/architectures/architectures/concat.py b/libs/architectures/architectures/concat.py new file mode 100644 index 000000000..07d99be9b --- /dev/null +++ b/libs/architectures/architectures/concat.py @@ -0,0 +1,70 @@ +from typing import Literal, Optional + +from architectures.supervised import SupervisedArchitecture +from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D + +import torch + +class ConcatResNet(SupervisedArchitecture): + def __init__( + self, + in_channels: int, + layers: List[int], + classes: int, + kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[List[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + ) -> None: + super().__init__( + ) + self.layers = layers + # Initialize resnets here + self.hlresnet = ResNet1D( + in_channels=2, + classes=64, + layers=layers + ) + self.vresnet = ResNet1D( + classes=64, + layers=layers, + in_channels=1, + ) + # Initialize linear classifier here + self.classifier = torch.nn.Linear(128, 1) + + def forward(self, X): + # Extract hl data and v data from X + hl = X[:, :2, :] + v = X[:, 2:, :] + # Pass hl and v through resnets + hl = self.hlresnet(hl) + v = self.vresnet(v) + # Concatenate hl and v outputs + concat = torch.concat([hl, v], dim=-1) + # Pass concatenated output through linear classifier + outputs = self.classifier(concat) + + return outputs + + + + + + + + + + + + + + + + + + + + diff --git a/ml4gw b/ml4gw new file mode 160000 index 000000000..92ad1d047 --- /dev/null +++ b/ml4gw @@ -0,0 +1 @@ +Subproject commit 92ad1d047e892d57a76477c22811b69979ba2066 diff --git a/pycondor b/pycondor new file mode 160000 index 000000000..213774e98 --- /dev/null +++ b/pycondor @@ -0,0 +1 @@ +Subproject commit 213774e985a00cda74b04e1da028c9edbe082f5b From 8fa3a254c22e13400bba29a4d361541f9144e0f8 Mon Sep 17 00:00:00 2001 From: Steven James Henderson Date: Fri, 6 Dec 2024 09:57:07 -0800 Subject: [PATCH 2/9] Added architecture to concat HL and V neural networks --- libs/architectures/architectures/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py index 26032f78d..61435a175 100644 --- a/libs/architectures/architectures/__init__.py +++ b/libs/architectures/architectures/__init__.py @@ -6,3 +6,4 @@ SupervisedSpectrogramDomainResNet, SupervisedTimeDomainResNet, ) +from .concat import ConcatArchitecture, ConcatResNet From 7c4e378614c213834de039d57215e295e4cb65bf Mon Sep 17 00:00:00 2001 From: Steven James Henderson Date: Mon, 13 Jan 2025 13:55:44 -0800 Subject: [PATCH 3/9] Fixed imports --- libs/architectures/architectures/__init__.py | 2 +- libs/architectures/architectures/concat.py | 24 ++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff 
--git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py index 61435a175..376fcf8d3 100644 --- a/libs/architectures/architectures/__init__.py +++ b/libs/architectures/architectures/__init__.py @@ -6,4 +6,4 @@ SupervisedSpectrogramDomainResNet, SupervisedTimeDomainResNet, ) -from .concat import ConcatArchitecture, ConcatResNet +from .concat import ConcatResNet diff --git a/libs/architectures/architectures/concat.py b/libs/architectures/architectures/concat.py index 07d99be9b..d0a6a34a2 100644 --- a/libs/architectures/architectures/concat.py +++ b/libs/architectures/architectures/concat.py @@ -8,32 +8,36 @@ class ConcatResNet(SupervisedArchitecture): def __init__( self, - in_channels: int, - layers: List[int], - classes: int, + v_dim: int, + v_layers: list[int], + hl_dim: int, + hl_layers: list[int], + num_ifos: int, + layers: list[int], + sample_rate: float, + kernel_length: float, kernel_size: int = 3, zero_init_residual: bool = False, groups: int = 1, width_per_group: int = 64, - stride_type: Optional[List[Literal["stride", "dilation"]]] = None, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, norm_layer: Optional[NormLayer] = None, ) -> None: super().__init__( ) - self.layers = layers # Initialize resnets here self.hlresnet = ResNet1D( in_channels=2, - classes=64, - layers=layers + classes=hl_dim, + layers=hl_layers, ) self.vresnet = ResNet1D( - classes=64, - layers=layers, + classes=v_dim, + layers=v_layers, in_channels=1, ) # Initialize linear classifier here - self.classifier = torch.nn.Linear(128, 1) + self.classifier = torch.nn.Linear(hl_dim+v_dim, 1) def forward(self, X): # Extract hl data and v data from X From 84b2f972587a41d2ba454ef1d95731b086fd7135 Mon Sep 17 00:00:00 2001 From: Steven James Henderson Date: Wed, 13 Aug 2025 10:33:43 -0700 Subject: [PATCH 4/9] Reformatted main and snapshotter to take high, low, fft inputs --- libs/utils/utils/preprocessing.py | 128 ++++++++++++++ .../train/train/data/supervised/__init__.py | 1 + .../data/supervised/multimodal_time_split.py | 156 ++++++++++++++++++ projects/train/train/model/__init__.py | 1 + projects/train/train/model/supervised.py | 70 ++++++++ 5 files changed, 356 insertions(+) create mode 100644 projects/train/train/data/supervised/multimodal_time_split.py diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py index e06e6c788..321137d28 100644 --- a/libs/utils/utils/preprocessing.py +++ b/libs/utils/utils/preprocessing.py @@ -254,3 +254,131 @@ def forward(self, x: Tensor) -> Tensor: x_fft = torch.cat([x_fft.real, x_fft.imag, inv_asd], dim=1) return x, x_fft + +class MultiModalTimeSplitPreprocessor(torch.nn.Module): + """ + Preprocess a batch of waveforms for multimodal + training with a split time strain domain. 
+ This includes whitening the time domain data and + calculating the frequency domain data + """ + def __init__( + self, + kernel_length: float, + sample_rate: float, + inference_sampling_rate: float, + batch_size: int, + fduration: float, + fftlength: float, + augmentor: Optional[Callable] = None, + highpass: Optional[float] = None, + lowpass: Optional[float] = None, + return_whitened: bool = False, + ) -> None: + super().__init__() + self.stride_size = int(sample_rate / inference_sampling_rate) + self.kernel_size = int(kernel_length * sample_rate) + self.augmentor = augmentor + self.return_whitened = return_whitened + self.highpass = highpass + self.lowpass = lowpass + self.fftlength = fftlength + self.n_fft = int(fftlength * sample_rate) + self.sample_rate = sample_rate + self.n_fft_psd = int(fftlength * sample_rate) + self.feature_nfft = self.kernel_size + + strides = (batch_size - 1) * self.stride_size + fsize = int(fduration * sample_rate) + size = strides + self.kernel_size + fsize + length = size / sample_rate + self.psd_estimator = PsdEstimator( + length, + sample_rate, + fftlength=fftlength, + overlap=None, + average="median", + fast=highpass is not None, + ) + self.whitener = Whiten(fduration, sample_rate, highpass, lowpass) + + freqs = torch.fft.rfftfreq(self.feature_nfft, d=1 / sample_rate) + self.freq_mask = torch.ones_like(freqs, dtype=torch.bool) + if highpass is not None: + self.freq_mask &= freqs > highpass + # if lowpass is not None: + # self.freq_mask &= freqs < lowpass + self.freqs = freqs + self.freq_mask = self.freq_mask + self.num_freqs = int(freqs.numel()) + + def forward(self, x: Tensor): + if x.ndim == 3: + num_channels = x.size(1) + elif x.ndim == 2: + num_channels = x.size(0) + else: + raise ValueError(f"Unexpected input shape: {x.shape}") + + x, psd = self.psd_estimator(x) + whitened = self.whitener(x, psd) + whitened_low = self.whitener( + x, psd, highpass=self.highpass, lowpass=self.lowpass + ) + whitened_high = self.whitener( + x, psd, highpass=self.lowpass, lowpass=None + ) + + x = x.float() + + asds = psd**0.5 + asds = asds.float() + asds *= 1e23 + + # ensure asd is 3D so mask works even if asd is never interpolated + if asds.ndim == 2: + asds = asds.unsqueeze(0) + + # unfold x and other inputs and then put into expected shape. + # Note that if x has both signal and background + # batch elements, they will be interleaved along + # the batch dimension after unfolding + x = unfold_windows(whitened, self.kernel_size, self.stride_size) + x_low = unfold_windows( + whitened_low, self.kernel_size, self.stride_size + ) + x_high = unfold_windows( + whitened_high, self.kernel_size, self.stride_size + ) + x = x.reshape(-1, num_channels, self.kernel_size) + + x_fft = torch.fft.rfft(x, n=self.feature_nfft, dim=-1) + F_expected = self.num_freqs + F_actual = x_fft.shape[-1] + if F_actual != F_expected: + raise ValueError( + f"""FFT bin mismatch: got F={F_actual} from rfft," \ + expected {F_expected}""" + f"(kernel_size={self.kernel_size}, n_fft={self.n_fft}). " + "Ensure fftlength*sample_rate matches the " \ + "intended kernel-length FFT." 
+ ) + + asds = torch.nn.functional.interpolate( + asds, + size=(self.num_freqs,), + mode="linear", + ) + + asds = asds[:, :, self.freq_mask] + inv_asd = 1 / asds + inv_asd = inv_asd.repeat(x_fft.shape[0], 1, 1) + x_fft = x_fft[..., self.freq_mask] + + x_fft = torch.cat([x_fft.real, x_fft.imag, inv_asd], dim=1) + + x_low = x_low.reshape(-1, num_channels, self.kernel_size) + x_high = x_high.reshape(-1, num_channels, self.kernel_size) + x_fft = x_fft.reshape(x_low.shape[0], -1, x_fft.shape[-1]) + + return x_low, x_high, x_fft diff --git a/projects/train/train/data/supervised/__init__.py b/projects/train/train/data/supervised/__init__.py index 4ca80c178..ab559bde3 100644 --- a/projects/train/train/data/supervised/__init__.py +++ b/projects/train/train/data/supervised/__init__.py @@ -3,5 +3,6 @@ SpectrogramDomainSupervisedAframeDataset, ) from .multimodal import MultiModalSupervisedAframeDataset +from .multimodal_time_split import MultimodalTimeSplitSupervisedAframeDataset from .supervised import SupervisedAframeDataset from .time_domain import TimeDomainSupervisedAframeDataset diff --git a/projects/train/train/data/supervised/multimodal_time_split.py b/projects/train/train/data/supervised/multimodal_time_split.py new file mode 100644 index 000000000..b47153c44 --- /dev/null +++ b/projects/train/train/data/supervised/multimodal_time_split.py @@ -0,0 +1,156 @@ +from typing import Optional +import torch +import torch.nn.functional as F +from train.data.supervised.supervised import SupervisedAframeDataset + + +class MultimodalSupervisedAframeDataset(SupervisedAframeDataset): + def __init__( + self, + *args, + swap_prob: Optional[float] = None, + mute_prob: Optional[float] = None, + **kwargs, + ) -> None: + super().__init__( + *args, swap_prob=swap_prob, mute_prob=mute_prob, **kwargs + ) + + @torch.no_grad() + def augment(self, X, waveforms): + X, y, psds = super().augment(X, waveforms) + + psds = psds + + X_whitened = self.whitener(X, psds) + X_low = self.whitener( + X, + psds, + highpass=self.hparams.highpass, + lowpass=self.hparams.lowpass, + ).float() + X_high = self.whitener( + X, psds, highpass=self.hparams.lowpass, lowpass=None + ).float() + + X_fft = torch.fft.rfft(X_whitened, dim=-1) + asds = psds**0.5 + + freqs = torch.fft.rfftfreq( + X_whitened.shape[-1], d=1 / self.hparams.sample_rate + ) + num_freqs = len(freqs) + + asds = torch.nn.functional.interpolate( + asds, + size=(num_freqs,), + mode="linear", + ) + mask = torch.ones_like(freqs, dtype=torch.bool) + if self.hparams.highpass is not None: + mask &= freqs > self.hparams.highpass + # if self.hparams .lowpass is not None: + # mask &= freqs < self.hparams.lowpass + + asds = asds[:, :, mask] + asds *= 1e23 + inv_asds = 1 / asds + + X_fft = X_fft[:, :, mask] + X_fft = torch.cat([X_fft.real, X_fft.imag, inv_asds], dim=1).float() + if torch.isnan(X_low).any() or torch.isinf(X_low).any(): + raise ValueError("NaN or Inf in X_low") + if torch.isnan(X_fft).any() or torch.isinf(X_fft).any(): + raise ValueError("NaN or Inf in X_fft") + if torch.isnan(y).any() or torch.isinf(y).any(): + raise ValueError("NaN or Inf in y") + + return X_low, X_high, X_fft, y.float() + + def on_after_batch_transfer(self, batch, _): + """ + Perform on-device preprocessing after transferring batch to device. + + Augments data during training and injects signals into background during validation, + performing whitening and FFT-based preprocessing. 
+ """ + if self.trainer.training: + # Training mode: perform random augmentations using waveforms + [X], waveforms = batch + return self.augment(X, waveforms) + + elif self.trainer.validating or self.trainer.sanity_checking: + # Validation mode: prepare signal-injected validation batches + [background, _, timeslide_idx], [signals] = batch + if isinstance(timeslide_idx, torch.Tensor): + timeslide_idx = timeslide_idx[0].item() + shift = float(self.timeslides[timeslide_idx].shift_size) + + # Build validation inputs and corresponding PSDs + X_bg, X_inj, psds = super().build_val_batches(background, signals) + + # Background: low/high-passed and FFT-processed + X_bg_whitened = self.whitener(X_bg, psds) + X_bg_low = self.whitener( + X_bg, + psds, + highpass=self.hparams.highpass, + lowpass=self.hparams.lowpass, + ).float() + X_bg_high = self.whitener( + X_bg, psds, highpass=self.hparams.lowpass, lowpass=None + ).float() + X_bg_fft = torch.fft.rfft(X_bg_whitened) + + # Foreground: process injected signals similarly + X_fg_whitened, X_fg_low, X_fg_high = [], [], [] + for inj in X_inj: + X_fg_low.append( + self.whitener( + inj, + psds, + lowpass=self.hparams.lowpass, + highpass=self.hparams.highpass, + ).float() + ) + X_fg_high.append( + self.whitener( + inj, psds, lowpass=None, highpass=self.hparams.lowpass + ).float() + ) + X_fg_whitened.append(self.whitener(inj, psds)) + X_fg_low = torch.stack(X_fg_low) + X_fg_high = torch.stack(X_fg_high) + X_fg_whitened = torch.stack(X_fg_whitened, dim=0) + X_fg_fft = torch.fft.rfft(X_fg_whitened) + + asds = psds**0.5 + asds *= 1e23 + asds = asds.float() + num_freqs = X_fg_fft.shape[-1] + if asds.shape[-1] != num_freqs: + asds = F.interpolate(asds, size=(num_freqs,), mode="linear") + inv_asds = 1 / asds + + X_bg_fft = torch.cat( + [X_bg_fft.real, X_bg_fft.imag, inv_asds], dim=1 + ).float() + inv_asds = inv_asds.unsqueeze(0).repeat(5, 1, 1, 1) + X_fg_fft = torch.cat( + [X_fg_fft.real, X_fg_fft.imag, inv_asds], dim=2 + ).float() + + # Return data grouped into background and injected signal components + return ( + shift, + X_bg_low, + X_bg_high, + X_bg_fft, + X_fg_low, + X_fg_high, + X_fg_fft, + psds, + ) + + # Default: return batch unchanged + return batch \ No newline at end of file diff --git a/projects/train/train/model/__init__.py b/projects/train/train/model/__init__.py index 5bbd13d30..d127ce481 100644 --- a/projects/train/train/model/__init__.py +++ b/projects/train/train/model/__init__.py @@ -4,4 +4,5 @@ SupervisedAframe, SupervisedAframeS4, SupervisedMultiModalAframe, + SupervisedMultiModalTimeSplitAframe, ) diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 5fe42b241..7088895f6 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -126,3 +126,73 @@ def configure_optimizers(self): scheduler_config = {"scheduler": scheduler, "interval": "step"} return {"optimizer": optimizer, "lr_scheduler": scheduler_config} + +class SupervisedMultimodalTimeSplitAframe(AframeBase): + def __init__( + self, + arch: SupervisedArchitecture, + *args, + **kwargs, + ) -> None: + super().__init__(arch, *args, **kwargs) + + def forward(self, x_low: Tensor, x_high: Tensor, x_fft: Tensor) -> Tensor: + return self.model(x_low, x_high, x_fft) + + def train_step(self, batch: tuple) -> Tensor: + # Unpack depending on number of elements + if len(batch) != 4: + raise ValueError( + f"Unexpected batch format in train_step: {len(batch)} elements" + ) + + X_low, X_high, X_fft, y = batch + + 
y_hat = self(X_low, X_high, X_fft).squeeze(-1) + # Match shape of y_hat + y = y.float().view_as(y_hat) + return F.binary_cross_entropy_with_logits(y_hat, y) + + def score(self, X): + X_low, X_high, X_fft = X + return self.model(X_low, X_high, X_fft).squeeze(-1) + + def validation_step(self, batch, batch_idx): + try: + shift, X_bg, X_inj = batch + except ValueError: + ( + shift, + X_bg_low, + X_bg_high, + X_bg_fft, + X_fg_low, + X_fg_high, + X_fg_fft, + *_, + ) = batch + X_bg = (X_bg_low, X_bg_high, X_bg_fft) + X_inj = (X_fg_low, X_fg_high, X_fg_fft) + + # Score background + y_bg = self.score(X_bg) + + # Score injected signals + x0 = X_inj[0] + if x0.ndim >= 4: + V, B = x0.shape[:2] + x_flat = tuple(x.reshape(V * B, *x.shape[2:]) for x in X_inj) + y_fg = self.score(x_flat).view(V, B).mean(0) + else: + y_fg = self.score(X_inj) + + shift_val = float(shift) if not isinstance(shift, float) else shift + self.metric.update(shift_val, y_bg, y_fg) + + self.log( + "valid_auroc", + self.metric, + on_step=True, + on_epoch=True, + sync_dist=True, + ) \ No newline at end of file From a42a77969430470a0409a7cb41c6a0ffbb8aa2af Mon Sep 17 00:00:00 2001 From: Steven James Henderson Date: Wed, 27 Aug 2025 07:15:04 -0700 Subject: [PATCH 5/9] Changes to test for multimodal --- .../architectures/architectures/multimodal.py | 86 +++++++++++++++++++ projects/export/export/main.py | 30 +++++-- 2 files changed, 109 insertions(+), 7 deletions(-) create mode 100644 libs/architectures/architectures/multimodal.py diff --git a/libs/architectures/architectures/multimodal.py b/libs/architectures/architectures/multimodal.py new file mode 100644 index 000000000..c7d80ff31 --- /dev/null +++ b/libs/architectures/architectures/multimodal.py @@ -0,0 +1,86 @@ +import torch +import torch.nn as nn +from architectures.supervised import SupervisedArchitecture +from ml4gw.nn.resnet.resnet_1d import ResNet1D, NormLayer +from typing import Optional, Literal + + +class MultimodalSupervisedArchitecture(SupervisedArchitecture): + def __init__( + self, + num_ifos: int, + time_classes: int, + freq_classes: int, + time_layers: list[int], + freq_layers: list[int], + time_kernel_size: int = 3, + freq_kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + **kwargs, + ): + super().__init__() + + # Time-domain ResNets + self.strain_low_resnet = ResNet1D( + in_channels=num_ifos, + layers=time_layers, + classes=time_classes, + kernel_size=time_kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + self.strain_high_resnet = ResNet1D( + in_channels=num_ifos, + layers=time_layers, + classes=time_classes, + kernel_size=time_kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + # Frequency-domain ResNet + freq_input_channels = int(num_ifos * 3) + self.fft_resnet = ResNet1D( + in_channels=freq_input_channels, + layers=freq_layers, + classes=freq_classes, + kernel_size=freq_kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + embed_dim = 2 * time_classes + freq_classes + self.classifier = nn.Linear(embed_dim, 1) + + def forward(self, x_low: torch.Tensor, x_high: 
torch.Tensor, x_fft: torch.Tensor): + if x_low.ndim == 4: + # [num_views, batch, C, L] -> [num_views * batch, C, L] + x_low = x_low.flatten(0, 1) + x_high = x_high.flatten(0, 1) + x_fft = x_fft.flatten(0, 1) + + assert x_low.shape[1] == self.strain_low_resnet.conv1.in_channels, f"x_low has wrong shape: {x_low.shape}" + assert x_high.shape[1] == self.strain_high_resnet.conv1.in_channels, f"x_high has wrong shape: {x_high.shape}" + assert x_fft.shape[1] == self.fft_resnet.conv1.in_channels, f"x_fft has wrong shape: {x_fft.shape}" + + low_out = self.strain_low_resnet(x_low) + high_out = self.strain_high_resnet(x_high) + fft_out = self.fft_resnet(x_fft) + + features = torch.cat([low_out, high_out, fft_out], dim=-1) + return self.classifier(features) + diff --git a/projects/export/export/main.py b/projects/export/export/main.py index b5f6ed9b3..c2b5dd669 100644 --- a/projects/export/export/main.py +++ b/projects/export/export/main.py @@ -107,13 +107,29 @@ def export( # load in the model graph logging.info("Initializing model graph") - - with open_file(weights, "rb") as f: - graph = nn = torch.jit.load(f, map_location="cpu") - - graph.eval() - logging.info(f"Initialize:\n{nn}") - +#UNCOMMENT BELOW FOR ACTUAL RUN + #with open_file(weights, "rb") as f: + # graph = nn = torch.jit.load(f, map_location="cpu") +# REMOVE BELOW FOR ACTUAL RUN + if weights.endswith(".ckpt"): + from train.model.multimodal import MultimodalAframe # or whatever class wraps your arch + + repo = qv.ModelRepository(repository_directory, clean=clean) + try: + aframe = repo.models["aframe"] + except KeyError: + aframe = repo.add("aframe", platform=platform) + + if aframe_instances is not None: + scale_model(aframe, aframe_instances) + + ckpt_model = MultimodalAframe.load_from_checkpoint(weights) + graph = ckpt_model.arch + +#UNCOMMENT ON ACTUAL RUN + #graph.eval() + logging.info(f"Initialize:\n{graph}") + #logging.info(f"Initialize:\n{nn}") # instantiate a model repository at the # indicated location. 
Split up the preprocessor # and the neural network (which we'll call aframe) From 3bfaf9612ccb98666828136e6af0fba33d05377a Mon Sep 17 00:00:00 2001 From: Steven Henderson Date: Thu, 30 Oct 2025 07:21:19 -0700 Subject: [PATCH 6/9] Added Multimodal time split arch --- libs/architectures/architectures/concat.py | 74 ----------------- .../architectures/multimodal_time_split.py | 79 +++++++++++++++++++ 2 files changed, 79 insertions(+), 74 deletions(-) delete mode 100644 libs/architectures/architectures/concat.py create mode 100644 libs/architectures/architectures/multimodal_time_split.py diff --git a/libs/architectures/architectures/concat.py b/libs/architectures/architectures/concat.py deleted file mode 100644 index d0a6a34a2..000000000 --- a/libs/architectures/architectures/concat.py +++ /dev/null @@ -1,74 +0,0 @@ -from typing import Literal, Optional - -from architectures.supervised import SupervisedArchitecture -from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D - -import torch - -class ConcatResNet(SupervisedArchitecture): - def __init__( - self, - v_dim: int, - v_layers: list[int], - hl_dim: int, - hl_layers: list[int], - num_ifos: int, - layers: list[int], - sample_rate: float, - kernel_length: float, - kernel_size: int = 3, - zero_init_residual: bool = False, - groups: int = 1, - width_per_group: int = 64, - stride_type: Optional[list[Literal["stride", "dilation"]]] = None, - norm_layer: Optional[NormLayer] = None, - ) -> None: - super().__init__( - ) - # Initialize resnets here - self.hlresnet = ResNet1D( - in_channels=2, - classes=hl_dim, - layers=hl_layers, - ) - self.vresnet = ResNet1D( - classes=v_dim, - layers=v_layers, - in_channels=1, - ) - # Initialize linear classifier here - self.classifier = torch.nn.Linear(hl_dim+v_dim, 1) - - def forward(self, X): - # Extract hl data and v data from X - hl = X[:, :2, :] - v = X[:, 2:, :] - # Pass hl and v through resnets - hl = self.hlresnet(hl) - v = self.vresnet(v) - # Concatenate hl and v outputs - concat = torch.concat([hl, v], dim=-1) - # Pass concatenated output through linear classifier - outputs = self.classifier(concat) - - return outputs - - - - - - - - - - - - - - - - - - - - diff --git a/libs/architectures/architectures/multimodal_time_split.py b/libs/architectures/architectures/multimodal_time_split.py new file mode 100644 index 000000000..885dcaf37 --- /dev/null +++ b/libs/architectures/architectures/multimodal_time_split.py @@ -0,0 +1,79 @@ +import torch +import torch.nn as nn +from architectures.supervised import SupervisedArchitecture +from ml4gw.nn.resnet.resnet_1d import ResNet1D, NormLayer +from typing import Optional, Literal + + +class MultimodalTimeSplitSupervisedArchitecture(SupervisedArchitecture): + def __init__( + self, + num_ifos: int, + low_time_classes: int, + high_time_classes: int, + freq_classes: int, + low_time_layers: list[int], + high_time_layers: list[int], + freq_layers: list[int], + time_kernel_size: int = 3, + freq_kernel_size: int = 3, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + stride_type: Optional[list[Literal["stride", "dilation"]]] = None, + norm_layer: Optional[NormLayer] = None, + **kwargs, + ): + super().__init__() + + # Time-domain ResNets + self.strain_low_resnet = ResNet1D( + in_channels=num_ifos, + layers=low_time_layers, + classes=low_time_classes, + kernel_size=time_kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + 
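+        # A second, independently parameterized ResNet handles the strain
+        # band above the low/high split frequency; its embedding is later
+        # concatenated with the low-band and frequency-domain embeddings.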
self.strain_high_resnet = ResNet1D( + in_channels=num_ifos, + layers=high_time_layers, + classes=high_time_classes, + kernel_size=time_kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + # Frequency-domain ResNet + freq_input_channels = int(num_ifos * 3) + self.fft_resnet = ResNet1D( + in_channels=freq_input_channels, + layers=freq_layers, + classes=freq_classes, + kernel_size=freq_kernel_size, + zero_init_residual=zero_init_residual, + groups=groups, + width_per_group=width_per_group, + stride_type=stride_type, + norm_layer=norm_layer, + ) + + embed_dim = high_time_classes + low_time_classes + freq_classes + self.classifier = nn.Linear(embed_dim, 1) + + def forward( + self, x_low: torch.Tensor, x_high: torch.Tensor, x_fft: torch.Tensor + ): + low_out = self.strain_low_resnet(x_low) + high_out = self.strain_high_resnet(x_high) + fft_out = self.fft_resnet(x_fft) + + features = torch.cat([low_out, high_out, fft_out], dim=-1) + return self.classifier(features) \ No newline at end of file From 570dafd3da9d37624d885ef361bef06f5d3c4cc3 Mon Sep 17 00:00:00 2001 From: Steven Henderson Date: Thu, 30 Oct 2025 07:22:35 -0700 Subject: [PATCH 7/9] Updated init file from previous commits --- libs/architectures/architectures/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/libs/architectures/architectures/__init__.py b/libs/architectures/architectures/__init__.py index 376fcf8d3..26032f78d 100644 --- a/libs/architectures/architectures/__init__.py +++ b/libs/architectures/architectures/__init__.py @@ -6,4 +6,3 @@ SupervisedSpectrogramDomainResNet, SupervisedTimeDomainResNet, ) -from .concat import ConcatResNet From 27cf05761dacbd583097f165be9dd1314a5801ec Mon Sep 17 00:00:00 2001 From: Steven Henderson Date: Thu, 30 Oct 2025 07:26:53 -0700 Subject: [PATCH 8/9] Renamed multimodal arch --- .../architectures/architectures/multimodal.py | 86 ------------------- 1 file changed, 86 deletions(-) delete mode 100644 libs/architectures/architectures/multimodal.py diff --git a/libs/architectures/architectures/multimodal.py b/libs/architectures/architectures/multimodal.py deleted file mode 100644 index c7d80ff31..000000000 --- a/libs/architectures/architectures/multimodal.py +++ /dev/null @@ -1,86 +0,0 @@ -import torch -import torch.nn as nn -from architectures.supervised import SupervisedArchitecture -from ml4gw.nn.resnet.resnet_1d import ResNet1D, NormLayer -from typing import Optional, Literal - - -class MultimodalSupervisedArchitecture(SupervisedArchitecture): - def __init__( - self, - num_ifos: int, - time_classes: int, - freq_classes: int, - time_layers: list[int], - freq_layers: list[int], - time_kernel_size: int = 3, - freq_kernel_size: int = 3, - zero_init_residual: bool = False, - groups: int = 1, - width_per_group: int = 64, - stride_type: Optional[list[Literal["stride", "dilation"]]] = None, - norm_layer: Optional[NormLayer] = None, - **kwargs, - ): - super().__init__() - - # Time-domain ResNets - self.strain_low_resnet = ResNet1D( - in_channels=num_ifos, - layers=time_layers, - classes=time_classes, - kernel_size=time_kernel_size, - zero_init_residual=zero_init_residual, - groups=groups, - width_per_group=width_per_group, - stride_type=stride_type, - norm_layer=norm_layer, - ) - - self.strain_high_resnet = ResNet1D( - in_channels=num_ifos, - layers=time_layers, - classes=time_classes, - kernel_size=time_kernel_size, - zero_init_residual=zero_init_residual, - groups=groups, - 
width_per_group=width_per_group, - stride_type=stride_type, - norm_layer=norm_layer, - ) - - # Frequency-domain ResNet - freq_input_channels = int(num_ifos * 3) - self.fft_resnet = ResNet1D( - in_channels=freq_input_channels, - layers=freq_layers, - classes=freq_classes, - kernel_size=freq_kernel_size, - zero_init_residual=zero_init_residual, - groups=groups, - width_per_group=width_per_group, - stride_type=stride_type, - norm_layer=norm_layer, - ) - - embed_dim = 2 * time_classes + freq_classes - self.classifier = nn.Linear(embed_dim, 1) - - def forward(self, x_low: torch.Tensor, x_high: torch.Tensor, x_fft: torch.Tensor): - if x_low.ndim == 4: - # [num_views, batch, C, L] -> [num_views * batch, C, L] - x_low = x_low.flatten(0, 1) - x_high = x_high.flatten(0, 1) - x_fft = x_fft.flatten(0, 1) - - assert x_low.shape[1] == self.strain_low_resnet.conv1.in_channels, f"x_low has wrong shape: {x_low.shape}" - assert x_high.shape[1] == self.strain_high_resnet.conv1.in_channels, f"x_high has wrong shape: {x_high.shape}" - assert x_fft.shape[1] == self.fft_resnet.conv1.in_channels, f"x_fft has wrong shape: {x_fft.shape}" - - low_out = self.strain_low_resnet(x_low) - high_out = self.strain_high_resnet(x_high) - fft_out = self.fft_resnet(x_fft) - - features = torch.cat([low_out, high_out, fft_out], dim=-1) - return self.classifier(features) - From d14677d11a95fa282d9cd380731825a9d9b25919 Mon Sep 17 00:00:00 2001 From: Steven Henderson Date: Thu, 30 Oct 2025 07:32:18 -0700 Subject: [PATCH 9/9] Removed submodules --- hermes | 1 - ml4gw | 1 - projects/export/export/main.py | 30 +++++++----------------------- pycondor | 1 - 4 files changed, 7 insertions(+), 26 deletions(-) delete mode 160000 hermes delete mode 160000 ml4gw delete mode 160000 pycondor diff --git a/hermes b/hermes deleted file mode 160000 index 86c6a2478..000000000 --- a/hermes +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 86c6a2478c93ac2e7cf50039de7528023579bd5e diff --git a/ml4gw b/ml4gw deleted file mode 160000 index 92ad1d047..000000000 --- a/ml4gw +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 92ad1d047e892d57a76477c22811b69979ba2066 diff --git a/projects/export/export/main.py b/projects/export/export/main.py index c2b5dd669..b5f6ed9b3 100644 --- a/projects/export/export/main.py +++ b/projects/export/export/main.py @@ -107,29 +107,13 @@ def export( # load in the model graph logging.info("Initializing model graph") -#UNCOMMENT BELOW FOR ACTUAL RUN - #with open_file(weights, "rb") as f: - # graph = nn = torch.jit.load(f, map_location="cpu") -# REMOVE BELOW FOR ACTUAL RUN - if weights.endswith(".ckpt"): - from train.model.multimodal import MultimodalAframe # or whatever class wraps your arch - - repo = qv.ModelRepository(repository_directory, clean=clean) - try: - aframe = repo.models["aframe"] - except KeyError: - aframe = repo.add("aframe", platform=platform) - - if aframe_instances is not None: - scale_model(aframe, aframe_instances) - - ckpt_model = MultimodalAframe.load_from_checkpoint(weights) - graph = ckpt_model.arch - -#UNCOMMENT ON ACTUAL RUN - #graph.eval() - logging.info(f"Initialize:\n{graph}") - #logging.info(f"Initialize:\n{nn}") + + with open_file(weights, "rb") as f: + graph = nn = torch.jit.load(f, map_location="cpu") + + graph.eval() + logging.info(f"Initialize:\n{nn}") + # instantiate a model repository at the # indicated location. 
Split up the preprocessor # and the neural network (which we'll call aframe) diff --git a/pycondor b/pycondor deleted file mode 160000 index 213774e98..000000000 --- a/pycondor +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 213774e985a00cda74b04e1da028c9edbe082f5b
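The sketch below is a minimal, illustrative smoke test of the MultimodalTimeSplitSupervisedArchitecture added in patch 6, assuming the architectures package introduced by these patches and ml4gw are importable. The layer depths, embedding sizes, sample counts, and the 513 frequency bins (an unmasked rfft of a 1024-sample kernel) are placeholder values; the real shapes are set by the training data module and the MultiModalTimeSplitPreprocessor configuration.

import torch

from architectures.multimodal_time_split import (
    MultimodalTimeSplitSupervisedArchitecture,
)

num_ifos, batch = 2, 8
arch = MultimodalTimeSplitSupervisedArchitecture(
    num_ifos=num_ifos,
    low_time_classes=64,
    high_time_classes=64,
    freq_classes=64,
    low_time_layers=[2, 2],
    high_time_layers=[2, 2],
    freq_layers=[2, 2],
)

# Illustrative shapes: band-limited whitened strain for the two time-domain
# branches, and [real, imag, 1/ASD] channels per interferometer for the FFT
# branch (the bin count here is an unmasked 1024-point rfft).
x_low = torch.randn(batch, num_ifos, 1024)
x_high = torch.randn(batch, num_ifos, 1024)
x_fft = torch.randn(batch, num_ifos * 3, 513)

logits = arch(x_low, x_high, x_fft)
print(logits.shape)  # (batch, 1) classifier logits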