| import time |
| import os |
| import base64 |
| from io import BytesIO |
| import concurrent.futures |
| import logging |
| import numpy as np |
| from PIL import Image |
| import torch |
| import torch.nn as nn |
| import torch_neuronx |
| import transformers |
| from transformers import AutoConfig, AutoTokenizer |
| from llava.constants import MM_TOKEN_INDEX, DEFAULT_VIDEO_TOKEN, DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN |
| from llava.conversation import conv_templates |
| from llava.model.utils import LayerNorm |
| from llava.mm_utils import tokenizer_image_token |
| from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor |
| from transformers_neuronx import MistralForSampling, GQA, NeuronConfig, QuantizationConfig |
| from typing import Dict, Optional, Any |
| from fastapi import FastAPI, Request, HTTPException |
| |
# Restrict the transformers library's own logger to error-level messages.
transformers.logging.set_verbosity_error()
# Number of frames kept per request (passed to uniform_sample by the endpoint).
NUM_SEGMENTS = 10

# Root directory under which the checkpoint and all exported weights live.
WEIGHT_ROOT = '/home/ubuntu/'

# Directory holding the individually exported weight files.
_INF2_WEIGHT_DIR = os.path.join(WEIGHT_ROOT, "inf2_weights")

# Checkpoint directory: used for the model config, tokenizer and LLM weights.
CONFIG_DIR = os.path.join(
    WEIGHT_ROOT,
    "llava-mistral_videollava_ptv12_250k_samep_only_sopv2_mistralv2_scratch",
)

# TorchScript modules loaded with torch.jit.load.
NEURON_VISION_PATH = os.path.join(_INF2_WEIGHT_DIR, "neuron_eva_vit_batch7.pth")
NEURON_BERT_PATH = os.path.join(_INF2_WEIGHT_DIR, "neuron_bert.pth")

# State dicts / tensors loaded with torch.load.
PROJECTOR_PATH = os.path.join(_INF2_WEIGHT_DIR, 'projector.pth')
EMBED_TOKEN_PATH = os.path.join(_INF2_WEIGHT_DIR, 'embed_tokens.pth')
QUERY_TOKEN_PATH = os.path.join(_INF2_WEIGHT_DIR, 'query_tokens.pth')
LAYERNORM_SAVE_PATH = os.path.join(_INF2_WEIGHT_DIR, 'ln_state_dict.pth')
POSITION_ENCODING_SAVE_PATH = os.path.join(_INF2_WEIGHT_DIR, 'frame_position_encoding.pth')

