| import os |
| import uuid |
| from typing import Union |
|
|
| import torch |
| from box import Box |
|
|
| from modules import models |
| from modules.utils.SeedContext import SeedContext |
|
|
|
|
def create_speaker_from_seed(seed):
    """Sample a speaker embedding from the ChatTTS model under ``seed``.

    The model comes from ``models.load_chat_tts()``. Sampling via
    ``sample_random_speaker()`` happens inside ``SeedContext(seed, True)``,
    presumably so the same seed yields the same embedding (the meaning of the
    ``True`` flag is defined in ``SeedContext`` — TODO confirm).

    Returns:
        Whatever ``sample_random_speaker()`` returns (used elsewhere in this
        file as the speaker's ``emb``).
    """
    tts_model = models.load_chat_tts()
    with SeedContext(seed, True):
        return tts_model.sample_random_speaker()
|
|
|
|
class Speaker:
    """A speaker embedding together with descriptive metadata.

    Speakers are persisted by pickling the whole object with ``torch.save``.
    Files written by older versions may therefore lack attributes added
    later; ``fix`` back-fills them after loading.
    """

    @staticmethod
    def from_file(file_like):
        """Load a pickled ``Speaker`` from a path or file-like object.

        The object is mapped onto the CPU, and ``fix`` is applied so that
        missing attributes receive defaults.

        NOTE(review): ``torch.load`` unpickles arbitrary objects and can run
        code, so only load speaker files from trusted sources. Newer torch
        releases default ``weights_only=True``, which rejects pickled custom
        classes. Confirm this against the torch version the project pins.
        """
        speaker = torch.load(file_like, map_location=torch.device("cpu"))
        speaker.fix()
        return speaker

    @staticmethod
    def from_tensor(tensor):
        """Create a seedless speaker (``seed == -2``) whose ``emb`` is ``tensor``."""
        speaker = Speaker(seed_or_tensor=-2)
        speaker.emb = tensor
        return speaker

    @staticmethod
    def from_seed(seed: int):
        """Create a speaker whose ``emb`` is sampled from ``seed``.

        Sampling is done by the module-level ``create_speaker_from_seed``.
        """
        speaker = Speaker(seed_or_tensor=seed)
        speaker.emb = create_speaker_from_seed(seed)
        return speaker

    def __init__(
        self, seed_or_tensor: Union[int, torch.Tensor], name="", gender="", describe=""
    ):
        """Create a speaker with a fresh random ``id``.

        Args:
            seed_or_tensor: An int seed or an embedding tensor.
                - An int is stored as ``seed``, and ``emb`` stays ``None``
                  for the caller to fill in.
                - A tensor is stored as ``emb``, and ``seed`` is set to -2.
                  -2 appears to be the "not derived from a seed" sentinel.
            name: Display name.
            gender: Free-form gender label.
            describe: Free-form description.
        """
        self.id = uuid.uuid4()
        self.seed = -2 if isinstance(seed_or_tensor, torch.Tensor) else seed_or_tensor
        self.name = name
        self.gender = gender
        self.describe = describe
        self.emb = None if isinstance(seed_or_tensor, int) else seed_or_tensor
        self.tokens = []

    def to_json(self, with_emb=False):
        """Return a ``Box`` with the speaker's metadata.

        Args:
            with_emb: If true, include the embedding as a nested list via
                ``emb.tolist()``. Otherwise ``"emb"`` is ``None``.
                ``"emb"`` is also ``None`` when no embedding has been set,
                instead of raising AttributeError.
        """
        emb = None
        if with_emb and self.emb is not None:
            emb = self.emb.tolist()
        return Box(
            **{
                "id": str(self.id),
                "seed": self.seed,
                "name": self.name,
                "gender": self.gender,
                "describe": self.describe,
                "emb": emb,
            }
        )

    def fix(self):
        """Add any attribute missing from an older pickled instance.

        Returns:
            True if at least one attribute was added, otherwise False.
        """
        # Default factories are called per instance, so the uuid is fresh and
        # the list is not shared between instances.
        # NOTE(review): gender defaults to "*" here but to "" in __init__;
        # the difference is kept as-is. Confirm it is intentional.
        defaults = (
            ("id", uuid.uuid4),
            ("seed", lambda: -2),
            ("name", lambda: ""),
            ("gender", lambda: "*"),
            ("describe", lambda: ""),
            ("emb", lambda: None),
            ("tokens", list),
        )
        is_update = False
        for attr, make_default in defaults:
            if attr not in self.__dict__:
                setattr(self, attr, make_default())
                is_update = True
        return is_update

    def __hash__(self):
        return hash(str(self.id))

    def __eq__(self, other):
        # Identity is the uuid alone; other fields are not compared.
        if not isinstance(other, Speaker):
            return False
        return str(self.id) == str(other.id)
