zae_engine.loss package
Submodules
zae_engine.loss.angular module
- class zae_engine.loss.angular.ArcFaceLoss(in_features: int, out_features: int, s: float = 30.0, m: float = 0.5)
Bases: Module
ArcFace loss for classification with an additive angular margin.
- Parameters:
in_features (int) – Size of each input sample.
out_features (int) – Number of classes.
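s (float, optional) – Scale applied to the normalized cosine logits, i.e. the norm to which input features are rescaled (default is 30.0).
m (float, optional) – Additive angular margin, in radians (default is 0.5).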
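- forward(features: Tensor, labels: Tensor)
Forward pass to compute the ArcFace logits. The returned values are scaled, margin-adjusted class scores; apply a cross-entropy loss to them to obtain the ArcFace loss.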
- Parameters:
features (torch.Tensor) – Input features of shape (batch_size, in_features).
labels (torch.Tensor) – Ground truth class indices of shape (batch_size,).
- Returns:
Computed ArcFace logits of shape (batch_size, out_features).
- Return type:
torch.Tensor
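To make the roles of s and m concrete, the sketch below computes ArcFace logits for a single sample in plain Python: both the feature and each class weight vector are L2-normalized, the target class angle θ_y is replaced by θ_y + m, and every cosine is multiplied by s. It is a minimal illustration of the technique from the ArcFace paper, not the zae_engine implementation; the helper names (arcface_logits, cross_entropy) are hypothetical, and refinements that real implementations often add (for example a fallback when θ_y + m exceeds π) are deliberately omitted.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm (assumes a non-zero vector)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def arcface_logits(feature: Sequence[float],
                   weights: Sequence[Sequence[float]],
                   label: int,
                   s: float = 30.0,
                   m: float = 0.5) -> List[float]:
    """Hypothetical helper: s * cos(theta_j), with theta_label shifted by m."""
    x = _normalize(feature)
    logits = []
    for j, w in enumerate(weights):
        w_unit = _normalize(w)
        cos_t = sum(a * b for a, b in zip(x, w_unit))
        cos_t = max(-1.0, min(1.0, cos_t))  # guard acos against rounding error
        if j == label:
            # Additive angular margin applied only to the ground-truth class.
            cos_t = math.cos(math.acos(cos_t) + m)
        logits.append(s * cos_t)
    return logits


def cross_entropy(logits: Sequence[float], label: int) -> float:
    """Numerically stable -log(softmax(logits)[label])."""
    mx = max(logits)
    log_sum = mx + math.log(sum(math.exp(z - mx) for z in logits))
    return log_sum - logits[label]


logits = arcface_logits([1.0, 0.0], [[1.0, 0.0], [0.0, 1.0]], label=0)
loss = cross_entropy(logits, label=0)
```

Because the margin lowers only the target-class logit, the model must separate classes by more than m radians to reach the same loss it would get without the margin.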
zae_engine.loss.iou module
- zae_engine.loss.iou.GIoU(true_onoff: ndarray | Tensor, pred_onoff: ndarray | Tensor)
Compute the mean Generalized Intersection over Union (GIoU) of the given true and predicted labels.
The true and predicted labels must be 2-D arrays or tensors with integer elements. See https://arxiv.org/abs/1902.09630v2 for details on GIoU.
- Parameters:
true_onoff (Union[np.ndarray, torch.Tensor]) – True labels tensor. Elements must be of type int.
pred_onoff (Union[np.ndarray, torch.Tensor]) – Predicted labels tensor. Elements must be of type int.
- Returns:
The mean GIoU score.
- Return type:
torch.Tensor
- Raises:
AssertionError – If the elements of true_onoff or pred_onoff are not of integer type.
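The parameter names suggest that each row of true_onoff and pred_onoff holds an (onset, offset) pair. The sketch below assumes exactly that layout and treats intervals as half-open, so an interval's length is offset - onset; neither assumption is confirmed by this page. It shows the GIoU definition in one dimension: IoU minus the fraction of the smallest enclosing interval that neither interval covers. The function names are hypothetical.

```python
from typing import Sequence, Tuple

Interval = Tuple[int, int]  # assumed (onset, offset) pair


def giou_1d(true: Interval, pred: Interval) -> float:
    """GIoU of two 1-D intervals: IoU - |hull \\ union| / |hull|."""
    t_on, t_off = true
    p_on, p_off = pred
    inter = max(0, min(t_off, p_off) - max(t_on, p_on))
    union = (t_off - t_on) + (p_off - p_on) - inter
    hull = max(t_off, p_off) - min(t_on, p_on)  # smallest enclosing interval
    return inter / union - (hull - union) / hull


def mean_giou(true_rows: Sequence[Interval], pred_rows: Sequence[Interval]) -> float:
    """Average GIoU over paired rows (hypothetical helper)."""
    scores = [giou_1d(t, p) for t, p in zip(true_rows, pred_rows)]
    return sum(scores) / len(scores)


score = mean_giou([(0, 10), (0, 2), (0, 4)], [(0, 10), (4, 6), (2, 6)])
```

Unlike plain IoU, GIoU stays informative for disjoint intervals (it goes negative as they drift apart), which is why the GIoU paper uses 1 - GIoU as a training loss.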
- zae_engine.loss.iou.IoU(pred: Tensor, true: Tensor)
Compute the mean Intersection over Union (IoU) of the given predicted and true labels.
The predicted and true labels must be 2-D tensors with integer elements.
- Parameters:
pred (Union[np.ndarray, torch.Tensor]) – Predicted labels tensor. Elements must be of type int.
true (Union[np.ndarray, torch.Tensor]) – True labels tensor. Elements must be of type int.
- Returns:
The mean IoU score.
- Return type:
torch.Tensor
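The sketch below illustrates a mean IoU over a batch, assuming each row of the 2-D input is a binary mask (1 = positive, 0 = negative) and that scores are averaged across rows. That interpretation, the helper name, and the convention that two empty masks score 1.0 are assumptions for illustration, not documented behaviour of IoU.

```python
from typing import Sequence


def mean_iou_binary(pred_rows: Sequence[Sequence[int]],
                    true_rows: Sequence[Sequence[int]]) -> float:
    """Hypothetical helper: per-row |pred AND true| / |pred OR true|, averaged."""
    scores = []
    for p_row, t_row in zip(pred_rows, true_rows):
        inter = sum(1 for p, t in zip(p_row, t_row) if p and t)
        union = sum(1 for p, t in zip(p_row, t_row) if p or t)
        # Assumed convention: two empty masks count as a perfect match.
        scores.append(inter / union if union else 1.0)
    return sum(scores) / len(scores)


score = mean_iou_binary([[1, 1, 0, 0], [0, 0, 1, 1]],
                        [[1, 0, 0, 0], [0, 0, 1, 1]])
```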
- zae_engine.loss.iou.mIoU(pred: Tensor, true: Tensor)
Compute the mean Intersection over Union (mIoU) of the given predicted and true labels.
The predicted and true labels must be 1-D or 2-D tensors with integer or boolean elements.
- Parameters:
pred (torch.Tensor) – Predicted labels tensor. Elements must be of type int or bool.
true (torch.Tensor) – True labels tensor. Elements must be of type int or bool.
- Returns:
The mean IoU score.
- Return type:
torch.Tensor
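A common definition of mIoU computes one IoU per class and averages over classes. The sketch below follows that definition on flattened label sequences (a 2-D input would be flattened first), averaging over every class that appears in either the prediction or the ground truth. Whether mIoU uses the same class set and averaging is an assumption; the function name is hypothetical.

```python
from typing import Sequence


def miou(pred: Sequence[int], true: Sequence[int]) -> float:
    """Hypothetical helper: mean over classes of |pred==c AND true==c| / |pred==c OR true==c|."""
    classes = sorted(set(pred) | set(true))
    ious = []
    for c in classes:
        inter = sum(1 for p, t in zip(pred, true) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, true) if p == c or t == c)
        ious.append(inter / union)  # union > 0: c occurs in pred or true
    return sum(ious) / len(ious)


# Per-class IoU: class 0 -> 1/2, class 1 -> 2/3, class 2 -> 1.
score = miou([0, 0, 1, 1, 2], [0, 1, 1, 1, 2])
```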
zae_engine.loss.norm module
- zae_engine.loss.norm.mse(true: Tensor, predict: Tensor)
Compute the mean squared error (MSE) between the true and predicted values.
- Parameters:
true (torch.Tensor) – The ground truth values.
predict (torch.Tensor) – The predicted values.
- Returns:
The computed mean squared error.
- Return type:
torch.Tensor
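For reference, MSE is the average of the squared element-wise differences, (1/n) Σ (true_i - predict_i)². The plain-Python sketch below shows that arithmetic on flat sequences; mse_reference is a hypothetical name chosen to avoid shadowing the library function, and it assumes mse averages over all elements, as torch.nn.functional.mse_loss does with its default reduction="mean".

```python
from typing import Sequence


def mse_reference(true: Sequence[float], predict: Sequence[float]) -> float:
    """Mean of squared differences between paired values."""
    if len(true) != len(predict):
        raise ValueError("true and predict must have the same length")
    return sum((t - p) ** 2 for t, p in zip(true, predict)) / len(true)


# Only the last element differs, by 2: (0 + 0 + 4) / 3.
err = mse_reference([1.0, 2.0, 3.0], [1.0, 2.0, 5.0])
```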