# Directory of the pre-compiled Neuron artifacts for the Mistral model.
COMPILED_MODEL_PATH = os.path.join(WEIGHT_ROOT, 'mistral-compiled')
class MistralModel:
    """Mistral language model compiled for AWS Neuron, together with its tokenizer.

    Constructing an instance loads the checkpoint found at ``model_name``,
    restores the pre-compiled artifacts from ``COMPILED_MODEL_PATH`` and moves
    the model onto the Neuron device, so construction is slow and has side
    effects on the accelerator.
    """

    # Sentence-boundary markers trimmed from the ends of the decoded text.
    _BOS_TOKEN = "<s>"
    _EOS_TOKEN = "</s>"

    def __init__(self, model_name):
        """Load, restore and place the model on Neuron; load the tokenizer.

        Args:
            model_name: Path or identifier accepted by ``from_pretrained`` for
                both the model and the tokenizer.
        """
        # Grouped-query attention sharded over heads; weights quantized to
        # int8 ('s8') and dequantized to bf16.
        self.neuron_config = NeuronConfig(
            group_query_attention=GQA.SHARD_OVER_HEADS,
            quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'),
        )
        self.model_name = model_name
        self.amp = 'bf16'
        self.batch_size = 1
        self.tp_degree = 2
        self.n_positions = 4096
        # NOTE(review): presumably the length of the multimodal prompt
        # (visual tokens + text tokens) fed to the model -- confirm.
        self.context_length_estimate_start = 2289
        self.context_length_estimate = [self.context_length_estimate_start, 4096]
        self.model = MistralForSampling.from_pretrained(
            self.model_name,
            amp=self.amp,
            batch_size=self.batch_size,
            tp_degree=self.tp_degree,
            n_positions=self.n_positions,
            neuron_config=self.neuron_config,
            context_length_estimate=self.context_length_estimate,
        )
        # Restore pre-compiled artifacts before placing the model on device.
        self.model.load(COMPILED_MODEL_PATH)
        self.model.to_neuron()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    @classmethod
    def _strip_special_tokens(cls, text):
        """Remove leading ``<s>`` / trailing ``</s>`` markers and outer whitespace.

        ``str.strip("</s>")`` must not be used for this: its argument is a set
        of characters, so it would also eat genuine leading/trailing ``s``,
        ``<``, ``/`` and ``>`` characters of the answer.
        """
        text = text.strip()
        while text.startswith(cls._BOS_TOKEN):
            text = text[len(cls._BOS_TOKEN):].lstrip()
        while text.endswith(cls._EOS_TOKEN):
            text = text[:-len(cls._EOS_TOKEN)].rstrip()
        return text

    def generate(self, inputs: torch.Tensor, parameters: Optional[Dict[str, Any]] = None) -> str:
        """Sample a continuation for ``inputs`` and return it as text.

        Args:
            inputs: Model input forwarded unchanged to ``self.model.sample``.
            parameters: Optional sampling options. Recognised keys and their
                defaults: ``max_new_tokens`` (256), ``top_k`` (100),
                ``top_p`` (0.1), ``temperature`` (0.1),
                ``no_repeat_ngram_size`` (3). ``None`` means all defaults.

        Returns:
            The decoded first sequence with boundary markers and surrounding
            whitespace removed.

        Raises:
            Exception: Any error from sampling or decoding is logged and
                re-raised unchanged.
        """
        # The signature allows None, so fall back to an empty mapping.
        params = parameters or {}
        try:
            max_new_tokens = params.get("max_new_tokens", 256)
            top_k = params.get("top_k", 100)
            top_p = params.get("top_p", 0.1)
            temperature = params.get("temperature", 0.1)
            no_repeat_ngram_size = params.get("no_repeat_ngram_size", 3)
            # Never request more positions than the model was built for.
            sequence_length = min(
                self.n_positions,
                self.context_length_estimate_start + max_new_tokens,
            )
            with torch.inference_mode():
                generated_sequence = self.model.sample(
                    inputs,
                    sequence_length=sequence_length,
                    start_ids=None,
                    top_k=top_k,
                    top_p=top_p,
                    temperature=temperature,
                    no_repeat_ngram_size=no_repeat_ngram_size,
                )
            # Only the first sequence is returned, so decode just that one
            # directly instead of spinning up a thread pool per call.
            decoded_output = self.tokenizer.decode(generated_sequence[0])
            return self._strip_special_tokens(decoded_output)
        except Exception as e:
            logging.exception(f"Error generating text: {e}")
            raise
| |
# Application object; the /generate route further down is registered on it.
app = FastAPI()
# Instantiating MistralModel loads the compiled model and moves it to Neuron
# (see MistralModel.__init__), so this happens once, at import time.
mistral_model = MistralModel(model_name=CONFIG_DIR)
# Image preprocessor for 224-pixel inputs, created with is_training=False.
processor = Blip2ImageTrainProcessor(image_size=224, is_training=False)
def generate_input_ids(tokenizer):
    """Tokenize the fixed "describe this video" prompt.

    The prompt is built from the 'thoth' conversation template: the first role
    says the video placeholder tokens followed by the question, and the second
    role is left open (message ``None``) for the model to complete.

    Args:
        tokenizer: Tokenizer handed to ``tokenizer_image_token``.

    Returns:
        The prompt's token ids with one extra leading dimension added by
        ``unsqueeze(0)``; the video placeholder is encoded as
        ``MM_TOKEN_INDEX``.
    """
    question = "Please describe this video in detail."
    video_placeholder = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN

    conversation = conv_templates['thoth'].copy()
    conversation.append_message(conversation.roles[0], video_placeholder + '\n' + question)
    conversation.append_message(conversation.roles[1], None)

    token_ids = tokenizer_image_token(
        conversation.get_prompt(),
        tokenizer,
        MM_TOKEN_INDEX,
        return_tensors='pt',
    )
    return token_ids.unsqueeze(0)
