zae_engine.loss package

Submodules

zae_engine.loss.angular module

class zae_engine.loss.angular.ArcFaceLoss(in_features: int, out_features: int, s: float = 30.0, m: float = 0.5)[source]

Bases: Module

ArcFace loss for classification with an additive angular margin.

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Number of classes.

  • s (float, optional) – Scale factor: the norm to which the normalized features are rescaled before the logits are computed (default is 30.0).

  • m (float, optional) – Additive angular margin, in radians (default is 0.5).

forward(features: Tensor, labels: Tensor)[source]

Forward pass that computes the ArcFace logits (margin-adjusted, scaled cosine similarities) for the given features and labels.

Parameters:
  • features (torch.Tensor) – Input features of shape (batch_size, in_features).

  • labels (torch.Tensor) – Ground-truth class indices of shape (batch_size,).

Returns:

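Computed ArcFace logits of shape (batch_size, out_features). These are logits, not a scalar loss; in the usual ArcFace setup they are passed to a cross-entropy loss.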

Return type:

torch.Tensor
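The forward computation can be sketched in plain Python, assuming the standard ArcFace formulation. Both the features and a per-class weight matrix are L2-normalized. The target-class cosine cos(theta_y) is replaced by cos(theta_y + m), and every logit is multiplied by s. The weight matrix is presumably a learnable parameter of shape (out_features, in_features) inside the module. Here it is passed explicitly as the hypothetical argument `weights`. This is an illustration only, not the library's implementation, and it omits safeguards that production versions often add (for example, handling theta_y + m > pi).

```python
import math
from typing import List, Sequence


def _l2_normalize(v: Sequence[float]) -> List[float]:
    # Scale a vector to unit length (assumes it is not all zeros).
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def arcface_logits_sketch(
    features: Sequence[Sequence[float]],  # (batch_size, in_features)
    weights: Sequence[Sequence[float]],   # hypothetical (out_features, in_features) class centres
    labels: Sequence[int],                # (batch_size,) target class indices
    s: float = 30.0,
    m: float = 0.5,
) -> List[List[float]]:
    """Return s * cos(theta_j) per class, using cos(theta_y + m) for the target class y."""
    w_hat = [_l2_normalize(w) for w in weights]
    logits = []
    for x, y in zip(features, labels):
        x_hat = _l2_normalize(x)
        # Cosine similarity between this sample and every class centre.
        cosines = [sum(a * b for a, b in zip(x_hat, w)) for w in w_hat]
        row = []
        for j, c in enumerate(cosines):
            if j == y:
                # Clamp before acos: rounding can push c slightly outside [-1, 1].
                theta = math.acos(max(-1.0, min(1.0, c)))
                c = math.cos(theta + m)  # additive angular margin on the target class only
            row.append(s * c)
        logits.append(row)
    return logits


# A feature aligned with class 0: its target logit drops from 30.0 to 30 * cos(0.5).
print(arcface_logits_sketch([[1.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [0]))
```

The margin only penalizes the target class, so the model must push theta_y smaller than it otherwise would to keep the target logit highest. In training, the resulting logits would typically be fed to a cross-entropy loss together with the same labels.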

zae_engine.loss.iou module

zae_engine.loss.iou.GIoU(true_onoff: ndarray | Tensor, pred_onoff: ndarray | Tensor)[source]

Compute mean Generalized Intersection over Union (GIoU) using the given true and predicted labels.

The true and predicted labels must be 2-D tensors with elements of integer type. See https://arxiv.org/abs/1902.09630v2 for details on GIoU.

Parameters:
  • true_onoff (Union[np.ndarray, torch.Tensor]) – True labels tensor. Elements must be of type int.

  • pred_onoff (Union[np.ndarray, torch.Tensor]) – Predicted labels tensor. Elements must be of type int.

Returns:

The mean GIoU score.

Return type:

torch.Tensor

Raises:

AssertionError – If the elements of true_onoff or pred_onoff are not of integer type.
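The GIoU definition can be illustrated for 1-D segments. This sketch rests on several assumptions that the signature does not confirm: each row of true_onoff and pred_onoff is an (onset, offset) pair for one segment, segments are half-open [onset, offset) and non-empty, and rows are matched one-to-one. GIoU takes the IoU and subtracts the fraction of the smallest enclosing segment that the union does not cover. The function names are hypothetical.

```python
from typing import Sequence, Tuple

Segment = Tuple[int, int]  # (onset, offset), treated as half-open [onset, offset)


def giou_1d_sketch(true_seg: Segment, pred_seg: Segment) -> float:
    t_on, t_off = true_seg
    p_on, p_off = pred_seg
    inter = max(0, min(t_off, p_off) - max(t_on, p_on))
    union = (t_off - t_on) + (p_off - p_on) - inter
    # Length of the smallest segment enclosing both inputs.
    hull = max(t_off, p_off) - min(t_on, p_on)
    iou = inter / union
    return iou - (hull - union) / hull


def mean_giou_sketch(true_onoff: Sequence[Segment], pred_onoff: Sequence[Segment]) -> float:
    # Average the per-row GIoU over one-to-one matched rows.
    scores = [giou_1d_sketch(t, p) for t, p in zip(true_onoff, pred_onoff)]
    return sum(scores) / len(scores)


print(giou_1d_sketch((0, 10), (0, 10)))  # identical segments
print(giou_1d_sketch((0, 2), (4, 6)))    # disjoint segments give a negative score
```

Unlike plain IoU, which is 0 for every pair of disjoint segments, GIoU still decreases as disjoint segments move further apart. This is what makes it usable as a training signal.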

zae_engine.loss.iou.IoU(pred: Tensor, true: Tensor)[source]

Compute mean Intersection over Union (IoU) using the given true and predicted labels.

The true and predicted labels must be 2-D tensors with elements of integer type.

Parameters:
  • pred (Union[np.ndarray, torch.Tensor]) – Predicted labels tensor. Elements must be of type int.

  • true (Union[np.ndarray, torch.Tensor]) – True labels tensor. Elements must be of type int.

Returns:

The mean IoU score.

Return type:

torch.Tensor
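A minimal sketch of a mean IoU over rows, assuming each row of the 2-D input is a binary (0/1) mask for one sample. The library may interpret the integers differently, for example as class indices. The helper name and the convention that two empty masks score 1.0 are hypothetical choices for illustration.

```python
from typing import Sequence


def mean_row_iou_sketch(pred: Sequence[Sequence[int]], true: Sequence[Sequence[int]]) -> float:
    """Mean IoU over rows, treating each row as a binary mask."""
    scores = []
    for p_row, t_row in zip(pred, true):
        inter = sum(1 for p, t in zip(p_row, t_row) if p and t)
        union = sum(1 for p, t in zip(p_row, t_row) if p or t)
        # Hypothetical convention: two empty masks count as a perfect match.
        scores.append(inter / union if union else 1.0)
    return sum(scores) / len(scores)


# Row 1: overlap 1, union 2 -> 0.5; row 2: overlap 2, union 2 -> 1.0; mean 0.75.
print(mean_row_iou_sketch([[1, 1, 0, 0], [0, 1, 1, 0]], [[1, 0, 0, 0], [0, 1, 1, 0]]))
```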

zae_engine.loss.iou.mIoU(pred: Tensor, true: Tensor)[source]

Compute mean Intersection over Union (mIoU) using the given predicted and true labels.

The predicted and true labels must be 1-D or 2-D tensors with elements of integer or boolean type.

Parameters:
  • pred (torch.Tensor) – Predicted labels tensor. Elements must be of type int or bool.

  • true (torch.Tensor) – True labels tensor. Elements must be of type int or bool.

Returns:

The mean IoU score.

Return type:

torch.Tensor
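The usual mIoU definition computes the IoU separately for each class and then averages the results. The sketch below shows this for 1-D integer or boolean labels (in Python, True and False compare equal to 1 and 0, so boolean labels behave as classes 1 and 0). Averaging only over classes that appear in either input is an assumption: the library may use a fixed class count or treat 2-D inputs differently. The function name is hypothetical.

```python
from typing import Sequence


def mean_class_iou_sketch(pred: Sequence[int], true: Sequence[int]) -> float:
    """Per-class IoU, averaged over every class present in pred or true."""
    if len(pred) != len(true):
        raise ValueError("pred and true must have the same length")
    ious = []
    for c in sorted(set(pred) | set(true)):
        inter = sum(1 for p, t in zip(pred, true) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, true) if p == c or t == c)
        ious.append(inter / union)  # union > 0 because c occurs somewhere
    return sum(ious) / len(ious)


# Class IoUs: 0 -> 1/2, 1 -> 2/3, 2 -> 1; mean = 13/18.
print(mean_class_iou_sketch([0, 0, 1, 1, 2], [0, 1, 1, 1, 2]))
```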

zae_engine.loss.norm module

zae_engine.loss.norm.mse(true: Tensor, predict: Tensor)[source]

Compute the mean squared error (MSE) between the true and predicted values.

Parameters:
  • true (torch.Tensor) – The ground truth values.

  • predict (torch.Tensor) – The predicted values.

Returns:

The computed mean squared error.

Return type:

torch.Tensor
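The MSE is the average of the squared element-wise differences, (1/N) Σ (true_i - predict_i)². A plain-Python sketch over flat sequences follows. The function name is hypothetical, chosen to avoid confusion with the library's mse.

```python
from typing import Sequence


def mse_sketch(true: Sequence[float], predict: Sequence[float]) -> float:
    """Mean of the squared element-wise differences between true and predict."""
    if len(true) != len(predict):
        raise ValueError("true and predict must have the same length")
    return sum((t - p) ** 2 for t, p in zip(true, predict)) / len(true)


# Only the last element differs, by 2, so the result is 4 / 3.
print(mse_sketch([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```

With PyTorch tensors, torch.nn.functional.mse_loss(predict, true), whose default reduction is 'mean', computes the same quantity, assuming the library function also averages over all elements.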

Module contents