pytorch lightning save checkpoint every epoch

Solutions on MaxInterview for pytorch lightning save checkpoint every epoch by the best coders in the world

Showing results for: "pytorch lightning save checkpoint every epoch"
Serena
23 Aug 2020
class CheckpointEveryEpoch(pl.Callback):
    """Lightning callback that saves a full trainer checkpoint at the end of
    every training epoch, once a configurable starting epoch is reached.

    Checkpoints are written to ``{save_path}_e{epoch}.ckpt``.
    """

    def __init__(self, start_epoc, save_path):
        """
        Args:
            start_epoc: first epoch (0-based, inclusive) at which to start
                saving checkpoints.
            save_path: path prefix for checkpoint files; the epoch number
                and ``.ckpt`` suffix are appended.
        """
        self.start_epoc = start_epoc
        # BUG FIX: original stored this as ``self.file_path`` but read
        # ``self.save_path`` in on_epoch_end, raising AttributeError.
        self.save_path = save_path

    def on_epoch_end(self, trainer: pl.Trainer, _):
        """Save a checkpoint after each train epoch >= ``start_epoc``."""
        epoch = trainer.current_epoch
        if epoch >= self.start_epoc:
            ckpt_path = f"{self.save_path}_e{epoch}.ckpt"
            trainer.save_checkpoint(ckpt_path)
13            
# Register the callback so checkpoint saving begins once epoch 2 is reached.
epoch_ckpt_callback = CheckpointEveryEpoch(2, args.save_path)
trainer = Trainer(callbacks=[epoch_ckpt_callback])