import os
import re
import shutil

import torch
|
|
| |
# Load the XTTS checkpoint on CPU; its parameters live under the 'model' entry.
# NOTE: torch.load unpickles the file, so only point these paths at trusted
# checkpoints.
src = torch.load("./models/xtts/model.pth", map_location="cpu")['model']


# Tortoise checkpoints that will receive the XTTS weights, keyed by the short
# name used throughout this script ("ar" / "df", per the file names:
# autoregressive / diffusion decoder).
dst_paths = {
    "ar": "./models/tortoise/autoregressive.pth",
    "df": "./models/tortoise/diffusion_decoder.pth",
}


# dst maps each short name to its loaded tortoise state dict.
dst = {}
for model, path in dst_paths.items():
    dst[model] = torch.load(path, map_location="cpu")

    # Back up the original file byte-for-byte. Copying the file, rather than
    # torch.save-ing the loaded object, keeps the backup identical to the
    # original: the loaded object has been remapped to CPU and would be
    # re-pickled. An existing backup is never overwritten, so re-running the
    # script after replacing a checkpoint cannot clobber the pristine copy.
    backup = f'{path}.bkp'
    if os.path.exists(backup):
        print(f"Keeping existing backup {backup}")
    else:
        shutil.copyfile(path, backup)
|
|
| |
# Each destination model takes the XTTS tensors whose key starts with its
# prefix; the prefix is stripped to get the key in the tortoise state dict.
regexes = {
    "ar": r'^gpt\.',
    "df": r'^diffusion_decoder\.',
}
# Compile once instead of re-parsing each pattern for every source key.
patterns = {model: re.compile(regex) for model, regex in regexes.items()}

unmatched = []  # source keys that matched none of the prefixes
skipped = []    # source keys whose stripped name is absent from the target
for k, v in src.items():
    for model, pattern in patterns.items():
        if not pattern.match(k):
            continue
        key = pattern.sub("", k)
        target = dst[model]
        if key not in target:
            skipped.append(k)
        else:
            old = target[key]
            # A shape mismatch is still written, but reported: a checkpoint
            # holding it will fail load_state_dict for a model whose config
            # expects the original shape.
            if torch.is_tensor(v) and torch.is_tensor(old) and v.shape != old.shape:
                print(f"Warning: {k} has shape {tuple(v.shape)}, but {key} had {tuple(old.shape)}")
            print(f"Writing {k} into {key}")
            target[key] = v
        # The prefixes are mutually exclusive, so stop at the first match.
        break
    else:
        unmatched.append(k)

if skipped:
    print(f"Skipped {len(skipped)} keys with no counterpart in the tortoise checkpoints")
if unmatched:
    print(f"Ignored {len(unmatched)} keys matching none of: {', '.join(regexes.values())}")
|
|
| |
# Write each patched state dict to its own ".xtts" file in the tortoise model
# directory.
outputs = (
    ("ar", "./models/tortoise/autoregressive.xtts.pth"),
    ("df", "./models/tortoise/diffusion_decoder.xtts.pth"),
)
for name, out_path in outputs:
    torch.save(dst[name], out_path)