Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ jobs:
# tests the libraries that our projects depend on
library-ci-tests:
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
permissions:
packages: read
container:
Expand Down
5 changes: 3 additions & 2 deletions libs/trainer/deepclean/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def train(
batch_size=batch_size,
chunk_length=chunk_length,
num_chunks=num_chunks,
shuffle=False,
shuffle=True,
device=device,
)

Expand Down Expand Up @@ -423,6 +423,7 @@ def train(

# generate some analyses of our model
logging.info("Performing post-hoc analysis on trained model")
train_data.batch_size = 32
gradients, coherences = analyze_model(train_data, model, sample_rate)
history.update(
{
Expand Down Expand Up @@ -452,4 +453,4 @@ def train(
nn = PrePostDeepClean(model)
nn.fit(X, y)
torch.save(nn.state_dict(), weights_path)
# return history
return history
39 changes: 39 additions & 0 deletions projects/sandbox/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,45 @@ verbose = "${base.verbose}"
# architecture subcommands
commands.autoencoder = "${base.autoencoder}"

[tool.typeo.scripts.search]
output_directory = "${PROJECT_DIR}"
data_path = "${DATA_DIR}"

# data parameters
channels = "${base.channels}"
t0 = "${base.train_t0}"
duration = "${base.train_duration}"
valid_frac = 0.25

# data loading parameters
sample_rate = "${base.sample_rate}"
kernel_length = 8
kernel_stride = 0.25
chunk_length = 0

# preprocessing parameters
freq_low = "${base.freq_low}"
freq_high = "${base.freq_high}"
filter_order = 8

# optimization parameters
batch_size = 64
max_epochs = 500
num_trials = 10
patience = 8
factor = 0.2
early_stop = 50

# criterion parameters
fftlength = 2
overlap = 1
alpha = 1.0

# miscellaneous parameters
profile = false
verbose = "${base.verbose}"


[tool.typeo.scripts.export-model]
repository_directory = "${MODEL_REPOSITORY}"
output_directory = "${PROJECT_DIR}"
Expand Down
Loading