zae_engine.loss.angular의 소스 코드

import torch
import torch.nn as nn
import torch.nn.functional as F


[문서] class ArcFaceLoss(nn.Module): """ ArcFace loss for classification with angular margin. Parameters ---------- in_features : int Size of each input sample. out_features : int Number of classes. s : float, optional Norm of input feature (default is 30.0). m : float, optional Margin (default is 0.50). Methods ------- forward(features, labels): Forward pass to compute the ArcFace loss. References ---------- .. [1] Deng, Jiankang, et al. "ArcFace: Additive Angular Margin Loss for Deep Face Recognition." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2019. https://arxiv.org/abs/1801.07698 """ def __init__(self, in_features: int, out_features: int, s: float = 30.0, m: float = 0.50): super(ArcFaceLoss, self).__init__() self.weight = nn.Parameter(torch.Tensor(out_features, in_features)) nn.init.xavier_uniform_(self.weight) self.s = s self.m = m
[문서] def forward(self, features: torch.Tensor, labels: torch.Tensor): """ Forward pass to compute the ArcFace loss. Parameters ---------- features : torch.Tensor Input features of shape (batch_size, in_features). labels : torch.Tensor Ground truth labels of shape (batch_size). Returns ------- torch.Tensor Computed ArcFace logits of shape (batch_size, out_features). """ cosine = F.linear(F.normalize(features), F.normalize(self.weight)) theta = torch.acos(cosine.clamp(-1 + 1e-7, 1 - 1e-7)) target_logit = torch.cos(theta + self.m) one_hot = torch.zeros_like(cosine) one_hot.scatter_(1, labels.view(-1, 1).long(), 1) output = (one_hot * target_logit) + ((1.0 - one_hot) * cosine) output *= self.s return output