import torch


@torch.no_grad()
def generate(
    vlm,
    samples,
    use_nucleus_sampling=False,
    num_beams=5,
    max_length=256,
    min_length=1,
    top_p=0.9,
    repetition_penalty=1.5,
    length_penalty=1.0,
    num_captions=1,
    temperature=1,
):
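    """Generate text for a batch of image (or video) samples.

    Visual features from `vlm.visual_encoder` are distilled through the
    Q-Former, projected into the T5 embedding space, prepended to the
    embedded prompt, and decoded with `vlm.t5_model.generate`. Assumes a
    BLIP-2 / InstructBLIP-style `vlm` exposing `visual_encoder`, `ln_vision`,
    `Qformer`, `t5_proj`, `t5_tokenizer`, and `t5_model`.

    Returns:
        (output_text, sequences_scores): decoded strings and their scores.
    """
    # Use the per-sample prompt when provided; otherwise fall back to the
    # model's default prompt.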
| if "prompt" in samples.keys(): |
| prompt = samples["prompt"] |
| else: |
| prompt = vlm.prompt |
|
|
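    # The image tensor is (B, C, H, W) for images or (B, C, T, H, W) for videos.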
    image = samples["image"]
    bs = image.size(0)
|
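    # Broadcast a single prompt string across the whole batch.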
    if isinstance(prompt, str):
        prompt = [prompt] * bs
    else:
        assert len(prompt) == bs, "The number of prompts must be equal to the batch size."
|
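    # If the prompt carries a "{}" placeholder, fill it with up to 30 OCR
    # tokens per sample (e.g., for OCR-heavy tasks such as TextVQA).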
    if "ocr_tokens" in samples.keys() and "{}" in prompt[0]:
        prompt = [p.format(', '.join(samples['ocr_tokens'][i][:30])) for i, p in enumerate(prompt)]
|
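    # Learned query tokens, expanded to one copy per batch element.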
    query_tokens = vlm.query_tokens.expand(bs, -1, -1)
    if vlm.qformer_text_input:
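        # Instruction-aware Q-Former (InstructBLIP-style): tokenize the prompt
        # so the query tokens can attend to the instruction text as well.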
        text_Qformer = vlm.tokenizer(
            prompt,
            padding='longest',
            truncation=True,
            max_length=vlm.max_txt_len,
            return_tensors="pt",
        ).to(image.device)
        query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)
        Qformer_atts = torch.cat([query_atts, text_Qformer.attention_mask], dim=1)
|
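    # Video input: a 5-D tensor (B, C, T, H, W) is encoded frame by frame and
    # the per-frame query outputs are concatenated along the sequence dim.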
    if image.dim() == 5:
        inputs_t5, atts_t5 = [], []
        for j in range(image.size(2)):
            this_frame = image[:, :, j, :, :]
            with vlm.maybe_autocast():
                frame_embeds = vlm.ln_vision(vlm.visual_encoder(this_frame))
            frame_atts = torch.ones(frame_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
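            # With text input the Q-Former attends jointly to the instruction
            # and the frame features; otherwise it uses the query tokens alone.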
            if vlm.qformer_text_input:
                frame_query_output = vlm.Qformer.bert(
                    text_Qformer.input_ids,
                    attention_mask=Qformer_atts,
                    query_embeds=query_tokens,
                    encoder_hidden_states=frame_embeds,
                    encoder_attention_mask=frame_atts,
                    return_dict=True,
                )
            else:
                frame_query_output = vlm.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=frame_embeds,
                    encoder_attention_mask=frame_atts,
                    return_dict=True,
                )
|
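            # Keep only the query positions and project them into the T5
            # embedding space.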
            frame_inputs_t5 = vlm.t5_proj(frame_query_output.last_hidden_state[:, :query_tokens.size(1), :])
            frame_atts_t5 = torch.ones(frame_inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
            inputs_t5.append(frame_inputs_t5)
            atts_t5.append(frame_atts_t5)
        inputs_t5 = torch.cat(inputs_t5, dim=1)
        atts_t5 = torch.cat(atts_t5, dim=1)
    else:
        with vlm.maybe_autocast():
            image_embeds = vlm.ln_vision(vlm.visual_encoder(image))
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)
|
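        # Same Q-Former forward as the video path, applied once to the image.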
        if vlm.qformer_text_input:
            query_output = vlm.Qformer.bert(
                text_Qformer.input_ids,
                attention_mask=Qformer_atts,
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
        else:
            query_output = vlm.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )
|
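        # Project the query outputs into the T5 embedding space.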
        inputs_t5 = vlm.t5_proj(query_output.last_hidden_state[:, :query_tokens.size(1), :])
        atts_t5 = torch.ones(inputs_t5.size()[:-1], dtype=torch.long).to(image.device)
|
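    # Tokenize the (possibly OCR-formatted) prompt for the T5 encoder.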
    input_tokens = vlm.t5_tokenizer(
        prompt,
        padding="longest",
        return_tensors="pt",
    ).to(image.device)
|
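    # Attention mask covering the visual prefix plus the text tokens.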
    encoder_atts = torch.cat([atts_t5, input_tokens.attention_mask], dim=1)
|
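    # Prepend the projected visual queries to the embedded prompt; embedding
    # and generation both run under bfloat16 autocast.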
    with vlm.maybe_autocast(dtype=torch.bfloat16):
        inputs_embeds = vlm.t5_model.encoder.embed_tokens(input_tokens.input_ids)
        inputs_embeds = torch.cat([inputs_t5, inputs_embeds], dim=1)
|
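        # Beam search by default; nucleus sampling when use_nucleus_sampling
        # is True. Note that max_length is passed as max_new_tokens.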
        outputs = vlm.t5_model.generate(
            return_dict_in_generate=True,
            output_scores=True,
            inputs_embeds=inputs_embeds,
            attention_mask=encoder_atts,
            do_sample=use_nucleus_sampling,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
        )
    output_text = vlm.t5_tokenizer.batch_decode(
        outputs.sequences, skip_special_tokens=True
    )
|
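    # Caution: `sequences_scores` is only present on beam-search outputs
    # (num_beams > 1); with num_beams=1 this attribute access would fail.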
    return output_text, outputs.sequences_scores