| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
|
|
| import requests |
| import torch |
| from PIL import Image |
| from transformers import AutoProcessor |
|
|
| from nemo import lightning as nl |
| from nemo.collections import vlm |
| from nemo.utils import logging |
|
|
|
|
def load_image(image_url: str) -> "Image.Image | None":
    """Download an image over HTTP and open it with PIL.

    Args:
        image_url: URL of the image to fetch.

    Returns:
        The opened PIL image, or ``None`` when the HTTP request fails; in
        that case the error is printed instead of being raised.
    """
    import io

    try:
        # Without a timeout, requests waits indefinitely on an unresponsive host.
        response = requests.get(image_url, timeout=30)
        response.raise_for_status()
        # Reading ``response.content`` inside the ``try`` keeps body-transfer
        # failures within the requests exception hierarchy and hands PIL a
        # fully decoded, seekable buffer instead of the raw socket stream.
        payload = io.BytesIO(response.content)
    except requests.exceptions.RequestException as e:
        print(f"Error loading image from {image_url}: {e}")
        return None
    return Image.open(payload)
|
|
|
|
def generate(model, processor, raw_image, text, max_new_tokens=20):
    """Greedily decode a reply to ``text`` about ``raw_image`` and log it.

    Args:
        model: Callable invoked with ``media``, ``input_ids``, ``position_ids``,
            ``image_sizes``, ``num_media_tiles`` and ``attention_mask`` keyword
            arguments; its output is indexed as ``output[:, -1]`` to pick the
            next token.
        processor: Object providing ``apply_chat_template``, a call interface
            that returns ``input_ids`` and ``pixel_values``, and a ``tokenizer``
            with ``eos_token_id`` and ``batch_decode``.
        raw_image: Image whose ``size`` attribute unpacks to ``(width, height)``.
        text: User prompt placed in the chat message next to the image.
        max_new_tokens: Upper bound on the number of decoding steps.

    Returns:
        None. The decoded text is written through ``logging.info``.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": text},
                {"type": "image"},
            ],
        }
    ]

    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(input_text, raw_image, return_tensors='pt').to(0, torch.float32)

    input_ids = inputs['input_ids'].cuda()
    # NOTE(review): 32000 is presumably the processor's image placeholder id and
    # -200 the index the model expects for image positions -- confirm.
    input_ids[input_ids == 32000] = -200
    media = inputs['pixel_values'].cuda()
    # Flatten to one row per tile; each tile is treated as 3 x 336 x 336.
    media = media.reshape(media.size(1), 3, 336, 336)

    generated_ids = input_ids.clone()
    width, height = raw_image.size
    image_sizes = torch.tensor([[height, width]], dtype=torch.long).cuda()

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # The whole sequence is re-fed every step, so positions and the
            # mask are rebuilt to match the current length.
            position_ids = (
                torch.arange(input_ids.size(1), dtype=torch.long, device=input_ids.device)
                .unsqueeze(0)
                .expand_as(input_ids)
            )
            attention_mask = (input_ids != 0).long().cuda()
            output = model(
                media=media,
                input_ids=input_ids,
                position_ids=position_ids,
                image_sizes=image_sizes,
                num_media_tiles=[media.size(0)],
                attention_mask=attention_mask,
            )
            # Greedy choice: highest-scoring entry at the last position.
            next_token_ids = torch.argmax(output[:, -1], dim=-1, keepdim=True)

            generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)
            input_ids = generated_ids
            print(f"next_token_ids {next_token_ids}")

            # ``.item()`` requires a single element, i.e. a batch of one.
            if next_token_ids.item() == processor.tokenizer.eos_token_id:
                print("breaking")
                break

    # -200 is not a valid tokenizer id; map it to 0 so decoding does not fail.
    generated_ids[generated_ids == -200] = 0
    generated_texts = processor.tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
    logging.info("======== GENERATED TEXT OUTPUT ========")
    logging.info(f"{generated_texts}")
    logging.info("=======================================")
|
|
|
|
def main(args) -> None:
    """Load a LLaVA-Next model and run one image-question generation.

    Args:
        args: Parsed command-line namespace. Reads ``tp_size``,
            ``load_from_hf``, ``local_model_path`` and ``image_url``.

    Raises:
        ValueError: If ``load_from_hf`` is not set and no
            ``local_model_path`` was supplied.
    """
    model_id = 'llava-hf/llava-v1.6-vicuna-7b-hf'

    # Fail before any trainer or model setup rather than handing ``None`` to
    # the checkpoint loader.
    if not args.load_from_hf and not args.local_model_path:
        raise ValueError("--local_model_path is required unless --load_from_hf is set.")

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=args.tp_size,
        ckpt_load_optimizer=False,
        ckpt_save_optimizer=False,
    )
    # The trainer is only used to obtain a fabric below; this function never
    # calls fit or validate on it.
    trainer = nl.Trainer(
        devices=args.tp_size,
        max_steps=1000,
        accelerator="gpu",
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(precision="bf16-mixed"),
        val_check_interval=1000,
        limit_val_batches=50,
    )

    processor = AutoProcessor.from_pretrained(model_id)
    tokenizer = processor.tokenizer

    fabric = trainer.to_fabric()

    if args.load_from_hf:
        # Build the hub URI from ``model_id`` so processor and weights cannot drift apart.
        model = fabric.import_model(f"hf://{model_id}", vlm.LlavaNextModel)
    else:
        model = vlm.LlavaNextModel(vlm.LlavaNextConfig7B(), tokenizer=tokenizer)
        model = fabric.load_model(args.local_model_path, model)

    # Unwrap the inner module, then move to GPU, switch to eval mode and cast to bf16.
    model = model.module.cuda()
    model.eval()
    model = model.to(torch.bfloat16)

    raw_image = load_image(args.image_url)
    if raw_image is None:
        # load_image has already reported the failure.
        return

    generate(model, processor, raw_image=raw_image, text="What are these?")
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare the options in a table, register them
    # in order, then pass the parsed namespace to main().
    cli = argparse.ArgumentParser(description="Llava Next Generation example")
    option_specs = (
        (
            "--load_from_hf",
            {
                "action": "store_true",
                "help": "Flag to indicate whether to load the model from Hugging Face hub.",
            },
        ),
        (
            "--local_model_path",
            {
                "type": str,
                "default": None,
                "help": "Local path to the model if not loading from Hugging Face.",
            },
        ),
        (
            "--image_url",
            {
                "type": str,
                "default": "http://images.cocodataset.org/val2017/000000039769.jpg",
                "help": "URL of the image to use for inference.",
            },
        ),
        ("--devices", {"type": int, "required": False, "default": 1}),
        ("--tp_size", {"type": int, "required": False, "default": 1}),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)

    main(cli.parse_args())
|
|