| import numpy as np |
| import torch |
| import torch.nn as nn |
| from torch import hub |
|
|
| from . import vggish_input, vggish_params |
|
|
|
|
class VGG(nn.Module):
    """VGG-style network: a caller-supplied feature trunk followed by a
    fully connected embedding head that outputs 128 values per sample.

    Args:
        features: module applied first in ``forward``. Its per-sample output
            must flatten to exactly ``512 * 4 * 6`` values, the input width
            of the first ``nn.Linear`` in the head.
    """

    def __init__(self, features):
        super(VGG, self).__init__()
        self.features = features
        head_layers = [
            nn.Linear(512 * 4 * 6, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Linear(4096, 128),
            nn.ReLU(True),
        ]
        self.embeddings = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the trunk, flatten channels-last, and apply the embedding head."""
        trunk_out = self.features(x)
        # Reorder (N, C, H, W) -> (N, H, W, C) so flattening walks channels
        # fastest; this is the same ordering two successive transposes
        # (1<->3, then 1<->2) would produce.
        channels_last = trunk_out.permute(0, 2, 3, 1).contiguous()
        flattened = channels_last.view(channels_last.size(0), -1)
        return self.embeddings(flattened)
|
|
|
|
class Postprocessor(nn.Module):
    """Post-processes VGGish embeddings. Returns a torch.Tensor instead of a
    numpy array in order to preserve the gradient.

    "The initial release of AudioSet included 128-D VGGish embeddings for each
    segment of AudioSet. These released embeddings were produced by applying
    a PCA transformation (technically, a whitening transform is included as well)
    and 8-bit quantization to the raw embedding output from VGGish, in order to
    stay compatible with the YouTube-8M project which provides visual embeddings
    in the same format for a large set of YouTube videos. This class implements
    the same PCA (with whitening) and quantization transformations."

    NOTE: the final quantization uses ``torch.round``, whose gradient is zero,
    so gradients propagated back through this module are zero even though the
    result stays in the autograd graph.
    """

    def __init__(self):
        """Constructs a postprocessor.

        The PCA parameters are registered as non-trainable
        (``requires_grad=False``) ``nn.Parameter``s so they appear in
        ``state_dict`` and follow ``.to(device)``. They start zero-filled and
        must be populated via ``load_state_dict`` before the output is
        meaningful.
        """
        super(Postprocessor, self).__init__()
        # Zero-filled rather than torch.empty: a postprocessor whose PCA
        # parameters were never loaded then produces deterministic output
        # instead of values derived from uninitialized memory.
        self.pca_eigen_vectors = nn.Parameter(
            torch.zeros(
                (vggish_params.EMBEDDING_SIZE, vggish_params.EMBEDDING_SIZE),
                dtype=torch.float,
            ),
            requires_grad=False,
        )
        # Stored as a column (EMBEDDING_SIZE, 1) so it broadcasts against the
        # transposed (embedding_size, batch_size) batch in postprocess().
        self.pca_means = nn.Parameter(
            torch.zeros((vggish_params.EMBEDDING_SIZE, 1), dtype=torch.float),
            requires_grad=False,
        )

    def postprocess(self, embeddings_batch):
        """Applies tensor postprocessing to a batch of embeddings.

        Args:
            embeddings_batch: A tensor of shape [batch_size, embedding_size]
                containing output from the embedding layer of VGGish.

        Returns:
            A float tensor holding the PCA-transformed, clipped, and quantized
            input: values are rounded to whole numbers in [0, 255] (assuming
            QUANTIZE_MAX_VAL > QUANTIZE_MIN_VAL). Its shape is
            [batch_size, embedding_size], except that ``torch.squeeze`` drops
            every size-1 dimension, so a batch of one yields [embedding_size].

        Raises:
            AssertionError: if the input is not 2-D or its second dimension is
                not ``vggish_params.EMBEDDING_SIZE``.
        """
        assert len(embeddings_batch.shape) == 2, "Expected 2-d batch, got %r" % (
            embeddings_batch.shape,
        )
        assert (
            embeddings_batch.shape[1] == vggish_params.EMBEDDING_SIZE
        ), "Bad batch shape: %r" % (embeddings_batch.shape,)

        # Subtract the mean column from each embedding (as columns), project
        # with the eigenvector matrix rows, then transpose back to
        # [batch_size, embedding_size].
        pca_applied = torch.mm(
            self.pca_eigen_vectors, (embeddings_batch.t() - self.pca_means)
        ).t()

        # Clip to the quantization range before mapping it linearly onto 0..255.
        clipped_embeddings = torch.clamp(
            pca_applied, vggish_params.QUANTIZE_MIN_VAL, vggish_params.QUANTIZE_MAX_VAL
        )

        quantized_embeddings = torch.round(
            (clipped_embeddings - vggish_params.QUANTIZE_MIN_VAL)
            * (
                255.0
                / (vggish_params.QUANTIZE_MAX_VAL - vggish_params.QUANTIZE_MIN_VAL)
            )
        )
        # squeeze() removes *all* size-1 dims (including the batch dim when
        # batch_size == 1); kept as-is because existing callers may rely on it.
        return torch.squeeze(quantized_embeddings)

    def forward(self, x):
        """Alias for ``postprocess`` so the module can be called directly."""
        return self.postprocess(x)
|
|
|
|
def make_layers():
    """Build the convolutional trunk as an ``nn.Sequential``.

    Follows a fixed plan in which an integer ``n`` appends a 3x3 ``Conv2d``
    (padding 1) producing ``n`` channels followed by an in-place ``ReLU``, and
    ``"M"`` appends a 2x2 ``MaxPool2d`` with stride 2. The first convolution
    consumes a single input channel.

    Returns:
        nn.Sequential with the layers in plan order.
    """
    plan = (64, "M", 128, "M", 256, 256, "M", 512, 512, "M")
    modules = []
    channels = 1
    for step in plan:
        if step == "M":
            modules.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue
        modules.extend(
            (nn.Conv2d(channels, step, kernel_size=3, padding=1), nn.ReLU(inplace=True))
        )
        channels = step
    return nn.Sequential(*modules)
|
|
|
|
def _vgg():
    """Return a plain ``VGG`` whose trunk comes from ``make_layers()``."""
    trunk = make_layers()
    return VGG(trunk)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
class VGGish(VGG):
    """VGG audio embedding model configured from a ``cfg`` object.

    ``cfg`` attributes read here: FREEZE_AUDIO_EXTRACTOR,
    PRETRAINED_VGGISH_MODEL_PATH, PREPROCESS_AUDIO_TO_LOG_MEL,
    POSTPROCESS_LOG_MEL_WITH_PCA and PRETRAINED_PCA_PARAMS_PATH.
    """

    def __init__(self, cfg, device=None):
        """Build the network, optionally load pretrained weights, and move it to
        ``device``.

        Args:
            cfg: configuration object providing the attributes listed on the
                class.
            device: ``torch.device`` to place the model on; when ``None``,
                CUDA is used if available, otherwise CPU.
        """
        super().__init__(make_layers())
        # NOTE(review): despite its name, FREEZE_AUDIO_EXTRACTOR only gates
        # loading pretrained weights here; no requires_grad flags are changed
        # in this class — confirm whether freezing happens elsewhere.
        if cfg.FREEZE_AUDIO_EXTRACTOR:
            # map_location="cpu" lets checkpoints saved from a GPU load on a
            # CPU-only host; the whole module is moved to `device` below.
            state_dict = torch.load(cfg.PRETRAINED_VGGISH_MODEL_PATH, map_location="cpu")
            super().load_state_dict(state_dict)
            print(f'==> Load pretrained VGGish parameters from {cfg.PRETRAINED_VGGISH_MODEL_PATH}')

        if device is None:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print("device: ", device)
        self.device = device

        self.preprocess = cfg.PREPROCESS_AUDIO_TO_LOG_MEL
        self.postprocess = cfg.POSTPROCESS_LOG_MEL_WITH_PCA
        if self.postprocess:
            self.pproc = Postprocessor()
            # NOTE(review): when FREEZE_AUDIO_EXTRACTOR is false the PCA
            # parameters are never loaded and keep Postprocessor's initial
            # values — confirm that configuration is intended.
            if cfg.FREEZE_AUDIO_EXTRACTOR:
                state_dict = torch.load(cfg.PRETRAINED_PCA_PARAMS_PATH, map_location="cpu")
                # Stored entries are converted to float tensors; the means are
                # reshaped into the (N, 1) column layout Postprocessor expects.
                # NOTE(review): if this file stores numpy arrays, newer torch
                # versions defaulting to weights_only=True may refuse to load
                # it — verify with the installed torch version.
                state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME] = torch.as_tensor(
                    state_dict[vggish_params.PCA_EIGEN_VECTORS_NAME], dtype=torch.float
                )
                state_dict[vggish_params.PCA_MEANS_NAME] = torch.as_tensor(
                    state_dict[vggish_params.PCA_MEANS_NAME].reshape(-1, 1), dtype=torch.float
                )
                self.pproc.load_state_dict(state_dict)
        self.to(self.device)

    def forward(self, x):
        """Embed ``x``, optionally preprocessing and PCA-postprocessing it.

        Args:
            x: a wav file path (``str``) when preprocessing is enabled;
                otherwise an input object with ``.to(device)`` that the
                feature trunk accepts.
        """
        if self.preprocess:
            print(">>> pre processing...")
            x = self._preprocess(x)
        x = x.to(self.device)
        x = VGG.forward(self, x)
        if self.postprocess:
            print(">>> post processing...")
            x = self._postprocess(x)
        return x

    def _preprocess(self, x):
        """Convert a wav file path into model examples via ``vggish_input``.

        Raises:
            AttributeError: if ``x`` is not a ``str`` (the only supported
                input type here; the exception type is kept for existing
                callers).
        """
        if isinstance(x, str):
            x = vggish_input.wavfile_to_examples(x)
        else:
            raise AttributeError(
                "Input type must be str (path to a wav file), got %s" % type(x).__name__
            )
        return x

    def _postprocess(self, x):
        """Apply the PCA + quantization postprocessor."""
        return self.pproc(x)
|
|