diff --git a/libs/architectures/architectures/multimodal_time_split.py b/libs/architectures/architectures/multimodal_time_split.py
new file mode 100644
index 00000000..885dcaf3
--- /dev/null
+++ b/libs/architectures/architectures/multimodal_time_split.py
@@ -0,0 +1,79 @@
+from typing import Literal, Optional
+
+import torch
+import torch.nn as nn
+from architectures.supervised import SupervisedArchitecture
+from ml4gw.nn.resnet.resnet_1d import NormLayer, ResNet1D
+
+
+class MultimodalTimeSplitSupervisedArchitecture(SupervisedArchitecture):
+    def __init__(
+        self,
+        num_ifos: int,
+        low_time_classes: int,
+        high_time_classes: int,
+        freq_classes: int,
+        low_time_layers: list[int],
+        high_time_layers: list[int],
+        freq_layers: list[int],
+        time_kernel_size: int = 3,
+        freq_kernel_size: int = 3,
+        zero_init_residual: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
+        norm_layer: Optional[NormLayer] = None,
+        **kwargs,
+    ):
+        super().__init__()
+
+        # Time-domain ResNets for the low- and high-frequency bands
+        self.strain_low_resnet = ResNet1D(
+            in_channels=num_ifos,
+            layers=low_time_layers,
+            classes=low_time_classes,
+            kernel_size=time_kernel_size,
+            zero_init_residual=zero_init_residual,
+            groups=groups,
+            width_per_group=width_per_group,
+            stride_type=stride_type,
+            norm_layer=norm_layer,
+        )
+
+        self.strain_high_resnet = ResNet1D(
+            in_channels=num_ifos,
+            layers=high_time_layers,
+            classes=high_time_classes,
+            kernel_size=time_kernel_size,
+            zero_init_residual=zero_init_residual,
+            groups=groups,
+            width_per_group=width_per_group,
+            stride_type=stride_type,
+            norm_layer=norm_layer,
+        )
+
+        # Frequency-domain ResNet: real, imaginary, and inverse-ASD
+        # channels for each interferometer
+        freq_input_channels = int(num_ifos * 3)
+        self.fft_resnet = ResNet1D(
+            in_channels=freq_input_channels,
+            layers=freq_layers,
+            classes=freq_classes,
+            kernel_size=freq_kernel_size,
+            zero_init_residual=zero_init_residual,
+            groups=groups,
+            width_per_group=width_per_group,
+            stride_type=stride_type,
+            norm_layer=norm_layer,
+        )
+
+        embed_dim = high_time_classes + low_time_classes + freq_classes
+        self.classifier = nn.Linear(embed_dim, 1)
+
+    def forward(
+        self, x_low: torch.Tensor, x_high: torch.Tensor, x_fft: torch.Tensor
+    ):
+        # Embed each modality separately, then fuse by concatenation
+        low_out = self.strain_low_resnet(x_low)
+        high_out = self.strain_high_resnet(x_high)
+        fft_out = self.fft_resnet(x_fft)
+
+        features = torch.cat([low_out, high_out, fft_out], dim=-1)
+        return self.classifier(features)
\ No newline at end of file
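
As a quick sanity check on the fusion shapes, a minimal forward-pass sketch. The hyperparameter values and tensor sizes are illustrative only, the import path is inferred from the file location above, and it assumes ml4gw's `ResNet1D` pools over the time dimension so the embedding is `(batch, classes)` regardless of input length:

```python
import torch

from architectures.multimodal_time_split import (
    MultimodalTimeSplitSupervisedArchitecture,
)

num_ifos, batch, kernel_size, num_freqs = 2, 8, 2048, 900

arch = MultimodalTimeSplitSupervisedArchitecture(
    num_ifos=num_ifos,
    low_time_classes=64,
    high_time_classes=64,
    freq_classes=64,
    low_time_layers=[2, 2],
    high_time_layers=[2, 2],
    freq_layers=[2, 2],
)

x_low = torch.randn(batch, num_ifos, kernel_size)
x_high = torch.randn(batch, num_ifos, kernel_size)
# 3 * num_ifos channels: real, imaginary, and inverse-ASD per ifo
x_fft = torch.randn(batch, 3 * num_ifos, num_freqs)

logits = arch(x_low, x_high, x_fft)
assert logits.shape == (batch, 1)  # one logit per batch element
```
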
diff --git a/libs/utils/utils/preprocessing.py b/libs/utils/utils/preprocessing.py
index e06e6c78..321137d2 100644
--- a/libs/utils/utils/preprocessing.py
+++ b/libs/utils/utils/preprocessing.py
@@ -254,3 +254,131 @@ def forward(self, x: Tensor) -> Tensor:
         x_fft = torch.cat([x_fft.real, x_fft.imag, inv_asd], dim=1)
 
         return x, x_fft
+
+
+class MultiModalTimeSplitPreprocessor(torch.nn.Module):
+    """
+    Preprocess a batch of waveforms for multimodal training,
+    with the time-domain strain split into low- and
+    high-frequency bands.
+
+    This includes whitening the time-domain data and
+    calculating the frequency-domain features.
+    """
+
+    def __init__(
+        self,
+        kernel_length: float,
+        sample_rate: float,
+        inference_sampling_rate: float,
+        batch_size: int,
+        fduration: float,
+        fftlength: float,
+        augmentor: Optional[Callable] = None,
+        highpass: Optional[float] = None,
+        lowpass: Optional[float] = None,
+        return_whitened: bool = False,
+    ) -> None:
+        super().__init__()
+        self.stride_size = int(sample_rate / inference_sampling_rate)
+        self.kernel_size = int(kernel_length * sample_rate)
+        self.augmentor = augmentor
+        self.return_whitened = return_whitened
+        self.highpass = highpass
+        self.lowpass = lowpass
+        self.fftlength = fftlength
+        self.n_fft = int(fftlength * sample_rate)
+        self.sample_rate = sample_rate
+        self.feature_nfft = self.kernel_size
+
+        strides = (batch_size - 1) * self.stride_size
+        fsize = int(fduration * sample_rate)
+        size = strides + self.kernel_size + fsize
+        length = size / sample_rate
+        self.psd_estimator = PsdEstimator(
+            length,
+            sample_rate,
+            fftlength=fftlength,
+            overlap=None,
+            average="median",
+            fast=highpass is not None,
+        )
+        self.whitener = Whiten(fduration, sample_rate, highpass, lowpass)
+
+        freqs = torch.fft.rfftfreq(self.feature_nfft, d=1 / sample_rate)
+        self.freq_mask = torch.ones_like(freqs, dtype=torch.bool)
+        if highpass is not None:
+            self.freq_mask &= freqs > highpass
+        # if lowpass is not None:
+        #     self.freq_mask &= freqs < lowpass
+        self.freqs = freqs
+        self.num_freqs = int(freqs.numel())
+
+    def forward(self, x: Tensor):
+        if x.ndim == 3:
+            num_channels = x.size(1)
+        elif x.ndim == 2:
+            num_channels = x.size(0)
+        else:
+            raise ValueError(f"Unexpected input shape: {x.shape}")
+
+        x, psd = self.psd_estimator(x)
+        whitened = self.whitener(x, psd)
+        whitened_low = self.whitener(
+            x, psd, highpass=self.highpass, lowpass=self.lowpass
+        )
+        whitened_high = self.whitener(
+            x, psd, highpass=self.lowpass, lowpass=None
+        )
+
+        asds = psd**0.5
+        asds = asds.float()
+        asds *= 1e23
+
+        # ensure asd is 3D so the mask works even if
+        # the asd is never interpolated
+        if asds.ndim == 2:
+            asds = asds.unsqueeze(0)
+
+        # Unfold the whitened inputs into kernel-sized windows.
+        # Note that if x has both signal and background
+        # batch elements, they will be interleaved along
+        # the batch dimension after unfolding
+        x = unfold_windows(whitened, self.kernel_size, self.stride_size)
+        x_low = unfold_windows(
+            whitened_low, self.kernel_size, self.stride_size
+        )
+        x_high = unfold_windows(
+            whitened_high, self.kernel_size, self.stride_size
+        )
+        x = x.reshape(-1, num_channels, self.kernel_size)
+
+        x_fft = torch.fft.rfft(x, n=self.feature_nfft, dim=-1)
+        F_expected = self.num_freqs
+        F_actual = x_fft.shape[-1]
+        if F_actual != F_expected:
+            raise ValueError(
+                f"FFT bin mismatch: got F={F_actual} from rfft, "
+                f"expected {F_expected} "
+                f"(kernel_size={self.kernel_size}, n_fft={self.n_fft}). "
+                "Ensure fftlength * sample_rate matches the "
+                "intended kernel-length FFT."
+            )
+
+        asds = torch.nn.functional.interpolate(
+            asds,
+            size=(self.num_freqs,),
+            mode="linear",
+        )
+
+        asds = asds[:, :, self.freq_mask]
+        inv_asd = 1 / asds
+        inv_asd = inv_asd.repeat(x_fft.shape[0], 1, 1)
+        x_fft = x_fft[..., self.freq_mask]
+
+        x_fft = torch.cat([x_fft.real, x_fft.imag, inv_asd], dim=1)
+
+        x_low = x_low.reshape(-1, num_channels, self.kernel_size)
+        x_high = x_high.reshape(-1, num_channels, self.kernel_size)
+        x_fft = x_fft.reshape(x_low.shape[0], -1, x_fft.shape[-1])
+
+        return x_low, x_high, x_fft
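
The frequency-domain feature construction above (rfft, highpass mask, inverse-ASD channels) can be exercised standalone with dummy data; all sizes and values below are illustrative:

```python
import torch

sample_rate, kernel_size, highpass = 2048.0, 4096, 32.0
batch, num_ifos = 8, 2

# Whitened time-domain kernels and per-ifo PSDs (dummy values)
whitened = torch.randn(batch, num_ifos, kernel_size)
psds = (torch.rand(batch, num_ifos, 513) + 0.5) * 1e-46

# Complex spectrum of each whitened kernel
x_fft = torch.fft.rfft(whitened, dim=-1)

# Keep only the bins above the highpass frequency
freqs = torch.fft.rfftfreq(kernel_size, d=1 / sample_rate)
mask = freqs > highpass

# Interpolate the ASDs onto the rfft frequency grid, rescale to
# order unity (the 1e23 factor), and take the reciprocal
asds = torch.nn.functional.interpolate(
    psds**0.5, size=freqs.numel(), mode="linear"
)
inv_asds = 1 / (asds[:, :, mask] * 1e23)

x_fft = x_fft[:, :, mask]
features = torch.cat([x_fft.real, x_fft.imag, inv_asds], dim=1)
assert features.shape == (batch, 3 * num_ifos, int(mask.sum()))
```
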
diff --git a/projects/train/train/data/supervised/__init__.py b/projects/train/train/data/supervised/__init__.py
index 4ca80c17..ab559bde 100644
--- a/projects/train/train/data/supervised/__init__.py
+++ b/projects/train/train/data/supervised/__init__.py
@@ -3,5 +3,6 @@
     SpectrogramDomainSupervisedAframeDataset,
 )
 from .multimodal import MultiModalSupervisedAframeDataset
+from .multimodal_time_split import MultimodalTimeSplitSupervisedAframeDataset
 from .supervised import SupervisedAframeDataset
 from .time_domain import TimeDomainSupervisedAframeDataset
diff --git a/projects/train/train/data/supervised/multimodal_time_split.py b/projects/train/train/data/supervised/multimodal_time_split.py
new file mode 100644
index 00000000..b47153c4
--- /dev/null
+++ b/projects/train/train/data/supervised/multimodal_time_split.py
@@ -0,0 +1,156 @@
+from typing import Optional
+
+import torch
+import torch.nn.functional as F
+
+from train.data.supervised.supervised import SupervisedAframeDataset
+
+
+class MultimodalTimeSplitSupervisedAframeDataset(SupervisedAframeDataset):
+    def __init__(
+        self,
+        *args,
+        swap_prob: Optional[float] = None,
+        mute_prob: Optional[float] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            *args, swap_prob=swap_prob, mute_prob=mute_prob, **kwargs
+        )
+
+    @torch.no_grad()
+    def augment(self, X, waveforms):
+        X, y, psds = super().augment(X, waveforms)
+
+        X_whitened = self.whitener(X, psds)
+        X_low = self.whitener(
+            X,
+            psds,
+            highpass=self.hparams.highpass,
+            lowpass=self.hparams.lowpass,
+        ).float()
+        X_high = self.whitener(
+            X, psds, highpass=self.hparams.lowpass, lowpass=None
+        ).float()
+
+        X_fft = torch.fft.rfft(X_whitened, dim=-1)
+        asds = psds**0.5
+
+        freqs = torch.fft.rfftfreq(
+            X_whitened.shape[-1], d=1 / self.hparams.sample_rate
+        )
+        num_freqs = len(freqs)
+
+        asds = F.interpolate(
+            asds,
+            size=(num_freqs,),
+            mode="linear",
+        )
+        mask = torch.ones_like(freqs, dtype=torch.bool)
+        if self.hparams.highpass is not None:
+            mask &= freqs > self.hparams.highpass
+        # if self.hparams.lowpass is not None:
+        #     mask &= freqs < self.hparams.lowpass
+
+        asds = asds[:, :, mask]
+        asds *= 1e23
+        inv_asds = 1 / asds
+
+        X_fft = X_fft[:, :, mask]
+        X_fft = torch.cat([X_fft.real, X_fft.imag, inv_asds], dim=1).float()
+        if torch.isnan(X_low).any() or torch.isinf(X_low).any():
+            raise ValueError("NaN or Inf in X_low")
+        if torch.isnan(X_fft).any() or torch.isinf(X_fft).any():
+            raise ValueError("NaN or Inf in X_fft")
+        if torch.isnan(y).any() or torch.isinf(y).any():
+            raise ValueError("NaN or Inf in y")
+
+        return X_low, X_high, X_fft, y.float()
+
+    def on_after_batch_transfer(self, batch, _):
+        """
+        Perform on-device preprocessing after transferring the
+        batch to the device.
+
+        Augments data during training and injects signals into
+        background during validation, performing whitening and
+        FFT-based preprocessing.
+        """
+        if self.trainer.training:
+            # Training mode: perform random augmentations using waveforms
+            [X], waveforms = batch
+            return self.augment(X, waveforms)
+
+        elif self.trainer.validating or self.trainer.sanity_checking:
+            # Validation mode: prepare signal-injected validation batches
+            [background, _, timeslide_idx], [signals] = batch
+            if isinstance(timeslide_idx, torch.Tensor):
+                timeslide_idx = timeslide_idx[0].item()
+            shift = float(self.timeslides[timeslide_idx].shift_size)
+
+            # Build validation inputs and corresponding PSDs
+            X_bg, X_inj, psds = super().build_val_batches(
+                background, signals
+            )
+
+            # Background: low/high-passed and FFT-processed
+            X_bg_whitened = self.whitener(X_bg, psds)
+            X_bg_low = self.whitener(
+                X_bg,
+                psds,
+                highpass=self.hparams.highpass,
+                lowpass=self.hparams.lowpass,
+            ).float()
+            X_bg_high = self.whitener(
+                X_bg, psds, highpass=self.hparams.lowpass, lowpass=None
+            ).float()
+            X_bg_fft = torch.fft.rfft(X_bg_whitened)
+
+            # Foreground: process injected signals similarly
+            X_fg_whitened, X_fg_low, X_fg_high = [], [], []
+            for inj in X_inj:
+                X_fg_low.append(
+                    self.whitener(
+                        inj,
+                        psds,
+                        highpass=self.hparams.highpass,
+                        lowpass=self.hparams.lowpass,
+                    ).float()
+                )
+                X_fg_high.append(
+                    self.whitener(
+                        inj, psds, highpass=self.hparams.lowpass, lowpass=None
+                    ).float()
+                )
+                X_fg_whitened.append(self.whitener(inj, psds))
+            X_fg_low = torch.stack(X_fg_low)
+            X_fg_high = torch.stack(X_fg_high)
+            X_fg_whitened = torch.stack(X_fg_whitened, dim=0)
+            X_fg_fft = torch.fft.rfft(X_fg_whitened)
+
+            asds = psds**0.5
+            asds *= 1e23
+            asds = asds.float()
+            num_freqs = X_fg_fft.shape[-1]
+            if asds.shape[-1] != num_freqs:
+                asds = F.interpolate(asds, size=(num_freqs,), mode="linear")
+            inv_asds = 1 / asds
+
+            X_bg_fft = torch.cat(
+                [X_bg_fft.real, X_bg_fft.imag, inv_asds], dim=1
+            ).float()
+            # Tile the inverse ASDs across the injection-view dimension
+            inv_asds = inv_asds.unsqueeze(0).repeat(
+                X_fg_fft.shape[0], 1, 1, 1
+            )
+            X_fg_fft = torch.cat(
+                [X_fg_fft.real, X_fg_fft.imag, inv_asds], dim=2
+            ).float()
+
+            # Return data grouped into background and
+            # injected signal components
+            return (
+                shift,
+                X_bg_low,
+                X_bg_high,
+                X_bg_fft,
+                X_fg_low,
+                X_fg_high,
+                X_fg_fft,
+                psds,
+            )
+
+        # Default: return batch unchanged
+        return batch
\ No newline at end of file
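
One subtlety in the validation path: `inv_asds` has no view dimension while `X_fg_fft` does, which is why it is tiled by the number of injection views (`X_fg_fft.shape[0]`) before concatenation along the channel dimension. A standalone shape check with illustrative sizes:

```python
import torch

V, B, C, F = 5, 8, 2, 900  # views, batch, ifos, frequency bins

X_fg_fft = torch.randn(V, B, C, F, dtype=torch.complex64)
inv_asds = torch.rand(B, C, F)

# Tile the inverse ASDs across the injection-view dimension so
# they line up with the per-view FFT features
inv_asds = inv_asds.unsqueeze(0).repeat(X_fg_fft.shape[0], 1, 1, 1)
features = torch.cat([X_fg_fft.real, X_fg_fft.imag, inv_asds], dim=2)
assert features.shape == (V, B, 3 * C, F)
```
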
+ """ + if self.trainer.training: + # Training mode: perform random augmentations using waveforms + [X], waveforms = batch + return self.augment(X, waveforms) + + elif self.trainer.validating or self.trainer.sanity_checking: + # Validation mode: prepare signal-injected validation batches + [background, _, timeslide_idx], [signals] = batch + if isinstance(timeslide_idx, torch.Tensor): + timeslide_idx = timeslide_idx[0].item() + shift = float(self.timeslides[timeslide_idx].shift_size) + + # Build validation inputs and corresponding PSDs + X_bg, X_inj, psds = super().build_val_batches(background, signals) + + # Background: low/high-passed and FFT-processed + X_bg_whitened = self.whitener(X_bg, psds) + X_bg_low = self.whitener( + X_bg, + psds, + highpass=self.hparams.highpass, + lowpass=self.hparams.lowpass, + ).float() + X_bg_high = self.whitener( + X_bg, psds, highpass=self.hparams.lowpass, lowpass=None + ).float() + X_bg_fft = torch.fft.rfft(X_bg_whitened) + + # Foreground: process injected signals similarly + X_fg_whitened, X_fg_low, X_fg_high = [], [], [] + for inj in X_inj: + X_fg_low.append( + self.whitener( + inj, + psds, + lowpass=self.hparams.lowpass, + highpass=self.hparams.highpass, + ).float() + ) + X_fg_high.append( + self.whitener( + inj, psds, lowpass=None, highpass=self.hparams.lowpass + ).float() + ) + X_fg_whitened.append(self.whitener(inj, psds)) + X_fg_low = torch.stack(X_fg_low) + X_fg_high = torch.stack(X_fg_high) + X_fg_whitened = torch.stack(X_fg_whitened, dim=0) + X_fg_fft = torch.fft.rfft(X_fg_whitened) + + asds = psds**0.5 + asds *= 1e23 + asds = asds.float() + num_freqs = X_fg_fft.shape[-1] + if asds.shape[-1] != num_freqs: + asds = F.interpolate(asds, size=(num_freqs,), mode="linear") + inv_asds = 1 / asds + + X_bg_fft = torch.cat( + [X_bg_fft.real, X_bg_fft.imag, inv_asds], dim=1 + ).float() + inv_asds = inv_asds.unsqueeze(0).repeat(5, 1, 1, 1) + X_fg_fft = torch.cat( + [X_fg_fft.real, X_fg_fft.imag, inv_asds], dim=2 + ).float() + + # Return data grouped into background and injected signal components + return ( + shift, + X_bg_low, + X_bg_high, + X_bg_fft, + X_fg_low, + X_fg_high, + X_fg_fft, + psds, + ) + + # Default: return batch unchanged + return batch \ No newline at end of file diff --git a/projects/train/train/model/__init__.py b/projects/train/train/model/__init__.py index 5bbd13d3..d127ce48 100644 --- a/projects/train/train/model/__init__.py +++ b/projects/train/train/model/__init__.py @@ -4,4 +4,5 @@ SupervisedAframe, SupervisedAframeS4, SupervisedMultiModalAframe, + SupervisedMultiModalTimeSplitAframe, ) diff --git a/projects/train/train/model/supervised.py b/projects/train/train/model/supervised.py index 5fe42b24..7088895f 100644 --- a/projects/train/train/model/supervised.py +++ b/projects/train/train/model/supervised.py @@ -126,3 +126,73 @@ def configure_optimizers(self): scheduler_config = {"scheduler": scheduler, "interval": "step"} return {"optimizer": optimizer, "lr_scheduler": scheduler_config} + +class SupervisedMultimodalTimeSplitAframe(AframeBase): + def __init__( + self, + arch: SupervisedArchitecture, + *args, + **kwargs, + ) -> None: + super().__init__(arch, *args, **kwargs) + + def forward(self, x_low: Tensor, x_high: Tensor, x_fft: Tensor) -> Tensor: + return self.model(x_low, x_high, x_fft) + + def train_step(self, batch: tuple) -> Tensor: + # Unpack depending on number of elements + if len(batch) != 4: + raise ValueError( + f"Unexpected batch format in train_step: {len(batch)} elements" + ) + + X_low, X_high, X_fft, y = batch + + 
+        y_hat = self(X_low, X_high, X_fft).squeeze(-1)
+        # Match the shape of y_hat
+        y = y.float().view_as(y_hat)
+        return F.binary_cross_entropy_with_logits(y_hat, y)
+
+    def score(self, X):
+        X_low, X_high, X_fft = X
+        return self.model(X_low, X_high, X_fft).squeeze(-1)
+
+    def validation_step(self, batch, batch_idx):
+        try:
+            shift, X_bg, X_inj = batch
+        except ValueError:
+            (
+                shift,
+                X_bg_low,
+                X_bg_high,
+                X_bg_fft,
+                X_fg_low,
+                X_fg_high,
+                X_fg_fft,
+                *_,
+            ) = batch
+            X_bg = (X_bg_low, X_bg_high, X_bg_fft)
+            X_inj = (X_fg_low, X_fg_high, X_fg_fft)
+
+        # Score background
+        y_bg = self.score(X_bg)
+
+        # Score injected signals, flattening the (view, batch)
+        # dimensions and averaging the scores over views
+        x0 = X_inj[0]
+        if x0.ndim >= 4:
+            V, B = x0.shape[:2]
+            x_flat = tuple(x.reshape(V * B, *x.shape[2:]) for x in X_inj)
+            y_fg = self.score(x_flat).view(V, B).mean(0)
+        else:
+            y_fg = self.score(X_inj)
+
+        shift_val = shift if isinstance(shift, float) else float(shift)
+        self.metric.update(shift_val, y_bg, y_fg)
+
+        self.log(
+            "valid_auroc",
+            self.metric,
+            on_step=True,
+            on_epoch=True,
+            sync_dist=True,
+        )
\ No newline at end of file
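
The view-flattening logic in `validation_step` can be verified in isolation. The scorer below is a hypothetical stand-in with the same signature as `score` (a tuple of the three modalities in, one score per element out), and all sizes are illustrative:

```python
import torch

V, B, C, N, F = 5, 8, 2, 2048, 900  # views, batch, ifos, samples, bins

# Hypothetical stand-in for `score`: returns one value per element
def score(X):
    x_low, x_high, x_fft = X
    return x_low.mean(dim=(-1, -2))

X_inj = (
    torch.randn(V, B, C, N),      # x_low
    torch.randn(V, B, C, N),      # x_high
    torch.randn(V, B, 3 * C, F),  # x_fft
)

# Flatten (V, B) -> (V * B), score, then average over views,
# mirroring the ndim >= 4 branch of validation_step
x_flat = tuple(x.reshape(V * B, *x.shape[2:]) for x in X_inj)
y_fg = score(x_flat).view(V, B).mean(0)
assert y_fg.shape == (B,)
```
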