def uniform_sample(frames, num_segments):
    """Pick ``num_segments`` items from ``frames`` at evenly spaced positions.

    Positions are ``num_segments`` evenly spaced values between the first and
    last index (both included), truncated to integers. When ``frames`` has
    fewer than ``num_segments`` items, some items are returned more than once.

    Args:
        frames: Non-empty sequence supporting ``len()`` and integer indexing.
        num_segments: Number of items to return.

    Returns:
        A list with ``num_segments`` items taken from ``frames``.

    Raises:
        ValueError: If ``frames`` is empty (previously this surfaced as an
            opaque ``IndexError`` from indexing an empty sequence).
    """
    if len(frames) == 0:
        raise ValueError("cannot sample frames from an empty sequence")
    indices = np.linspace(start=0, stop=len(frames) - 1, num=num_segments).astype(int)
    return [frames[index] for index in indices]
def image_open_byteio(byte_data):
    """Open encoded image bytes with PIL and return the image converted to RGB.

    Args:
        byte_data: Bytes of an encoded image file.

    Returns:
        The result of ``Image.open(...).convert('RGB')``.
    """
    buffer = BytesIO(byte_data)
    decoded = Image.open(buffer)
    return decoded.convert('RGB')
def process_anyres_image(image):
    """Resize an image to 224x224, preprocess it and tile the result 7 times.

    Args:
        image: PIL image to prepare.

    Returns:
        ``processor.preprocess`` output repeated with ``repeat(7, 1, 1, 1)``.
        NOTE(review): assumes preprocess returns a (C, H, W) tensor, giving a
        (7, C, H, W) result; the 7 presumably matches the batch size the
        vision module was compiled for -- confirm.
    """
    target_size = (224, 224)
    # Black canvas that the resized image is pasted onto at the origin.
    canvas = Image.new('RGB', target_size, (0, 0, 0))
    resized = image.resize(target_size)
    canvas.paste(resized, (0, 0))
    pixel_values = processor.preprocess(canvas)
    return pixel_values.repeat(7, 1, 1, 1)
| |
# ---------------------------------------------------------------------------
# Import-time loading of everything the /generate endpoint reads as module
# state. Statement order is kept as-is because it determines when each
# component is placed on the accelerator.
# ---------------------------------------------------------------------------

# Model configuration; supplies pad_token_id, vocab_size, hidden_size and
# mm_hidden_size below.
config = AutoConfig.from_pretrained(CONFIG_DIR, trust_remote_code=True)
tokenizer = mistral_model.tokenizer
# The prompt never changes, so it is tokenized once here.
input_ids = generate_input_ids(tokenizer)
# Drop the leading dimension added by generate_input_ids (unsqueeze(0)).
input_ids = input_ids[0].to('cpu')
# Load the TorchScript vision module inside a NeuronCore context starting at
# core 0 and spanning 2 cores.
with torch_neuronx.experimental.neuron_cores_context(start_nc=0, nc_count=2):
    vision_module_neuron = torch.jit.load(NEURON_VISION_PATH)
    vision_module_neuron = vision_module_neuron.eval()

# Token embedding table used to embed the text part of the prompt.
padding_idx = config.pad_token_id
embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx)
embed_weight = torch.load(EMBED_TOKEN_PATH)
embed_tokens.load_state_dict(embed_weight)
embed_tokens = embed_tokens.eval()
# nn.Module.to() modifies the module in place, so the result need not be
# re-assigned here (unlike Tensor.to()).
embed_tokens.to(torch.float16).to('cpu')

