# Source code for zae_engine.trainer.addons.state_manager

import os
import pickle
import torch
import safetensors.torch
from typing import Type, Union, Optional, Dict

from .core import AddOnBase, T


class StateManagerAddon(AddOnBase):
    """
    Add-on to manage model, optimizer, and scheduler state.

    This add-on provides functionality for saving and loading the state of the
    model, optimizer, and scheduler during training. It supports `.ckpt` and
    `.safetensor` formats for storing model weights.

    Parameters
    ----------
    save_path : str
        Path to the directory where the model, optimizer, and scheduler states
        will be saved.
    save_format : str, optional
        Format to save the model state, either 'ckpt' or 'safetensor'.
        Default is 'ckpt'.

    Methods
    -------
    save_state()
        Save the state of the model, optimizer, and scheduler.
    load_state()
        Load the state of the model, optimizer, and scheduler.
    save_model(filename: str)
        Save the model's state dictionary to a file.
    save_optimizer()
        Save the optimizer's state dictionary.
    save_scheduler()
        Save the scheduler's state dictionary.

    Notes
    -----
    This add-on automatically saves the model state whenever a better state is
    detected during training based on loss.

    Examples
    --------
    Adding StateManagerAddon to a Trainer:

    >>> from zae_engine.trainer import Trainer
    >>> from zae_engine.trainer.addons import StateManagerAddon
    >>> MyTrainer = Trainer.add_on(StateManagerAddon)
    >>> trainer = MyTrainer(
    >>>     model=my_model,
    >>>     device='cuda',
    >>>     optimizer=my_optimizer,
    >>>     scheduler=my_scheduler,
    >>>     save_path='./checkpoints'
    >>> )
    >>> trainer.run(n_epoch=10, loader=train_loader)
    """

    # Model-weight formats understood by save_model()/load_state().
    VALID_FORMATS = ("ckpt", "safetensor")

    def __init__(self, save_path: str, save_format: str = "ckpt"):
        self.save_path = save_path
        self.save_format = save_format

    @classmethod
    def apply(cls, base_cls: Type[T]) -> Type[T]:
        """Return a subclass of *base_cls* with state-management hooks mixed in.

        The returned trainer consumes the ``save_path`` (required) and
        ``save_format`` (optional, default ``'ckpt'``) keyword arguments before
        forwarding the rest to ``base_cls.__init__``.
        """

        class TrainerWithStateManager(base_cls):
            def __init__(self, *args, **kwargs):
                # 'save_path' is required; a missing key raises KeyError,
                # matching the original contract.
                self.save_path = kwargs.pop("save_path")
                self.save_format = kwargs.pop("save_format", "ckpt")
                # Fail fast on an unknown format: previously save_model() was a
                # silent no-op for anything but 'ckpt'/'safetensor', so the
                # model was never written and the error surfaced as data loss.
                if self.save_format not in StateManagerAddon.VALID_FORMATS:
                    raise ValueError(
                        f"Unsupported save_format {self.save_format!r}; "
                        f"expected one of {StateManagerAddon.VALID_FORMATS}."
                    )
                super().__init__(*args, **kwargs)
                # exist_ok avoids the check-then-create race of the previous
                # os.path.exists() + os.makedirs() pair.
                os.makedirs(self.save_path, exist_ok=True)

            def check_better(self, cur_epoch: int, cur_loss: float) -> bool:
                """Delegate to the base check and persist state on improvement."""
                is_better = super().check_better(cur_epoch, cur_loss)
                if is_better:
                    self.save_state()
                return is_better

            def save_model(self, filename: str) -> None:
                """Write the model's state dict to *filename* in the configured format."""
                state = self.model.state_dict()
                if self.save_format == "ckpt":
                    torch.save(state, filename)
                else:  # 'safetensor' — the only other value accepted in __init__
                    safetensors.torch.save_file(state, filename)

            def save_optimizer(self) -> None:
                """Pickle the optimizer's state dict to <save_path>/optimizer.zae."""
                with open(os.path.join(self.save_path, "optimizer.zae"), "wb") as f:
                    pickle.dump(self.optimizer.state_dict(), f)

            def save_scheduler(self) -> None:
                """Pickle the scheduler's state dict to <save_path>/scheduler.zae."""
                # Scheduler is optional in many trainers (presumably also here —
                # TODO confirm against base_cls); skip rather than crash on None.
                if self.scheduler is None:
                    return
                with open(os.path.join(self.save_path, "scheduler.zae"), "wb") as f:
                    pickle.dump(self.scheduler.state_dict(), f)

            def save_result(self) -> None:
                # TODO: Save result for validation set (or test)
                pass

            def save_state(self) -> None:
                """
                Save the state of the model, optimizer, and scheduler.
                """
                self.save_model(os.path.join(self.save_path, f"model.{self.save_format}"))
                self.save_optimizer()
                self.save_scheduler()

            def load_state(self) -> None:
                """
                Load the state of the model, optimizer, and scheduler.

                Missing files are skipped silently so partially written
                checkpoints still load what is available.
                """
                model_file = os.path.join(self.save_path, f"model.{self.save_format}")
                optimizer_file = os.path.join(self.save_path, "optimizer.zae")
                scheduler_file = os.path.join(self.save_path, "scheduler.zae")
                if os.path.exists(model_file):
                    if self.save_format == "ckpt":
                        self.model.load_state_dict(torch.load(model_file))
                    else:  # 'safetensor'
                        self.model.load_state_dict(safetensors.torch.load_file(model_file))
                if os.path.exists(optimizer_file):
                    # NOTE(review): pickle.load executes arbitrary code from the
                    # file — only load checkpoints from trusted sources.
                    with open(optimizer_file, "rb") as f:
                        self.optimizer.load_state_dict(pickle.load(f))
                if self.scheduler is not None and os.path.exists(scheduler_file):
                    with open(scheduler_file, "rb") as f:
                        self.scheduler.load_state_dict(pickle.load(f))

        return TrainerWithStateManager