Source code for zae_engine.trainer.addons.mix_precision

from typing import Type, Union

import torch
from torch.cuda.amp import autocast, GradScaler

from .core import AddOnBase
from .._trainer import T


class PrecisionMixerAddon(AddOnBase):
    """
    Add-on for mixed precision training.

    This add-on enables mixed precision training to improve computational efficiency and reduce
    memory usage. It supports automatic precision selection or user-defined precision settings
    (e.g., 'fp32', 'fp16', 'bf16').

    Parameters
    ----------
    precision : Union[str, list], optional
        The precision setting for training. Default is "auto".
        - "auto": Automatically selects the best precision based on hardware capabilities.
        - "fp32": Uses full precision (default in PyTorch).
        - "fp16": Uses half precision for accelerated computation.
        - "bf16": Uses Brain Float 16 precision for supported hardware.
        - list: Specifies a priority order for precision (e.g., ["bf16", "fp16"]).

    Methods
    -------
    run_batch(batch, **kwargs)
        Override the batch processing method to apply mixed precision.

    Notes
    -----
    - Mixed precision improves training speed and reduces memory usage by performing certain
      operations in lower precision (e.g., FP16) while maintaining stability in others
      (e.g., FP32 for loss calculation).
    - Automatically handles loss scaling via `torch.cuda.amp.GradScaler` to prevent overflow
      issues when using FP16.

    Examples
    --------
    Using PrecisionMixerAddon with auto precision:

    >>> from zae_engine.trainer import Trainer
    >>> from zae_engine.trainer.addons import PrecisionMixerAddon
    >>> MyTrainer = Trainer.add_on(PrecisionMixerAddon)
    >>> trainer = MyTrainer(
    ...     model=my_model,
    ...     device='cuda',
    ...     optimizer=my_optimizer,
    ...     scheduler=my_scheduler,
    ...     precision='auto',  # Automatically selects the best precision
    ... )
    >>> trainer.run(n_epoch=10, loader=train_loader, valid_loader=valid_loader)

    Using a priority list for precision:

    >>> trainer = MyTrainer(
    ...     model=my_model,
    ...     device='cuda',
    ...     optimizer=my_optimizer,
    ...     scheduler=my_scheduler,
    ...     precision=["bf16", "fp16"],  # Tries bf16 first, falls back to fp16
    ... )
    >>> trainer.run(n_epoch=10, loader=train_loader)
    """
    @classmethod
    def apply(cls, base_cls: Type[T]) -> Type[T]:
        class PrecisionTrainer(base_cls):
            def __init__(self, *args, precision: Union[str, list] = "auto", **kwargs):
                super().__init__(*args, **kwargs)

                # Check if all devices are CPU
                if all("cpu" in str(device) for device in self.device):
                    raise ValueError("PrecisionMixerAddon requires at least one GPU device to function.")

                # Determine precision settings
                self.precision = self._determine_precision(precision)
                self.scaler = GradScaler() if "fp16" in self.precision else None

            def _determine_precision(self, precision):
                """
                Determine the appropriate precision settings based on user input or hardware capabilities.

                Parameters
                ----------
                precision : Union[str, list]
                    The user-specified precision setting.
                    - "auto": Automatically selects the best precision based on the GPU's hardware capabilities.
                    - str: A single precision mode (e.g., "fp32", "fp16", "bf16").
                    - list: A priority order for precision modes (e.g., ["bf16", "fp16"]).

                Returns
                -------
                list
                    A list of supported precision modes, ordered by priority.

                Raises
                ------
                ValueError
                    If the `precision` argument is not a valid string or list.

                Notes
                -----
                - When "auto" is specified, the method checks the hardware's compute capability to determine
                  the best available precision (e.g., "bf16" on NVIDIA Ampere GPUs, otherwise "fp16").
                - If multiple precision modes are specified in a list, they are used in the given priority
                  order; no additional hardware validation is performed.

                Examples
                --------
                Automatically select the best precision:

                >>> addon._determine_precision("auto")
                ['bf16', 'fp16']

                Use a specific precision:

                >>> addon._determine_precision("fp32")
                ['fp32']

                Provide a priority list:

                >>> addon._determine_precision(["bf16", "fp16"])
                ['bf16', 'fp16']
                """
                if precision == "auto":
                    if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8:
                        return ["bf16", "fp16"]  # Prefer BF16 on Ampere (compute capability >= 8) or newer
                    elif torch.cuda.is_available():
                        return ["fp16"]
                    else:
                        return ["fp32"]
                elif isinstance(precision, str):
                    return [precision]
                elif isinstance(precision, list):
                    return precision
                else:
                    raise ValueError("Invalid precision setting.")

            def run_batch(self, batch, **kwargs):
                """
                Override the batch processing method to enable mixed precision training.

                Parameters
                ----------
                batch : Union[tuple, dict]
                    The batch of data to process.
                kwargs : dict
                    Additional arguments for batch processing.

                Notes
                -----
                - If precision is set to "fp16" or "bf16", this method uses `torch.cuda.amp.autocast` to enable
                  mixed precision operations for the forward pass.
                - Loss scaling is automatically applied when using FP16 to prevent overflow issues.
                - The precision setting does not affect the input data's original precision; it only applies
                  during model computations and optimizer steps.

                Examples
                --------
                Forward pass and gradient scaling with mixed precision:

                >>> trainer.run_batch(batch)
                >>> # autocast ensures model operations are performed in fp16 or bf16
                """
                batch = self._to_device(**batch) if isinstance(batch, dict) else self._to_device(*batch)

                # Autocast to bfloat16 when it is the preferred precision; otherwise fall back to float16.
                use_amp = "fp16" in self.precision or "bf16" in self.precision
                amp_dtype = torch.bfloat16 if "bf16" in self.precision else torch.float16

                if self.mode == "train":
                    self.model.train()
                    self.optimizer.zero_grad()

                    with autocast(enabled=use_amp, dtype=amp_dtype):
                        step_dict = self.train_step(batch)
                        loss = step_dict["loss"]

                    if self.scaler:
                        self.scaler.scale(loss).backward()
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        loss.backward()
                        self.optimizer.step()

                    if self.scheduler_step_on_batch:
                        self.scheduler.step(**kwargs)
                elif self.mode == "test":
                    self.model.eval()
                    with torch.no_grad():
                        with autocast(enabled=use_amp, dtype=amp_dtype):
                            step_dict = self.test_step(batch)
                else:
                    raise ValueError(f"Unexpected mode {self.mode}.")

                torch.cuda.synchronize()
                self.logging(step_dict)
                self.progress_checker.update_step()

        return PrecisionTrainer
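
For reference, the snippet below is a minimal, self-contained sketch of the autocast/GradScaler pattern that run_batch wraps, using the same torch.cuda.amp primitives imported above. The toy model, optimizer, and random tensors are illustrative assumptions, not part of zae_engine.

import torch
from torch.cuda.amp import autocast, GradScaler

# Toy model, optimizer, and data (illustrative placeholders only).
model = torch.nn.Linear(16, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(8, 16, device="cuda")
y = torch.randn(8, 1, device="cuda")

scaler = GradScaler()  # scales the loss so small FP16 gradients do not underflow
for _ in range(3):
    optimizer.zero_grad()
    # Forward pass runs in reduced precision inside the autocast context.
    with autocast(dtype=torch.float16):
        loss = torch.nn.functional.mse_loss(model(x), y)
    # Backward on the scaled loss, then step the optimizer and update the scale factor.
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

When bf16 is the preferred mode, its wider exponent range makes loss scaling largely unnecessary, which is why the add-on only instantiates GradScaler when "fp16" appears among the selected precisions.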