# Layer norm applied to the vision module's output features.
vision_width = 1408
ln_vision = LayerNorm(vision_width)
ln_vision_weight = torch.load(LAYERNORM_SAVE_PATH)
ln_vision.load_state_dict(ln_vision_weight)
ln_vision = ln_vision.eval()
ln_vision = ln_vision.to(torch.float32)
# Query tokens passed as the first input of neuron_bert in the endpoint.
num_query_token = 32
query_tokens = nn.Parameter(
    torch.zeros(1, num_query_token, 768)
)
# NOTE(review): this random initialisation is overwritten two lines below by
# the saved weights, so it looks redundant -- confirm before removing.
query_tokens.data.normal_(mean=0.0, std=0.02)
query_tokens_weight = torch.load(QUERY_TOKEN_PATH)['query_tokens']
query_tokens.data = query_tokens_weight
# Embedding indexed by frame number in the endpoint; 10 entries.
# NOTE(review): presumably tied to NUM_SEGMENTS (also 10) -- confirm.
frame_position_encoding = nn.Embedding(10, 768)
frame_position_encoding_weight = torch.load(POSITION_ENCODING_SAVE_PATH)
frame_position_encoding.load_state_dict(frame_position_encoding_weight)
# Linear map from config.mm_hidden_size to the language model's hidden size.
projector = nn.Linear(config.mm_hidden_size, config.hidden_size)
projector_weight = torch.load(PROJECTOR_PATH)
projector.load_state_dict(projector_weight)
# TorchScript module called in the endpoint with (query tokens, image
# features, attention mask); its "last_hidden_state" output is used.
neuron_bert = torch.jit.load(NEURON_BERT_PATH)
neuron_bert = neuron_bert.eval()
def _prepare_frames(packed_data):
    """Turn the request's base64-encoded images into the stacked vision input.

    The payload is sub-sampled to ``NUM_SEGMENTS`` entries *before* decoding,
    so only the frames that are actually used get base64-decoded and opened.
    The selected indices depend only on the payload length, so the chosen
    frames are the same as when sampling after decoding.

    Args:
        packed_data: Non-empty list of base64-encoded image files.

    Returns:
        ``torch.stack`` of ``process_anyres_image`` applied to each frame.

    Raises:
        HTTPException: 400 if an entry is not valid base64 or not an image.
    """
    t_start = time.time()
    sampled = uniform_sample(packed_data, NUM_SEGMENTS)
    t_sampled = time.time()
    print("uniform_sample time: ", t_sampled - t_start)
    # One pool is shared by all three per-frame stages.
    with concurrent.futures.ThreadPoolExecutor(10) as executor:
        try:
            unpacked_data = list(executor.map(base64.b64decode, sampled))
            t_unpacked = time.time()
            print("unpacked_data time: ", t_unpacked - t_sampled)
            opened_images = list(executor.map(image_open_byteio, unpacked_data))
            t_opened = time.time()
            print("image_open_byteio time: ", t_opened - t_unpacked)
        except (ValueError, TypeError, OSError) as exc:
            # Undecodable client data is a client error, not a server error.
            raise HTTPException(status_code=400, detail=f"Invalid image data: {exc}") from exc
        new_images = list(executor.map(process_anyres_image, opened_images))
    input_images = torch.stack(new_images, dim=0)
    print("process_images_v2 time: ", time.time() - t_opened)
    return input_images


