diff --git a/test/test_vambtools.py b/test/test_vambtools.py
index 0dadcde8..3a0feaeb 100644
--- a/test/test_vambtools.py
+++ b/test/test_vambtools.py
@@ -537,14 +537,14 @@ def test_bad_params(self):
         # Too many bins for maxbins
         with self.assertRaises(ValueError):
             vamb.vambtools.write_bins(
-                self.dir, self.bins, self.file, maxbins=self.N_BINS - 1
+                self.dir, self.bins.items(), self.file, maxbins=self.N_BINS - 1
             )
 
         # Parent does not exist
         with self.assertRaises(NotADirectoryError):
             vamb.vambtools.write_bins(
                 pathlib.Path("svogew/foo"),
-                self.bins,
+                self.bins.items(),
                 self.file,
                 maxbins=self.N_BINS + 1,
             )
@@ -554,7 +554,7 @@ def test_bad_params(self):
         with tempfile.NamedTemporaryFile() as file:
             vamb.vambtools.write_bins(
                 pathlib.Path(file.name),
-                self.bins,
+                self.bins.items(),
                 self.file,
                 maxbins=self.N_BINS + 1,
             )
@@ -564,14 +564,14 @@ def test_bad_params(self):
         bins = {k: v.copy() for k, v in self.bins.items()}
         next(iter(bins.values())).add("a_new_bin_which_does_not_exist")
         vamb.vambtools.write_bins(
-            self.dir, bins, self.file, maxbins=self.N_BINS + 1
+            self.dir, bins.items(), self.file, maxbins=self.N_BINS + 1
         )
 
     def test_round_trip(self):
         with tempfile.TemporaryDirectory() as dir:
             vamb.vambtools.write_bins(
                 pathlib.Path(dir),
-                self.bins,
+                self.bins.items(),
                 self.file,
                 maxbins=self.N_BINS,
             )
diff --git a/vamb/__main__.py b/vamb/__main__.py
index b7e8c486..173cc7c1 100755
--- a/vamb/__main__.py
+++ b/vamb/__main__.py
@@ -14,11 +14,10 @@
 from math import isfinite
 from typing import Optional, Tuple, Union, cast, Callable, Literal, NamedTuple
 from pathlib import Path
-from collections.abc import Sequence
+from collections.abc import Sequence, Collection
 from torch.utils.data import DataLoader
 from functools import partial
 from loguru import logger
-from typing import TextIO
 from typing import Iterable
 from contextlib import nullcontext
 
@@ -1165,15 +1164,78 @@ def try_from_common(cls, common: BinnerCommonOptions):
         return None
 
 
-def write_clusters_table(
-    file_handle: Optional[TextIO],
-    file_path: Optional[str],
-    clusters: Iterable[tuple[str, set[str]]],
-    print_header: bool,
-) -> tuple[int, dict[str, set[str]]]:
-    handle = nullcontext(file_handle) if file_handle else open(file_path, "w")
-    with handle as file:
-        return vamb.vambtools.write_clusters(file, clusters, print_header)
+def ceil_div(num: int, den: int) -> int:
+    return -(num // -den)
+
+
+# Similar to cluster_and_write_files, but is given the clusters beforehand,
+# e.g. from reclustering.
+# So they do not need to be printed while clustering, and there are no clustering
+# parameters, nor any clustering metadata
+def export_clusters(
+    binsplitter: vamb.vambtools.BinSplitter,
+    clusters: Collection[tuple[str, Collection[str]]],
+    base_clusters_name: str,  # e.g. /foo/bar/vae -> /foo/bar/vae_unsplit.tsv
+    fasta_output: Optional[tuple[FastaOutput, Sequence[str], Sequence[int]]],
+) -> None:
+    begintime = time.time()
+
+    if binsplitter.splitter is None:
+        context = nullcontext(None)
+    else:
+        context = open(base_clusters_name + "_split.tsv", "w")
+        print(vamb.vambtools.CLUSTERS_HEADER, file=context)
+
+    n_split_clusters = 0
+    n_unsplit_clusters = len(clusters)
+    n_total_contigs = sum(len(cl) for (_, cl) in clusters)
+
+    with (
+        open(base_clusters_name + "_unsplit.tsv", "w") as unsplit_clusters_file,
+        context as maybe_split_file,
+    ):
+        print(vamb.vambtools.CLUSTERS_HEADER, file=unsplit_clusters_file)
+
+        for unsplit_bin_name, unsplit_contigs in clusters:
+            for unsplit_contig in unsplit_contigs:
+                print(
+                    unsplit_bin_name,
+                    unsplit_contig,
+                    sep="\t",
+                    file=unsplit_clusters_file,
+                )
+
+            if maybe_split_file is not None:
+                for split_bin_name, split_members in binsplitter.split_bin(
+                    unsplit_bin_name, unsplit_contigs
+                ):
+                    n_split_clusters += 1
+
+                    for split_member in split_members:
+                        print(
+                            split_bin_name,
+                            split_member,
+                            sep="\t",
+                            file=maybe_split_file,
+                        )
+
+    # When done: Log the time
+    if binsplitter.splitter is not None:
+        msg = f"\tClustered {n_total_contigs} contigs in {n_split_clusters} split bins ({n_unsplit_clusters} clusters)"
+    else:
+        msg = f"\tClustered {n_total_contigs} contigs in {n_unsplit_clusters} unsplit bins"
+    logger.info(msg)
+    elapsed = round(time.time() - begintime, 2)
+    logger.info(f"\tWrote cluster file(s) in {elapsed} seconds.")
+
+    # If FASTA output requested, create those files
+    if fasta_output is not None:
+        (fasta_output_struct, sequence_names, sequence_lens) = fasta_output
+        create_cluster_fasta_files(
+            fasta_output_struct.bins_dir_to_populate,
+            clusters,
+            fasta_output_struct.existing_fasta_path.path,
+            sequence_lens,
+            sequence_names,
+            fasta_output_struct.min_fasta_size,
+        )
+
+    return None
 
 
 def cluster_and_write_files(
@@ -1210,141 +1272,151 @@ def cluster_and_write_files(
         rng_seed=seed,
     )
 
+    # Take only the first `max_clusters` clusters
     clusters = itertools.islice(cluster_generator, cluster_options.max_clusters)
-    # Write the cluster metadata to file
-    with open(Path(base_clusters_name + "_metadata.tsv"), "w") as file:
-        unsplit_path = Path(base_clusters_name + "_unsplit.tsv")
-        split_path = Path(base_clusters_name + "_split.tsv")
-        with (
-            open(unsplit_path, "w") as unsplit_clusters_file,
-            open(split_path, "w") as split_clusters_file,
-        ):
-            print(
-                "name\tradius\tpeak valley ratio\tkind\tbp\tncontigs\tmedoid", file=file
-            )
-            num_contigs = latent.shape[0]
-            progress_step = num_contigs / 10
-            next_reporting_threshold = progress_step
-            progress = 0
-            processed_contigs = 0
-            total_unsplit = 0
-            total_split = 0
-            for i, cluster in enumerate(clusters):
-                cluster_members = {
-                    sequence_names[cast(int, i)] for i in cluster.members
-                }
-                header = i == 0
-                n_unsplit_clusters, n_split_clusters = export_binning_results(
-                    fasta_output,
-                    bin_prefix,
-                    binsplitter,
-                    base_clusters_name,
-                    [(str(i), cluster_members)],
-                    sequence_names,
-                    cast(Sequence[int], sequence_lens),
-                    header,
-                    unsplit_clusters_file,
-                    split_clusters_file,
-                )
+    if binsplitter.splitter is None:
+        context = nullcontext(None)
+    else:
+        context = open(base_clusters_name + "_split.tsv", "w")
 
-                print(
-                    str(i + 1),
-                    None if cluster.radius is None else round(cluster.radius, 3),
-                    (
-                        None
-                        if cluster.observed_pvr is None
-                        else round(cluster.observed_pvr, 2)
-                    ),
-                    cluster.kind_str,
-                    sum(sequence_lens[i] for i in cluster.members),
-                    len(cluster_members),
-                    sequence_names[cluster.medoid],
-                    file=file,
-                    sep="\t",
-                )
-                total_unsplit += n_unsplit_clusters
-                total_split += n_split_clusters
-                processed_contigs += len(cluster_members)
-                while processed_contigs >= next_reporting_threshold:
-                    next_reporting_threshold += progress_step
-                    progress += 10
-                    logger.info(f"{progress}% of contigs clustered")
-                if processed_contigs == num_contigs:
-                    if binsplitter.splitter is not None:
-                        msg = f"\tClustered {processed_contigs} contigs in {total_split} split bins ({total_unsplit} clusters)"
-                    else:
-                        msg = f"\tClustered {processed_contigs} contigs in {total_unsplit} unsplit bins"
-                    logger.info(msg)
-                    elapsed = round(time.time() - begintime, 2)
-                    logger.info(f"\tWrote cluster file(s) in {elapsed} seconds.")
-
-                    if fasta_output is not None:
-                        logger.info(
-                            f"\tWrote {max(total_split, total_unsplit)} bins with {processed_contigs} sequences in {elapsed} seconds."
-                        )
 
-    elapsed = round(time.time() - begintime, 2)
-    logger.info(f"\tClustered contigs in {elapsed} seconds.\n")
+    # The FASTA output requires us to know all clusters - we cannot stream them,
+    # so we keep them all in memory only if the FASTA output is needed
+    if fasta_output is None:
+        maybe_stored_clusters = None
+    else:
+        maybe_stored_clusters: Optional[list[tuple[str, list[str]]]] = []
+
+    n_processed_contigs = 0
+    n_split_clusters = 0
+    n_unsplit_clusters = 0
+    with (
+        open(base_clusters_name + "_metadata.tsv", "w") as metadata_file,
+        open(base_clusters_name + "_unsplit.tsv", "w") as unsplit_clusters_file,
+        context as maybe_split_file,
+    ):
+        # Print headers
+        print(
+            "name\tradius\tpeak valley ratio\tkind\tbp\tncontigs\tmedoid",
+            file=metadata_file,
+        )
+        print(vamb.vambtools.CLUSTERS_HEADER, file=unsplit_clusters_file)
 
-def export_binning_results(
-    fasta_output: Optional[FastaOutput],
-    # If `x` and not None, clusters will be renamed `x` + old_name.
-    # This is necessary since for the AAE, we may need to write bins
-    # from three latent spaces into the same directory, and the names
-    # must not clash.
-    bin_prefix: Optional[str],
-    binsplitter: vamb.vambtools.BinSplitter,
-    base_clusters_name: str,  # e.g. /foo/bar/vae -> /foo/bar/vae_unsplit.tsv
-    clusters: Iterable[tuple[str, set[str]]],
-    sequence_names: Sequence[str],
-    sequence_lens: Sequence[int],
-    to_file: bool = True,
-    unsplit_clusters_file: Optional[TextIO] = None,
-    split_clusters_file: Optional[TextIO] = None,
-):
-    # Write unsplit clusters to file
-    unsplit_path = Path(base_clusters_name + "_unsplit.tsv")
-    split_path = Path(base_clusters_name + "_split.tsv")
+        if maybe_split_file is not None:
+            print(vamb.vambtools.CLUSTERS_HEADER, file=maybe_split_file)
 
-    n_unsplit_clusters, _ = write_clusters_table(
-        unsplit_clusters_file, unsplit_path, clusters, to_file
-    )
+        n_total_contigs = latent.shape[0]
+        last_decile_printed = 0
+        for cluster_index, cluster in enumerate(clusters):
+            cluster_members = [sequence_names[cast(int, i)] for i in cluster.members]
+            cluster_name = str(cluster_index + 1)
+            if bin_prefix is not None:
+                cluster_name = bin_prefix + cluster_name
+
+            n_processed_contigs += len(cluster_members)
+            n_unsplit_clusters += 1
+
+            # Store the cluster in memory if FASTA output is needed.
+            # Prefer storing the split clusters if we use a binsplitter
+            if maybe_stored_clusters is not None and maybe_split_file is None:
+                maybe_stored_clusters.append((cluster_name, list(cluster_members)))
+
+            # Print the clusters to the two cluster files
+            for member in cluster_members:
+                print(cluster_name, member, sep="\t", file=unsplit_clusters_file)
+
+            if maybe_split_file is not None:
+                for split_bin_name, split_members in binsplitter.split_bin(
+                    cluster_name, cluster_members
+                ):
+                    n_split_clusters += 1
+
+                    if maybe_stored_clusters is not None:
+                        maybe_stored_clusters.append(
+                            (split_bin_name, list(split_members))
+                        )
+
+                    for split_member in split_members:
+                        print(
+                            split_bin_name,
+                            split_member,
+                            sep="\t",
+                            file=maybe_split_file,
+                        )
+
+            # Print metadata
+            print(
+                cluster_name,
+                None if cluster.radius is None else round(cluster.radius, 3),
+                (
+                    None
+                    if cluster.observed_pvr is None
+                    else round(cluster.observed_pvr, 2)
+                ),
+                cluster.kind_str,
+                sum(sequence_lens[i] for i in cluster.members),
+                len(cluster_members),
+                sequence_names[cluster.medoid],
+                file=metadata_file,
+                sep="\t",
+            )
+
+            # Log each decile
+            current_decile = ceil_div(10 * n_processed_contigs, n_total_contigs)
+            for unprinted_decile in range(last_decile_printed + 1, current_decile + 1):
+                logger.info(f"\t {unprinted_decile * 10:3} % of contigs clustered")
+
+            last_decile_printed = current_decile
 
-    # Open unsplit clusters and split them
+    # When done: Log the time
     if binsplitter.splitter is not None:
-        split_clusters = binsplitter.binsplit(clusters)
-        if bin_prefix is not None:
-            split_clusters = add_bin_prefix(dict(split_clusters), bin_prefix).items()
-        clusters_with_prefix = split_clusters
-        n_split_clusters, _ = write_clusters_table(
-            split_clusters_file, split_path, split_clusters, to_file
-        )
+        msg = f"\tClustered {n_processed_contigs} contigs in {n_split_clusters} split bins ({n_unsplit_clusters} clusters)"
     else:
-        clusters_with_prefix = clusters
-        if bin_prefix is not None:
-            clusters_with_prefix = add_bin_prefix(dict(clusters), bin_prefix).items()
-        n_split_clusters = n_unsplit_clusters
+        msg = f"\tClustered {n_processed_contigs} contigs in {n_unsplit_clusters} unsplit bins"
+    logger.info(msg)
+    elapsed = round(time.time() - begintime, 2)
+    logger.info(f"\tWrote cluster file(s) in {elapsed} seconds.")
 
-    # Write bins, if necessary
+    # If FASTA output requested, create those files
     if fasta_output is not None:
-        filtered_clusters: dict[str, set[str]] = dict()
-        assert len(sequence_lens) == len(sequence_names)
-        sizeof = dict(zip(sequence_names, sequence_lens))
-        for binname, contigs in clusters_with_prefix:
-            if sum(sizeof[c] for c in contigs) >= fasta_output.min_fasta_size:
-                filtered_clusters[binname] = contigs
-
-        with vamb.vambtools.Reader(fasta_output.existing_fasta_path.path) as file:
-            vamb.vambtools.write_bins(
-                fasta_output.bins_dir_to_populate,
-                filtered_clusters,
-                file,
-                None,
-            )
+        assert maybe_stored_clusters is not None
+        create_cluster_fasta_files(
+            fasta_output.bins_dir_to_populate,
+            maybe_stored_clusters,
+            fasta_output.existing_fasta_path.path,
+            cast(Sequence[int], sequence_lens),
+            sequence_names,
+            fasta_output.min_fasta_size,
+        )
 
-    return n_unsplit_clusters, n_split_clusters
+
+def create_cluster_fasta_files(
+    dir_to_populate: Path,
+    clusters: Iterable[tuple[str, Collection[str]]],
+    existing_fasta_path: Path,
+    sequence_lens: Sequence[int],
+    sequence_names: Sequence[str],
+    min_bin_size: int,
+) -> None:
+    begintime = time.time()
+    filtered_clusters: list[tuple[str, list[str]]] = []
+    assert len(sequence_lens) == len(sequence_names)
+    sizeof = dict(zip(sequence_names, sequence_lens))
+    for binname, contigs in clusters:
+        if sum(sizeof[c] for c in contigs) >= min_bin_size:
+            filtered_clusters.append((binname, list(contigs)))
+
+    with vamb.vambtools.Reader(existing_fasta_path) as file:
+        vamb.vambtools.write_bins(
+            dir_to_populate,
+            filtered_clusters,
+            file,
+            None,
+        )
+    elapsed = round(time.time() - begintime, 2)
+    logger.info(
+        f"\tWrote clusters above {min_bin_size} bp to FASTA files in {elapsed} seconds.\n"
+    )
 
 
 def add_bin_prefix(
@@ -1445,15 +1517,6 @@ def run_bin_aae(opt: BinAvambOptions):
     # We enforce this in the VAEAAEOptions constructor, see comment there
     # Cluster and output the Y clusters
     assert opt.common.clustering.max_clusters is None
-    export_binning_results(
-        FastaOutput.try_from_common(opt.common),
-        "y_",
-        binsplitter=opt.common.output.binsplitter,
-        base_clusters_name=str(opt.common.general.out_dir.joinpath("aae_y_clusters")),
-        clusters=clusters_y_dict,
-        sequence_names=cast(Sequence[str], comp_metadata.identifiers),
-        sequence_lens=cast(Sequence[int], comp_metadata.lengths),
-    )
 
 
 def predict_taxonomy(
@@ -1804,27 +1867,27 @@ def run_reclustering(opt: ReclusteringOptions):
         (str(i), {identifiers[c] for c in cluster})
         for i, cluster in enumerate(reclustered_contigs)
     ]
-    # for i, cluster in enumerate(reclustered_contigs):
-    #     clusters_dict[str(i)] = {identifiers[c] for c in cluster}
 
     if opt.output.min_fasta_output_size is None:
         fasta_output = None
     else:
         assert isinstance(opt.composition.path, FASTAPath)
-        fasta_output = FastaOutput(
+        fasta_output_struct = FastaOutput(
            opt.composition.path,
            opt.general.out_dir.joinpath("bins"),
            opt.output.min_fasta_output_size,
        )
+        fasta_output = (
+            fasta_output_struct,
+            cast(Sequence[str], identifiers),
+            cast(Sequence[int], composition.metadata.lengths),
+        )
 
-    export_binning_results(
-        fasta_output,
-        None,
+    export_clusters(
         opt.output.binsplitter,
-        str(opt.general.out_dir.joinpath("clusters_reclustered")),
         clusters_dict,
-        cast(Sequence[str], composition.metadata.identifiers),
-        cast(Sequence[int], composition.metadata.lengths),
+        str(opt.general.out_dir.joinpath("clusters_reclustered")),
+        fasta_output,
     )
diff --git a/vamb/vambtools.py b/vamb/vambtools.py
index 47721b48..b524e39a 100644
--- a/vamb/vambtools.py
+++ b/vamb/vambtools.py
@@ -9,7 +9,7 @@
 import collections as _collections
 from itertools import zip_longest
 from hashlib import md5 as _md5
-from collections.abc import Iterable, Iterator
+from collections.abc import Iterable, Iterator, Collection
 from typing import Optional, IO, Union
 from pathlib import Path
 from loguru import logger
@@ -634,7 +634,7 @@ def create_dir_if_not_existing(path: Path) -> None:
 
 def write_bins(
     directory: Path,
-    bins: dict[str, set[str]],
+    bins: Collection[tuple[str, Iterable[str]]],
     fastaio: Iterable[bytes],
     maxbins: Optional[int] = 1000,
 ):
@@ -657,8 +657,9 @@ def write_bins(
     create_dir_if_not_existing(directory)
 
     keep: set[str] = set()
-    for i in bins.values():
-        keep.update(i)
+    for _, contigs in bins:
+        for contig in contigs:
+            keep.add(contig)
 
     bytes_by_id: dict[str, bytes] = dict()
     for entry in byte_iterfasta(fastaio, None):
@@ -668,7 +669,7 @@ def write_bins(
     )
 
     # Now actually print all the contigs to files
-    for binname, contigs in bins.items():
+    for binname, contigs in bins:
         for contig in contigs:
             byts = bytes_by_id.get(contig)
             if byts is None:
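
Note on the changed write_bins API above: instead of a dict[str, set[str]], the function now accepts any collection of (bin name, contig names) pairs, so existing dict-based callers can simply pass some_dict.items(), as the updated tests do. A minimal usage sketch under that assumption follows; the bin names, contig names, FASTA path, and output directory are made-up illustration values, not part of this patch.

    import pathlib
    import vamb

    # Hypothetical bins: map each bin name to the contig names it contains
    bins = {"bin_1": {"contig_1", "contig_2"}, "bin_2": {"contig_3"}}

    # Reader opens the FASTA file and yields raw bytes, which is what the
    # fastaio: Iterable[bytes] parameter of write_bins expects
    with vamb.vambtools.Reader(pathlib.Path("contigs.fna.gz")) as fasta:
        vamb.vambtools.write_bins(
            pathlib.Path("output/bins"),  # directory to populate, one FASTA file per bin
            bins.items(),                 # any Collection[tuple[str, Iterable[str]]]
            fasta,
            maxbins=None,                 # None, as the internal caller in this patch passes
        )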