From 5361a883d234ee7f951c0aa84714213d583f9914 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 27 Aug 2025 10:25:52 +0200 Subject: [PATCH 1/2] Refactor cluster printing This PR picks up some of the remaining work from Elek's recent refactoring of cluster printing. It simplifies the logical flow of contig printing by reducing the number of functions. The logic is more clear about when printing happens: It now separates the cluster writing into two functions: One clusters and writes the associated files while clustering. Another function, used e.g. when reclustering, merely writes the files. This commit also: * Fixes a bunch of type errors, making the code pass typecheck * More thoroughly encodes the state into the type system: i.e. if arguments A and B are needed together, or not at all, they are condensed in a single variable of type Option[tuple[A, B]]. --- vamb/__main__.py | 365 +++++++++++++++++++++++++++------------------- vamb/vambtools.py | 11 +- 2 files changed, 220 insertions(+), 156 deletions(-) diff --git a/vamb/__main__.py b/vamb/__main__.py index b7e8c486..246fc18a 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -14,11 +14,10 @@ from math import isfinite from typing import Optional, Tuple, Union, cast, Callable, Literal, NamedTuple from pathlib import Path -from collections.abc import Sequence +from collections.abc import Sequence, Collection from torch.utils.data import DataLoader from functools import partial from loguru import logger -from typing import TextIO from typing import Iterable from contextlib import nullcontext @@ -1165,15 +1164,78 @@ def try_from_common(cls, common: BinnerCommonOptions): return None -def write_clusters_table( - file_handle: Optional[TextIO], - file_path: Optional[str], - clusters: Iterable[tuple[str, set[str]]], - print_header: bool, -) -> tuple[int, dict[str, set[str]]]: - handle = nullcontext(file_handle) if file_handle else open(file_path, "w") - with handle as file: - return vamb.vambtools.write_clusters(file, clusters, print_header) +def ceil_div(num: int, den: int) -> int: + return -(num // -den) + + +# Similar to cluster_and_write_files, but is given the clusters beforehandm +# e.g. from reclustering. +# So it does not need to be printed while clustering, and there is no clustering +# parameters, nor any clustering metadata +def export_clusters( + binsplitter: vamb.vambtools.BinSplitter, + clusters: Collection[tuple[str, Collection[str]]], + base_clusters_name: str, # e.g. /foo/bar/vae -> /foo/bar/vae_unsplit.tsv + fasta_output: Optional[tuple[FastaOutput, Sequence[str], Sequence[int]]], +) -> None: + begintime = time.time() + + if binsplitter.splitter is None: + context = nullcontext(None) + else: + context = open(base_clusters_name + "_split.tsv", "w") + print(vamb.vambtools.CLUSTERS_HEADER, file=context) + + n_split_clusters = 0 + n_unsplit_clusters = len(clusters) + n_total_contigs = sum(len(cl) for (_, cl) in clusters) + + with ( + open(base_clusters_name + "_unsplit.tsv", "w") as unsplit_clusters_file, + context as maybe_split_file, + ): + print(vamb.vambtools.CLUSTERS_HEADER, file=unsplit_clusters_file) + + for unsplit_bin_name, unsplit_contigs in clusters: + for unsplit_contig in unsplit_contigs: + print(unsplit_bin_name, unsplit_contig, file=unsplit_clusters_file) + + if maybe_split_file is not None: + for split_bin_name, split_members in binsplitter.split_bin( + unsplit_bin_name, unsplit_contigs + ): + n_split_clusters += 1 + + for split_member in split_members: + print( + split_bin_name, + split_member, + sep="\t", + file=unsplit_clusters_file, + ) + + # When done: Log the time + if binsplitter.splitter is not None: + msg = f"\tClustered {n_total_contigs} contigs in {n_split_clusters} split bins ({n_unsplit_clusters} clusters)" + else: + msg = f"\tClustered {n_total_contigs} contigs in {n_unsplit_clusters} unsplit bins" + logger.info(msg) + elapsed = round(time.time() - begintime, 2) + logger.info(f"\tWrote cluster file(s) in {elapsed} seconds.") + + # If FASTA output requested, create those files + if fasta_output is not None: + (fasta_output_struct, sequence_names, sequence_lens) = fasta_output + create_cluster_fasta_files( + fasta_output_struct.bins_dir_to_populate, + clusters, + fasta_output_struct.existing_fasta_path.path, + sequence_lens, + sequence_names, + fasta_output_struct.min_fasta_size, + ) + + return None def cluster_and_write_files( @@ -1210,141 +1272,151 @@ def cluster_and_write_files( rng_seed=seed, ) + # Take only the first `max_clusters` clusters clusters = itertools.islice(cluster_generator, cluster_options.max_clusters) - # Write the cluster metadata to file - with open(Path(base_clusters_name + "_metadata.tsv"), "w") as file: - unsplit_path = Path(base_clusters_name + "_unsplit.tsv") - split_path = Path(base_clusters_name + "_split.tsv") - with ( - open(unsplit_path, "w") as unsplit_clusters_file, - open(split_path, "w") as split_clusters_file, - ): - print( - "name\tradius\tpeak valley ratio\tkind\tbp\tncontigs\tmedoid", file=file - ) - num_contigs = latent.shape[0] - progress_step = num_contigs / 10 - next_reporting_threshold = progress_step - progress = 0 - processed_contigs = 0 - total_unsplit = 0 - total_split = 0 - for i, cluster in enumerate(clusters): - cluster_members = { - sequence_names[cast(int, i)] for i in cluster.members - } - header = i == 0 - n_unsplit_clusters, n_split_clusters = export_binning_results( - fasta_output, - bin_prefix, - binsplitter, - base_clusters_name, - [(str(i), cluster_members)], - sequence_names, - cast(Sequence[int], sequence_lens), - header, - unsplit_clusters_file, - split_clusters_file, - ) + if binsplitter.splitter is None: + context = nullcontext(None) + else: + context = open(base_clusters_name + "_split.tsv", "w") - print( - str(i + 1), - None if cluster.radius is None else round(cluster.radius, 3), - ( - None - if cluster.observed_pvr is None - else round(cluster.observed_pvr, 2) - ), - cluster.kind_str, - sum(sequence_lens[i] for i in cluster.members), - len(cluster_members), - sequence_names[cluster.medoid], - file=file, - sep="\t", - ) - total_unsplit += n_unsplit_clusters - total_split += n_split_clusters - processed_contigs += len(cluster_members) - while processed_contigs >= next_reporting_threshold: - next_reporting_threshold += progress_step - progress += 10 - logger.info(f"{progress}% of contigs clustered") - if processed_contigs == num_contigs: - if binsplitter.splitter is not None: - msg = f"\tClustered {processed_contigs} contigs in {total_split} split bins ({total_unsplit} clusters)" - else: - msg = f"\tClustered {processed_contigs} contigs in {total_unsplit} unsplit bins" - logger.info(msg) - elapsed = round(time.time() - begintime, 2) - logger.info(f"\tWrote cluster file(s) in {elapsed} seconds.") - - if fasta_output is not None: - logger.info( - f"\tWrote {max(total_split, total_unsplit)} bins with {processed_contigs} sequences in {elapsed} seconds." - ) + # The FASTA output requires us to know all clusters - we cannot stream those. + # so we keep them all in memory only if the FASTA output is needed + if fasta_output is None: + maybe_stored_clusters = None + else: + maybe_stored_clusters: Optional[list[tuple[str, list[str]]]] = [] - elapsed = round(time.time() - begintime, 2) - logger.info(f"\tClustered contigs in {elapsed} seconds.\n") + n_processed_contigs = 0 + n_split_clusters = 0 + n_unsplit_clusters = 0 + with ( + open(base_clusters_name + "_metadata.tsv", "w") as metadata_file, + open(base_clusters_name + "_unsplit.tsv", "w") as unsplit_clusters_file, + context as maybe_split_file, + ): + # Print headers + print( + "name\tradius\tpeak valley ratio\tkind\tbp\tncontigs\tmedoid", + file=metadata_file, + ) + print(vamb.vambtools.CLUSTERS_HEADER, file=unsplit_clusters_file) -def export_binning_results( - fasta_output: Optional[FastaOutput], - # If `x` and not None, clusters will be renamed `x` + old_name. - # This is necessary since for the AAE, we may need to write bins - # from three latent spaces into the same directory, and the names - # must not clash. - bin_prefix: Optional[str], - binsplitter: vamb.vambtools.BinSplitter, - base_clusters_name: str, # e.g. /foo/bar/vae -> /foo/bar/vae_unsplit.tsv - clusters: Iterable[tuple[str, set[str]]], - sequence_names: Sequence[str], - sequence_lens: Sequence[int], - to_file: bool = True, - unsplit_clusters_file: Optional[TextIO] = None, - split_clusters_file: Optional[TextIO] = None, -): - # Write unsplit clusters to file - unsplit_path = Path(base_clusters_name + "_unsplit.tsv") - split_path = Path(base_clusters_name + "_split.tsv") + if maybe_split_file is not None: + print(vamb.vambtools.CLUSTERS_HEADER, file=maybe_split_file) - n_unsplit_clusters, _ = write_clusters_table( - unsplit_clusters_file, unsplit_path, clusters, to_file - ) + n_total_contigs = latent.shape[0] + last_decile_printed = 0 + for cluster_index, cluster in enumerate(clusters): + cluster_members = [sequence_names[cast(int, i)] for i in cluster.members] + cluster_name = str(cluster_index + 1) + if bin_prefix is not None: + cluster_name = bin_prefix + cluster_name + + n_processed_contigs += len(cluster_members) + n_unsplit_clusters += 1 + + # Print the clusters to the two cluster files + for member in cluster_members: + # Prefer storing the split clusters if we use a binsplitter + if maybe_stored_clusters is not None and maybe_split_file is None: + maybe_stored_clusters.append((cluster_name, list(cluster_members))) + + print(cluster_name, member, sep="\t", file=unsplit_clusters_file) + + if maybe_split_file is not None: + for split_bin_name, split_members in binsplitter.split_bin( + cluster_name, cluster_members + ): + n_split_clusters += 1 + + if maybe_stored_clusters is not None: + maybe_stored_clusters.append( + (split_bin_name, list(split_members)) + ) + + for split_member in split_members: + print( + split_bin_name, + split_member, + sep="\t", + file=unsplit_clusters_file, + ) + + # Print metadata + print( + cluster_name, + None if cluster.radius is None else round(cluster.radius, 3), + ( + None + if cluster.observed_pvr is None + else round(cluster.observed_pvr, 2) + ), + cluster.kind_str, + sum(sequence_lens[i] for i in cluster.members), + len(cluster_members), + sequence_names[cluster.medoid], + file=metadata_file, + sep="\t", + ) + + # Log each decile + current_decile = ceil_div(10 * n_processed_contigs, n_total_contigs) + for unprinted_decile in range(last_decile_printed + 1, current_decile + 1): + logger.info(f"\t {unprinted_decile * 10:3} % of contigs clustered") + + last_decile_printed = current_decile - # Open unsplit clusters and split them + # When done: Log the time if binsplitter.splitter is not None: - split_clusters = binsplitter.binsplit(clusters) - if bin_prefix is not None: - split_clusters = add_bin_prefix(dict(split_clusters), bin_prefix).items() - clusters_with_prefix = split_clusters - n_split_clusters, _ = write_clusters_table( - split_clusters_file, split_path, split_clusters, to_file - ) + msg = f"\tClustered {n_processed_contigs} contigs in {n_split_clusters} split bins ({n_unsplit_clusters} clusters)" else: - clusters_with_prefix = clusters - if bin_prefix is not None: - clusters_with_prefix = add_bin_prefix(dict(clusters), bin_prefix).items() - n_split_clusters = n_unsplit_clusters + msg = f"\tClustered {n_processed_contigs} contigs in {n_unsplit_clusters} unsplit bins" + logger.info(msg) + elapsed = round(time.time() - begintime, 2) + logger.info(f"\tWrote cluster file(s) in {elapsed} seconds.") - # Write bins, if necessary + # If FASTA output requested, create those files if fasta_output is not None: - filtered_clusters: dict[str, set[str]] = dict() - assert len(sequence_lens) == len(sequence_names) - sizeof = dict(zip(sequence_names, sequence_lens)) - for binname, contigs in clusters_with_prefix: - if sum(sizeof[c] for c in contigs) >= fasta_output.min_fasta_size: - filtered_clusters[binname] = contigs - - with vamb.vambtools.Reader(fasta_output.existing_fasta_path.path) as file: - vamb.vambtools.write_bins( - fasta_output.bins_dir_to_populate, - filtered_clusters, - file, - None, - ) + assert maybe_stored_clusters is not None + create_cluster_fasta_files( + fasta_output.bins_dir_to_populate, + maybe_stored_clusters, + fasta_output.existing_fasta_path.path, + cast(Sequence[int], sequence_lens), + sequence_names, + fasta_output.min_fasta_size, + ) - return n_unsplit_clusters, n_split_clusters + +def create_cluster_fasta_files( + dir_to_populate: Path, + clusters: Iterable[tuple[str, Collection[str]]], + existing_fasta_path: Path, + sequence_lens: Sequence[int], + sequence_names: Sequence[str], + min_bin_size: int, +) -> None: + begintime = time.time() + filtered_clusters: list[tuple[str, list[str]]] = [] + assert len(sequence_lens) == len(sequence_names) + sizeof = dict(zip(sequence_names, sequence_lens)) + for binname, contigs in clusters: + if sum(sizeof[c] for c in contigs) >= min_bin_size: + filtered_clusters.append((binname, list(contigs))) + + with vamb.vambtools.Reader(existing_fasta_path) as file: + vamb.vambtools.write_bins( + dir_to_populate, + filtered_clusters, + file, + None, + ) + elapsed = round(time.time() - begintime, 2) + logger.info( + f"\tWrote clusters above {min_bin_size} bp to FASTA files in {elapsed} seconds.\n" + ) def add_bin_prefix( @@ -1445,15 +1517,6 @@ def run_bin_aae(opt: BinAvambOptions): # We enforce this in the VAEAAEOptions constructor, see comment there # Cluster and output the Y clusters assert opt.common.clustering.max_clusters is None - export_binning_results( - FastaOutput.try_from_common(opt.common), - "y_", - binsplitter=opt.common.output.binsplitter, - base_clusters_name=str(opt.common.general.out_dir.joinpath("aae_y_clusters")), - clusters=clusters_y_dict, - sequence_names=cast(Sequence[str], comp_metadata.identifiers), - sequence_lens=cast(Sequence[int], comp_metadata.lengths), - ) def predict_taxonomy( @@ -1804,27 +1867,27 @@ def run_reclustering(opt: ReclusteringOptions): (str(i), {identifiers[c] for c in cluster}) for i, cluster in enumerate(reclustered_contigs) ] - # for i, cluster in enumerate(reclustered_contigs): - # clusters_dict[str(i)] = {identifiers[c] for c in cluster} if opt.output.min_fasta_output_size is None: fasta_output = None else: assert isinstance(opt.composition.path, FASTAPath) - fasta_output = FastaOutput( + fasta_output_struct = FastaOutput( opt.composition.path, opt.general.out_dir.joinpath("bins"), opt.output.min_fasta_output_size, ) + fasta_output = ( + fasta_output_struct, + cast(Sequence[str], identifiers), + cast(Sequence[int], composition.metadata.lengths), + ) - export_binning_results( - fasta_output, - None, + export_clusters( opt.output.binsplitter, - str(opt.general.out_dir.joinpath("clusters_reclustered")), clusters_dict, - cast(Sequence[str], composition.metadata.identifiers), - cast(Sequence[int], composition.metadata.lengths), + str(opt.general.out_dir.joinpath("clusters_reclustered")), + fasta_output, ) diff --git a/vamb/vambtools.py b/vamb/vambtools.py index 47721b48..b524e39a 100644 --- a/vamb/vambtools.py +++ b/vamb/vambtools.py @@ -9,7 +9,7 @@ import collections as _collections from itertools import zip_longest from hashlib import md5 as _md5 -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Collection from typing import Optional, IO, Union from pathlib import Path from loguru import logger @@ -634,7 +634,7 @@ def create_dir_if_not_existing(path: Path) -> None: def write_bins( directory: Path, - bins: dict[str, set[str]], + bins: Collection[tuple[str, Iterable[str]]], fastaio: Iterable[bytes], maxbins: Optional[int] = 1000, ): @@ -657,8 +657,9 @@ def write_bins( create_dir_if_not_existing(directory) keep: set[str] = set() - for i in bins.values(): - keep.update(i) + for _, contigs in bins: + for contig in contigs: + keep.add(contig) bytes_by_id: dict[str, bytes] = dict() for entry in byte_iterfasta(fastaio, None): @@ -668,7 +669,7 @@ def write_bins( ) # Now actually print all the contigs to files - for binname, contigs in bins.items(): + for binname, contigs in bins: for contig in contigs: byts = bytes_by_id.get(contig) if byts is None: From be4c69f451a6b857c85479998c842c3e7eec135e Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 27 Aug 2025 10:53:18 +0200 Subject: [PATCH 2/2] Fixup --- test/test_vambtools.py | 10 +++++----- vamb/__main__.py | 26 +++++++++++++------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/test_vambtools.py b/test/test_vambtools.py index 0dadcde8..3a0feaeb 100644 --- a/test/test_vambtools.py +++ b/test/test_vambtools.py @@ -537,14 +537,14 @@ def test_bad_params(self): # Too many bins for maxbins with self.assertRaises(ValueError): vamb.vambtools.write_bins( - self.dir, self.bins, self.file, maxbins=self.N_BINS - 1 + self.dir, self.bins.items(), self.file, maxbins=self.N_BINS - 1 ) # Parent does not exist with self.assertRaises(NotADirectoryError): vamb.vambtools.write_bins( pathlib.Path("svogew/foo"), - self.bins, + self.bins.items(), self.file, maxbins=self.N_BINS + 1, ) @@ -554,7 +554,7 @@ def test_bad_params(self): with tempfile.NamedTemporaryFile() as file: vamb.vambtools.write_bins( pathlib.Path(file.name), - self.bins, + self.bins.items(), self.file, maxbins=self.N_BINS + 1, ) @@ -564,14 +564,14 @@ def test_bad_params(self): bins = {k: v.copy() for k, v in self.bins.items()} next(iter(bins.values())).add("a_new_bin_which_does_not_exist") vamb.vambtools.write_bins( - self.dir, bins, self.file, maxbins=self.N_BINS + 1 + self.dir, bins.items(), self.file, maxbins=self.N_BINS + 1 ) def test_round_trip(self): with tempfile.TemporaryDirectory() as dir: vamb.vambtools.write_bins( pathlib.Path(dir), - self.bins, + self.bins.items(), self.file, maxbins=self.N_BINS, ) diff --git a/vamb/__main__.py b/vamb/__main__.py index 246fc18a..173cc7c1 100755 --- a/vamb/__main__.py +++ b/vamb/__main__.py @@ -1200,19 +1200,19 @@ def export_clusters( for unsplit_contig in unsplit_contigs: print(unsplit_bin_name, unsplit_contig, file=unsplit_clusters_file) - if maybe_split_file is not None: - for split_bin_name, split_members in binsplitter.split_bin( - unsplit_bin_name, unsplit_contigs - ): - n_split_clusters += 1 - - for split_member in split_members: - print( - split_bin_name, - split_member, - sep="\t", - file=unsplit_clusters_file, - ) + if maybe_split_file is not None: + for split_bin_name, split_members in binsplitter.split_bin( + unsplit_bin_name, unsplit_contigs + ): + n_split_clusters += 1 + + for split_member in split_members: + print( + split_bin_name, + split_member, + sep="\t", + file=unsplit_clusters_file, + ) # When done: Log the time if binsplitter.splitter is not None: