|
|
| from model_neo import NeoMiniConfig, NeoMini
|
| import torch
|
|
|
def extend_model_context(checkpoint_path="checkpoints/checkpoint_step_149999.pt",
                         new_max_len=16384):
    """Extend a trained model's positional context window to *new_max_len*.

    Loads the checkpoint at *checkpoint_path*, builds a fresh ``NeoMini``
    configured with the larger ``max_seq_len``, copies every weight over,
    and linearly interpolates any learned position-embedding tensors along
    the sequence axis so they span the longer context. The result is saved
    to ``checkpoints/extended_context_model.pt``.

    Args:
        checkpoint_path: Path to a ``.pt`` checkpoint containing a
            ``'model_state_dict'`` entry.
        new_max_len: Target context length in tokens.

    Returns:
        Tuple of ``(extended_model, config)`` — the extended ``NeoMini``
        instance and its ``NeoMiniConfig``.
    """
    print(f"Extending context window to {new_max_len} tokens...")

    config = NeoMiniConfig()
    config.max_seq_len = new_max_len

    extended_model = NeoMini(config)

    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    original_state = checkpoint['model_state_dict']

    extended_state = extended_model.state_dict()

    for key in original_state:
        if key not in extended_state:
            # Key absent from the new architecture; silently skipped,
            # matching load_state_dict's tolerance for the copied dict.
            continue
        if 'pos' in key and extended_state[key].shape != original_state[key].shape:
            print(f"Interpolating position embeddings: {key}")
            old_pos_emb = original_state[key]
            # BUG FIX: the previous call unsqueezed to 4-D and passed
            # mode='linear' with a 2-D size, which torch rejects ('linear'
            # requires 3-D (N, C, L) input and a 1-D size). Interpolate
            # only along the sequence axis instead:
            # (seq, dim) -> (1, dim, seq) -> stretch seq -> (new_seq, dim).
            # NOTE(review): assumes the embedding is stored as (seq, dim),
            # as the original squeeze/unsqueeze pattern implies — confirm.
            new_pos_emb = torch.nn.functional.interpolate(
                old_pos_emb.t().unsqueeze(0),
                size=new_max_len,
                mode='linear',
                align_corners=True,
            ).squeeze(0).t()
            extended_state[key] = new_pos_emb
        else:
            extended_state[key] = original_state[key]

    extended_model.load_state_dict(extended_state)

    extended_checkpoint = {
        'model_state_dict': extended_model.state_dict(),
        'config': config.to_dict(),
    }

    output_path = "checkpoints/extended_context_model.pt"
    torch.save(extended_checkpoint, output_path)
    print(f"Extended model saved to {output_path}")

    return extended_model, config
|
|
|
| if __name__ == "__main__":
|
| extend_model_context()
|
|
|