|
|
|
|
| |
| |
| |
| |
| |
class SpeakerManager:
    """In-memory index of speakers stored as ``*.pt`` files in ``speaker_dir``.

    ``self.speakers`` maps each file name (relative to ``speaker_dir``) to the
    ``Speaker`` loaded from that file. Creation and update methods write to
    disk and then rescan the directory, so the index mirrors its contents.
    """

    def __init__(self):
        self.speakers = {}
        self.speaker_dir = "./data/speakers/"
        self.refresh_speakers()

    def _speaker_path(self, filename):
        """Return the path of ``filename`` inside ``speaker_dir``."""
        return os.path.join(self.speaker_dir, filename)

    def _save(self, speaker, filename):
        """Write ``speaker`` to ``filename``, creating ``speaker_dir`` if needed."""
        os.makedirs(self.speaker_dir, exist_ok=True)
        torch.save(speaker, self._speaker_path(filename))

    def refresh_speakers(self):
        """Rebuild ``self.speakers`` from the ``*.pt`` files in ``speaker_dir``.

        A missing directory produces an empty index instead of raising,
        because this runs when the module-level manager is created. Files are
        loaded in sorted order so the listing order is deterministic.
        """
        loaded = {}
        if os.path.isdir(self.speaker_dir):
            for speaker_file in sorted(os.listdir(self.speaker_dir)):
                if speaker_file.endswith(".pt"):
                    loaded[speaker_file] = Speaker.from_file(
                        self._speaker_path(speaker_file)
                    )
        # Drop entries whose file disappeared while loading. A new dict is
        # built here because deleting keys while iterating .items() raises
        # RuntimeError.
        self.speakers = {
            fname: spk
            for fname, spk in loaded.items()
            if os.path.exists(self._speaker_path(fname))
        }

    def list_speakers(self) -> list[Speaker]:
        """Return all indexed speakers."""
        return list(self.speakers.values())

    def create_speaker_from_seed(self, seed, name="", gender="", describe=""):
        """Create, save as ``<name>.pt``, and index a speaker sampled from ``seed``.

        When ``name`` is empty the seed is used as the name, converted to
        ``str`` so that it can form the file name.
        """
        if name == "":
            name = str(seed)
        filename = f"{name}.pt"
        speaker = Speaker(seed, name=name, gender=gender, describe=describe)
        # Calls the module-level create_speaker_from_seed, not this method.
        speaker.emb = create_speaker_from_seed(seed)
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def create_speaker_from_tensor(
        self, tensor, filename="", name="", gender="", describe=""
    ):
        """Create, save as ``<filename>.pt``, and index a speaker from ``tensor``.

        ``filename`` defaults to ``name``. ``tensor`` may be a ``torch.Tensor``
        or a (nested) list, which is converted with ``torch.tensor``. Any other
        type leaves ``emb`` as ``None``.

        NOTE(review): if ``filename`` and ``name`` are both empty the file is
        saved as ``.pt``. Consider requiring one of them.
        """
        if filename == "":
            filename = name
        speaker = Speaker(
            seed_or_tensor=-2, name=name, gender=gender, describe=describe
        )
        if isinstance(tensor, torch.Tensor):
            speaker.emb = tensor
        if isinstance(tensor, list):
            speaker.emb = torch.tensor(tensor)
        self._save(speaker, f"{filename}.pt")
        self.refresh_speakers()
        return speaker

    def get_speaker(self, name) -> Union[Speaker, None]:
        """Return the first speaker whose ``name`` equals ``name``, else ``None``."""
        for speaker in self.speakers.values():
            if speaker.name == name:
                return speaker
        return None

    def get_speaker_by_id(self, id) -> Union[Speaker, None]:
        """Return the speaker whose id matches ``id`` (compared as str), else ``None``."""
        for speaker in self.speakers.values():
            if str(speaker.id) == str(id):
                return speaker
        return None

    def get_speaker_filename(self, id: str):
        """Return the file name of the first speaker with id ``id``, else ``None``."""
        for fname, spk in self.speakers.items():
            if str(spk.id) == str(id):
                return fname
        return None

    def update_speaker(self, speaker: Speaker):
        """Overwrite the file of the indexed speaker with the same id, then rescan.

        Raises:
            ValueError: If no indexed speaker has ``speaker.id``.
        """
        filename = self.get_speaker_filename(speaker.id)
        if not filename:
            raise ValueError("Speaker not found for update")
        self._save(speaker, filename)
        self.refresh_speakers()
        return speaker

    def save_all(self):
        """Write every indexed speaker back to the file it was loaded from."""
        # The index is keyed by file name, so each speaker goes straight to
        # its own file without a per-speaker id lookup. Speakers sharing an id
        # therefore no longer overwrite each other's files.
        for fname, spk in self.speakers.items():
            self._save(spk, fname)

    def __len__(self):
        return len(self.speakers)
|
|
|
|
# Shared module-level manager. Constructing it scans the speaker directory
# (SpeakerManager.__init__ calls refresh_speakers) when this module is imported.
speaker_mgr: SpeakerManager = SpeakerManager()
|
|