| import numpy as np |
| import torch |
| import hashlib |
| import pathlib |
| from scipy.fft import fft |
| from pybase16384 import encode_to_string, decode_from_string |
|
|
| from configs import CPUConfig, singleton_variable |
| from rvc.synthesizer import get_synthesizer |
|
|
| from .pipeline import Pipeline |
| from .utils import load_hubert |
|
|
|
|
class TorchSeedContext:
    """Context manager that seeds torch's RNG for the duration of the
    ``with`` body and restores the previous RNG state afterwards, so the
    seeding has no effect on code running outside the block."""

    def __init__(self, seed):
        self.seed = seed
        self.state = None  # previous RNG state, captured on __enter__

    def __enter__(self):
        # Snapshot current state before reseeding so we can undo on exit.
        self.state = torch.random.get_rng_state()
        torch.manual_seed(self.seed)

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Restore unconditionally, even if the body raised.
        torch.random.set_rng_state(self.state)
|
|
|
|
# Number of big-endian int16 words in a wave hash: the first half holds
# averaged FFT values, the second half holds time-domain sums (see wave_hash).
half_hash_len = 512
# Decay constant for the exponential soft clipping in _cut_u16 and the
# distance-to-score mapping in hash_similarity.
expand_factor = 65536 * 8
|
|
|
|
@singleton_variable
def original_audio_storage():
    """Lazily load (once) the bundled reference-audio archive lgdsng.npz
    shipped next to this module."""
    archive_path = pathlib.Path(__file__).parent / "lgdsng.npz"
    return np.load(archive_path)
|
|
|
|
@singleton_variable
def original_audio():
    """Reference waveform (array "a" of the bundled archive) used to probe
    models when computing their hash."""
    storage = original_audio_storage()
    return storage["a"]
|
|
|
|
@singleton_variable
def original_audio_time_minus():
    """Time-domain offset curve (array "t") added to the normalized waveform
    inside wave_hash."""
    storage = original_audio_storage()
    return storage["t"]
|
|
|
|
@singleton_variable
def original_audio_freq_minus():
    """Frequency-domain offset curve (array "f") added to the FFT of the
    waveform inside wave_hash."""
    storage = original_audio_storage()
    return storage["f"]
|
|
|
|
@singleton_variable
def original_rmvpe_f0():
    """Precomputed f0 data for the reference audio as a ("pitch", "pitchf")
    pair, fed to the conversion pipeline in model_hash.

    NOTE(review): names suggest coarse vs. continuous pitch tracks from
    RMVPE — confirm against Pipeline.pipeline's expectations.
    """
    storage = original_audio_storage()
    return storage["pitch"], storage["pitchf"]
|
|
|
|
def _cut_u16(n):
    """Soft-clip *n* so it stays representable as a signed 16-bit value.

    Values within [-16384, 16384] pass through unchanged; beyond that the
    overshoot is compressed exponentially (rate set by ``expand_factor``),
    approaching but never reaching +/-32768.
    """
    if n > 16384:
        return 16384 + 16384 * (1 - np.exp((16384 - n) / expand_factor))
    if n < -16384:
        return -16384 - 16384 * (1 - np.exp((n + 16384) / expand_factor))
    return n
