zae_engine.loss package

Submodules

zae_engine.loss.angular module

class zae_engine.loss.angular.ArcFaceLoss(in_features: int, out_features: int, s: float = 30.0, m: float = 0.5)

Bases: Module

ArcFace loss for classification with an additive angular margin.

Parameters:
  • in_features (int) – Size of each input sample.

  • out_features (int) – Number of classes.

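  • s (float, optional) – Scale factor applied to the logits, i.e. the norm to which the L2-normalized features are rescaled (default is 30.0).

  • m (float, optional) – Additive angular margin, in radians, applied to the angle of the target class (default is 0.5).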
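For reference, the standard ArcFace formulation (Deng et al., "ArcFace: Additive Angular Margin Loss for Deep Face Recognition") shows where s and m enter. Here x_i is the i-th input feature, W_j the weight vector of class j (length in_features), y_i the label of sample i, N = batch_size and C = out_features. This page does not state whether the class follows the paper exactly (for example, how it handles θ + m > π), so treat this as the textbook definition rather than a description of the implementation.

```latex
\cos\theta_{ij} = \frac{W_j^{\top} x_i}{\lVert W_j \rVert \, \lVert x_i \rVert},
\qquad
z_{ij} =
\begin{cases}
  s \cos\left(\theta_{i y_i} + m\right), & j = y_i, \\
  s \cos\theta_{ij}, & j \neq y_i,
\end{cases}
\qquad
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \log \frac{e^{z_{i y_i}}}{\sum_{j=1}^{C} e^{z_{ij}}}
```

The margin m only lowers the target-class logit, forcing features to sit at least m radians closer to their class centre than the softmax alone would require; s keeps the cosine values in [-1, 1] from producing an almost uniform softmax.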

forward(features, labels) – Forward pass to compute the ArcFace loss.


forward(features: Tensor, labels: Tensor)

Forward pass to compute the ArcFace loss.

Parameters:
  • features (torch.Tensor) – Input features of shape (batch_size, in_features).

  • labels (torch.Tensor) – Ground truth class indices of shape (batch_size,).

Returns:

Computed ArcFace logits of shape (batch_size, out_features).

Return type:

torch.Tensor
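A minimal PyTorch sketch of how margin-adjusted ArcFace logits of shape (batch_size, out_features) can be computed, assuming the class keeps a learnable weight matrix of shape (out_features, in_features). The function name arcface_logits and its explicit weight argument are hypothetical and not part of zae_engine; this naive version also omits the θ + m > π correction that many implementations add.

```python
import torch
import torch.nn.functional as F


def arcface_logits(
    features: torch.Tensor,  # (batch_size, in_features)
    weight: torch.Tensor,    # (out_features, in_features), hypothetical class centres
    labels: torch.Tensor,    # (batch_size,) int64 class indices
    s: float = 30.0,
    m: float = 0.5,
) -> torch.Tensor:
    """Hypothetical sketch of ArcFace logits, not zae_engine's own code."""
    # Cosine similarity between L2-normalised features and class centres.
    cosine = F.linear(F.normalize(features, dim=1), F.normalize(weight, dim=1))
    # Clamp so acos stays finite at |cos| == 1.
    theta = torch.acos(cosine.clamp(-1.0 + 1e-7, 1.0 - 1e-7))
    # Add the angular margin m only to the ground-truth class angle.
    one_hot = F.one_hot(labels, num_classes=weight.size(0)).to(cosine.dtype)
    logits = one_hot * torch.cos(theta + m) + (1.0 - one_hot) * cosine
    # Rescale so the softmax is not squashed by values in [-1, 1].
    return s * logits


torch.manual_seed(0)
features = torch.randn(4, 8)
weight = torch.randn(10, 8)
labels = torch.tensor([0, 3, 3, 9])
logits = arcface_logits(features, weight, labels)
loss = F.cross_entropy(logits, labels)  # logits -> scalar loss
print(logits.shape, loss.item())
```

Returning logits rather than a scalar, as the forward description above does, leaves the choice of reduction (cross-entropy, label smoothing, and so on) to the caller.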

zae_engine.loss.iou module

zae_engine.loss.iou.GIoU(true_onoff: ndarray | Tensor, pred_onoff: ndarray | Tensor)

Compute mean Generalized Intersection over Union (GIoU) using the given true and predicted labels.

The true and predicted labels must be 2-D tensors with elements of integer type. See https://arxiv.org/abs/1902.09630v2 for details on GIoU.

Parameters:
  • true_onoff (Union[np.ndarray, torch.Tensor]) – True labels tensor. Elements must be of type int.

  • pred_onoff (Union[np.ndarray, torch.Tensor]) – Predicted labels tensor. Elements must be of type int.

Returns:

The mean GIoU score.

Return type:

torch.Tensor

Raises:

AssertionError – If the elements of true_onoff or pred_onoff are not of integer type.
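The GIoU of two intervals is IoU minus the fraction of their enclosing interval that neither covers. The sketch below assumes, hypothetically, that each row of true_onoff and pred_onoff is an [onset, offset] pair describing a half-open 1-D interval of positive length; this page does not state the exact layout, and the function name giou_1d is illustrative only.

```python
import numpy as np
import torch


def giou_1d(true_onoff, pred_onoff) -> torch.Tensor:
    """Mean GIoU of 1-D intervals; rows are [onset, offset] (hypothetical layout)."""
    t = torch.as_tensor(true_onoff)
    p = torch.as_tensor(pred_onoff)
    # Mirror the documented AssertionError for non-integer inputs.
    for x in (t, p):
        assert not x.is_floating_point() and not x.is_complex() and x.dtype != torch.bool, "integer labels required"
    t, p = t.double(), p.double()
    t_on, t_off = t[:, 0], t[:, 1]
    p_on, p_off = p[:, 0], p[:, 1]
    # Overlap length, zero when the intervals are disjoint.
    inter = (torch.minimum(t_off, p_off) - torch.maximum(t_on, p_on)).clamp(min=0)
    union = (t_off - t_on) + (p_off - p_on) - inter
    # Length of the smallest interval enclosing both (the "C" term of GIoU).
    hull = torch.maximum(t_off, p_off) - torch.minimum(t_on, p_on)
    giou = inter / union - (hull - union) / hull
    return giou.mean()


# Row 1, partial overlap: inter 5, union 15, hull 15 -> GIoU 1/3.
# Row 2, disjoint: inter 0, union 20, hull 30 -> GIoU -1/3.
true = np.array([[0, 10], [0, 10]])
pred = np.array([[5, 15], [20, 30]])
print(giou_1d(true, pred))  # mean of 1/3 and -1/3
```

Unlike plain IoU, the disjoint row still yields a graded (negative) score, which is what makes GIoU usable as a training signal when predictions do not yet overlap the target.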

zae_engine.loss.iou.IoU(pred: Tensor, true: Tensor)

Compute mean Intersection over Union (IoU) using the given true and predicted labels.

The true and predicted labels must be 2-D tensors with elements of integer type.

Parameters:
  • pred (Union[np.ndarray, torch.Tensor]) – Predicted labels tensor. Elements must be of type int.

  • true (Union[np.ndarray, torch.Tensor]) – True labels tensor. Elements must be of type int.

Returns:

The mean IoU score.

Return type:

torch.Tensor
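A minimal sketch of a batch-mean IoU, assuming (hypothetically) that each row of the 2-D inputs is one sample's on/off mask, with nonzero entries meaning "on". The helper name iou_binary and the choice to score two empty masks as IoU = 1 are assumptions, not documented behaviour of zae_engine.

```python
import numpy as np
import torch


def iou_binary(pred, true) -> torch.Tensor:
    """Per-row IoU of binary masks, averaged over the batch (hypothetical sketch)."""
    pred = torch.as_tensor(pred)
    true = torch.as_tensor(true)
    assert pred.shape == true.shape and pred.dim() == 2, "expected matching 2-D inputs"
    p, t = pred.bool(), true.bool()
    inter = (p & t).sum(dim=1).double()
    union = (p | t).sum(dim=1).double()
    # Assumption: a row where both masks are empty counts as perfect agreement.
    iou = torch.where(union > 0, inter / union.clamp(min=1), torch.ones_like(union))
    return iou.mean()


pred = np.array([[1, 1, 0, 0], [0, 0, 0, 0]])
true = np.array([[0, 1, 1, 0], [0, 0, 0, 0]])
# Row 0: intersection 1, union 3 -> 1/3; row 1: both empty -> 1.
print(iou_binary(pred, true))
```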

zae_engine.loss.iou.mIoU(pred: Tensor, true: Tensor)

Compute mean Intersection over Union (mIoU) using the given predicted and true labels.

The predicted and true labels must be 1-D or 2-D tensors with elements of integer or boolean type.

Parameters:
  • pred (torch.Tensor) – Predicted labels tensor. Elements must be of type int or bool.

  • true (torch.Tensor) – True labels tensor. Elements must be of type int or bool.

Returns:

The mean IoU score.

Return type:

torch.Tensor
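Mean IoU is usually computed per class and then averaged. The sketch below assumes, hypothetically, that all elements are pooled (a 2-D input is flattened rather than scored per row) and that classes absent from both tensors are ignored; zae_engine may average differently, and the name miou_sketch is illustrative.

```python
import torch


def miou_sketch(pred: torch.Tensor, true: torch.Tensor) -> torch.Tensor:
    """Class-averaged IoU over all elements (hypothetical sketch)."""
    assert pred.shape == true.shape and pred.dim() in (1, 2)
    # bool -> {0, 1}; integer labels are kept as class indices.
    pred = pred.long().flatten()
    true = true.long().flatten()
    ious = []
    for c in torch.unique(torch.cat([pred, true])):
        p, t = pred == c, true == c
        inter = (p & t).sum().double()
        union = (p | t).sum().double()  # > 0 because c occurs in pred or true
        ious.append(inter / union)
    return torch.stack(ious).mean()


pred = torch.tensor([0, 0, 1, 1])
true = torch.tensor([0, 1, 1, 1])
# Class 0: inter 1, union 2 -> 1/2; class 1: inter 2, union 3 -> 2/3.
print(miou_sketch(pred, true))  # (1/2 + 2/3) / 2 = 7/12
```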

zae_engine.loss.norm module

zae_engine.loss.norm.mse(true: Tensor, predict: Tensor)

Compute the mean squared error (MSE) between the true and predicted values.

Parameters:
  • true (torch.Tensor) – The ground truth values.

  • predict (torch.Tensor) – The predicted values.

Returns:

The computed mean squared error.

Return type:

torch.Tensor
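As a worked example, MSE is the average of the squared element-wise residuals. The helper mse_sketch below is hypothetical; it is cross-checked against torch.nn.functional.mse_loss, whose default reduction="mean" computes the same quantity. Whether zae_engine's mse reduces over every element in the same way is an assumption.

```python
import torch
import torch.nn.functional as F


def mse_sketch(true: torch.Tensor, predict: torch.Tensor) -> torch.Tensor:
    """Mean of squared residuals over every element (hypothetical sketch)."""
    return ((predict - true) ** 2).mean()


true = torch.tensor([1.0, 2.0, 3.0])
predict = torch.tensor([1.5, 2.0, 2.0])
# Residuals 0.5, 0.0, -1.0 -> (0.25 + 0.0 + 1.0) / 3
value = mse_sketch(true, predict)
print(value, F.mse_loss(predict, true))
```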

Module contents