| from abc import abstractmethod, ABC |
| from typing import Sequence |
| import torch |
| import torch.nn as nn |
|
|
|
|
| class AutoEncoderBase(ABC): |
| def __init__( |
| self, downsampling_ratio: int, sample_rate: int, |
| latent_shape: Sequence[int | None] |
| ): |
| self.downsampling_ratio = downsampling_ratio |
| self.sample_rate = sample_rate |
| self.latent_token_rate = sample_rate // downsampling_ratio |
| self.latent_shape = latent_shape |
| self.time_dim = latent_shape.index(None) + 1 |
|
|
| @abstractmethod |
| def encode( |
| self, waveform: torch.Tensor, waveform_lengths: torch.Tensor |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| ... |
|
|