diff --git a/pymarlin/utils/checkpointer/checkpoint_utils.py b/pymarlin/utils/checkpointer/checkpoint_utils.py
index f44c79c..42c711e 100644
--- a/pymarlin/utils/checkpointer/checkpoint_utils.py
+++ b/pymarlin/utils/checkpointer/checkpoint_utils.py
@@ -4,7 +4,7 @@
 import os
 import re
-from typing import Optional, Dict
+from typing import Optional, Dict, Union, List, Callable
 from abc import ABC, abstractmethod
 from operator import itemgetter
 from dataclasses import dataclass
@@ -291,3 +291,101 @@ def check_mk_dir(self, dirpath: str) -> None:
             os.makedirs(dirpath)
         assert os.path.isdir(dirpath), "supplied checkpoint dirpath "\
             "is not a directory"
+
+
+@dataclass
+class BestCheckpointerArguments(DefaultCheckpointerArguments):
+    """Additional arguments for the best-checkpoint checkpointer.
+
+    metric_name: name of the metric used to select the best checkpoint. Must be a registered buffer in the
+        module interface.
+    init_metric_value: optional initial value for the best metric, e.g. when resuming a previous run.
+    criteria: "min", "max", or a callable ``(new, old) -> bool`` returning True when ``new`` is better than ``old``.
+    save_intermediate_checkpoints: whether to save a checkpoint every epoch in addition to the latest and best.
+    load_best: whether to load the best or the latest checkpoint. Default behavior is to load the latest.
+    """
+    metric_name: str = "val_perplexity"
+    init_metric_value: Optional[float] = None
+    criteria: Union[str, Callable] = "min"
+    save_intermediate_checkpoints: bool = False  # not usually necessary in practice
+    load_best: bool = False  # default to load latest
+
+
+class BestCheckpointer(DefaultCheckpointer):
+    """
+    Saves the best and the latest checkpoint. The best checkpoint is the one whose value of ``metric_name`` is
+    optimal under ``criteria``, so this checkpointer relies on ``metric_name`` existing as a single scalar value
+    in the module interface state. By default it checks "val_perplexity", a registered buffer in
+    `AbstractUserMessageReplyModule` that is updated after every call to `on_end_val_epoch`.
+    """
+    def __init__(self, args: BestCheckpointerArguments):
+        super().__init__(args)
+        self.best_checkpoint_name = f"{self.args.file_prefix}_best_checkpoint.{self.args.file_ext}"
+        self.latest_checkpoint_name = f"{self.args.file_prefix}_latest_checkpoint.{self.args.file_ext}"
+        if self.args.criteria == 'min':
+            self.criteria_func = lambda new, old: new < old
+            self.best_metric = float('inf')
+        elif self.args.criteria == 'max':
+            self.criteria_func = lambda new, old: new > old
+            self.best_metric = -float('inf')
+        else:
+            self.criteria_func = self.args.criteria
+            self.best_metric = self.args.init_metric_value
+
+        if self.args.init_metric_value is not None:
+            self.best_metric = self.args.init_metric_value
+
+    def save(self, checkpoint_state: Checkpoint, index: int, force=False) -> List[str]:
+        """
+        Saves trainer, optimizer, and module interface state.
+
+        Args:
+            checkpoint_state: instance of `Checkpoint` which contains trainer, optimizer, and module interface state
+            index: current epoch number
+            force: whether to force a save even if the checkpointing period does not line up
+
+        Returns:
+            list of paths the checkpoint state was saved to
+        """
+        paths = []
+        if self.args.save_intermediate_checkpoints:
+            paths.append(super().save(checkpoint_state, index, force))
+        if self.args.checkpoint:
+            # TODO grab this from logged metrics instead, checkpoint state is hacky
+            self.logger.debug(f"Available metrics {checkpoint_state.module_interface_state.keys()}")
+            metric = float(checkpoint_state.module_interface_state[self.args.metric_name])
+            self.logger.info(f"epoch {index}: metric {self.args.metric_name}={metric}, best score={self.best_metric}")
+
+            # optionally save best
+            if self.criteria_func(metric, self.best_metric):
+                self.best_metric = metric
+                best_path = os.path.join(self.args.save_dir, self.best_checkpoint_name)
+                torch.save(checkpoint_state.__dict__, best_path)
+                paths.append(best_path)
+
+            # save latest
+            latest_path = os.path.join(self.args.save_dir, self.latest_checkpoint_name)
+            torch.save(checkpoint_state.__dict__, latest_path)
+            paths.append(latest_path)
+        return paths
+
+    def load(self) -> Checkpoint:
+        """
+        Optionally loads a checkpoint from a given directory: either a specified filename, the best checkpoint,
+        or the latest checkpoint. Raises a `ValueError` upon failure to load a checkpoint.
+
+        Returns:
+            An instance of `Checkpoint`
+        """
+        if self.args.load_dir:
+            if self.args.load_filename:
+                load_path = os.path.join(self.args.load_dir, self.args.load_filename)
+            elif self.args.load_best:
+                load_path = os.path.join(self.args.load_dir, self.best_checkpoint_name)
+            else:
+                load_path = os.path.join(self.args.load_dir, self.latest_checkpoint_name)
+
+            # TODO how to set best metric to match loaded checkpoint?
+            self.logger.debug(f"loading checkpoint from {load_path}")
+            checkpoint = torch.load(load_path, map_location=torch.device('cpu'))
+            self.logger.debug('Checkpoint loaded')
+            return Checkpoint(**checkpoint)
+
+        return Checkpoint()
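
For reviewers, a minimal usage sketch of the new checkpointer (not part of the diff). The concrete argument values are illustrative; `checkpoint`, `save_dir`, `load_dir`, and `file_prefix` are the fields inherited from `DefaultCheckpointerArguments` that the code above already references.

```python
# Illustrative only: how BestCheckpointer might be wired up. Directory and prefix
# values are placeholders; metric_name must match a registered buffer on the
# module interface, as described in the class docstring.
from pymarlin.utils.checkpointer.checkpoint_utils import (
    BestCheckpointer,
    BestCheckpointerArguments,
)

args = BestCheckpointerArguments(
    checkpoint=True,
    save_dir="checkpoints",
    load_dir="checkpoints",
    file_prefix="model",
    metric_name="val_perplexity",  # registered buffer on the module interface
    criteria="min",                # "min", "max", or a callable (new, old) -> bool
    save_intermediate_checkpoints=False,
    load_best=True,                # resume from the best rather than the latest checkpoint
)
checkpointer = BestCheckpointer(args)

# The trainer is expected to call checkpointer.save(checkpoint_state, epoch) after
# each epoch, and checkpointer.load() on startup; load() returns a Checkpoint
# instance (an empty one if no load_dir is configured).
```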