Source code for zae_engine.models.builds.nested_autoencoder

from typing import Union, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...nn_night.blocks.unet_block import RSUBlock

__all__ = ["NestedUNet"]
# https://arxiv.org/pdf/2005.09007
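
One practical constraint is recorded (but not enforced) in ``__init__`` below: ``minimum_resolution`` is the smallest input side length the configuration is designed for, assuming an RSU block of height ``h`` needs ``2 ** (h - 1)`` pixels at its own input on top of the ``i`` halvings already applied by the outer encoder. A minimal sketch evaluating that expression for the defaults:

# Illustrative only, not part of the module: the smallest supported input side
# length implied by the formula in NestedUNet.__init__ for the default heights.
heights = (7, 6, 5, 4, 4)
minimum_resolution = max(2 ** (h - 1 + i) for i, h in enumerate(heights))
print(minimum_resolution)  # 128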


class NestedUNet(nn.Module):
    """
    Implementation of the U²-Net architecture.

    Parameters
    ----------
    in_ch : int, optional
        Number of input channels. Default is 3.
    out_ch : int, optional
        Number of output channels. Default is 1.
    width : Union[int, Sequence], optional
        Base number of channels for the first encoder stage. An int is doubled
        at each deeper stage; a sequence specifies every stage explicitly.
        Default is 32.
    heights : Sequence[int], optional
        List of RSU block heights for each encoder layer. Default is (7, 6, 5, 4, 4).
    dilation_heights : Sequence[int], optional
        List of dilation heights for each encoder layer. Default is (2, 2, 2, 2, 4).
    middle_width : Union[int, Sequence], optional
        Middle channels for each RSU block; an int is reused for every stage.
        Default is (32, 32, 64, 128, 256).

    References
    ----------
    .. [1] Qin, X., Zhang, Z., Huang, C., Dehghan, M., Zaiane, O. R., & Jagersand, M. (2020).
           U2-Net: Going deeper with nested U-structure for salient object detection.
           Pattern Recognition, 106, 107404. (https://arxiv.org/pdf/2005.09007)
    """

    def __init__(
        self,
        in_ch=3,
        out_ch=1,
        width: Union[int, Sequence] = 32,
        heights: Sequence[int] = (7, 6, 5, 4, 4),
        dilation_heights: Sequence[int] = (2, 2, 2, 2, 4),
        middle_width: Union[int, Sequence] = (32, 32, 64, 128, 256),
    ):
        super(NestedUNet, self).__init__()
        assert len(heights) == len(dilation_heights), "heights and dilation_heights must have the same length."
        self.minimum_resolution = max([2 ** (h - 1 + i) for i, h in enumerate(heights)])
        self.num_layers = len(heights)

        # Encoder configuration
        self.encoder = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        self.encoder_channels = []

        # Verify input width
        if isinstance(width, int):
            width_list = [width * 2**i for i in range(len(heights))]
        else:
            assert len(width) == len(heights), "width and heights must have the same length."
            width_list = width

        # Broadcast an int middle_width to every stage (the type hint allows it).
        if isinstance(middle_width, int):
            middle_width = [middle_width] * len(heights)
        else:
            assert len(middle_width) == len(heights), "middle_width and heights must have the same length."

        for i, (h, dh, w, mw) in enumerate(zip(heights, dilation_heights, width_list, middle_width)):
            enc_ch_in = w if i else in_ch
            enc_ch_out = 2 * w  # Double the channels
            self.encoder.append(RSUBlock(ch_in=enc_ch_in, ch_mid=mw, ch_out=enc_ch_out, height=h, dilation_height=dh))
            self.pool_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))

        # Bottleneck: same height and dilation as the last encoder layer
        bottleneck_height = heights[-1]
        bottleneck_dil = dilation_heights[-1]
        self.bottleneck = RSUBlock(
            ch_in=enc_ch_out,
            ch_mid=mw,
            ch_out=enc_ch_out,
            height=bottleneck_height,
            dilation_height=bottleneck_dil,
        )

        # Decoder configuration
        self.up_layers = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.decoder_channels = []
        dec_ch_ = enc_ch_out * 2  # Bottleneck output concatenated with the deepest skip connection
        for i, (h, dh, w, mw) in enumerate(
            zip(heights[::-1], dilation_heights[::-1], width_list[::-1], middle_width[::-1])
        ):
            dec_ch_out = w  # Set output channels
            self.up_layers.append(nn.Upsample(scale_factor=2))
            self.decoder.append(
                RSUBlock(
                    ch_in=dec_ch_,  # Concatenated channels from skip connection
                    ch_mid=mw,
                    ch_out=dec_ch_out,
                    height=h,
                    dilation_height=dh,
                )
            )
            self.decoder_channels.append((dec_ch_, mw, dec_ch_out))
            dec_ch_ = 2 * dec_ch_out

        # Output layer
        self.out_conv = nn.Conv2d(dec_ch_out, out_ch, kernel_size=1)

        # Side outputs for deep supervision, one per decoder stage
        self.side_layers = nn.ModuleList()
        for i in range(self.num_layers):
            self.side_layers.append(nn.Conv2d(self.decoder_channels[i][2], out_ch, kernel_size=3, padding=1))

    def forward(self, x):
        encoder_features = []

        # Encoder path
        for i in range(self.num_layers):
            x = self.encoder[i](x)
            encoder_features.append(x)
            x = self.pool_layers[i](x)

        # Bottleneck
        x = self.bottleneck(x)

        side_outputs = []
        # Decoder path
        for i in range(self.num_layers):
            x = self.up_layers[i](x)
            enc_feat = encoder_features[self.num_layers - i - 1]
            # Resize if necessary: odd spatial sizes make pooling and upsampling
            # disagree, and torch.cat would fail on mismatched feature maps.
            if x.shape[2:] != enc_feat.shape[2:]:
                x = F.interpolate(x, size=enc_feat.shape[2:], mode="bilinear", align_corners=True)
            x = torch.cat([x, enc_feat], dim=1)
            x = self.decoder[i](x)

            # Generate side output
            side_output = self.side_layers[i](x)
            side_outputs.append(side_output)

        # Final output
        out = self.out_conv(x)
        return out, *side_outputs
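
A minimal usage sketch, not part of the module: it assumes the default configuration, an input whose side lengths are divisible by ``2 ** 5 = 32`` so pooling and upsampling line up exactly, and that ``RSUBlock`` preserves spatial resolution. The model returns the full-resolution prediction followed by one side output per decoder stage.

import torch

from zae_engine.models.builds.nested_autoencoder import NestedUNet

model = NestedUNet(in_ch=3, out_ch=1)
x = torch.randn(1, 3, 320, 320)  # 320 x 320 inputs, as used in the U²-Net paper

out, *side_outputs = model(x)
print(out.shape)  # expected: torch.Size([1, 1, 320, 320])
for i, s in enumerate(side_outputs):
    # Side output i is produced at 1 / 2 ** (model.num_layers - 1 - i) of the
    # input resolution, i.e. 20, 40, 80, 160, 320 here.
    print(i, s.shape)

For training with deep supervision as in [1], a loss is typically applied to ``out`` and to every side output (after bringing predictions and targets to a common resolution); at inference time the side outputs can simply be discarded.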