Pj12 commited on
Commit
7d82963
·
verified ·
1 Parent(s): f20b851

Upload convert.py

Browse files
Files changed (1) hide show
  1. convert.py +150 -0
convert.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import HubertConfig, HubertModel
4
+ import logging
5
+
6
+ # Ignore fairseq's logger
7
+ logging.getLogger("fairseq").setLevel(logging.WARNING)
8
+ logging.getLogger("torch.distributed.nn.jit.instantiator").setLevel(logging.WARNING)
9
+
10
+ from fairseq import checkpoint_utils
11
+
12
+ models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
13
+ ["content-vec-best-legacy-500.pt"], suffix=""
14
+ )
15
+ model = models[0]
16
+ model.eval()
17
+ model.eval()
18
+
19
+
20
+ class HubertModelWithFinalProj(HubertModel):
21
+ def __init__(self, config):
22
+ super().__init__(config)
23
+
24
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size)
25
+
26
+
27
+ # Default Config
28
+ hubert = HubertModelWithFinalProj(HubertConfig())
29
+
30
+ # huggingface: fairseq
31
+ mapping = {
32
+ "masked_spec_embed": "mask_emb",
33
+ "encoder.layer_norm.bias": "encoder.layer_norm.bias",
34
+ "encoder.layer_norm.weight": "encoder.layer_norm.weight",
35
+ "encoder.pos_conv_embed.conv.bias": "encoder.pos_conv.0.bias",
36
+ "encoder.pos_conv_embed.conv.weight_g": "encoder.pos_conv.0.weight_g",
37
+ "encoder.pos_conv_embed.conv.weight_v": "encoder.pos_conv.0.weight_v",
38
+ "feature_projection.layer_norm.bias": "layer_norm.bias",
39
+ "feature_projection.layer_norm.weight": "layer_norm.weight",
40
+ "feature_projection.projection.bias": "post_extract_proj.bias",
41
+ "feature_projection.projection.weight": "post_extract_proj.weight",
42
+ "final_proj.bias": "final_proj.bias",
43
+ "final_proj.weight": "final_proj.weight",
44
+ }
45
+
46
+ # Convert encoder
47
+ for layer in range(12):
48
+ for j in ["q", "k", "v"]:
49
+ mapping[
50
+ f"encoder.layers.{layer}.attention.{j}_proj.weight"
51
+ ] = f"encoder.layers.{layer}.self_attn.{j}_proj.weight"
52
+ mapping[
53
+ f"encoder.layers.{layer}.attention.{j}_proj.bias"
54
+ ] = f"encoder.layers.{layer}.self_attn.{j}_proj.bias"
55
+
56
+ mapping[
57
+ f"encoder.layers.{layer}.final_layer_norm.bias"
58
+ ] = f"encoder.layers.{layer}.final_layer_norm.bias"
59
+ mapping[
60
+ f"encoder.layers.{layer}.final_layer_norm.weight"
61
+ ] = f"encoder.layers.{layer}.final_layer_norm.weight"
62
+
63
+ mapping[
64
+ f"encoder.layers.{layer}.layer_norm.bias"
65
+ ] = f"encoder.layers.{layer}.self_attn_layer_norm.bias"
66
+ mapping[
67
+ f"encoder.layers.{layer}.layer_norm.weight"
68
+ ] = f"encoder.layers.{layer}.self_attn_layer_norm.weight"
69
+
70
+ mapping[
71
+ f"encoder.layers.{layer}.attention.out_proj.bias"
72
+ ] = f"encoder.layers.{layer}.self_attn.out_proj.bias"
73
+ mapping[
74
+ f"encoder.layers.{layer}.attention.out_proj.weight"
75
+ ] = f"encoder.layers.{layer}.self_attn.out_proj.weight"
76
+
77
+ mapping[
78
+ f"encoder.layers.{layer}.feed_forward.intermediate_dense.bias"
79
+ ] = f"encoder.layers.{layer}.fc1.bias"
80
+ mapping[
81
+ f"encoder.layers.{layer}.feed_forward.intermediate_dense.weight"
82
+ ] = f"encoder.layers.{layer}.fc1.weight"
83
+
84
+ mapping[
85
+ f"encoder.layers.{layer}.feed_forward.output_dense.bias"
86
+ ] = f"encoder.layers.{layer}.fc2.bias"
87
+ mapping[
88
+ f"encoder.layers.{layer}.feed_forward.output_dense.weight"
89
+ ] = f"encoder.layers.{layer}.fc2.weight"
90
+
91
+ # Convert Conv Layers
92
+ for layer in range(7):
93
+ mapping[
94
+ f"feature_extractor.conv_layers.{layer}.conv.weight"
95
+ ] = f"feature_extractor.conv_layers.{layer}.0.weight"
96
+
97
+ if layer != 0:
98
+ continue
99
+
100
+ mapping[
101
+ f"feature_extractor.conv_layers.{layer}.layer_norm.weight"
102
+ ] = f"feature_extractor.conv_layers.{layer}.2.weight"
103
+ mapping[
104
+ f"feature_extractor.conv_layers.{layer}.layer_norm.bias"
105
+ ] = f"feature_extractor.conv_layers.{layer}.2.bias"
106
+
107
+ hf_keys = set(hubert.state_dict().keys())
108
+ fair_keys = set(model.state_dict().keys())
109
+
110
+ hf_keys -= set(mapping.keys())
111
+ fair_keys -= set(mapping.values())
112
+
113
+ for i, j in zip(sorted(hf_keys), sorted(fair_keys)):
114
+ print(i, j)
115
+
116
+ print(hf_keys, fair_keys)
117
+ print(len(hf_keys), len(fair_keys))
118
+
119
+ # try loading the weights
120
+ new_state_dict = {}
121
+ for k, v in mapping.items():
122
+ new_state_dict[k] = model.state_dict()[v]
123
+
124
+ x = hubert.load_state_dict(new_state_dict, strict=False)
125
+ print(x)
126
+ hubert.eval()
127
+
128
+ with torch.no_grad():
129
+ new_input = torch.randn(1, 16384)
130
+
131
+ result1 = hubert(new_input, output_hidden_states=True)["hidden_states"][9]
132
+ result1 = hubert.final_proj(result1)
133
+
134
+ result2 = model.extract_features(
135
+ **{
136
+ "source": new_input,
137
+ "padding_mask": torch.zeros(1, 16384, dtype=torch.bool),
138
+ # "features_only": True,
139
+ "output_layer": 9,
140
+ }
141
+ )[0]
142
+ result2 = model.final_proj(result2)
143
+
144
+ assert torch.allclose(result1, result2, atol=1e-3)
145
+
146
+ print("Sanity check passed")
147
+
148
+ # Save huggingface model
149
+ hubert.save_pretrained(".")
150
+ print("Saved model")