from pathlib import Path
from typing import Dict, List, Tuple
import os

import numpy as np
import torch

from text.symbol_table import SymbolTable
from text import text_to_sequence


class TextTokenCollator:
    """Map text tokens to integer IDs and collate token sequences into padded batches."""

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol

        # Index 0 is reserved for padding; the special symbols come next,
        # followed by the sorted token inventory.
        unique_tokens = [pad_symbol]
        if add_bos:
            unique_tokens.append(bos_symbol)
        if add_eos:
            unique_tokens.append(eos_symbol)
        unique_tokens.extend(sorted(text_tokens))

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = unique_tokens

    def index(
        self, tokens_list: List[List[str]]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        seqs, seq_lens = [], []
        for tokens in tokens_list:
            assert all(s in self.token2idx for s in tokens)
            seq = (
                ([self.bos_symbol] if self.add_bos else [])
                + list(tokens)
                + ([self.eos_symbol] if self.add_eos else [])
            )
            seqs.append(seq)
            seq_lens.append(len(seq))

        # Right-pad every sequence to the batch maximum.
        max_len = max(seq_lens)
        for seq, seq_len in zip(seqs, seq_lens):
            seq.extend([self.pad_symbol] * (max_len - seq_len))

        tokens = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )
        tokens_lens = torch.IntTensor(seq_lens)

        return tokens, tokens_lens

    def __call__(self, text: List[str]) -> Tuple[List[int], int]:
        tokens_seq = list(text)
        seq = (
            ([self.bos_symbol] if self.add_bos else [])
            + tokens_seq
            + ([self.eos_symbol] if self.add_eos else [])
        )

        token_ids = [self.token2idx[token] for token in seq]
        token_lens = len(tokens_seq) + int(self.add_eos) + int(self.add_bos)

        return token_ids, token_lens
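

# Illustrative usage sketch for TextTokenCollator (not part of the pipeline;
# the three-token inventory below is hypothetical). With add_bos/add_eos
# enabled, index 0 is <pad>, so batches are right-padded with zeros:
#
#   collator = TextTokenCollator(["a", "b", "c"])
#   batch, lens = collator.index([["a", "b"], ["c"]])
#   # batch.shape == (2, 4); lens == tensor([4, 3], dtype=torch.int32)
#   ids, n = collator(["a", "b"])  # single sequence: <bos> a b <eos> -> 4 ids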


def get_text_token_collater(
    text_tokens_file: str,
) -> Tuple[TextTokenCollator, Dict[str, int]]:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollator(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater, collater.token2idx
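

# Usage sketch (the path is a hypothetical placeholder; the file must be in
# whatever on-disk format text.symbol_table.SymbolTable.from_file expects):
#
#   collater, token2idx = get_text_token_collater(
#       "processed/mydataset/symbols.dict"
#   )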


class phoneIDCollation:
    def __init__(self, cfg, dataset=None, symbols_dict_file=None) -> None:
        # For non-lexicon phone extractors, build the phone-to-ID collator
        # from the symbols dict written during preprocessing.
        if cfg.preprocess.phone_extractor != "lexicon":
            if symbols_dict_file is None:
                assert dataset is not None
                symbols_dict_file = os.path.join(
                    cfg.preprocess.processed_dir,
                    dataset,
                    cfg.preprocess.symbols_dict,
                )
            self.text_token_collator, token2idx = get_text_token_collater(
                symbols_dict_file
            )

    def get_phone_id_sequence(self, cfg, phones_seq):
        # Lexicon-based extraction maps through text_to_sequence; any other
        # extractor uses the collator built in __init__.
        if cfg.preprocess.phone_extractor == "lexicon":
            phones_seq = " ".join(phones_seq)
            sequence = text_to_sequence(phones_seq, cfg.preprocess.text_cleaners)
        else:
            sequence, seq_len = self.text_token_collator(phones_seq)
        return sequence
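

# Usage sketch (cfg fields and phone labels are hypothetical stand-ins for
# the project's preprocessing config):
#
#   collation = phoneIDCollation(cfg, dataset="mydataset")
#   phone_ids = collation.get_phone_id_sequence(cfg, ["AH0", "B", "K"])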