zae_engine.data.collate package

Submodules

zae_engine.data.collate.core module

class zae_engine.data.collate.core.CollateBase(*, x_key: Sequence[str] = ('x',), y_key: Sequence[str] = ('y',), aux_key: Sequence[str] = ('aux',), functions: List[Callable] | OrderedDict[str, Callable] = None)

Bases: object

Base class for collating and processing batches of data using a sequence of functions.

This class allows you to define a sequence of preprocessing functions that will be applied to data batches in the specified order. It supports initialization with either an OrderedDict or a list of functions.

Parameters:
  • x_key (Sequence[str], default=('x',)) – The key(s) in the batch dictionary that hold the input data.

  • y_key (Sequence[str], default=('y',)) – The key(s) in the batch dictionary that hold the labels.

  • aux_key (Sequence[str], default=('aux',)) – The key(s) in the batch dictionary that hold auxiliary data.

  • functions (Union[List[Callable], OrderedDict[str, Callable]], optional) – The preprocessing functions to apply to the batches in sequence.
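To make the ordering behaviour concrete, here is a minimal sketch of a function pipeline that, like CollateBase, accepts either a list or an OrderedDict of callables and applies them in order. MiniCollate, add_one and double are hypothetical names. The sketch ignores x_key/y_key/aux_key and takes an already-collated dict rather than a list of samples, so it illustrates the idea and is not the library's implementation.

```python
from collections import OrderedDict
from typing import Callable, Dict, Iterator, List, Optional, Union

Batch = Dict[str, list]


class MiniCollate:
    """Hypothetical, simplified stand-in for CollateBase's function pipeline."""

    def __init__(self, functions: Optional[Union[List[Callable], "OrderedDict[str, Callable]"]] = None):
        self._fns: "OrderedDict[str, Callable]" = OrderedDict()
        if isinstance(functions, OrderedDict):
            self._fns.update(functions)  # keep the caller's names and order
        elif functions is not None:
            # Assumption: list entries are named by their position.
            for i, fn in enumerate(functions):
                self._fns[str(i)] = fn

    def __len__(self) -> int:
        return len(self._fns)

    def __iter__(self) -> Iterator[Callable]:
        return iter(self._fns.values())

    def __call__(self, batch: Batch) -> Batch:
        # Apply each registered function in insertion order.
        for fn in self._fns.values():
            batch = fn(batch)
        return batch


def add_one(batch: Batch) -> Batch:
    return {**batch, "x": [v + 1 for v in batch["x"]]}


def double(batch: Batch) -> Batch:
    return {**batch, "x": [2 * v for v in batch["x"]]}


print(MiniCollate([add_one, double])({"x": [1, 2]}))  # {'x': [4, 6]}
print(MiniCollate(OrderedDict([("double", double), ("add_one", add_one)]))({"x": [1, 2]}))  # {'x': [3, 5]}
```

Swapping the order of the same two functions changes the result, which is why the pipeline keeps them in an ordered container.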

__len__()

Returns the number of functions in the collator.

__iter__()

Returns an iterator over the functions in the collator.

io_check(sample_data: dict | OrderedDict, check_all: bool = False) → None

Checks whether the registered functions preserve the structure of the sample data.

set_batch(batch: dict | OrderedDict) → None

Sets the sample batch used for structure checking.

add_fn(name: str, fn: Callable) → None

Adds a function to the collator under the given name.

__call__(batch: List[dict | OrderedDict]) → Union[dict, OrderedDict]

Applies the registered functions to the input batch in sequence.

accumulate(batches: Tuple | List) → Dict

Converts a list of per-sample dictionaries into a single batch dictionary with list-type values.

Usage:

Example 1: Initialization with a list of functions
>>> from zae_engine.data.collate.core import CollateBase
>>> def fn1(batch):
...     # Function to process the batch
...     return batch
>>> def fn2(batch):
...     # Another function to process the batch
...     return batch
>>> collator = CollateBase(x_key=['x'], y_key=['y'], aux_key=['aux'], functions=[fn1, fn2])
>>> batch = {'x': [1, 2, 3], 'y': [1], 'aux': [0.5], 'filename': 'sample.txt'}
>>> processed_batch = collator([batch, batch])

Example 2: Initialization with an OrderedDict
>>> from collections import OrderedDict
>>> functions = OrderedDict([('fn1', fn1), ('fn2', fn2)])
>>> collator = CollateBase(x_key=['x'], y_key=['y'], aux_key=['aux'], functions=functions)
>>> processed_batch = collator([batch, batch])

Example 3: Checking input-output consistency
>>> def fn3(batch):
...     # Structure-preserving: returns the batch unchanged
...     return batch
>>> sample_data = {'x': [1, 2, 3], 'y': [1], 'aux': [0.5], 'filename': 'sample.txt'}
>>> collator.set_batch(sample_data)
>>> collator.add_fn('fn3', fn3)  # Checks that fn3 preserves the structure of sample_data

accumulate(batches: Tuple | List) → Dict

Converts a list of per-sample dictionaries into a single batch dictionary with list-type values.

Parameters:

batches (Union[Tuple, List]) – A list of dictionaries, where each dictionary represents one data sample.

Returns:

A dictionary where keys are batch attributes and values are lists or concatenated tensors.

Return type:

Dict
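The following is a minimal sketch of the list-of-dicts to dict-of-lists conversion that accumulate describes, written in plain Python so it runs without PyTorch. accumulate_sketch is a hypothetical name; it assumes every sample has the same keys and leaves out the tensor concatenation mentioned in the return description.

```python
from collections import defaultdict
from typing import Any, Dict, List, Sequence


def accumulate_sketch(samples: Sequence[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Hypothetical sketch: gather per-sample dicts into one dict of lists.

    The real accumulate may also concatenate tensors; that step is omitted here.
    """
    batch: Dict[str, List[Any]] = defaultdict(list)
    for sample in samples:
        for key, value in sample.items():
            batch[key].append(value)  # one list entry per sample, in input order
    return dict(batch)


samples = [
    {"x": [1, 2, 3], "y": 1, "filename": "a.txt"},
    {"x": [4, 5, 6], "y": 0, "filename": "b.txt"},
]
print(accumulate_sketch(samples))
# {'x': [[1, 2, 3], [4, 5, 6]], 'y': [1, 0], 'filename': ['a.txt', 'b.txt']}
```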

add_fn(name: str, fn: Callable) → None

Adds a new preprocessing function to the pipeline after validation.

Parameters:
  • name (str) – The name of the function.

  • fn (Callable) – The preprocessing function.

io_check(sample_data: dict | OrderedDict, check_all: bool = False) → None

Checks whether the registered functions preserve the structure of the sample batch. If check_all is False, only the most recently added function is checked.

Parameters:
  • sample_data (Union[dict, OrderedDict]) – The sample data to test the functions with.

  • check_all (bool) – If True, checks all registered functions. Otherwise, checks only the newly added function.
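One way to picture the structure check behind set_batch, add_fn and io_check is sketched below. The criterion used here (same keys, same value type per key) is an assumption and not necessarily what io_check compares; same_structure, check_fn, scale_x and drop_aux are hypothetical names.

```python
from typing import Any, Callable, Dict

Batch = Dict[str, Any]


def same_structure(before: Batch, after: Batch) -> bool:
    """Hypothetical criterion: identical keys and identical value types per key."""
    if set(before) != set(after):
        return False
    return all(type(before[k]) is type(after[k]) for k in before)


def check_fn(fn: Callable[[Batch], Batch], sample: Batch) -> None:
    """Raise ValueError if fn changes the structure of the sample batch."""
    out = fn(dict(sample))  # shallow copy, so top-level key changes do not touch the stored sample
    if not same_structure(sample, out):
        raise ValueError(f"{getattr(fn, '__name__', fn)!r} changed the batch structure")


sample = {"x": [1, 2, 3], "y": [1], "aux": [0.5], "filename": "sample.txt"}


def scale_x(batch: Batch) -> Batch:  # keeps keys and types, so it passes
    batch["x"] = [2 * v for v in batch["x"]]
    return batch


def drop_aux(batch: Batch) -> Batch:  # removes a key, so it fails
    batch.pop("aux")
    return batch


check_fn(scale_x, sample)
try:
    check_fn(drop_aux, sample)
except ValueError as err:
    print(err)  # 'drop_aux' changed the batch structure
```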

set_batch(batch: dict | OrderedDict) → None

Sets the sample batch to be used for input-output structure validation.

Parameters:

batch (Union[dict, OrderedDict]) – The sample batch.

wrap(func: Callable = None)

zae_engine.data.collate.modules module

class zae_engine.data.collate.modules.Chunk(n: int)

Bases: object

Class for reshaping the ‘x’ tensor within a batch.

Parameters:
  • n (int) – The chunk length used to reshape the ‘x’ tensor; the last dimension of ‘x’ is expected to be a multiple of n.

Expected input:
  • batch (Dict[str, Any]) – ‘x’: torch.Tensor of shape (batch_size, n * some_integer).

__call__(batch: Dict[str, Any]) → Dict[str, Any]

Reshapes the ‘x’ tensor in the batch.
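The expected input shape (batch_size, n * some_integer) suggests that each row is cut into consecutive pieces of length n. The sketch below assumes Chunk behaves like x.reshape(-1, n), so (batch_size, n * k) becomes (batch_size * k, n); the exact output shape is not stated above, and chunk_rows is a hypothetical plain-Python stand-in.

```python
from typing import List


def chunk_rows(x: List[List[float]], n: int) -> List[List[float]]:
    """Hypothetical sketch: split every row of length n*k into k rows of length n.

    Mirrors x.reshape(-1, n) for a row-major (batch_size, n*k) tensor.
    """
    out: List[List[float]] = []
    for row in x:
        if len(row) % n != 0:
            raise ValueError(f"row length {len(row)} is not a multiple of n={n}")
        out.extend(row[i:i + n] for i in range(0, len(row), n))
    return out


x = [[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]]  # shape (2, 3 * 2)
print(chunk_rows(x, 2))
# [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]]  -> shape (6, 2)
```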

class zae_engine.data.collate.modules.HotEncoder(n_cls: int)

Bases: object

Class for converting labels to one-hot encoded format.

Parameters:
  • n_cls (int) – Number of classes for one-hot encoding.

Expected input:
  • batch (Dict[str, Any]) – ‘y’: torch.Tensor of shape (batch_size,) or (batch_size, …).

__call__(batch: Dict[str, Any]) → Dict[str, Any]

Applies one-hot encoding to the labels in the batch.
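For 1-D integer labels of shape (batch_size,), one-hot encoding maps each label to a row of length n_cls containing a single 1. The plain-Python sketch below (one_hot is a hypothetical name) shows the mapping; PyTorch offers the same transform as torch.nn.functional.one_hot, though whether HotEncoder uses it is not stated here.

```python
from typing import List


def one_hot(labels: List[int], n_cls: int) -> List[List[int]]:
    """Hypothetical sketch of one-hot encoding for integer labels in [0, n_cls)."""
    encoded = []
    for label in labels:
        if not 0 <= label < n_cls:
            raise ValueError(f"label {label} outside [0, {n_cls})")
        row = [0] * n_cls
        row[label] = 1  # mark the class position
        encoded.append(row)
    return encoded


print(one_hot([2, 0, 1], n_cls=3))
# [[0, 0, 1], [1, 0, 0], [0, 1, 0]]
```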

class zae_engine.data.collate.modules.SignalFilter(fs: float, method: str, lowcut: float = None, highcut: float = None, cutoff: float = None)

Bases: object

Class for filtering signals within a batch.

Parameters:
  • fs (float) – Sampling frequency of the signal.

  • method (str) – Filtering method to apply. Options are ‘bandpass’, ‘bandstop’, ‘lowpass’, ‘highpass’.

  • lowcut (float, optional) – Low cut-off frequency for bandpass and bandstop filters.

  • highcut (float, optional) – High cut-off frequency for bandpass and bandstop filters.

  • cutoff (float, optional) – Cut-off frequency for lowpass and highpass filters.

Expected input:
  • batch (Dict[str, Any]) – ‘x’: torch.Tensor of shape (sequence_length,) or (batch_size, sequence_length).

__call__(batch: Dict[str, Any]) → Dict[str, Any]

Applies the specified filter to the signals in the batch.
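The method argument decides which cut-off parameters matter: lowcut and highcut for ‘bandpass’ and ‘bandstop’, cutoff for ‘lowpass’ and ‘highpass’. The hypothetical helper below sketches that mapping. Its extra check that cut-offs lie below the Nyquist frequency fs / 2 is a general requirement of digital filter design, not something this page says SignalFilter enforces.

```python
from typing import Optional, Tuple, Union


def resolve_cutoffs(
    fs: float,
    method: str,
    lowcut: Optional[float] = None,
    highcut: Optional[float] = None,
    cutoff: Optional[float] = None,
) -> Union[float, Tuple[float, float]]:
    """Hypothetical helper: pick the cut-off argument(s) each method needs."""
    nyquist = fs / 2.0
    if method in ("bandpass", "bandstop"):
        if lowcut is None or highcut is None:
            raise ValueError(f"{method} requires lowcut and highcut")
        if not 0 < lowcut < highcut < nyquist:
            raise ValueError("need 0 < lowcut < highcut < fs / 2")
        return (lowcut, highcut)
    if method in ("lowpass", "highpass"):
        if cutoff is None:
            raise ValueError(f"{method} requires cutoff")
        if not 0 < cutoff < nyquist:
            raise ValueError("need 0 < cutoff < fs / 2")
        return cutoff
    raise ValueError(f"unknown method {method!r}")


print(resolve_cutoffs(250.0, "bandpass", lowcut=0.5, highcut=40.0))  # (0.5, 40.0)
print(resolve_cutoffs(250.0, "lowpass", cutoff=40.0))  # 40.0
```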

class zae_engine.data.collate.modules.SignalScaler

Bases: object

Class for scaling signals within a batch using MinMaxScaler.

Parameters:
  None

Expected input:
  • batch (Dict[str, Any]) – ‘x’: torch.Tensor of shape (features,) or (batch_size, features).

__call__(batch: Dict[str, Any]) → Dict[str, Any]

Applies MinMax scaling to the signals in the batch.
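Min-max scaling maps values linearly so that the minimum becomes 0 and the maximum becomes 1, i.e. x' = (x - min) / (max - min). The sketch below scales each signal (row) independently, which is an assumption: scikit-learn's MinMaxScaler, for instance, scales each column across samples, so the axis SignalScaler actually uses may differ. min_max_scale is a hypothetical name, and returning zeros for a constant signal is also an assumption.

```python
from typing import List


def min_max_scale(signal: List[float]) -> List[float]:
    """Hypothetical sketch: rescale one signal to [0, 1] via (x - min) / (max - min)."""
    lo, hi = min(signal), max(signal)
    span = hi - lo
    if span == 0:
        # Constant signal: avoid division by zero (this choice is an assumption).
        return [0.0 for _ in signal]
    return [(v - lo) / span for v in signal]


batch_x = [[2.0, 4.0, 6.0], [10.0, 0.0, 5.0]]  # shape (batch_size, features)
print([min_max_scale(row) for row in batch_x])
# [[0.0, 0.5, 1.0], [1.0, 0.0, 0.5]]
```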

class zae_engine.data.collate.modules.Spliter(chunk_size: int = 2560, overlapped: int = 0)

Bases: object

Class for splitting signals within a batch into fixed-size chunks, optionally overlapping.

Parameters:
  • chunk_size (int, default=2560) – The size of each chunk after splitting.

  • overlapped (int, default=0) – The number of overlapping samples between adjacent segments.

Expected input:
  • batch (Dict[str, Any]) – ‘x’: torch.Tensor of shape (sequence_length,).

__call__(batch: Dict[str, Any]) → Dict[str, Any]

Splits the signal in the batch with the specified overlap.
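With overlap, consecutive chunks start chunk_size - overlapped samples apart. Below is a minimal sketch, assuming a trailing remainder shorter than chunk_size is dropped (the text above does not say how Spliter treats it); split_with_overlap is a hypothetical name.

```python
from typing import List


def split_with_overlap(x: List[float], chunk_size: int = 2560, overlapped: int = 0) -> List[List[float]]:
    """Hypothetical sketch: slide a window of chunk_size with stride chunk_size - overlapped.

    Assumption: a trailing remainder shorter than chunk_size is dropped.
    """
    if not 0 <= overlapped < chunk_size:
        raise ValueError("need 0 <= overlapped < chunk_size")
    stride = chunk_size - overlapped
    return [x[i:i + chunk_size] for i in range(0, len(x) - chunk_size + 1, stride)]


signal = list(range(10))
print(split_with_overlap(signal, chunk_size=4, overlapped=2))
# [[0, 1, 2, 3], [2, 3, 4, 5], [4, 5, 6, 7], [6, 7, 8, 9]]
```

Under this assumption a signal of length L yields floor((L - chunk_size) / stride) + 1 chunks when L >= chunk_size.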

class zae_engine.data.collate.modules.UnifiedChunker(chunk_size: int, overlap: int = 0)

Bases: object

Unified Chunker Class: Splits input data into chunks based on the tensor’s dimensionality.

Parameters:
  • chunk_size (int) – The size of each chunk.

  • overlap (int, optional) – The overlap size between consecutive chunks. Default is 0.
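A rough sketch of dimensionality-based dispatch follows, under the assumption that 1-D input is a single sequence and 2-D input is (batch, length) with each row chunked separately. unified_chunk and _windows are hypothetical names; the real class may support other ranks or handle remainders differently.

```python
from typing import List, Union

Signal1D = List[float]
Signal2D = List[List[float]]


def _windows(x: Signal1D, size: int, overlap: int) -> List[Signal1D]:
    # Full-length windows only; start positions advance by size - overlap.
    step = size - overlap
    return [x[i:i + size] for i in range(0, len(x) - size + 1, step)]


def unified_chunk(x: Union[Signal1D, Signal2D], chunk_size: int, overlap: int = 0) -> List[Signal1D]:
    """Hypothetical sketch: dispatch on dimensionality.

    1-D input: chunk the single sequence.
    2-D input (batch, length): chunk every row and stack the results.
    """
    if not 0 <= overlap < chunk_size:
        raise ValueError("need 0 <= overlap < chunk_size")
    if x and isinstance(x[0], list):  # treat as 2-D
        return [c for row in x for c in _windows(row, chunk_size, overlap)]
    return _windows(x, chunk_size, overlap)


print(unified_chunk([1, 2, 3, 4, 5, 6], chunk_size=3))
# [[1, 2, 3], [4, 5, 6]]
print(unified_chunk([[1, 2, 3, 4], [5, 6, 7, 8]], chunk_size=2))
# [[1, 2], [3, 4], [5, 6], [7, 8]]
```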

Module contents