def _encode_frames(input_images):
    """Convert stacked frames into projected visual embeddings (float16, CPU).

    Pipeline: vision module per frame -> layer norm -> neuron_bert with the
    query tokens -> add per-frame position encoding -> projector -> flatten.
    """
    t_vision = time.time()
    # Each element along dim 0 is sent to the vision module separately,
    # two at a time.
    with concurrent.futures.ThreadPoolExecutor(2) as executor:
        image_features_list = list(executor.map(vision_module_neuron, input_images))
    image_features = torch.cat(image_features_list, dim=0)
    print("si - image_features neuron time: ", time.time() - t_vision)

    t_prepare = time.time()
    image_features = ln_vision(image_features)
    attn_mask = torch.ones(image_features.size()[:-1], dtype=torch.long).to(image_features.device)
    query_tokens_inputs = query_tokens.expand(image_features.shape[0], -1, -1)
    image_features = neuron_bert(
        query_tokens_inputs.to(torch.float32),
        image_features.to(torch.float32),
        attn_mask.to(torch.int64)
    )["last_hidden_state"].to(torch.float32)
    # One frame id per row of image_features: frame i is repeated once for
    # every entry along dim 1 of input_images.
    frame_ids = torch.arange(input_images.shape[0], dtype=torch.long, device=image_features.device).unsqueeze(1)
    frame_ids = frame_ids.repeat(1, input_images.shape[1]).flatten(0, 1)
    image_features = image_features + frame_position_encoding(frame_ids).unsqueeze(-2)
    image_features = projector(image_features).flatten(0, 1)
    print(image_features.shape)
    # Tensor.to() returns a new tensor; the result must be kept, otherwise the
    # features stay float32 while the text embeddings are float16.
    image_features = image_features.to(device='cpu', dtype=torch.float16)
    print("s2 - image_features prepare time: ", time.time() - t_prepare)
    return image_features


def _build_inputs_embeds(image_features):
    """Splice the visual embeddings into the embedded prompt at the video token."""
    t_text = time.time()
    # Position of the (first) video placeholder in the fixed prompt.
    vision_token_indice = torch.where(input_ids == MM_TOKEN_INDEX)[0][0]
    pre_text_token = embed_tokens(input_ids[:vision_token_indice])
    post_text_token = embed_tokens(input_ids[vision_token_indice + 1:])
    print("s3 - text_token time: ", time.time() - t_text)
    t_concat = time.time()
    inputs_embeds = torch.cat([pre_text_token, image_features, post_text_token]).unsqueeze(0)
    print("s4 - inputs time: ", time.time() - t_concat)
    return inputs_embeds


@app.post("/generate")
async def generate(request: Request) -> Dict[str, str]:
    """
    Describe a video, given as a list of frames, using the Mistral model.

    The request body must be a JSON object with:
      * "images": non-empty list of base64-encoded image files;
      * "parameters" (optional): object with sampling options forwarded to
        ``mistral_model.generate``.

    Args:
        request (Request): The incoming request object.
    Returns:
        Dict[str, str]: ``{"generated_text": ...}``.
    Raises:
        HTTPException: 400 for a malformed request, 500 for any failure while
            running the models.
    """
    # NOTE(review): all model work below is synchronous, so it blocks the
    # event loop for the duration of the request.
    try:
        s1 = time.time()
        try:
            request_payload = await request.json()
        except ValueError as exc:  # body is not valid JSON
            raise HTTPException(status_code=400, detail="Please provide correct input") from exc
        s11 = time.time()
        print("request_payload_keys time: ", s11 - s1)

        packed_data = request_payload.get("images") if isinstance(request_payload, dict) else None
        if not isinstance(packed_data, list) or not packed_data:
            raise HTTPException(status_code=400, detail="Please provide correct input")
        # Treat a missing or null "parameters" entry as "use the defaults".
        parameters = request_payload.get("parameters") or {}
        if not isinstance(parameters, dict):
            raise HTTPException(status_code=400, detail="Please provide correct input")
        s12 = time.time()
        print("packed_data time: ", s12 - s11)

        input_images = _prepare_frames(packed_data)
        print("s1 - input_images time: ", time.time() - s1)

        si = time.time()
        with torch.inference_mode():
            image_features = _encode_frames(input_images)
            inputs_embeds = _build_inputs_embeds(image_features)

        s5 = time.time()
        generated_text = mistral_model.generate(inputs_embeds, parameters)
        print("s5 - generated_text time: ", time.time() - s5)
        print("total inference time: ", time.time() - si)
        return {"generated_text": generated_text}
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400s above) must reach the client
        # unchanged instead of being rewritten into a 500 below.
        raise
    except Exception as e:
        logging.exception("Error generating text")
        raise HTTPException(status_code=500, detail=f"Error generating text: {str(e)}") from e