Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 37 additions & 23 deletions amplfi/train/callbacks/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,12 @@ def __init__(self, outdir: Path, num_plot: int, save_data: bool = True):
self.num_plot = num_plot
self.save_data = save_data

def on_test_batch_end(
self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
) -> None:
def plot_strain(self, outdir, result, batch, trainer):
"""
Called at the end of each test step.
`outputs` consists of objects returned by `pl_module.test_step`.
"""

# test_step returns bilby result object
result = outputs

if batch_idx >= self.num_plot:
return

outdir = self.outdir / f"event_{batch_idx}"
outdir.mkdir(exist_ok=True)

# unpack batch
strain, asds, _ = batch
strain, asds = strain[0].cpu().numpy(), asds[0].cpu().numpy()
Expand Down Expand Up @@ -165,12 +154,31 @@ def on_test_batch_end(
plt.savefig(whitened_fd_strain_fname)
plt.close()

def on_test_batch_end(
    self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
) -> None:
    """
    Plot the strain for a single test batch.

    Only batches whose index is below `self.num_plot` are plotted;
    later batches are ignored. Output goes to the directory
    `self.outdir / "event_<batch_idx>"`, which is created if it
    does not already exist.

    Args:
        trainer: forwarded unchanged to `plot_strain`
        pl_module: unused
        outputs: object returned by the module's test step
            (a bilby result object, per the original comment)
        batch: forwarded unchanged to `plot_strain`
        batch_idx: index of the batch; also used in the directory name
        dataloader_idx: unused
    """
    if batch_idx < self.num_plot:
        event_dir = self.outdir / f"event_{batch_idx}"
        event_dir.mkdir(exist_ok=True)
        self.plot_strain(event_dir, outputs, batch, trainer)

def on_predict_batch_end(
    self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
    """
    Plot the strain for a single predict batch, in a directory
    named after the event's GPS time rather than the batch index.

    Only batches whose index is below `self.num_plot` are plotted.

    Args:
        trainer: forwarded unchanged to `plot_strain`
        pl_module: unused
        outputs: object returned by the module's predict step;
            forwarded to `plot_strain` as the result
        batch: indexable whose element 2 supports `.cpu().numpy()`;
            the first entry of that array is used as the GPS time
            (presumably a tensor of GPS times - confirm against
            the predict dataloader)
        batch_idx: index of the batch, only used for the
            `num_plot` cut-off
        dataloader_idx: unused
    """
    if batch_idx >= self.num_plot:
        return

    result = outputs

    # label the output directory by the integer part of the gps time
    gpstime = batch[2].cpu().numpy()[0]
    outdir = self.outdir / f"event_{int(gpstime)}"

    # create the directory before plotting, exactly as
    # `on_test_batch_end` does before it calls `plot_strain`
    outdir.mkdir(exist_ok=True)
    self.plot_strain(outdir, result, batch, trainer)


class PlotMollview(pl.Callback):
Expand All @@ -183,6 +191,12 @@ def __init__(self, outdir: Path, nside: int):
self.outdir = outdir
self.nside = nside

def plot_mollview(self, outdir: Path, result: "AmplfiResult"):
    """
    Write the skymap plot for `result` into `outdir`.

    Delegates to `result.plot_mollview`, passing `self.nside` and
    the output path `outdir / "mollview.png"`.

    Args:
        outdir: directory the image is written to
        result: object exposing `plot_mollview(nside, outpath=...)`
    """
    skymap_path = outdir / "mollview.png"
    result.plot_mollview(self.nside, outpath=skymap_path)

def on_test_batch_end(
self,
trainer,
Expand All @@ -202,18 +216,18 @@ def on_test_batch_end(

outdir = self.outdir / f"event_{batch_idx}"
outdir.mkdir(exist_ok=True)
skymap_filename = outdir / "mollview.png"
result.plot_mollview(
self.nside,
outpath=skymap_filename,
)
self.plot_mollview(outdir, result)

def on_predict_batch_end(
    self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
):
    """
    Plot the skymap for a single predict batch, in a directory
    named after the event's GPS time rather than the batch index.

    Args:
        trainer: unused
        pl_module: unused
        outputs: object returned by the module's predict step;
            forwarded to `plot_mollview` as the result
        batch: indexable whose element 2 supports `.cpu().numpy()`;
            the first entry of that array is used as the GPS time
            (presumably a tensor of GPS times - confirm against
            the predict dataloader)
        batch_idx: unused
        dataloader_idx: unused
    """
    result = outputs

    # label the output directory by the integer part of the gps time
    gpstime = batch[2].cpu().numpy()[0]
    outdir = self.outdir / f"event_{int(gpstime)}"
    outdir.mkdir(exist_ok=True)
    self.plot_mollview(outdir, result)


class PlotCorner(pl.Callback):
Expand Down
85 changes: 51 additions & 34 deletions amplfi/train/data/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,6 @@ def val_dataloader(self):

# build background dataloader
val_background = self.val_background[0][:, rank:]

background_dataset = InMemoryDataset(
val_background,
kernel_size=int(self.hparams.sample_rate * self.sample_length),
Expand Down Expand Up @@ -514,7 +513,7 @@ def test_dataloader(self):
batch_size=self.hparams.batch_size,
coincident=False,
batches_per_epoch=self.hparams.batches_per_epoch,
shuffle=True,
shuffle=False,
)
else:
background_dataset = Hdf5TimeSeriesDataset(
Expand All @@ -527,7 +526,7 @@ def test_dataloader(self):
)

background_dataloader = torch.utils.data.DataLoader(
background_dataset, pin_memory=False, num_workers=10
background_dataset, shuffle=False, pin_memory=False, num_workers=1
)
return ZippedDataset(
waveform_dataloader,
Expand All @@ -545,7 +544,10 @@ def inject(self, *args, **kwargs):
raise NotImplementedError

def background_from_gpstimes(
self, gpstimes: np.ndarray, fnames: List[Path]
self,
gpstimes: np.ndarray,
fnames: List[Path],
timeslides: Optional[np.ndarray] = None,
Comment on lines +548 to +550
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we raise a `RuntimeError` if the lengths of `gpstimes` and `timeslides` differ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good idea

) -> torch.Tensor:
"""
Construct a Tensor of background segments corresponding
Expand All @@ -554,12 +556,18 @@ def background_from_gpstimes(
"""

# load in background segments corresponding to gpstimes
background = []
segments = [
tuple(map(float, f.name.split(".")[0].split("-")[1:]))
for f in fnames
]

# apply timeslides if specified
if timeslides is None:
timeslides = np.zeros((len(gpstimes), self.num_ifos))
else:
self._logger.info("Applying timeshifts to data")
timeslides = timeslides[:, : self.num_ifos]

def find_file(time: float) -> Optional[Path]:
"""
Find file that contains `time`
Expand All @@ -583,41 +591,50 @@ def find_file(time: float) -> Optional[Path]:
)

# convert to number of indices
post = int(post * self.hparams.sample_rate)
pre = int(pre * self.hparams.sample_rate)
num_post = int(post * self.hparams.sample_rate)
num_pre = int(pre * self.hparams.sample_rate)

background = []
for time in gpstimes:
time = time.item()
for time, shifts in zip(gpstimes, timeslides, strict=True):
strain = []

# find file for this gpstime
file, start = find_file(time)

# if none exists, use random segment
if file is None:
self._logger.info(
"No segment in testing directory containing "
f"{time}. Using random segment"
)
file = random.choice(self.test_fnames)
start, length = list(
map(float, file.name.split(".")[0].split("-")[1:])
)
time = start + random.randint(
self.sample_length,
length - self.sample_length,
# loop over ifo shifts
for ifo, shift in zip(self.hparams.ifos, shifts, strict=True):
shifted_time = time + shift
self._logger.debug(
f"Shifted {time} by {shift} seconds "
f"to {shifted_time} for ifo {ifo}"
)

# convert from time to index in file
middle_idx = int((time - start) * self.hparams.sample_rate)
start_idx = middle_idx + pre
end_idx = middle_idx + post
file, start = find_file(shifted_time)
# find file for this gpstime

# if none exists, use random segment
if file is None:
self._logger.info(
"No segment in testing directory containing "
f"{time}. Using random segment"
)
file = random.choice(self.test_fnames)
start, length = list(
map(float, file.name.split(".")[0].split("-")[1:])
)
time = start + random.randint(
-int(pre),
int(length - post),
)
else:
self._logger.info(f"Found segment for {shifted_time}")

# convert from time to index in file
middle_idx = int(
(shifted_time - start) * self.hparams.sample_rate
)
start_idx = middle_idx + num_pre
end_idx = middle_idx + num_post

with h5py.File(file) as f:
for ifo in self.hparams.ifos:
with h5py.File(file) as f:
strain.append(torch.tensor(f[ifo][start_idx:end_idx]))
strain = torch.stack(strain, dim=0)
background.append(strain)
strain = torch.stack(strain, dim=0)
background.append(strain)
background = torch.stack(background, dim=0)
return background
38 changes: 32 additions & 6 deletions amplfi/train/data/datasets/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def setup(self, stage: str):
self.waveforms = torch.stack([cross, plus], dim=0)
self.parameters = torch.column_stack(params)
self.background = self.background_from_gpstimes(
parameters["gpstime"], self.test_fnames
parameters["gpstime"], self.test_fnames, timeslides=None
)

# once we've generated validation/testing waveforms on cpu,
Expand Down Expand Up @@ -390,22 +390,48 @@ class RawStrainTestingDataset(FlowDataset):
the gps times of the events to analyze. If a path is
passed, the gps times should be stored in a dataset
named `gpstimes`
timeslides:
A tuple of floats, a list of tuples of floats, or a path
to an hdf5 file containing a `timeslides` dataset that
holds a list of tuples of floats. Each tuple gives the
timeshift to apply to each detector relative to the
gpstime. If left as `None`, no timeshifts will be
applied.

"""

def __init__(
    self,
    *args,
    gpstimes: Union[float, list[float], Path],
    timeslides: Optional[
        Union[float, tuple[float, ...], list[tuple[float, ...]], Path]
    ] = None,
    **kwargs,
):
    """
    Args:
        *args: forwarded to the parent class constructor
        gpstimes: gps time(s) of the events to analyze, or a path
            to an hdf5 file holding them in a `gpstimes` dataset
        timeslides: per-detector timeshift(s) to apply relative to
            each gps time, or a path to an hdf5 file holding them
            in a `timeslides` dataset. `None` applies no shifts
        **kwargs: forwarded to the parent class constructor
    """
    # both inputs are normalised by a single call; the result is
    # stored on the instance before the parent constructor runs
    self.gpstimes, self.timeslides = self.parse_gpstimes_and_timeslides(
        gpstimes, timeslides
    )
    super().__init__(*args, **kwargs)

def parse_gpstimes_and_timeslides(
    self, gpstimes: Union[float, np.ndarray, Path], timeslides
):
    """
    Normalise user supplied gps times and timeslides to numpy arrays.

    Args:
        gpstimes: a single number, a sequence or array of numbers,
            or a `Path` to an hdf5 file with a `gpstimes` dataset
        timeslides: `None`, a single number, one tuple of shifts,
            a list of such tuples, an array, or a `Path` to an
            hdf5 file with a `timeslides` dataset

    Returns:
        A `(gpstimes, timeslides)` tuple. `gpstimes` is an array
        with at least one dimension. `timeslides` is `None` when
        no timeslides were given, otherwise an array with at least
        two dimensions holding one row of shifts per gps time.

    Raises:
        ValueError: if timeslides are given and their number of
            rows differs from the number of gps times
    """
    if isinstance(gpstimes, Path):
        with h5py.File(gpstimes, "r") as f:
            gpstimes = f["gpstimes"][:]
    # scalars (python or numpy) become length-1 arrays
    gpstimes = np.atleast_1d(np.asarray(gpstimes))

    if timeslides is None:
        return gpstimes, None

    if isinstance(timeslides, Path):
        with h5py.File(timeslides, "r") as f:
            timeslides = f["timeslides"][:]
    # a scalar or a single tuple of shifts becomes one row, so the
    # result can always be indexed as `timeslides[:, :num_ifos]`
    timeslides = np.atleast_2d(np.asarray(timeslides))

    # fail here with a clear message instead of part way through
    # loading background data
    if len(timeslides) != len(gpstimes):
        raise ValueError(
            f"Got {len(timeslides)} timeslides for {len(gpstimes)} "
            "gpstimes; expected one row of shifts per gps time"
        )

    return gpstimes, timeslides

def setup(self, stage: str):
world_size, rank = self.get_world_size_and_rank()
Expand All @@ -416,7 +442,7 @@ def setup(self, stage: str):
)

self.background = self.background_from_gpstimes(
self.gpstimes, self.test_fnames
self.gpstimes, self.test_fnames, self.timeslides
)

# once we've generated validation/testing waveforms on cpu,
Expand Down
5 changes: 4 additions & 1 deletion amplfi/train/models/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def validation_step(self, batch, _):

def test_step(self, batch, batch_idx) -> AmplfiResult:
strain, asds, parameters = batch

context = (strain, asds)

samples = self.model.sample(
Expand All @@ -147,6 +148,7 @@ def test_step(self, batch, batch_idx) -> AmplfiResult:
parameters = self.scale(parameters, reverse=True)

log_probs = log_probs[mask]

result = self.cast_as_bilby_result(
descaled.cpu().numpy(),
log_probs.cpu().numpy(),
Expand All @@ -166,6 +168,8 @@ def predict_step(self, batch, _):
self.hparams.samples_per_event, context=context
)
log_probs = self.model.log_prob(samples, context)

samples = samples.squeeze(1)
log_probs = log_probs.squeeze(1)
descaled = self.scale(samples, reverse=True)
descaled, mask = self.filter_parameters(descaled)
Expand Down Expand Up @@ -223,7 +227,6 @@ def cast_as_bilby_result(

num_samples = len(posterior)
log_evidence = logsumexp(posterior["log_prob"]) - np.log(num_samples)

r = AmplfiResult(
label="PEModel",
injection_parameters=injection_parameters,
Expand Down