zae_engine.data.collate package
Submodules
zae_engine.data.collate.core module
- class zae_engine.data.collate.core.CollateBase(*, x_key: Sequence[str] = ('x',), y_key: Sequence[str] = ('y',), aux_key: Sequence[str] = ('aux',), functions: List[Callable] | OrderedDict[str, Callable] = None)
Bases: object
Base class for collating and processing batches of data using a sequence of functions.
This class allows you to define a sequence of preprocessing functions that will be applied to data batches in the specified order. It supports initialization with either an OrderedDict or a list of functions.
- Parameters:
x_key (Sequence[str], default=('x',)) – The key(s) in the batch dictionary that hold the input data.
y_key (Sequence[str], default=('y',)) – The key(s) in the batch dictionary that hold the labels.
aux_key (Sequence[str], default=('aux',)) – The key(s) in the batch dictionary that hold auxiliary data.
functions (Union[List[Callable], OrderedDict[str, Callable]], optional) – The preprocessing functions to apply to the batches in sequence.
- __len__():
Returns the number of functions in the collator.
- __iter__():
Returns an iterator over the functions in the collator.
- io_check(sample_data: dict | OrderedDict, check_all: bool = False) → None:
Checks whether the registered functions preserve the structure of the sample data.
- __call__(batch: List[dict | OrderedDict]) → Union[dict, OrderedDict]:
Applies the registered functions to the input batch in sequence.
- accumulate(batches: Tuple | List) → Dict:
Converts a list of per-sample dictionaries into a single batch dictionary with list-type values.
Usage
-----
Example 1: Initialization with a list of functions
>>> def fn1(batch):
...     # Function to process batch
...     return batch
>>> def fn2(batch):
...     # Another function to process batch
...     return batch
>>> collator = CollateBase(x_key=['x'], y_key=['y'], aux_key=['aux'], functions=[fn1, fn2])
>>> batch = {'x': [1, 2, 3], 'y': [1], 'aux': [0.5], 'filename': 'sample.txt'}
>>> processed_batch = collator([batch, batch])
Example 2: Initialization with an OrderedDict
>>> from collections import OrderedDict
>>> functions = OrderedDict([('fn1', fn1), ('fn2', fn2)])
>>> collator = CollateBase(x_key=['x'], y_key=['y'], aux_key=['aux'], functions=functions)
>>> processed_batch = collator([batch, batch])
Example 3: Checking input-output consistency
>>> def fn3(batch):
...     # Illustrative pass-through function; keeps every key of the batch
...     return batch
>>> sample_data = {'x': [1, 2, 3], 'y': [1], 'aux': [0.5], 'filename': 'sample.txt'}
>>> collator.set_batch(sample_data)
>>> collator.add_fn('fn3', fn3)  # add_fn validates that fn3 preserves the structure of sample_data
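The pipeline behaviour described above (functions applied in registration order, supplied either as a list or as an OrderedDict) can be sketched in plain Python. `MiniCollate` and `add_count` are hypothetical names. The sketch assumes that `__call__` first merges the per-sample dictionaries into one dictionary of lists and then passes that dictionary through each function; key handling via `x_key`/`y_key`/`aux_key` is omitted, so this is an illustration, not the library's implementation.

```python
from collections import OrderedDict
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

Batch = Dict[str, Any]


class MiniCollate:
    """Hypothetical, simplified stand-in for CollateBase (not the library's code)."""

    def __init__(self, functions: Optional[Union[List[Callable], "OrderedDict[str, Callable]"]] = None):
        # Normalize both accepted forms into one OrderedDict of name -> function.
        if functions is None:
            self._fns: "OrderedDict[str, Callable]" = OrderedDict()
        elif isinstance(functions, OrderedDict):
            self._fns = OrderedDict(functions)
        else:
            # Assumption: functions passed as a list are named by their position.
            self._fns = OrderedDict((str(i), fn) for i, fn in enumerate(functions))

    def __len__(self) -> int:
        return len(self._fns)

    def __iter__(self) -> Iterator[Callable]:
        return iter(self._fns.values())

    def __call__(self, samples: List[Batch]) -> Batch:
        # Assumption: merge per-sample dicts into one dict of lists first.
        batch: Batch = {key: [s[key] for s in samples] for key in samples[0]}
        # Then apply every registered function in registration order.
        for fn in self:
            batch = fn(batch)
        return batch


def add_count(batch: Batch) -> Batch:
    # Hypothetical step: record how many samples were merged.
    batch["count"] = len(batch["x"])
    return batch


collator = MiniCollate([add_count])
result = collator([{"x": 1, "y": 0}, {"x": 2, "y": 1}])
# result == {"x": [1, 2], "y": [0, 1], "count": 2}
```

Normalizing both input forms into an OrderedDict means later steps (iteration, `len`, adding a named function) only have to deal with one container type.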
- accumulate(batches: Tuple | List) → Dict
Converts a list of per-sample dictionaries into a single batch dictionary with list-type values.
- Parameters:
batches (Union[Tuple, List]) – A list of dictionaries, each representing one data sample.
- Returns:
A dictionary whose keys are the batch attributes and whose values are lists or concatenated tensors.
- Return type:
Dict
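A minimal pure-Python sketch of the list-of-dicts to dict-of-lists conversion described above. `accumulate_sketch` is a hypothetical name; it keeps every value as a list entry and does not reproduce the tensor concatenation mentioned under Returns. The key-consistency check is an added assumption, not documented behaviour.

```python
from typing import Any, Dict, List, Sequence


def accumulate_sketch(samples: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical: merge per-sample dicts into one dict whose values are lists."""
    if not samples:
        return {}
    keys = list(samples[0])
    merged: Dict[str, List[Any]] = {key: [] for key in keys}
    for index, sample in enumerate(samples):
        # Assumption: every sample must carry exactly the same keys.
        if set(sample) != set(keys):
            raise ValueError(f"sample {index} has keys {sorted(sample)}, expected {sorted(keys)}")
        for key in keys:
            merged[key].append(sample[key])
    return merged


batch = accumulate_sketch([
    {"x": [1, 2, 3], "y": [1], "filename": "a.txt"},
    {"x": [4, 5, 6], "y": [0], "filename": "b.txt"},
])
# batch == {"x": [[1, 2, 3], [4, 5, 6]], "y": [[1], [0]], "filename": ["a.txt", "b.txt"]}
```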
- add_fn(name: str, fn: Callable) → None
Adds a new preprocessing function to the pipeline after validating it.
- Parameters:
name (str) – The name of the function.
fn (Callable) – The preprocessing function.
- io_check(sample_data: dict | OrderedDict, check_all: bool = False) → None
Checks whether the registered functions preserve the structure of the sample batch. If check_all is False, only the most recently added function is checked.
- Parameters:
sample_data (Union[dict, OrderedDict]) – The sample data to test the functions with.
check_all (bool) – If True, checks all registered functions; otherwise, checks only the most recently added function.
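One way to read "preserve the structure" is "return a dictionary with the same keys as the sample". The sketch below implements that reading. The function name `io_check_sketch`, the key-set criterion, and the choice to chain the functions in order when `check_all=True` are all assumptions rather than the library's confirmed behaviour.

```python
import copy
from collections import OrderedDict
from typing import Any, Callable, Dict


def io_check_sketch(
    fns: "OrderedDict[str, Callable]", sample: Dict[str, Any], check_all: bool = False
) -> None:
    """Hypothetical structure check: each checked function must keep the sample's keys."""
    items = list(fns.items())
    # Check the whole pipeline, or only the most recently added function.
    selected = items if check_all else items[-1:]
    expected = set(sample)
    data = copy.deepcopy(sample)  # never mutate the caller's sample
    for name, fn in selected:
        data = fn(data)
        if not isinstance(data, dict):
            raise TypeError(f"{name!r} returned {type(data).__name__}, expected a dict")
        if set(data) != expected:
            raise ValueError(f"{name!r} changed the keys from {sorted(expected)} to {sorted(data)}")


pipeline = OrderedDict(identity=lambda b: b)
io_check_sketch(pipeline, {"x": [1, 2, 3], "y": [1]})  # passes silently
pipeline["drop_y"] = lambda b: {"x": b["x"]}
# io_check_sketch(pipeline, {"x": [1, 2, 3], "y": [1]}) now raises ValueError
```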
zae_engine.data.collate.modules module
- class zae_engine.data.collate.modules.Chunk(n: int)
Bases: object
Class for reshaping the ‘x’ tensor within a batch.
- Parameters:
n (int) – The size used to reshape the ‘x’ tensor; the last dimension of ‘x’ must be a multiple of n.
Shapes (Expected Input)
---------------------
batch (Dict[str, Any]) –
‘x’: torch.Tensor of shape (batch_size, n * some_integer)
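The expected input shape implies that the last dimension of ‘x’ is a multiple of `n`. Assuming `Chunk` cuts each row into consecutive pieces of length `n`, so that `(batch_size, n * k)` becomes `(batch_size * k, n)` (for a contiguous `torch.Tensor` this is `x.reshape(-1, n)`), the operation looks like this on nested lists. `chunk_rows` is a hypothetical helper, and the direction of the reshape is a guess from the documented shapes.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def chunk_rows(x: Sequence[Sequence[T]], n: int) -> List[List[T]]:
    """Hypothetical: split each row of length n * k into k rows of length n."""
    if n <= 0:
        raise ValueError("n must be positive")
    out: List[List[T]] = []
    for row in x:
        if len(row) % n:
            raise ValueError(f"row length {len(row)} is not a multiple of n={n}")
        # Consecutive, non-overlapping slices keep the original sample order.
        out.extend(list(row[i:i + n]) for i in range(0, len(row), n))
    return out


print(chunk_rows([[1, 2, 3, 4], [5, 6, 7, 8]], 2))
# [[1, 2], [3, 4], [5, 6], [7, 8]]
```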
- class zae_engine.data.collate.modules.HotEncoder(n_cls: int)
Bases: object
Class for converting labels to one-hot encoded format.
- Parameters:
n_cls (int) – Number of classes for one-hot encoding.
Shapes (Expected Input)
---------------------
batch (Dict[str, Any]) –
‘y’: torch.Tensor of shape (batch_size,) or (batch_size, …)
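One-hot encoding maps an integer label `c` in `[0, n_cls)` to a length-`n_cls` vector that is 1 at index `c` and 0 elsewhere. For tensors, PyTorch provides `torch.nn.functional.one_hot(y, num_classes=n_cls)`, which appends a class axis, so `(batch_size,)` becomes `(batch_size, n_cls)`. The plain-Python sketch below shows the same mapping on nested lists; `one_hot` is a hypothetical helper, and whether `HotEncoder` relies on the PyTorch function is not stated here.

```python
from typing import Any, List, Union

Labels = Union[int, List[Any]]


def one_hot(labels: Labels, n_cls: int) -> List[Any]:
    """Hypothetical: one-hot encode an int or an arbitrarily nested list of ints."""
    if isinstance(labels, int):
        if not 0 <= labels < n_cls:
            raise ValueError(f"label {labels} outside [0, {n_cls})")
        return [1 if c == labels else 0 for c in range(n_cls)]
    # Recurse so that (batch_size, ...) inputs gain a trailing class axis.
    return [one_hot(item, n_cls) for item in labels]


print(one_hot([0, 2, 1], n_cls=3))
# [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
```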
- class zae_engine.data.collate.modules.SignalFilter(fs: float, method: str, lowcut: float = None, highcut: float = None, cutoff: float = None)
Bases: object
Class for filtering signals within a batch.
- Parameters:
fs (float) – Sampling frequency of the signal.
method (str) – Filtering method to apply. Options are ‘bandpass’, ‘bandstop’, ‘lowpass’, ‘highpass’.
lowcut (float, optional) – Low cut-off frequency for bandpass and bandstop filters.
highcut (float, optional) – High cut-off frequency for bandpass and bandstop filters.
cutoff (float, optional) – Cut-off frequency for lowpass and highpass filters.
Shapes (Expected Input)
---------------------
batch (Dict[str, Any]) –
‘x’: torch.Tensor of shape (sequence_length,) or (batch_size, sequence_length)
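The parameter list implies a validation rule: band filters need `lowcut` and `highcut`, while low- and high-pass filters need `cutoff`. The sketch below encodes that rule and converts the cut-offs to fractions of the Nyquist frequency (`fs / 2`), which is the form `scipy.signal.butter` expects for `Wn` when its own `fs` argument is omitted. Whether `SignalFilter` uses a Butterworth design at all is an assumption, and `normalized_cutoffs` is a hypothetical helper.

```python
from typing import Optional, Tuple, Union


def normalized_cutoffs(
    fs: float,
    method: str,
    lowcut: Optional[float] = None,
    highcut: Optional[float] = None,
    cutoff: Optional[float] = None,
) -> Union[float, Tuple[float, float]]:
    """Hypothetical: validate SignalFilter-style arguments and normalize by Nyquist."""
    nyquist = fs / 2.0
    if method in ("bandpass", "bandstop"):
        if lowcut is None or highcut is None:
            raise ValueError(f"{method} requires both lowcut and highcut")
        if not 0 < lowcut < highcut < nyquist:
            raise ValueError("expected 0 < lowcut < highcut < fs / 2")
        return lowcut / nyquist, highcut / nyquist
    if method in ("lowpass", "highpass"):
        if cutoff is None:
            raise ValueError(f"{method} requires cutoff")
        if not 0 < cutoff < nyquist:
            raise ValueError("expected 0 < cutoff < fs / 2")
        return cutoff / nyquist
    raise ValueError(f"unknown method {method!r}")


# 0.5-50 Hz band-pass at fs = 500 Hz: the Nyquist frequency is 250 Hz.
print(normalized_cutoffs(500.0, "bandpass", lowcut=0.5, highcut=50.0))
# (0.002, 0.2)
```

With SciPy installed, the result could be passed as `Wn` to `scipy.signal.butter(order, Wn, btype=method)`, since SciPy's `btype` accepts the same four method names.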
- class zae_engine.data.collate.modules.SignalScaler
Bases: object
Class for scaling signals within a batch using MinMaxScaler.
- Parameters:
None
Shapes (Expected Input)
---------------------
batch (Dict[str, Any]) –
‘x’: torch.Tensor of shape (features,) or (batch_size, features)
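Min-max scaling maps each value `v` to `(v - min) / (max - min)`, so the smallest value becomes 0 and the largest becomes 1. scikit-learn's `MinMaxScaler` computes the minimum and maximum per column (feature), so which axis ends up scaled depends on how the data is oriented before it is passed in. The sketch below assumes each 1-D signal, or each row of a 2-D batch, is scaled independently; `scale_rows_minmax` is a hypothetical helper.

```python
from typing import List, Sequence, Union

Signal = Sequence[float]


def scale_rows_minmax(x: Union[Signal, Sequence[Signal]]) -> Union[List[float], List[List[float]]]:
    """Hypothetical: min-max scale a 1-D signal, or each row of a 2-D batch, to [0, 1]."""
    if len(x) == 0:
        return []
    if isinstance(x[0], (list, tuple)):
        # 2-D input: scale every row on its own.
        return [scale_rows_minmax(row) for row in x]
    lo, hi = min(x), max(x)
    if hi == lo:
        # Constant signal: avoid division by zero and map everything to 0.0.
        return [0.0 for _ in x]
    return [(v - lo) / (hi - lo) for v in x]


print(scale_rows_minmax([2.0, 4.0, 6.0]))           # [0.0, 0.5, 1.0]
print(scale_rows_minmax([[1.0, 1.0], [0.0, 10.0]]))  # [[0.0, 0.0], [0.0, 1.0]]
```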
- class zae_engine.data.collate.modules.Spliter(chunk_size: int = 2560, overlapped: int = 0)
Bases: object
Class for splitting signals within a batch into chunks that may overlap.
- Parameters:
chunk_size (int, default=2560) – The size of each chunk after splitting.
overlapped (int, default=0) – The number of overlapping samples between adjacent segments.
Shapes (Expected Input)
---------------------
batch (Dict[str, Any]) –
‘x’: torch.Tensor of shape (sequence_length,)
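With overlap, consecutive chunks start `chunk_size - overlapped` samples apart, so a signal of length `L >= chunk_size` yields `floor((L - chunk_size) / (chunk_size - overlapped)) + 1` full chunks. The sketch below assumes `Spliter` keeps only full-length chunks and drops a shorter tail; that choice and the helper name `split_signal` are assumptions.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_signal(x: Sequence[T], chunk_size: int = 2560, overlapped: int = 0) -> List[List[T]]:
    """Hypothetical: cut a 1-D signal into full-length, possibly overlapping windows."""
    if chunk_size <= 0:
        raise ValueError("chunk_size must be positive")
    if not 0 <= overlapped < chunk_size:
        raise ValueError("overlapped must satisfy 0 <= overlapped < chunk_size")
    step = chunk_size - overlapped  # distance between consecutive window starts
    # Assumption: a trailing remainder shorter than chunk_size is dropped.
    return [list(x[s:s + chunk_size]) for s in range(0, len(x) - chunk_size + 1, step)]


# L = 10, chunk_size = 4, overlapped = 2 -> step 2 -> (10 - 4) // 2 + 1 = 4 chunks
print(split_signal(list(range(10)), chunk_size=4, overlapped=2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```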
- class zae_engine.data.collate.modules.UnifiedChunker(chunk_size: int, overlap: int = 0)
Bases: object
Unified chunker class: splits input data into chunks, choosing the chunking behavior according to the tensor’s dimensionality.
- Parameters:
chunk_size (int) – The size of each chunk.
overlap (int, optional) – The overlap size between consecutive chunks. Default is 0.