zae_engine.metrics.confusion의 소스 코드

from typing import Union, List, Tuple, Optional

import numpy as np
import torch
from rich import box
from rich.console import Console
from rich.table import Table

from ..utils.decorators import np2torch, shape_check


@np2torch(dtype=torch.int)
@shape_check(2)
def confusion_matrix(
    y_hat: Union[np.ndarray, torch.Tensor], y_true: Union[np.ndarray, torch.Tensor], num_classes: int
) -> torch.Tensor:
    """
    Compute the confusion matrix for classification predictions.

    This function calculates the confusion matrix, comparing the predicted labels (y_hat)
    with the true labels (y_true). Entry ``[i, j]`` counts the samples whose true label
    is ``i`` and whose predicted label is ``j`` (rows: true labels, columns: predictions).
    Inputs are flattened, so every element is treated as one sample.

    Parameters
    ----------
    y_hat : Union[np.ndarray, torch.Tensor]
        The predicted labels, either as a numpy array or a torch tensor.
    y_true : Union[np.ndarray, torch.Tensor]
        The true labels, either as a numpy array or a torch tensor.
    num_classes : int
        The number of classes in the classification task.

    Returns
    -------
    torch.Tensor
        The confusion matrix as a 2-D tensor of shape (num_classes, num_classes),
        created with ``torch.zeros`` (default floating-point dtype and default device).

    Raises
    ------
    ValueError
        If ``y_hat`` and ``y_true`` contain a different number of labels.
    IndexError
        If any label lies outside ``[0, num_classes)``.

    Examples
    --------
    >>> y_true = np.array([0, 1, 2, 2, 1])
    >>> y_hat = np.array([0, 2, 2, 2, 0])
    >>> confusion_matrix(y_hat, y_true, 3)
    tensor([[1., 0., 0.],
            [1., 0., 1.],
            [0., 0., 2.]])
    >>> y_true = torch.tensor([0, 1, 2, 2, 1])
    >>> y_hat = torch.tensor([0, 2, 2, 2, 0])
    >>> confusion_matrix(y_hat, y_true, 3)
    tensor([[1., 0., 0.],
            [1., 0., 1.],
            [0., 0., 2.]])
    """
    canvas = torch.zeros((num_classes, num_classes))

    # Flatten so each element is one (true, predicted) sample, and move the labels to
    # the canvas device as int64, which index_put_ requires for its index tensors.
    true_idx = torch.as_tensor(y_true).reshape(-1).to(device=canvas.device, dtype=torch.long)
    hat_idx = torch.as_tensor(y_hat).reshape(-1).to(device=canvas.device, dtype=torch.long)

    # zip() would silently drop the surplus samples; fail loudly instead.
    if true_idx.numel() != hat_idx.numel():
        raise ValueError(
            f"y_hat and y_true must contain the same number of labels, "
            f"got {hat_idx.numel()} and {true_idx.numel()}."
        )

    # Negative labels would otherwise wrap around to the last class without any error.
    # .any() is used (not .min()/.max()) so that empty inputs are handled cleanly.
    out_of_range = (true_idx < 0) | (true_idx >= num_classes) | (hat_idx < 0) | (hat_idx >= num_classes)
    if bool(out_of_range.any()):
        raise IndexError(f"All labels must lie in [0, {num_classes}).")

    # accumulate=True sums repeated (true, predicted) pairs rather than writing each
    # position once, replacing the per-sample Python loop with one vectorized call.
    ones = torch.ones(true_idx.numel(), dtype=canvas.dtype, device=canvas.device)
    canvas.index_put_((true_idx, hat_idx), ones, accumulate=True)
    return canvas