Source code for zae_engine.data.collate.modules

from typing import Dict, Any
import numpy as np
import torch
from torch.nn import functional as F
from scipy import signal
from sklearn.preprocessing import MinMaxScaler
import logging

logger = logging.getLogger(__name__)


class UnifiedChunker:
    """
    Unified Chunker Class: Splits input data into chunks based on the tensor's dimensionality.

    Parameters
    ----------
    chunk_size : int
        The size of each chunk.
    overlap : int, optional
        The overlap size between consecutive chunks. Default is 0.
    """

    def __init__(self, chunk_size: int, overlap: int = 0):
        if overlap >= chunk_size:
            raise ValueError("overlap must be smaller than chunk_size.")
        self.chunk_size = chunk_size
        self.overlap = overlap

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """
        Splits the batch into chunks.

        Parameters
        ----------
        batch : Dict[str, Any]
            - 'x': torch.Tensor of shape (sequence_length,) or (batch_size, sequence_length)
            - 'y': Optional[torch.Tensor]
            - 'fn': Optional[Any]

        Returns
        -------
        Dict[str, Any]
            The batch after splitting into chunks.
        """
        if "x" not in batch:
            raise KeyError("Batch must contain 'x' key.")

        x = batch["x"]
        y = batch.get("y", None)
        fn = batch.get("fn", None)

        if not isinstance(x, torch.Tensor):
            raise TypeError("'x' must be a torch.Tensor.")

        if x.dim() == 1:
            # Handle 1D tensor
            chunks = self._split_1d(x)
        elif x.dim() == 2:
            # Handle 2D tensor (split each sample in the batch)
            chunks = self._split_2d(x)
        else:
            raise ValueError("Unsupported tensor dimension. Only 1D and 2D tensors are supported.")

        # Number of chunks produced per sample, derived from the sequence length
        num_chunks = self._get_num_chunks(x.shape[-1])

        # Process 'y' if it exists (split or repeat to match the chunking)
        if y is not None:
            batch["y"] = self._process_y(y, x.dim(), num_chunks)

        # Process 'fn' if it exists (repeat to match the chunks)
        if fn is not None:
            batch["fn"] = self._process_fn(fn, x.dim(), num_chunks)

        # Update 'x' in the batch with the chunks
        batch["x"] = chunks
        return batch

    def _split_1d(self, x: torch.Tensor) -> torch.Tensor:
        """
        Splits a 1D tensor into chunks.

        Parameters
        ----------
        x : torch.Tensor
            Shape: (sequence_length,)

        Returns
        -------
        torch.Tensor
            The tensor split into chunks. Shape: (num_chunks, chunk_size)
        """
        x_np = x.cpu().numpy()
        chunks = self._create_chunks(x_np, self.chunk_size, self.overlap)
        return torch.tensor(chunks, dtype=x.dtype, device=x.device)

    def _split_2d(self, x: torch.Tensor) -> torch.Tensor:
        """
        Splits a 2D tensor into chunks for each sample in the batch.

        Parameters
        ----------
        x : torch.Tensor
            Shape: (batch_size, sequence_length)

        Returns
        -------
        torch.Tensor
            The tensor split into chunks. Shape: (batch_size * num_chunks, chunk_size)
        """
        batch_size, seq_length = x.shape
        x_np = x.cpu().numpy()
        chunks = []
        for i in range(batch_size):
            sample_chunks = self._create_chunks(x_np[i], self.chunk_size, self.overlap)
            chunks.append(sample_chunks)
        chunks_np = np.vstack(chunks)
        return torch.tensor(chunks_np, dtype=x.dtype, device=x.device)

    def _create_chunks(self, data: np.ndarray, chunk_size: int, overlap: int) -> np.ndarray:
        """
        Splits data into chunks.

        Parameters
        ----------
        data : np.ndarray
            1D data array.
        chunk_size : int
            Size of each chunk.
        overlap : int
            Overlap size between chunks.

        Returns
        -------
        np.ndarray
            Array of chunks. Shape: (num_chunks, chunk_size)
        """
        step = chunk_size - overlap
        num_chunks = (len(data) - overlap) // step
        if (len(data) - overlap) % step != 0:
            num_chunks += 1  # Add an extra chunk for remaining data
        chunks = []
        for i in range(num_chunks):
            start = i * step
            end = start + chunk_size
            chunk = data[start:end]
            if len(chunk) < chunk_size:
                # Pad with the last value if the chunk is incomplete
                pad_width = chunk_size - len(chunk)
                chunk = np.pad(chunk, (0, pad_width), mode="edge")
            chunks.append(chunk)
        return np.array(chunks)

    def _process_y(self, y: torch.Tensor, x_dim: int, num_chunks: int) -> torch.Tensor:
        """
        Processes 'y' values to match the chunked 'x'.

        Parameters
        ----------
        y : torch.Tensor
            Original 'y' values. Shape: (batch_size,) or (sequence_length,)
        x_dim : int
            Dimensionality of 'x'.
        num_chunks : int
            Number of chunks produced per sample.

        Returns
        -------
        torch.Tensor
            Processed 'y' values.
        """
        if x_dim == 1:
            # For a 1D tensor, split 'y' the same way as 'x'
            y_np = y.cpu().numpy()
            chunks = self._create_chunks(y_np, self.chunk_size, self.overlap)
            return torch.tensor(chunks, dtype=y.dtype, device=y.device)
        elif x_dim == 2:
            # For a 2D tensor, repeat each sample's label once per chunk
            y_np = y.cpu().numpy()
            y_repeated = np.repeat(y_np, num_chunks)
            return torch.tensor(y_repeated, dtype=y.dtype, device=y.device)
        else:
            raise ValueError("Unsupported tensor dimension for 'y' processing.")

    def _process_fn(self, fn: Any, x_dim: int, num_chunks: int) -> Any:
        """
        Processes 'fn' values to match the chunked 'x'.

        Parameters
        ----------
        fn : Any
            Original 'fn' value.
        x_dim : int
            Dimensionality of 'x'.
        num_chunks : int
            Number of chunks produced per sample.

        Returns
        -------
        Any
            Processed 'fn' values.
        """
        if x_dim == 1:
            # For a 1D tensor, repeat the single 'fn' once per chunk
            return [fn] * num_chunks
        elif x_dim == 2:
            # For a 2D tensor, repeat each sample's 'fn' once per chunk
            if isinstance(fn, list):
                return [item for item in fn for _ in range(num_chunks)]
            return [fn] * num_chunks
        else:
            raise ValueError("Unsupported tensor dimension for 'fn' processing.")

    def _get_num_chunks(self, length: int) -> int:
        """
        Calculates the number of chunks for a given data length.

        Parameters
        ----------
        length : int
            Length of the original data.

        Returns
        -------
        int
            Number of chunks.
        """
        step = self.chunk_size - self.overlap
        num_chunks = (length - self.overlap) // step
        if (length - self.overlap) % step != 0:
            num_chunks += 1
        return num_chunks
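
# Illustrative usage sketch for UnifiedChunker (the values below are
# assumptions, not part of the module). A 1-D signal of length 10 with
# chunk_size=4 and overlap=1 gives a step of 3 and therefore 3 chunks:
#
#     >>> chunker = UnifiedChunker(chunk_size=4, overlap=1)
#     >>> out = chunker({"x": torch.arange(10, dtype=torch.float32)})
#     >>> out["x"].shape
#     torch.Size([3, 4])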

class Chunk:
    """
    Class for reshaping the 'x' tensor within a batch.

    Parameters
    ----------
    n : int
        The size to reshape the 'x' tensor.

    Expected Input Shapes
    ---------------------
    batch: Dict[str, Any]
        - 'x': torch.Tensor of shape (batch_size, n * some_integer)

    Methods
    -------
    __call__(batch: Dict[str, Any]) -> Dict[str, Any]:
        Reshapes the 'x' tensor in the batch.
    """

    def __init__(self, n: int):
        self.n = n

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        if "x" not in batch:
            raise KeyError("Batch must contain 'x' key.")
        x = batch["x"]
        if not isinstance(x, torch.Tensor):
            raise TypeError("'x' must be a torch.Tensor.")
        if x.dim() < 2:
            raise ValueError("'x' tensor must have at least 2 dimensions.")
        # Drop trailing samples so the last dimension is divisible by n
        modulo = x.shape[-1] % self.n
        if modulo != 0:
            x = x[..., :-modulo]
        x = x.reshape(-1, self.n)
        batch["x"] = x
        return batch
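
# Illustrative usage sketch for Chunk (shapes are assumptions). With n=4 and a
# (2, 10) input, the 10 % 4 == 2 trailing samples of each row are dropped
# before the reshape to (-1, 4):
#
#     >>> out = Chunk(n=4)({"x": torch.randn(2, 10)})
#     >>> out["x"].shape
#     torch.Size([4, 4])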

class HotEncoder:
    """
    Class for converting labels to one-hot encoded format.

    Parameters
    ----------
    n_cls : int
        Number of classes for one-hot encoding.

    Expected Input Shapes
    ---------------------
    batch: Dict[str, Any]
        - 'y': torch.Tensor of shape (batch_size,) or (batch_size, ...)

    Methods
    -------
    __call__(batch: Dict[str, Any]) -> Dict[str, Any]:
        Applies one-hot encoding to the labels in the batch.
    """

    def __init__(self, n_cls: int):
        if n_cls <= 0:
            raise ValueError("Number of classes 'n_cls' must be positive.")
        self.n_cls = n_cls

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        if "y" not in batch:
            raise KeyError("Batch must contain 'y' key.")
        y = batch["y"]
        # Indexing an identity matrix with the integer labels one-hot encodes
        # along a new last dimension; this works for 1D and multi-dimensional 'y'.
        batch["y_hot"] = torch.eye(self.n_cls, device=y.device)[y.long()]
        return batch
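
# Illustrative usage sketch for HotEncoder (labels are assumptions):
#
#     >>> out = HotEncoder(n_cls=3)({"y": torch.tensor([0, 2, 1])})
#     >>> out["y_hot"]
#     tensor([[1., 0., 0.],
#             [0., 0., 1.],
#             [0., 1., 0.]])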

class SignalFilter:
    """
    Class for filtering signals within a batch.

    Parameters
    ----------
    fs : float
        Sampling frequency of the signal.
    method : str
        Filtering method to apply. Options are 'bandpass', 'bandstop', 'lowpass', 'highpass'.
    lowcut : float, optional
        Low cut-off frequency for bandpass and bandstop filters.
    highcut : float, optional
        High cut-off frequency for bandpass and bandstop filters.
    cutoff : float, optional
        Cut-off frequency for lowpass and highpass filters.

    Expected Input Shapes
    ---------------------
    batch: Dict[str, Any]
        - 'x': torch.Tensor of shape (sequence_length,) or (batch_size, sequence_length)

    Methods
    -------
    __call__(batch: Dict[str, Any]) -> Dict[str, Any]:
        Applies the specified filter to the signals in the batch.
    """

    def __init__(self, fs: float, method: str, lowcut: float = None, highcut: float = None, cutoff: float = None):
        self.fs = fs
        self.method = method
        self.lowcut = lowcut
        self.highcut = highcut
        self.cutoff = cutoff

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        if "x" not in batch:
            raise KeyError("Batch must contain 'x' key.")
        x = batch["x"]
        if not isinstance(x, torch.Tensor):
            raise TypeError("'x' must be a torch.Tensor.")
        if x.dim() < 1:
            raise ValueError("'x' tensor must have at least 1 dimension.")

        nyq = self.fs / 2
        x_np = x.squeeze().cpu().numpy()

        if self.method == "bandpass":
            if self.lowcut is None or self.highcut is None:
                raise ValueError("Lowcut and highcut frequencies must be specified for bandpass filter.")
            b, a = signal.butter(2, [self.lowcut / nyq, self.highcut / nyq], btype="bandpass")
        elif self.method == "bandstop":
            if self.lowcut is None or self.highcut is None:
                raise ValueError("Lowcut and highcut frequencies must be specified for bandstop filter.")
            b, a = signal.butter(2, [self.lowcut / nyq, self.highcut / nyq], btype="bandstop")
        elif self.method == "lowpass":
            if self.cutoff is None:
                raise ValueError("Cutoff frequency must be specified for lowpass filter.")
            b, a = signal.butter(2, self.cutoff / nyq, btype="low")
        elif self.method == "highpass":
            if self.cutoff is None:
                raise ValueError("Cutoff frequency must be specified for highpass filter.")
            b, a = signal.butter(2, self.cutoff / nyq, btype="high")
        else:
            raise ValueError(
                f"Invalid method: {self.method}. Choose from 'bandpass', 'bandstop', 'lowpass', 'highpass'."
            )

        # Tile the signal along time before filtering to minimize edge effects,
        # then keep only the middle (original) segment.
        try:
            seq_len = x_np.shape[-1]
            x_filtered = signal.filtfilt(b, a, np.concatenate([x_np] * 3, axis=-1), method="gust")
            x_filtered = x_filtered[..., seq_len : 2 * seq_len]
        except Exception as e:
            logger.error(f"Error during filtering: {e}")
            raise RuntimeError(f"Error during filtering: {e}") from e

        # Convert back to torch.Tensor, preserve device, and restore a leading
        # batch dimension for single-signal inputs
        x_out = torch.tensor(x_filtered.copy(), dtype=torch.float32, device=x.device)
        if x_out.dim() == 1:
            x_out = x_out.unsqueeze(0)
        batch["x"] = x_out
        return batch
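
# Illustrative usage sketch for SignalFilter (fs and cut-offs are assumptions,
# e.g. a 0.5-40 Hz band-pass on a 250 Hz signal). The filtered signal comes
# back with a leading batch dimension:
#
#     >>> filt = SignalFilter(fs=250.0, method="bandpass", lowcut=0.5, highcut=40.0)
#     >>> out = filt({"x": torch.randn(2560)})
#     >>> out["x"].shape
#     torch.Size([1, 2560])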

class Spliter:
    """
    Class for splitting signals within a batch with overlapping.

    Parameters
    ----------
    chunk_size : int, default=2560
        The size of each chunk after splitting.
    overlapped : int, default=0
        The number of overlapping samples between adjacent segments.

    Expected Input Shapes
    ---------------------
    batch: Dict[str, Any]
        - 'x': torch.Tensor of shape (sequence_length,)

    Methods
    -------
    __call__(batch: Dict[str, Any]) -> Dict[str, Any]:
        Splits the signal in the batch with the specified overlap.
    """

    def __init__(self, chunk_size: int = 2560, overlapped: int = 0):
        self.chunk_size = chunk_size
        self.overlapped = overlapped

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        if "x" not in batch:
            raise KeyError("Batch must contain 'x' key.")
        x = batch["x"]
        if not isinstance(x, torch.Tensor):
            raise TypeError("'x' must be a torch.Tensor.")
        if x.dim() < 1:
            raise ValueError("'x' tensor must have at least 1 dimension.")
        raw_data = x.squeeze()

        # Calculate step size
        step = self.chunk_size - self.overlapped
        if step <= 0:
            raise ValueError("overlapped must be less than chunk_size.")

        total_length = len(raw_data)
        remain_length = (total_length - self.chunk_size) % step
        if remain_length != 0:
            # Pad by replicating the last value so the final window is complete
            padding_length = step - remain_length
            raw_data = F.pad(raw_data.unsqueeze(0), (0, padding_length), mode="replicate").squeeze()

        # Unfold to create overlapping chunks and update 'x'
        batch["x"] = raw_data.unfold(dimension=0, size=self.chunk_size, step=step)
        return batch
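
# Illustrative usage sketch for Spliter (sizes are assumptions). With
# chunk_size=4 and overlapped=2 the step is 2, so a length-10 signal unfolds
# into (10 - 4) / 2 + 1 = 4 windows:
#
#     >>> out = Spliter(chunk_size=4, overlapped=2)({"x": torch.arange(10.0)})
#     >>> out["x"].shape
#     torch.Size([4, 4])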

class SignalScaler:
    """
    Class for scaling signals within a batch using MinMaxScaler.

    Parameters
    ----------
    None

    Expected Input Shapes
    ---------------------
    batch: Dict[str, Any]
        - 'x': torch.Tensor of shape (features,) or (batch_size, features)

    Methods
    -------
    __call__(batch: Dict[str, Any]) -> Dict[str, Any]:
        Applies MinMax scaling to the signals in the batch.
    """

    def __init__(self):
        self.scaler = MinMaxScaler()

    def __call__(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        if "x" not in batch:
            raise KeyError("Batch must contain 'x' key.")
        x = batch["x"]
        if not isinstance(x, torch.Tensor):
            raise TypeError("'x' must be a torch.Tensor.")

        # Convert to numpy for scaling
        x_np = x.cpu().numpy()
        was_1d = x_np.ndim == 1

        # Ensure x has at least 2 dimensions for the scaler
        if was_1d:
            x_np = x_np.reshape(-1, 1)
        elif x_np.ndim > 2:
            raise ValueError("'x' tensor must have 1 or 2 dimensions.")

        # Fit and transform; MinMaxScaler scales each column independently
        self.scaler.fit(x_np)
        x_scaled = self.scaler.transform(x_np)

        # If the original input was 1D, flatten the scaled array back
        if was_1d:
            x_scaled = x_scaled.flatten()

        # Convert back to torch.Tensor and preserve device
        batch["x"] = torch.tensor(x_scaled, dtype=torch.float32, device=x.device)
        return batch
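
# Illustrative usage sketch for SignalScaler (values are assumptions). Each
# call re-fits the scaler on the incoming batch, mapping its min/max to [0, 1]:
#
#     >>> out = SignalScaler()({"x": torch.tensor([0.0, 5.0, 10.0])})
#     >>> out["x"]
#     tensor([0.0000, 0.5000, 1.0000])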