import json
from argparse import ArgumentParser

import datasets
import torch
import transformers
from transformers import AutoModelForCausalLM, BatchEncoding
| """ |
| Usage examples (with the best batch sizes on A100-80GB-400W) |
| ============================================================ |
| python -m benchmark_hf_model --model_name_or_path="Deci/DeciLM-7B" --batch_size=352 |
| python -m benchmark_hf_model --model_name_or_path="mistralai/Mistral-7B-v0.1" --batch_size=192 --model_kwargs_json='{"use_flash_attention_2": true}' |
| python -m benchmark_hf_model --model_name_or_path="meta-llama/Llama-2-7b-hf" --batch_size=48 --model_kwargs_json='{"use_flash_attention_2": true}' |
| """ |


def parse_args():
    parser = ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        type=str,
        required=True,
        help="Hugging Face hub name or local path of the model to benchmark",
    )
    parser.add_argument(
        "--warmup_iters",
        type=int,
        default=10,
        help="number of untimed warmup generation runs",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        default=5,
        help="number of timed generation runs",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=32,
        help="number of sequences generated in parallel",
    )
    parser.add_argument(
        "--prompt_length",
        type=int,
        default=512,
        help="length of the dummy prompt, in tokens",
    )
    parser.add_argument(
        "--max_new_tokens",
        type=int,
        default=512,
        help="number of tokens to generate per sequence",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="bf16",
        help="model precision, one of: fp32, fp16, bf16",
    )
    parser.add_argument(
        "--model_kwargs_json",
        type=str,
        default=None,
        help="JSON dict of extra kwargs forwarded to from_pretrained()",
    )
    return parser.parse_args()


def main():
    args = parse_args()
    # Silence progress bars and info logs so only the benchmark output is printed.
    transformers.logging.set_verbosity_error()
    datasets.logging.set_verbosity_error()

    dict_precisions = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }
    if args.precision not in dict_precisions:
        raise ValueError(
            f"Invalid precision {args.precision}, choose from: fp32, fp16, bf16"
        )
    dtype = dict_precisions[args.precision]

    model_kwargs = {}
    if args.model_kwargs_json is not None:
        model_kwargs = json.loads(args.model_kwargs_json)
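    # For example, --model_kwargs_json='{"use_flash_attention_2": true}' parses to
    # {"use_flash_attention_2": True} and is forwarded to from_pretrained() below.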

    print("loading model...")
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=dtype,
        **model_kwargs,
    )
    try:
        # Show which attention implementation the model actually loaded with.
        print(model.model.layers[0].self_attn)
    except (AttributeError, IndexError):
        print("couldn't print the model's attention module")

    # CUDA events for device-side timing of the generation runs below.
    starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    model.cuda()
    model.eval()

    # A dummy prompt of repeated token id 1: the token contents don't matter for
    # throughput, only the prompt length and batch size do.
    prompt = torch.ones(args.prompt_length, dtype=torch.long)
    inputs = BatchEncoding({"input_ids": prompt.repeat(args.batch_size, 1)})
    inputs = inputs.to(model.device)
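    # No attention_mask is passed; every position in the dummy batch is a real
    # token, so a full mask is the intended behavior (transformers may still log
    # a warning while inferring one, depending on the model config).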
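    # eos_token_id=-1234 is a sentinel that no real token id can match, so
    # generate() never stops early at an EOS token and every run produces exactly
    # max_new_tokens tokens per sequence, keeping iterations comparable.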
| print(f"warming up for {args.warmup_iters} iterations...") |
| for _ in range(args.warmup_iters): |
| with torch.no_grad(): |
| _ = model.generate( |
| **inputs, |
| max_new_tokens=1, |
| do_sample=False, |
| eos_token_id=-1234, |
| ) |
| print('finished warmup') |
| torch.cuda.synchronize() |
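    # Each timed run is bracketed by CUDA events recorded on the GPU stream;
    # elapsed_time() is only valid after synchronize() guarantees both events
    # have completed on the device.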
    batch_suffix = f" x {args.batch_size} batch" if args.batch_size > 1 else ""
    print(
        f"prefill ({args.prompt_length} tokens{batch_suffix})"
        f" + generation ({args.max_new_tokens} tokens{batch_suffix}):"
    )
    tokens_generated = args.max_new_tokens * args.batch_size
    prefill_and_generation = []
    for gen_iter in range(args.iterations):
        starter.record()
        with torch.no_grad():
            _ = model.generate(
                **inputs,
                max_new_tokens=args.max_new_tokens,
                do_sample=False,
                eos_token_id=-1234,
            )
        ender.record()
        torch.cuda.synchronize()
        t = starter.elapsed_time(ender) / 1000  # elapsed_time() returns milliseconds
        prefill_and_generation.append(t)
        print(f" iter {gen_iter + 1}: {t:.03f} sec total, {tokens_generated / t:.02f} generated tokens/sec")
    average = sum(prefill_and_generation) / len(prefill_and_generation)
    print(f" average: {average:.03f} sec total, {tokens_generated / average:.02f} generated tokens/sec")
    print(f"These results are obtained for model '{args.model_name_or_path}' with {args.batch_size=}.")


if __name__ == "__main__":
    main()