class CheckpointEveryEpoch(pl.Callback):
    """Callback that saves a checkpoint at the end of each epoch from a given epoch on.

    Args:
        start_epoc: Smallest value of ``trainer.current_epoch`` for which a
            checkpoint is written.
        save_path: Path prefix of the checkpoint files. Each checkpoint is
            written to ``f"{save_path}_e{epoch}.ckpt"``.
    """

    def __init__(self, start_epoc, save_path,):
        super().__init__()
        self.start_epoc = start_epoc
        # The hook below reads ``self.save_path``; storing the prefix only under
        # ``file_path`` made every save attempt fail with AttributeError.
        self.save_path = save_path
        # Alias kept so code that read the previous attribute name still works.
        self.file_path = save_path

    def on_epoch_end(self, trainer: pl.Trainer, _):
        """Save a checkpoint if the current epoch has reached ``start_epoc``.

        Args:
            trainer: Trainer whose ``current_epoch`` is inspected and whose
                ``save_checkpoint`` is called.
            _: Second hook argument; unused.
        """
        # NOTE(review): recent Lightning releases appear to have dropped the
        # ``on_epoch_end`` hook in favour of ``on_train_epoch_end`` -- confirm
        # against the installed version before relying on this hook firing.
        epoch = trainer.current_epoch
        if epoch >= self.start_epoc:
            ckpt_path = f"{self.save_path}_e{epoch}.ckpt"
            trainer.save_checkpoint(ckpt_path)
12
13
# Checkpoints are written at the end of every epoch once
# trainer.current_epoch reaches 2.
trainer = Trainer(
    callbacks=[CheckpointEveryEpoch(2, args.save_path)],
)