79 changes: 79 additions & 0 deletions libs/architectures/architectures/multimodal_time_split.py
@@ -0,0 +1,79 @@
import torch
import torch.nn as nn
from architectures.supervised import SupervisedArchitecture
from ml4gw.nn.resnet.resnet_1d import ResNet1D, NormLayer
from typing import Optional, Literal


class MultimodalTimeSplitSupervisedArchitecture(SupervisedArchitecture):
def __init__(
self,
num_ifos: int,
low_time_classes: int,
high_time_classes: int,
freq_classes: int,
low_time_layers: list[int],
high_time_layers: list[int],
freq_layers: list[int],
time_kernel_size: int = 3,
freq_kernel_size: int = 3,
zero_init_residual: bool = False,
groups: int = 1,
width_per_group: int = 64,
stride_type: Optional[list[Literal["stride", "dilation"]]] = None,
norm_layer: Optional[NormLayer] = None,
**kwargs,
):
super().__init__()

# Time-domain ResNets
self.strain_low_resnet = ResNet1D(
in_channels=num_ifos,
layers=low_time_layers,
classes=low_time_classes,
kernel_size=time_kernel_size,
zero_init_residual=zero_init_residual,
groups=groups,
width_per_group=width_per_group,
stride_type=stride_type,
norm_layer=norm_layer,
)

self.strain_high_resnet = ResNet1D(
in_channels=num_ifos,
layers=high_time_layers,
classes=high_time_classes,
kernel_size=time_kernel_size,
zero_init_residual=zero_init_residual,
groups=groups,
width_per_group=width_per_group,
stride_type=stride_type,
norm_layer=norm_layer,
)

        # Frequency-domain ResNet: 3 input channels per ifo
        # (FFT real part, imaginary part, and inverse ASD)
        freq_input_channels = int(num_ifos * 3)
self.fft_resnet = ResNet1D(
in_channels=freq_input_channels,
layers=freq_layers,
classes=freq_classes,
kernel_size=freq_kernel_size,
zero_init_residual=zero_init_residual,
groups=groups,
width_per_group=width_per_group,
stride_type=stride_type,
norm_layer=norm_layer,
)

        # Concatenate the three branch embeddings and map to a single logit
        embed_dim = high_time_classes + low_time_classes + freq_classes
        self.classifier = nn.Linear(embed_dim, 1)

def forward(
self, x_low: torch.Tensor, x_high: torch.Tensor, x_fft: torch.Tensor
):
low_out = self.strain_low_resnet(x_low)
high_out = self.strain_high_resnet(x_high)
fft_out = self.fft_resnet(x_fft)

features = torch.cat([low_out, high_out, fft_out], dim=-1)
return self.classifier(features)
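A short smoke test makes the expected tensor shapes concrete. This is a minimal sketch, not part of the PR: the import path, layer counts, and shapes below are illustrative assumptions.

import torch

# hypothetical import path, mirroring the package layout above
from architectures.multimodal_time_split import (
    MultimodalTimeSplitSupervisedArchitecture,
)

# illustrative shapes: 2 ifos, batch of 8, 2048-sample kernels,
# 500 retained frequency bins
num_ifos, batch, kernel, num_bins = 2, 8, 2048, 500
model = MultimodalTimeSplitSupervisedArchitecture(
    num_ifos=num_ifos,
    low_time_classes=64,
    high_time_classes=64,
    freq_classes=64,
    low_time_layers=[2, 2],
    high_time_layers=[2, 2],
    freq_layers=[2, 2],
)
x_low = torch.randn(batch, num_ifos, kernel)
x_high = torch.randn(batch, num_ifos, kernel)
x_fft = torch.randn(batch, num_ifos * 3, num_bins)  # real, imag, 1/ASD
logits = model(x_low, x_high, x_fft)
assert logits.shape == (batch, 1)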
128 changes: 128 additions & 0 deletions libs/utils/utils/preprocessing.py
@@ -254,3 +254,131 @@ def forward(self, x: Tensor) -> Tensor:
x_fft = torch.cat([x_fft.real, x_fft.imag, inv_asd], dim=1)

return x, x_fft

class MultiModalTimeSplitPreprocessor(torch.nn.Module):
    """
    Preprocess a batch of waveforms for multimodal training with the
    time-domain strain split into low- and high-frequency bands.
    This includes whitening the time-domain data and computing the
    frequency-domain features.
    """
def __init__(
self,
kernel_length: float,
sample_rate: float,
inference_sampling_rate: float,
batch_size: int,
fduration: float,
fftlength: float,
augmentor: Optional[Callable] = None,
highpass: Optional[float] = None,
lowpass: Optional[float] = None,
return_whitened: bool = False,
) -> None:
super().__init__()
self.stride_size = int(sample_rate / inference_sampling_rate)
self.kernel_size = int(kernel_length * sample_rate)
self.augmentor = augmentor
self.return_whitened = return_whitened
self.highpass = highpass
self.lowpass = lowpass
self.fftlength = fftlength
        self.sample_rate = sample_rate
        # FFT length used for PSD estimation
        self.n_fft = int(fftlength * sample_rate)
        self.n_fft_psd = self.n_fft
        # the feature FFT spans the full whitened kernel
        self.feature_nfft = self.kernel_size

strides = (batch_size - 1) * self.stride_size
fsize = int(fduration * sample_rate)
size = strides + self.kernel_size + fsize
length = size / sample_rate
self.psd_estimator = PsdEstimator(
length,
sample_rate,
fftlength=fftlength,
overlap=None,
average="median",
fast=highpass is not None,
)
self.whitener = Whiten(fduration, sample_rate, highpass, lowpass)

        # Frequency grid of the kernel-length FFT; keep only the bins
        # above the highpass cutoff. Note that no lowpass cut is applied,
        # so the frequency branch sees the full band above the highpass:
        # if lowpass is not None:
        #     freq_mask &= freqs < lowpass
        freqs = torch.fft.rfftfreq(self.feature_nfft, d=1 / sample_rate)
        freq_mask = torch.ones_like(freqs, dtype=torch.bool)
        if highpass is not None:
            freq_mask &= freqs > highpass
        # register as buffers so the mask and grid move with the module
        self.register_buffer("freqs", freqs)
        self.register_buffer("freq_mask", freq_mask)
        self.num_freqs = int(freqs.numel())

def forward(self, x: Tensor):
if x.ndim == 3:
num_channels = x.size(1)
elif x.ndim == 2:
num_channels = x.size(0)
else:
raise ValueError(f"Unexpected input shape: {x.shape}")

x, psd = self.psd_estimator(x)
whitened = self.whitener(x, psd)
        # Split the whitened strain at the lowpass cutoff: the low branch
        # keeps [highpass, lowpass], the high branch everything above
        whitened_low = self.whitener(
            x, psd, highpass=self.highpass, lowpass=self.lowpass
        )
        whitened_high = self.whitener(
            x, psd, highpass=self.lowpass, lowpass=None
        )

        # strain ASDs are O(1e-23); rescale to order unity
        asds = psd**0.5
        asds = asds.float()
        asds *= 1e23

        # ensure the ASDs are 3D so the frequency mask applies even if
        # they are never interpolated
        if asds.ndim == 2:
            asds = asds.unsqueeze(0)

# unfold x and other inputs and then put into expected shape.
# Note that if x has both signal and background
# batch elements, they will be interleaved along
# the batch dimension after unfolding
x = unfold_windows(whitened, self.kernel_size, self.stride_size)
x_low = unfold_windows(
whitened_low, self.kernel_size, self.stride_size
)
x_high = unfold_windows(
whitened_high, self.kernel_size, self.stride_size
)
x = x.reshape(-1, num_channels, self.kernel_size)

x_fft = torch.fft.rfft(x, n=self.feature_nfft, dim=-1)
        # sanity check: the rfft output must match the precomputed grid
        expected = self.num_freqs
        actual = x_fft.shape[-1]
        if actual != expected:
            raise ValueError(
                f"FFT bin mismatch: got {actual} bins from rfft, "
                f"expected {expected} "
                f"(kernel_size={self.kernel_size}, n_fft={self.n_fft}). "
                "Ensure fftlength * sample_rate matches the intended "
                "kernel-length FFT."
            )

asds = torch.nn.functional.interpolate(
asds,
size=(self.num_freqs,),
mode="linear",
)

asds = asds[:, :, self.freq_mask]
inv_asd = 1 / asds
inv_asd = inv_asd.repeat(x_fft.shape[0], 1, 1)
x_fft = x_fft[..., self.freq_mask]

x_fft = torch.cat([x_fft.real, x_fft.imag, inv_asd], dim=1)

x_low = x_low.reshape(-1, num_channels, self.kernel_size)
x_high = x_high.reshape(-1, num_channels, self.kernel_size)
x_fft = x_fft.reshape(x_low.shape[0], -1, x_fft.shape[-1])

return x_low, x_high, x_fft
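A hedged usage sketch of the preprocessor follows; the configuration values and import path are illustrative assumptions, not taken from the PR, and the call is left commented since the required input length depends on the PSD-estimation segment.

import torch

from utils.preprocessing import MultiModalTimeSplitPreprocessor

# illustrative settings; real values come from the training config
preprocessor = MultiModalTimeSplitPreprocessor(
    kernel_length=1.0,
    sample_rate=2048,
    inference_sampling_rate=16,
    batch_size=8,
    fduration=1.0,
    fftlength=2.0,
    highpass=32.0,
    lowpass=256.0,
)
# x holds multi-ifo strain, (num_ifos, time) or (views, num_ifos, time);
# the leading stretch feeds PSD estimation, the rest is whitened,
# band-split, and unfolded into kernels:
# x_low, x_high, x_fft = preprocessor(x)
# x_low, x_high: (batch, num_ifos, kernel_size)
# x_fft:         (batch, num_ifos * 3, num_retained_bins)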
1 change: 1 addition & 0 deletions projects/train/train/data/supervised/__init__.py
@@ -3,5 +3,6 @@
SpectrogramDomainSupervisedAframeDataset,
)
from .multimodal import MultiModalSupervisedAframeDataset
from .multimodal_time_split import MultimodalTimeSplitSupervisedAframeDataset
from .supervised import SupervisedAframeDataset
from .time_domain import TimeDomainSupervisedAframeDataset
156 changes: 156 additions & 0 deletions projects/train/train/data/supervised/multimodal_time_split.py
@@ -0,0 +1,156 @@
from typing import Optional
import torch
import torch.nn.functional as F
from train.data.supervised.supervised import SupervisedAframeDataset


class MultimodalTimeSplitSupervisedAframeDataset(SupervisedAframeDataset):
def __init__(
self,
*args,
swap_prob: Optional[float] = None,
mute_prob: Optional[float] = None,
**kwargs,
) -> None:
super().__init__(
*args, swap_prob=swap_prob, mute_prob=mute_prob, **kwargs
)

@torch.no_grad()
def augment(self, X, waveforms):
X, y, psds = super().augment(X, waveforms)

        X_whitened = self.whitener(X, psds)
        # Band-split as in the preprocessor: the low branch keeps
        # [highpass, lowpass], the high branch everything above lowpass
        X_low = self.whitener(
            X,
            psds,
            highpass=self.hparams.highpass,
            lowpass=self.hparams.lowpass,
        ).float()
        X_high = self.whitener(
            X, psds, highpass=self.hparams.lowpass, lowpass=None
        ).float()

X_fft = torch.fft.rfft(X_whitened, dim=-1)
asds = psds**0.5

freqs = torch.fft.rfftfreq(
X_whitened.shape[-1], d=1 / self.hparams.sample_rate
)
num_freqs = len(freqs)

asds = torch.nn.functional.interpolate(
asds,
size=(num_freqs,),
mode="linear",
)
mask = torch.ones_like(freqs, dtype=torch.bool)
if self.hparams.highpass is not None:
mask &= freqs > self.hparams.highpass
        # if self.hparams.lowpass is not None:
        #     mask &= freqs < self.hparams.lowpass

asds = asds[:, :, mask]
asds *= 1e23
inv_asds = 1 / asds

X_fft = X_fft[:, :, mask]
X_fft = torch.cat([X_fft.real, X_fft.imag, inv_asds], dim=1).float()
if torch.isnan(X_low).any() or torch.isinf(X_low).any():
raise ValueError("NaN or Inf in X_low")
if torch.isnan(X_fft).any() or torch.isinf(X_fft).any():
raise ValueError("NaN or Inf in X_fft")
if torch.isnan(y).any() or torch.isinf(y).any():
raise ValueError("NaN or Inf in y")

return X_low, X_high, X_fft, y.float()

def on_after_batch_transfer(self, batch, _):
"""
Perform on-device preprocessing after transferring batch to device.

Augments data during training and injects signals into background during validation,
performing whitening and FFT-based preprocessing.
"""
if self.trainer.training:
# Training mode: perform random augmentations using waveforms
[X], waveforms = batch
return self.augment(X, waveforms)

elif self.trainer.validating or self.trainer.sanity_checking:
# Validation mode: prepare signal-injected validation batches
[background, _, timeslide_idx], [signals] = batch
if isinstance(timeslide_idx, torch.Tensor):
timeslide_idx = timeslide_idx[0].item()
shift = float(self.timeslides[timeslide_idx].shift_size)

# Build validation inputs and corresponding PSDs
X_bg, X_inj, psds = super().build_val_batches(background, signals)

# Background: low/high-passed and FFT-processed
X_bg_whitened = self.whitener(X_bg, psds)
X_bg_low = self.whitener(
X_bg,
psds,
highpass=self.hparams.highpass,
lowpass=self.hparams.lowpass,
).float()
X_bg_high = self.whitener(
X_bg, psds, highpass=self.hparams.lowpass, lowpass=None
).float()
X_bg_fft = torch.fft.rfft(X_bg_whitened)

# Foreground: process injected signals similarly
X_fg_whitened, X_fg_low, X_fg_high = [], [], []
for inj in X_inj:
X_fg_low.append(
self.whitener(
inj,
psds,
lowpass=self.hparams.lowpass,
highpass=self.hparams.highpass,
).float()
)
X_fg_high.append(
self.whitener(
inj, psds, lowpass=None, highpass=self.hparams.lowpass
).float()
)
X_fg_whitened.append(self.whitener(inj, psds))
X_fg_low = torch.stack(X_fg_low)
X_fg_high = torch.stack(X_fg_high)
X_fg_whitened = torch.stack(X_fg_whitened, dim=0)
X_fg_fft = torch.fft.rfft(X_fg_whitened)

asds = psds**0.5
asds *= 1e23
asds = asds.float()
num_freqs = X_fg_fft.shape[-1]
if asds.shape[-1] != num_freqs:
asds = F.interpolate(asds, size=(num_freqs,), mode="linear")
inv_asds = 1 / asds

X_bg_fft = torch.cat(
[X_bg_fft.real, X_bg_fft.imag, inv_asds], dim=1
).float()
        # broadcast the inverse ASDs across the injection views rather
        # than hardcoding their count
        inv_asds = inv_asds.unsqueeze(0).repeat(X_fg_fft.shape[0], 1, 1, 1)
X_fg_fft = torch.cat(
[X_fg_fft.real, X_fg_fft.imag, inv_asds], dim=2
).float()

# Return data grouped into background and injected signal components
return (
shift,
X_bg_low,
X_bg_high,
X_bg_fft,
X_fg_low,
X_fg_high,
X_fg_fft,
psds,
)

# Default: return batch unchanged
return batch
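For orientation, a minimal sketch of how the training-path outputs feed the architecture's forward; the function name and loss are assumptions for illustration, not code from this PR.

import torch
import torch.nn.functional as F


def training_step(model, batch):
    # batch is the (X_low, X_high, X_fft, y) tuple produced by
    # augment() above; model is assumed to be a
    # MultimodalTimeSplitSupervisedArchitecture
    x_low, x_high, x_fft, y = batch
    logits = model(x_low, x_high, x_fft)
    return F.binary_cross_entropy_with_logits(logits, y)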
1 change: 1 addition & 0 deletions projects/train/train/model/__init__.py
@@ -4,4 +4,5 @@
SupervisedAframe,
SupervisedAframeS4,
SupervisedMultiModalAframe,
SupervisedMultiModalTimeSplitAframe,
)