| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| This script is designed to extract features from different layers of a pretrained SSL model. |
| The extracted features will be in *.npy format, and in the shape of [L, D, T], where L is the |
| number of layers, D is the feature dimension, and T is the time dimension. |
| |
| Example usage: |
| |
| python extract_features.py \ |
| --model_path="nvidia/ssl_en_nest_large_v1.0" \ |
| --input=<path to input manifest, or a dir containing audios, or path to audio> \ |
| --output=<output directory to store features and manifest> \ |
| --layers="all" \ |
| --batch_size=8 \ |
| --workers=8 \ |
| --max_cache=1000 # save features every 1000 samples to avoid OOM in system memory |
| """ |
|
|
|
|
| import argparse |
| import os |
| import tempfile |
| from pathlib import Path |
| from typing import List |
|
|
| import lightning.pytorch as pl |
| import numpy as np |
| import torch |
| from tqdm import tqdm |
|
|
| from nemo.collections.asr.data.audio_to_text_dataset import get_char_dataset |
| from nemo.collections.asr.models import EncDecDenoiseMaskedTokenPredModel |
| from nemo.collections.asr.modules import ConformerMultiLayerFeatureExtractor |
| from nemo.collections.asr.parts.utils.manifest_utils import read_manifest, write_manifest |
| from nemo.collections.common.data.utils import move_data_to_device |
| from nemo.collections.common.parts.preprocessing.manifest import get_full_path |
| from nemo.core.classes.common import typecheck |
| from nemo.utils import logging |
|
|
# Switch off NeMo's ``typecheck`` globally; this runs once, at import time.
typecheck.set_typecheck_enabled(enabled=False)


# Command-line options as (flags, keyword-arguments) pairs.
# They are registered in this order, which is also the order shown by ``--help``.
_CLI_OPTIONS = (
    (
        ("--model_path",),
        {
            "type": str,
            "required": True,
            "help": "Path to the .nemo model file or a pretrained model name from the NGC/HF model hub",
        },
    ),
    (
        ("-i", "--input"),
        {
            "type": str,
            "required": True,
            "help": "Path to the input audio file, or list of files, directory or jsonl manifest",
        },
    ),
    (
        ("-o", "--output"),
        {
            "type": str,
            "required": True,
            "help": "Path to the output directory that contains .npy file",
        },
    ),
    (
        ("-l", "--layers"),
        {
            "type": str,
            "default": "all",
            "help": "Layers to extract features from, use 'all' to extract from all layer, 'last' for last layer, "
            "or comma-separated indices of the target layers (e.g. '0,1,2')",
        },
    ),
    (
        ("-b", "--batch_size"),
        {"type": int, "default": 8, "help": "Batch size for feature extraction"},
    ),
    (
        ("-w", "--workers"),
        {"type": int, "default": 8, "help": "Number of workers for feature extraction"},
    ),
    (
        ("-d", "--device"),
        {"type": str, "default": "cuda", "help": "Device to use for feature extraction"},
    ),
    (
        ("-t", "--type"),
        {"type": str, "default": "wav", "help": "audio file type, only needed for directory input"},
    ),
    (
        ("--use_amp",),
        {"action": "store_true", "help": "Use automatic mixed precision"},
    ),
    (
        ("--amp_dtype",),
        {
            "type": str,
            "default": "float16",
            "choices": ["float16", "bfloat16"],
            "help": "Data type for automatic mixed precision",
        },
    ),
    (
        ("-mc", "--max_cache"),
        {"type": int, "default": -1, "help": "Max cache size before saving features"},
    ),
)

parser = argparse.ArgumentParser(description="Extract audio features using an SSL model")
for _flags, _options in _CLI_OPTIONS:
    parser.add_argument(*_flags, **_options)

# Parsed at module level: ``get_input_manifest`` reads ``args.type`` from this global.
args = parser.parse_args()
|
|
|
|
def get_input_manifest(input: str) -> List[dict]:
    """
    Build an in-memory manifest from a manifest file, a directory, or a single audio file.

    Args:
        input: path to an existing ``.json``/``.jsonl`` manifest, to a directory that is searched
            recursively for ``*.<args.type>`` files (``args`` is the module-level CLI namespace),
            or to a single audio file.

    Returns:
        A list of dicts with the keys ``audio_filepath``, ``duration`` (always ``None``)
        and ``text`` (always ``"-"``).

    Raises:
        ValueError: if ``input`` is neither an existing file nor an existing directory.
    """
    # The extension test must be combined with the existence check as a whole: with
    # ``a or b and c`` a missing (or directory) path ending in ".json" was treated as a manifest.
    if input.endswith((".json", ".jsonl")) and os.path.isfile(input):
        logging.info(f"Reading manifest from: {input}")
        manifest = [
            {"audio_filepath": str(get_full_path(item["audio_filepath"], input)), "duration": None, "text": "-"}
            for item in read_manifest(input)
        ]
    elif os.path.isdir(input):
        logging.info(f"Creating manifest from directory: {input}")
        manifest = [
            {"audio_filepath": str(p), "duration": None, "text": "-"} for p in Path(input).rglob(f"*.{args.type}")
        ]
        logging.info(f"Found {len(manifest)} items of {args.type} files")
    elif os.path.isfile(input):
        logging.info(f"Reading single file: {input}")
        # ``absolute`` is a method; without the call this branch raised AttributeError.
        manifest = [{"audio_filepath": Path(input).absolute().as_posix(), "duration": None, "text": "-"}]
    else:
        raise ValueError(f"Invalid input: {input}")
    return manifest
|
|
|
|
def load_model(model_path):
    """
    Return an ``EncDecDenoiseMaskedTokenPredModel`` for ``model_path``.

    A path that ends in ``.nemo`` and points to an existing file is restored from disk;
    anything else is passed to ``from_pretrained`` as a pretrained model name.
    """
    is_local_checkpoint = model_path.endswith(".nemo") and os.path.isfile(model_path)

    if not is_local_checkpoint:
        logging.info(f"Loading model from pretrained: {model_path}")
        return EncDecDenoiseMaskedTokenPredModel.from_pretrained(model_name=model_path)

    logging.info(f"Loading model from local: {model_path}")
    return EncDecDenoiseMaskedTokenPredModel.restore_from(model_path)
|
|
|
|
class FeatureExtractor(pl.LightningModule):
    """
    Wrapper class for extracting features from SSL model.

    Reuses the model's ``preprocessor`` and ``encoder`` and routes the encoder through a
    ``ConformerMultiLayerFeatureExtractor`` built without an aggregator.
    """

    def __init__(self, ssl_model: EncDecDenoiseMaskedTokenPredModel, layer: str = "all"):
        """
        Args:
            ssl_model: model that provides ``preprocessor``, ``encoder`` and ``cfg.sample_rate``.
            layer: ``"all"``, ``"last"``, or comma-separated layer indices (e.g. ``"0,1,2"``).

        Raises:
            ValueError: if ``layer`` is none of the accepted forms.
        """
        super().__init__()
        self.preprocessor = ssl_model.preprocessor
        self.encoder = ssl_model.encoder
        self.sample_rate = ssl_model.cfg.sample_rate

        if layer == "all":
            # ``None`` is handed to the extractor unchanged for the "all layers" case.
            self.layer_idx_list = None
        elif layer == "last":
            self.layer_idx_list = [len(self.encoder.layers) - 1]
        else:
            try:
                self.layer_idx_list = [int(idx) for idx in layer.split(",")]
            except Exception as e:
                # Keep the ValueError contract, but chain the original error for debugging.
                raise ValueError(f"Invalid layer argument: {layer}. Error: {e}") from e

        self.feature_extractor = ConformerMultiLayerFeatureExtractor(
            self.encoder, aggregator=None, layer_idx_list=self.layer_idx_list
        )

    def forward(
        self,
        input_signal=None,
        input_signal_length=None,
        processed_signal=None,
        processed_signal_length=None,
    ):
        """
        Forward pass to extract features, same input interface as EncDecDenoiseMaskedTokenPredModel.forward

        Exactly one of the pairs (``input_signal``, ``input_signal_length``) and
        (``processed_signal``, ``processed_signal_length``) must be provided. When only the raw
        signal is given, it is run through ``self.preprocessor`` first.

        Returns:
            The ``(encoded, encoded_len)`` pair produced by ``self.feature_extractor``.

        Raises:
            ValueError: if both pairs or neither pair is provided.
        """
        has_input_signal = input_signal is not None and input_signal_length is not None
        has_processed_signal = processed_signal is not None and processed_signal_length is not None
        # Both flags equal means "both given" or "neither given"; either way the call is ambiguous.
        if has_input_signal == has_processed_signal:
            raise ValueError(
                f"{type(self).__name__}: arguments ``input_signal`` and ``input_signal_length`` are mutually "
                "exclusive with ``processed_signal`` and ``processed_signal_length`` arguments."
            )
        if not has_processed_signal:
            processed_signal, processed_signal_length = self.preprocessor(
                input_signal=input_signal,
                length=input_signal_length,
            )
        encoded, encoded_len = self.feature_extractor(audio_signal=processed_signal, length=processed_signal_length)
        return encoded, encoded_len
|
|
|
|
def maybe_save_features(output_dir, results, max_cache, manifest):
    """
    Check if the cache is full and save features to disk.

    Args:
        output_dir: directory that receives one ``.npy`` file per cached item (created if missing).
        results: list of ``(sample_id, audio_file, features_np)`` tuples; emptied in place after saving.
        max_cache: flush threshold. Nothing is saved while ``len(results) < max_cache`` or when
            ``max_cache`` is negative; ``0`` flushes whatever is cached.
        manifest: manifest entries indexed by ``sample_id``; each saved entry gets a
            ``feature_path`` key pointing to its ``.npy`` file.
    """
    if not results or max_cache < 0 or len(results) < max_cache:
        return
    os.makedirs(output_dir, exist_ok=True)
    logging.info(f"Saving {len(results)} features to {output_dir}")

    extension = ".npy"
    # Common filesystems limit a file name to 255 characters including the extension, so the
    # stem is capped with the suffix budgeted in. Shorter names are unchanged.
    max_stem_len = 255 - len(extension)

    for sample_id, audio_file, features_np in tqdm(results, desc="Saving features", total=len(results)):
        # Flatten the audio path into a single file name.
        filename = str(audio_file).replace("/", "_").replace(".", "_")
        if len(filename) > max_stem_len:
            # Keep the tail: it contains the file's own name rather than the leading directories.
            filename = filename[-max_stem_len:]
        output_path = os.path.join(output_dir, f"{filename}{extension}")
        np.save(output_path, features_np)
        manifest[sample_id]["feature_path"] = output_path

    logging.info(f"Saved {len(results)} features to {output_dir}")
    results.clear()
|
|
|
|
def extract_features(args):
    """
    Main function to extract and save features from SSL model.

    Loads the model, builds a dataset from ``args.input``, runs batched inference and writes one
    ``.npy`` file per sample plus a ``features.json`` manifest into ``args.output``.

    Args:
        args: namespace with the attributes ``model_path``, ``layers``, ``device``, ``input``,
            ``batch_size``, ``workers``, ``use_amp``, ``amp_dtype``, ``max_cache`` and ``output``.
    """
    logging.info(f"Extracting features using params: {vars(args)}")

    # Load the SSL model and wrap its preprocessor/encoder for multi-layer extraction.
    model = load_model(args.model_path)
    feature_extractor = FeatureExtractor(model, args.layers)
    device = torch.device(args.device)
    feature_extractor.to(device)
    # Inference only, so make sure train-time behaviour of the submodules is off.
    # NOTE(review): the loaded model's train/eval state is not visible here; this is a no-op
    # if it is already in eval mode.
    feature_extractor.eval()

    # Build the manifest before creating the temp file, so a bad input leaves nothing behind.
    logging.info(f"Building dataset from input: {args.input}")
    manifest = get_input_manifest(args.input)
    total_num_samples = len(manifest)

    # Create the output directory up front: the final manifest is written there even when
    # no feature file was flushed (e.g. empty input).
    os.makedirs(args.output, exist_ok=True)

    # Only the path is needed; close the handle right away so it is not leaked and the file
    # can be reopened by name.
    with tempfile.NamedTemporaryFile(mode="w", delete=False) as tmp_file:
        tmp_manifest_path = tmp_file.name

    try:
        write_manifest(tmp_manifest_path, manifest)

        config = {
            "manifest_filepath": tmp_manifest_path,
            "sample_rate": feature_extractor.sample_rate,
            "return_sample_id": True,
        }
        dataset = get_char_dataset(config)
        logging.info(f"Built dataset with {len(dataset)} samples")
        dataloader = torch.utils.data.DataLoader(
            dataset=dataset,
            collate_fn=dataset.collate_fn,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=True,
            drop_last=False,
        )

        seen_ids = set()
        results = []
        amp_dtype = torch.float16 if args.amp_dtype == "float16" else torch.bfloat16
        logging.info(f"Extracting features using AMP: {args.use_amp}, dtype: {amp_dtype}")
        # Autocast must target the device the model actually runs on, not whatever happens to
        # be available on the machine.
        with torch.amp.autocast(device.type, dtype=amp_dtype, enabled=args.use_amp), torch.inference_mode():
            for batch in tqdm(dataloader, desc="Extracting features"):
                batch = move_data_to_device(batch, device)
                audio_signal, audio_signal_len, _, _, sample_id = batch
                features, features_len = feature_extractor(
                    input_signal=audio_signal, input_signal_length=audio_signal_len
                )
                batch_size = features[0].size(0)
                for i in range(batch_size):
                    # Presumably the ids arrive as tensor scalars, which hash by identity;
                    # a plain int makes the duplicate check and manifest indexing work by value.
                    sid_i = int(sample_id[i])
                    if sid_i in seen_ids:
                        logging.warning(f"Skipping duplicated sample_id: {sid_i}")
                        continue

                    # Cut the padded time axis (last dimension) of every layer's output to the
                    # length reported for the first returned layer.
                    feat_i_len = int(features_len[0][i])
                    feat_i = torch.stack([layer_feat[i][:, :feat_i_len] for layer_feat in features], dim=0)
                    if feat_i.dtype == torch.bfloat16:
                        # NumPy has no bfloat16 dtype; upcast so ``.numpy()`` does not fail.
                        feat_i = feat_i.float()
                    feat_i_np = feat_i.cpu().numpy()

                    seen_ids.add(sid_i)
                    results.append((sid_i, manifest[sid_i]['audio_filepath'], feat_i_np))

                # Flush to disk once the cache is full, to bound system memory usage.
                maybe_save_features(args.output, results, args.max_cache, manifest)

        # A threshold of 0 forces whatever is still cached to be written.
        maybe_save_features(args.output, results, 0, manifest)
    finally:
        # Remove the temporary manifest even if extraction failed.
        os.remove(tmp_manifest_path)

    output_manifest = Path(args.output) / "features.json"
    write_manifest(output_manifest, manifest)
    logging.info(f"Extracted features from {total_num_samples} samples to {args.output}")
    logging.info(f"Manifest saved to: {output_manifest}")


# Script entry point: run the extraction with the module-level ``args`` parsed from the CLI.
if __name__ == "__main__":
    extract_features(args)
|
|