From 59c80e1d1950fd387c322fddfd43e5286d023362 Mon Sep 17 00:00:00 2001 From: Nick Youngblut Date: Wed, 11 Jun 2025 08:22:10 -0700 Subject: [PATCH] Fix typos and minor bugs --- src/cell_load/config.py | 1 - src/cell_load/data_modules/perturbation_dataloader.py | 2 +- src/cell_load/dataset/_perturbation.py | 4 ++-- src/cell_load/utils/data_utils.py | 10 +++++++--- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/src/cell_load/config.py b/src/cell_load/config.py index 58e7f5d..ddc430e 100644 --- a/src/cell_load/config.py +++ b/src/cell_load/config.py @@ -70,7 +70,6 @@ def get_fewshot_celltypes(self, dataset: str) -> dict[str, dict[str, list[str]]] if key.startswith(f"{dataset}."): celltype = key.split(".", 1)[1] result[celltype] = pert_config - print(dataset, celltype, {k: len(v) for k, v in pert_config.items()}) return result def validate(self) -> None: diff --git a/src/cell_load/data_modules/perturbation_dataloader.py b/src/cell_load/data_modules/perturbation_dataloader.py index 08e625b..06480e1 100644 --- a/src/cell_load/data_modules/perturbation_dataloader.py +++ b/src/cell_load/data_modules/perturbation_dataloader.py @@ -59,7 +59,7 @@ def __init__( num_workers: Num workers for PyTorch DataLoader few_shot_percent: Fraction of data to use for few-shot tasks random_seed: For reproducible splits & sampling - embed_key: Embedding key or matrix in the H5 file to use for feauturizing cells + embed_key: Embedding key or matrix in the H5 file to use for featurizing cells output_space: The output space for model predictions (gene or latent, which uses embed_key) basal_mapping_strategy: One of {"batch","random","nearest","ot"} n_basal_samples: Number of control cells to sample per perturbed cell diff --git a/src/cell_load/dataset/_perturbation.py b/src/cell_load/dataset/_perturbation.py index 819502f..5bebd4e 100644 --- a/src/cell_load/dataset/_perturbation.py +++ b/src/cell_load/dataset/_perturbation.py @@ -456,8 +456,8 @@ def collate_fn(batch, int_counts=False): if has_ctrl_cell_counts: ctrl_cell_counts = torch.stack(ctrl_cell_counts_list) - is_discrete = suspected_discrete_torch(pert_cell_counts) - is_log = suspected_log_torch(pert_cell_counts) + is_discrete = suspected_discrete_torch(ctrl_cell_counts) + is_log = suspected_log_torch(ctrl_cell_counts) already_logged = (not is_discrete) and is_log if already_logged: # counts are already log transformed diff --git a/src/cell_load/utils/data_utils.py b/src/cell_load/utils/data_utils.py index c7a030b..64d322a 100644 --- a/src/cell_load/utils/data_utils.py +++ b/src/cell_load/utils/data_utils.py @@ -282,8 +282,7 @@ def is_on_target_knockdown( return False if target_gene not in adata.var_names: - print(f"Gene {target_gene!r} not found in `adata.var_names`.") - return 1 + raise KeyError(f"Gene {target_gene!r} not found in `adata.var_names`.") gene_idx = adata.var_names.get_loc(target_gene) X = adata.layers[layer] if layer is not None else adata.X @@ -396,7 +395,9 @@ def filter_on_target_knockdown( return adata_[keep_mask] -def set_var_index_to_col(adata: anndata.AnnData, col: str = "col", copy=True) -> None: +def set_var_index_to_col( + adata: anndata.AnnData, col: str = "col", copy: bool = True +) -> anndata.AnnData: """ Set `adata.var` index to the values in the specified column, allowing non-unique indices. @@ -412,6 +413,9 @@ def set_var_index_to_col(adata: anndata.AnnData, col: str = "col", copy=True) -> KeyError If the specified column does not exist in `adata.var`. """ + if copy: + adata = adata.copy() + if col not in adata.var.columns: raise KeyError(f"Column {col!r} not found in adata.var.")