| import torch |
| from PIL import Image |
| from mm_builder import load_pretrained_model |
| from mm_utils import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN |
| from mm_utils import conv_templates, SeparatorStyle |
| from mm_utils import disable_torch_init |
| from mm_utils import tokenizer_image_token, KeywordsStoppingCriteria |
| from modeling_mmalaya import MMAlayaMPTForCausalLM |
| from transformers.generation.streamers import TextIteratorStreamer |
| import argparse |
|
|
|
|
def main(args):
    """Run a fixed two-prompt image question-answering demo and print the answers.

    Loads the tokenizer, model and image processor through
    ``load_pretrained_model`` from ``args.model_path``, then asks each built-in
    prompt (one Chinese, one English) about ``./chang_chen.jpg`` using the
    ``mmalaya_llama`` conversation template. Every decoded answer is printed,
    followed by the total and the per-prompt elapsed time.

    Args:
        args: Object exposing a ``model_path`` attribute (for example an
            ``argparse.Namespace``).
    """
    import time

    disable_torch_init()
    conv_mode = "mmalaya_llama"

    tokenizer, model, image_processor, _ = load_pretrained_model(
        model_path=args.model_path,
    )
    prompts = [
        "这张图可能是在哪拍的?当去这里游玩时需要注意什么?",
        "Where might this picture have been taken? What should you pay attention to when visiting here?",
    ]

    # perf_counter is monotonic, so the measured interval cannot be distorted
    # by system clock adjustments.
    start = time.perf_counter()

    # The image is identical for every prompt: decode and preprocess it once
    # instead of once per iteration. The context manager closes the file
    # handle; convert() returns an independent, fully loaded copy.
    with Image.open('./chang_chen.jpg') as image_file:
        image = image_file.convert("RGB")
    image_tensor = image_processor(image, return_tensors='pt')['pixel_values'].half().cuda()

    for question in prompts:
        # Fresh copy of the template so turns from earlier prompts do not leak in.
        conv = conv_templates[conv_mode].copy()
        conv.append_message(conv.roles[0], DEFAULT_IMAGE_TOKEN + '\n' + question)
        conv.append_message(conv.roles[1], None)
        full_prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).cuda()

        # The mmalaya_llama mode always stops on sep2; other modes choose
        # between sep and sep2 based on the template's separator style.
        if conv_mode == 'mmalaya_llama' or conv.sep_style == SeparatorStyle.TWO:
            stop_str = conv.sep2
        else:
            stop_str = conv.sep
        stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

        with torch.inference_mode():
            generate_ids = model.generate(
                inputs=input_ids,
                images=image_tensor,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        # Skip the first input_token_len positions of the output (presumably
        # the echoed prompt -- confirm against the model's generate()).
        input_token_len = input_ids.shape[1]
        output = tokenizer.batch_decode(
            generate_ids[:, input_token_len:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )[0]
        print(output)

    elapsed = time.perf_counter() - start
    print("cost seconds: ", elapsed)
    print("cost seconds per sample: ", elapsed / len(prompts))
| |
|
|
if __name__ == "__main__":
    # Command-line entry point: read the model directory from --model_path
    # (falling back to the default below) and hand the parsed options to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--model_path', type=str, default='/tmp/MMAlaya-v0.1.6.1')
    main(cli.parse_args())
|
|
|
|
| """ |
| export PYTHONPATH=$PYTHONPATH:/tmp/MMAlaya |
| CUDA_VISIBLE_DEVICES=0 python inference.py --model_path /tmp/MMAlaya-v0.1.6.1 |
| """ |