| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
|
|
| import torch |
| import torch.nn as nn |
|
|
| from configuration_indictrans import IndicTransConfig |
| from modeling_indictrans import IndicTransForConditionalGeneration |
|
|
|
|
def remove_ignore_keys_(state_dict):
    """Strip a fixed set of unwanted entries from ``state_dict`` in place.

    The version markers and ``_float_tensor`` placeholder entries listed
    below are deleted when present; keys that are absent are silently
    skipped. Nothing is returned -- the mapping passed in is mutated.
    """
    for unwanted in (
        "encoder.version",
        "decoder.version",
        "model.encoder.version",
        "model.decoder.version",
        "_float_tensor",
        "encoder.embed_positions._float_tensor",
        "decoder.embed_positions._float_tensor",
    ):
        if unwanted in state_dict:
            del state_dict[unwanted]
|
|
|
|
def make_linear_from_emb(emb):
    """Build a bias-free linear layer whose weight is the given matrix.

    Args:
        emb: 2-D tensor of shape ``(vocab_size, emb_size)``.

    Returns:
        An ``nn.Linear`` with ``in_features == emb_size`` and
        ``out_features == vocab_size`` whose ``weight.data`` is ``emb.data``
        (no copy is made).
    """
    vocab_size, emb_size = emb.shape
    # nn.Linear stores its weight as (out_features, in_features), so a
    # (vocab_size, emb_size) matrix corresponds to in=emb_size, out=vocab_size.
    # Constructing the layer with these sizes keeps in_features/out_features
    # (and the module repr) consistent with the weight that is assigned below.
    lin_layer = nn.Linear(emb_size, vocab_size, bias=False)
    lin_layer.weight.data = emb.data
    return lin_layer
|
|
|
|
def convert_fairseq_IT2_checkpoint_from_disk(checkpoint_path):
    """Load a fairseq checkpoint and build the matching IndicTrans model.

    Args:
        checkpoint_path: Path to the checkpoint file readable by
            ``torch.load``.

    Returns:
        An ``IndicTransForConditionalGeneration`` whose inner ``model`` has
        been populated from the checkpoint's ``"model"`` state dict.

    Raises:
        KeyError: If the checkpoint carries neither a usable ``"args"`` entry
            nor a ``"cfg"`` entry with a ``"model"`` section, or if an
            expected weight key is missing from the state dict.
    """
    # NOTE(security): torch.load unpickles arbitrary Python objects; only run
    # this on checkpoints obtained from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")

    # The hyper-parameters live either under "args" or under cfg.model.
    # Use .get() so a checkpoint without an "args" key (or with it set to
    # None) falls through to cfg.model instead of raising KeyError early.
    args = checkpoint.get("args") or checkpoint["cfg"]["model"]

    state_dict = checkpoint["model"]
    remove_ignore_keys_(state_dict)

    # Vocabulary sizes are read from the embedding matrices themselves.
    encoder_vocab_size = state_dict["encoder.embed_tokens.weight"].shape[0]
    decoder_vocab_size = state_dict["decoder.embed_tokens.weight"].shape[0]

    config = IndicTransConfig(
        encoder_vocab_size=encoder_vocab_size,
        decoder_vocab_size=decoder_vocab_size,
        max_source_positions=args.max_source_positions,
        max_target_positions=args.max_target_positions,
        encoder_layers=args.encoder_layers,
        decoder_layers=args.decoder_layers,
        layernorm_embedding=args.layernorm_embedding,
        encoder_normalize_before=args.encoder_normalize_before,
        decoder_normalize_before=args.decoder_normalize_before,
        encoder_attention_heads=args.encoder_attention_heads,
        decoder_attention_heads=args.decoder_attention_heads,
        encoder_ffn_dim=args.encoder_ffn_embed_dim,
        decoder_ffn_dim=args.decoder_ffn_embed_dim,
        encoder_embed_dim=args.encoder_embed_dim,
        decoder_embed_dim=args.decoder_embed_dim,
        encoder_layerdrop=args.encoder_layerdrop,
        decoder_layerdrop=args.decoder_layerdrop,
        dropout=args.dropout,
        attention_dropout=args.attention_dropout,
        activation_dropout=args.activation_dropout,
        activation_function=args.activation_fn,
        share_decoder_input_output_embed=args.share_decoder_input_output_embed,
        # fairseq stores the negated flag.
        scale_embedding=not args.no_scale_embedding,
    )

    model = IndicTransForConditionalGeneration(config)
    # strict=False: keys that do not line up between the checkpoint and the
    # inner model are skipped rather than raising.
    model.model.load_state_dict(state_dict, strict=False)

    # With untied embeddings the output projection is a separate matrix in
    # the checkpoint, so the LM head is rebuilt from it explicitly.
    if not args.share_decoder_input_output_embed:
        model.lm_head = make_linear_from_emb(
            state_dict["decoder.output_projection.weight"]
        )

    print(model)
    return model
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the two path options, convert the checkpoint
    # and write the resulting model to the requested folder.
    cli = argparse.ArgumentParser()

    # (flag, default value, help text) for every supported option.
    option_table = (
        (
            "--fairseq_path",
            "indic-en/model/checkpoint_best.pt",
            "path to a model.pt on local filesystem.",
        ),
        (
            "--pytorch_dump_folder_path",
            "indic-en/hf_model",
            "Path to the output PyTorch model.",
        ),
    )
    for flag, default_value, help_text in option_table:
        cli.add_argument(flag, default=default_value, type=str, help=help_text)

    cli_args = cli.parse_args()
    converted = convert_fairseq_IT2_checkpoint_from_disk(cli_args.fairseq_path)
    converted.save_pretrained(cli_args.pytorch_dump_folder_path)
|
|