|
|
|
|
| |
def wave_hash(time_field):
    """Compute a fingerprint of a 48000-sample waveform.

    NOTE: mutates ``time_field`` in place (peak-normalizes it, then adds the
    stored reference offset curve).

    Returns the hash as a base16384 string encoding ``half_hash_len``
    big-endian int16 words: the first half summarizes the FFT (real/imag
    pairs), the second half summarizes the time domain (paired half-bin sums).

    Raises ``Exception`` if the input (or its FFT) is not exactly 48000 long.
    """
    # Normalize to peak amplitude 1.0, writing back into the input buffer.
    np.divide(time_field, np.abs(time_field).max(), time_field)
    if len(time_field) != 48000:
        raise Exception("time not hashable")
    freq_field = fft(time_field)
    if len(freq_field) != 48000:
        raise Exception("freq not hashable")
    # Offset both domains by the bundled reference ("minus") curves so the
    # hash reflects the difference from the original audio.
    np.add(time_field, original_audio_time_minus(), out=time_field)
    np.add(freq_field, original_audio_freq_minus(), out=freq_field)
    hash = np.zeros(half_hash_len // 2 * 2, dtype=">i2")
    # Bin width: (half_hash_len // 4) bins * d samples == 48000 samples.
    d = 375 * 512 // half_hash_len
    for i in range(half_hash_len // 4):
        a = i * 2                     # slot for real part (frequency half)
        b = a + 1                     # slot for imaginary part
        x = a + half_hash_len // 2    # slot for first half-bin time sum
        y = x + 1                     # slot for second half-bin time sum
        s = np.average(freq_field[i * d : (i + 1) * d])
        # _cut_u16 soft-clips each scaled value so the int16 cast can't wrap.
        hash[a] = np.int16(_cut_u16(round(32768 * np.real(s))))
        hash[b] = np.int16(_cut_u16(round(32768 * np.imag(s))))
        hash[x] = np.int16(
            _cut_u16(round(32768 * np.sum(time_field[i * d : i * d + d // 2])))
        )
        hash[y] = np.int16(
            _cut_u16(round(32768 * np.sum(time_field[i * d + d // 2 : (i + 1) * d])))
        )
    return encode_to_string(hash.tobytes())
|
|
|
|
def model_hash(config, tgt_sr, net_g, if_f0, version):
    """Fingerprint a synthesizer by converting the bundled reference audio
    through it and hashing the output waveform.

    Args:
        config: runtime config; only ``.device`` and ``.is_half`` are used.
        tgt_sr: the model's target sample rate.
        net_g: the synthesizer network to fingerprint.
        if_f0: truthy when the model uses an f0 (pitch) track.
        version: model version string, passed through to the pipeline.

    Returns:
        The base16384 wave hash (str) of the converted audio.
    """
    pipeline = Pipeline(tgt_sr, config)
    audio = original_audio()
    hbt = load_hubert(config.device, config.is_half)
    # NOTE(review): positional pipeline args follow Pipeline.pipeline's
    # signature — confirm against that definition before reordering.
    audio_opt = pipeline.pipeline(
        hbt,
        net_g,
        0,
        audio,
        [0, 0, 0],
        6,
        original_rmvpe_f0(),
        "",
        0,
        2 if if_f0 else 0,
        3,
        tgt_sr,
        16000,
        0.25,
        version,
        0.33,
    )
    del hbt
    # wave_hash requires exactly 48000 samples: left-pad short output,
    # center-trim long output.
    opt_len = len(audio_opt)
    diff = 48000 - opt_len
    if diff > 0:
        audio_opt = np.pad(audio_opt, (diff, 0))
    elif diff < 0:
        # Bug fix: the old `n = -(diff // 2); audio_opt[n:-n]` produced
        # 47999 samples for odd `diff` (making wave_hash raise). Slicing an
        # explicit 48000-long window always yields the right length and is
        # byte-identical to the old result whenever `diff` is even.
        start = (-diff) // 2
        audio_opt = audio_opt[start : start + 48000]
    h = wave_hash(audio_opt)
    del pipeline, audio_opt
    return h
|
|
|
|
def model_hash_ckpt(cpt):
    """Compute the model hash for a checkpoint dict already loaded in memory.

    Builds the synthesizer on CPU under a fixed torch seed so the resulting
    fingerprint is reproducible across runs.
    """
    config = CPUConfig()

    with TorchSeedContext(114514):
        net_g, cpt = get_synthesizer(cpt, config.device)
        tgt_sr = cpt["config"][-1]
        if_f0 = cpt.get("f0", 1)
        version = cpt.get("version", "v1")

        # Match the network's precision to the config before inference.
        net_g = net_g.half() if config.is_half else net_g.float()

        h = model_hash(config, tgt_sr, net_g, if_f0, version)

    del net_g

    return h
|
|
|
|
def model_hash_from(path):
    """Load a checkpoint file from *path* and return its model hash.

    NOTE(review): ``torch.load`` unpickles arbitrary Python objects — only
    call this on checkpoint files from a trusted source.
    """
    checkpoint = torch.load(path, map_location="cpu")
    digest = model_hash_ckpt(checkpoint)
    del checkpoint
    return digest
|
|
|
|
| def _extend_difference(n, a, b): |
| if n < a: |
| n = a |
| elif n > b: |
| n = b |
| n -= a |
| n /= b - a |
| return n |
|
|
|
|
def hash_similarity(h1: str, h2: str) -> float:
    """Score how similar two model hashes are (higher is more similar).

    The score averages (a) the absolute cosine similarity of the raw int16
    hash vectors and (b) an exponentially decayed distance accumulated over
    the complex frequency half of the hash, rescaled from [0.5, 1.0] onto
    [0, 1] by ``_extend_difference``.

    NOTE(review): despite the ``-> float`` annotation, any failure (e.g. a
    malformed hash) returns the error message string — legacy behavior that
    callers may rely on.
    """
    try:
        raw1 = decode_from_string(h1)
        raw2 = decode_from_string(h2)
        if len(raw1) != half_hash_len * 2 or len(raw2) != half_hash_len * 2:
            raise Exception("invalid hash length")
        vec1 = np.frombuffer(raw1, dtype=">i2")
        vec2 = np.frombuffer(raw2, dtype=">i2")
        dist_sum = 0
        for idx in range(half_hash_len // 4):
            re_i = idx * 2
            im_i = re_i + 1
            c1 = complex(vec1[re_i], vec1[im_i])
            c2 = complex(vec2[re_i], vec2[im_i])
            # Skip bins where either hash has a zero frequency component.
            if abs(c1) == 0 or abs(c2) == 0:
                continue
            dist_sum += np.abs(c1 - c2)
        norm_prod = np.linalg.norm(vec1) * np.linalg.norm(vec2)
        if norm_prod != 0:
            cosine = (
                np.dot(vec1.astype(np.float32), vec2.astype(np.float32)) / norm_prod
            )
        else:
            cosine = 1.0
        distance = _extend_difference(np.exp(-dist_sum / expand_factor), 0.5, 1.0)
        return round((abs(cosine) + distance) / 2, 6)
    except Exception as e:
        return str(e)
|
|
|
|
def hash_id(h: str) -> str:
    """Derive a short identifier from a full model hash *h*.

    Concatenates the base16384 encoding of a uint64 wrap-around checksum of
    the hash bytes (with its last two characters dropped) and the encoding
    of the first 7 bytes of the hash's MD5 digest.  Returns the literal
    string "invalid hash length" when *h* does not decode to a full hash.
    """
    raw = decode_from_string(h)
    if len(raw) != half_hash_len * 2:
        return "invalid hash length"
    checksum = np.frombuffer(raw, dtype=np.uint64).sum(keepdims=True)
    head = encode_to_string(checksum.tobytes())[:-2]
    tail = encode_to_string(hashlib.md5(raw).digest()[:7])
    return head + tail
|
|