import logging
import math
import time
from typing import Any, Dict, List, Literal, Optional, Union

import torch
import torch.nn.functional as F
from transformers import AutoTokenizer
from transformers_neuronx import GQA, MistralForSampling, NeuronConfig


# Model and compilation configuration for a locally fine-tuned Mistral
# checkpoint: bfloat16 weights, batch size 1, sharded across 8 NeuronCores,
# with an 8K-token context window.
model_name = './checkpoint-3000'
amp = 'bf16'        # run weights and activations in bfloat16
batch_size = 1
tp_degree = 8       # tensor-parallel degree: shard across 8 NeuronCores
n_positions = 8192  # maximum sequence length (prompt + generated tokens)
# Shard the grouped-query-attention KV cache over the attention heads
neuron_config = NeuronConfig(group_query_attention=GQA.SHARD_OVER_HEADS)


# Load the checkpoint and compile it for Neuron. to_neuron() moves the
# weights onto the NeuronCores and triggers compilation, which can take
# several minutes on the first run.
model = MistralForSampling.from_pretrained(
    model_name,
    amp=amp,
    batch_size=batch_size,
    tp_degree=tp_degree,
    n_positions=n_positions,
    neuron_config=neuron_config,
)
model.to_neuron()
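

# --- Illustrative usage sketch (not part of the original script) ---
# A minimal example of running generation with the compiled model, following
# the transformers-neuronx sampling API. The prompt string is a hypothetical
# placeholder.
tokenizer = AutoTokenizer.from_pretrained(model_name)
prompt = "[INST] Summarize the Mistral architecture in one sentence. [/INST]"
encoded = tokenizer(prompt, return_tensors='pt')

with torch.inference_mode():
    start = time.time()
    # sample() runs autoregressive generation on the NeuronCores, producing
    # up to sequence_length tokens (prompt included).
    generated = model.sample(encoded.input_ids, sequence_length=256)
    elapsed = time.time() - start

print(tokenizer.decode(generated[0], skip_special_tokens=True))
print(f'generation took {elapsed:.1f}s')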