| import torch |
| import torch.nn as nn |
| import numpy as np |
| from typing import Dict, List, Optional, Tuple, Union |
| from transformers.models.mask2former.modeling_mask2former import ( |
| Mask2FormerMaskedAttentionDecoderOutput, Mask2FormerModelOutput, |
| Mask2FormerForUniversalSegmentationOutput, Mask2FormerMLPPredictionHead, |
| sample_point, pair_wise_sigmoid_cross_entropy_loss, pair_wise_dice_loss, |
| sigmoid_cross_entropy_loss, dice_loss) |
| from torch import Tensor |
| import torch.nn.functional as F |
|
|
| from transformers.file_utils import is_scipy_available |
|
|
| if is_scipy_available(): |
| from scipy.optimize import linear_sum_assignment |
|
|
|
|
def get_classification_logits(x, text_classifier, logit_scale):
    """Compute temperature-scaled cosine-similarity logits (CLIP-style).

    Both the query features and the text embeddings are L2-normalized, so
    their inner product is a cosine similarity, which is then multiplied by
    a learned temperature.

    Args:
        x: query features of shape `(..., dim)`.
        text_classifier: text embeddings of shape `(num_classes, dim)`.
        logit_scale: scalar log-temperature parameter.

    Returns:
        Classification logits of shape `(..., num_classes)`.
    """
    queries = F.normalize(x, dim=-1)
    class_embeds = F.normalize(text_classifier, dim=-1)
    # exp() recovers the temperature; clamp at 100 for numerical stability.
    temperature = torch.clamp(logit_scale.exp(), max=100)
    return (temperature * queries) @ class_embeds.T
|
|
|
|
def _post_init(self):
    """Attach the open-vocabulary heads after the base model's `__init__`.

    Replaces the fixed-vocabulary classifier with an MLP that projects decoder
    queries into the text-embedding space, plus a learnable log-temperature
    initialized to log(1/0.07) (the CLIP initialization).
    """
    # 3-layer MLP mapping hidden_dim -> hidden_dim; its output is matched
    # against text embeddings in `get_classification_logits`.
    self.class_embed = Mask2FormerMLPPredictionHead(self.config.hidden_dim, self.config.hidden_dim, self.config.hidden_dim, 3)
    self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
|
|
|
def ov_class_predictor(self, x, text_classifier):
    """Predict open-vocabulary class logits, one image at a time.

    Each image in the batch may carry its own text classifier (its own set of
    category embeddings), so the batch cannot be scored in a single matmul.

    Args:
        x: decoder query features of shape `(batch, num_queries, hidden_dim)`.
        text_classifier: iterable of per-image text-embedding matrices.

    Returns:
        List of per-image logit tensors, shape `(num_queries, num_classes_i)`.
    """
    embeddings = self.class_embed(x)
    return [
        get_classification_logits(per_image.unsqueeze(0), per_classifier, self.logit_scale).squeeze(0)
        for per_image, per_classifier in zip(embeddings, text_classifier)
    ]
|
|
|
|
|
|
def Mask2FormerLoss_loss_labels(
    self, class_queries_logits: Tensor, class_labels: List[Tensor], indices: Tuple[np.array]
) -> Dict[str, Tensor]:
    """Cross-entropy classification loss, computed per image.

    Unlike the upstream batched implementation, each image is handled on its
    own because the number of classes can differ per image in the
    open-vocabulary setting.

    Args:
        class_queries_logits: per-image logits, each `(num_queries, num_classes_i + 1)`
            where the last column is the "no object" class.
        class_labels: per-image ground-truth class indices.
        indices: per-image `(prediction_idx, target_idx)` Hungarian assignments.

    Returns:
        Dict with key `"loss_cross_entropy"`.
    """
    num_queries = class_queries_logits[0].shape[0]
    per_image_losses = []
    for logits, labels, (pred_idx, tgt_idx) in zip(class_queries_logits, class_labels, indices):
        num_classes_with_null = logits.shape[-1]
        # Down-weight the "no object" (last) class by eos_coef.
        weight = torch.ones(num_classes_with_null)
        weight[-1] = self.eos_coef
        weight = weight.to(logits.device).to(logits.dtype)
        # Default every query to the null class, then overwrite matched queries
        # with their assigned ground-truth labels.
        targets = torch.full(
            (num_queries, ), fill_value=num_classes_with_null - 1, dtype=torch.int64, device=logits.device)
        targets[pred_idx] = labels[tgt_idx].to(logits.device)
        criterion = nn.CrossEntropyLoss(weight=weight, reduction='none')
        # CrossEntropyLoss expects (batch, classes, queries).
        ce = criterion(logits.unsqueeze(0).transpose(1, 2), targets.unsqueeze(0))
        per_image_losses.append(ce)
    # Concatenate per-query losses across the batch, then average.
    return {"loss_cross_entropy": torch.cat(per_image_losses, dim=-1).mean()}
|
|
def Mask2FormerLoss_loss_masks(
    self,
    masks_queries_logits: torch.Tensor,
    mask_labels: List[torch.Tensor],
    indices: Tuple[np.array],
    num_masks: int
) -> Dict[str, torch.Tensor]:
    """Mask losses (sigmoid cross-entropy + dice) evaluated on sampled points.

    Instead of evaluating losses densely over full-resolution masks, a small
    set of point locations is sampled, biased toward uncertain predictions
    (PointRend-style), and both losses are computed only at those points.

    Args:
        masks_queries_logits: predicted mask logits
            `(batch_size, num_queries, height, width)`.
        mask_labels: list of ground-truth mask tensors, one per image.
        indices: per-image `(prediction_idx, target_idx)` Hungarian assignments.
        num_masks: normalization factor for the losses.

    Returns:
        Dict with keys `"loss_mask"` and `"loss_dice"`.
    """
    # Flatten the per-image assignments so matched prediction/target pairs
    # line up along dim 0.
    src_idx = self._get_predictions_permutation_indices(indices)
    tgt_idx = self._get_targets_permutation_indices(indices)
    pred_masks = masks_queries_logits[src_idx]
    # Ground-truth masks can differ in spatial size per image; pad to a
    # common shape before indexing.
    target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
    target_masks = target_masks[tgt_idx]

    # Add a channel dim: sample_point expects (num_masks, 1, height, width).
    pred_masks = pred_masks[:, None]
    target_masks = target_masks[:, None]

    # Point selection depends only on prediction uncertainty; no gradients
    # should flow through the sampling itself.
    with torch.no_grad():
        point_coordinates = self.sample_points_using_uncertainty(
            pred_masks,
            lambda logits: self.calculate_uncertainty(logits),
            self.num_points,
            self.oversample_ratio,
            self.importance_sample_ratio,
        )
        # NOTE(review): labels are sampled in bfloat16 — presumably to match a
        # mixed-precision training setup; confirm this is intended for fp32 runs.
        point_labels = sample_point(target_masks.to(torch.bfloat16), point_coordinates.to(torch.bfloat16), align_corners=False).squeeze(1)

    point_logits = sample_point(pred_masks, point_coordinates.to(pred_masks.dtype), align_corners=False).squeeze(1)

    losses = {
        "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
        "loss_dice": dice_loss(point_logits, point_labels, num_masks),
    }

    # Drop references to the large mask tensors promptly.
    del pred_masks
    del target_masks
    return losses
|
|
def Mask2FormerLoss_sample_points_using_uncertainty(
    self,
    logits: torch.Tensor,
    uncertainty_function,
    num_points: int,
    oversample_ratio: int,
    importance_sample_ratio: float,
) -> torch.Tensor:
    """Sample point coordinates biased toward uncertain regions (PointRend).

    Oversamples `num_points * oversample_ratio` uniform-random candidates,
    keeps the `importance_sample_ratio * num_points` most uncertain of them,
    and fills the remainder with fresh uniform-random points.

    Args:
        logits: mask logits of shape `(num_boxes, 1, height, width)`.
        uncertainty_function: maps sampled logits to per-point uncertainty
            scores (higher = more uncertain).
        num_points: total number of points to return per box.
        oversample_ratio: oversampling factor for the candidate pool.
        importance_sample_ratio: fraction of points drawn from the most
            uncertain candidates.

    Returns:
        Normalized point coordinates of shape `(num_boxes, num_points, 2)`.
    """
    num_boxes = logits.shape[0]
    num_points_sampled = int(num_points * oversample_ratio)

    # Candidate pool: uniform-random points in [0, 1]^2.
    point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
    # Evaluate the logits at the candidates and score their uncertainty.
    point_logits = sample_point(logits, point_coordinates.to(logits.dtype), align_corners=False)
    point_uncertainties = uncertainty_function(point_logits)

    num_uncertain_points = int(importance_sample_ratio * num_points)
    num_random_points = num_points - num_uncertain_points

    # Top-k most uncertain candidates per box. `shift` turns per-box indices
    # into flat indices so one flat gather over the (num_boxes*num_sampled, 2)
    # view suffices.
    idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
    shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
    idx += shift[:, None]
    point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)

    # Pad with fresh uniform-random points up to `num_points`.
    if num_random_points > 0:
        point_coordinates = torch.cat(
            [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
            dim=1,
        )
    return point_coordinates
|
|
|
|
|
|
@torch.no_grad()
def Mask2FormerHungarianMatcher_forward(
    self,
    masks_queries_logits: torch.Tensor,
    class_queries_logits: torch.Tensor,
    mask_labels: torch.Tensor,
    class_labels: torch.Tensor,
) -> List[Tuple[Tensor]]:
    """Hungarian matching between predicted queries and ground-truth masks.

    The assignment cost combines classification probability and point-wise
    sigmoid cross-entropy / dice costs on a shared set of random points,
    weighted by `self.cost_class`, `self.cost_mask` and `self.cost_dice`.
    Matching is done per image so each image may have a different number of
    classes (open-vocabulary setting).

    Returns:
        List (one entry per image) of `(prediction_indices, target_indices)`
        int64 tensors describing the optimal assignment.
    """
    indices: List[Tuple[np.array]] = []

    batch_size = masks_queries_logits.shape[0]
    for i in range(batch_size):
        pred_probs = class_queries_logits[i].softmax(-1)
        pred_mask = masks_queries_logits[i]

        # Classification cost: -p(target class); the constant 1 in
        # (1 - p) is dropped since it does not change the assignment.
        cost_class = -pred_probs[:, class_labels[i]]
        target_mask = mask_labels[i].to(pred_mask)
        target_mask = target_mask[:, None]
        pred_mask = pred_mask[:, None]

        # One shared set of random points so prediction and target costs are
        # evaluated at identical locations.
        point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)

        target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1).to(target_mask.dtype)
        target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)

        pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1).to(pred_mask.dtype)
        pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)

        cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
        cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
        # Final (num_queries, num_targets) cost matrix, clamped so that
        # linear_sum_assignment stays numerically safe; NaNs are zeroed.
        cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
        cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
        cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
        cost_matrix = torch.nan_to_num(cost_matrix, 0)
        # scipy solves the assignment on CPU in float32.
        assigned_indices: Tuple[np.array] = linear_sum_assignment(cost_matrix.to(torch.float32).cpu())
        indices.append(assigned_indices)

    # Convert numpy index arrays into int64 tensors for downstream indexing.
    matched_indices = [
        (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
    ]
    return matched_indices
|
|
|
|
|
|
|
|
def Mask2FormerMaskedAttentionDecoder_forward_first3layers(
    self,
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run only the FIRST 3 layers (`self.layers[:3]`) of the masked-attention
    decoder. Split-forward variant of
    `Mask2FormerMaskedAttentionDecoder.forward`; the remaining layers are run
    by `Mask2FormerMaskedAttentionDecoder_forward_last3layers`.

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder.
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross(masked)-attention of the decoder.
        feature_size_list (`List[torch.Size]`):
            This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is not None:
        hidden_states = inputs_embeds

    # Layernormed decoder states after each layer (for deep supervision).
    intermediate = ()

    all_hidden_states = () if output_hidden_states else None
    attentions = () if output_attentions else None

    # Mask predictions from each intermediate state.
    intermediate_mask_predictions = ()

    # Layer-0 prediction from the raw queries; also produces the attention
    # mask that restricts the first layer's cross-attention to foreground.
    intermediate_hidden_states = self.layernorm(inputs_embeds)
    intermediate += (intermediate_hidden_states,)

    predicted_mask, attention_mask = self.mask_predictor(
        intermediate_hidden_states, pixel_embeddings, feature_size_list[0]
    )
    intermediate_mask_predictions += (predicted_mask,)

    for idx, decoder_layer in enumerate(self.layers[:3]):
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Stochastic depth (LayerDrop): randomly skip whole layers in training.
        dropout_probability = torch.rand([])

        if self.training and (dropout_probability < self.layerdrop):
            continue

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                None,
                None,
                output_attentions,
            )

        else:
            # Cycle the cross-attended feature map through the pyramid levels.
            level_index = idx % self.num_feature_levels

            # If a query's predicted mask is empty everywhere, disable masking
            # for that query so its cross-attention still attends somewhere.
            where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
            attention_mask = attention_mask * where.unsqueeze(-1)

            layer_outputs = decoder_layer(
                hidden_states,
                level_index=level_index,
                position_embeddings=multi_stage_positional_embeddings,
                query_position_embeddings=query_position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        intermediate_hidden_states = self.layernorm(layer_outputs[0])

        # Predict the mask at the NEXT layer's feature resolution; the returned
        # attention mask gates the next layer's masked cross-attention.
        predicted_mask, attention_mask = self.mask_predictor(
            intermediate_hidden_states,
            pixel_embeddings,
            feature_size_list[(idx + 1) % self.num_feature_levels],
        )

        intermediate_mask_predictions += (predicted_mask,)

        intermediate += (intermediate_hidden_states,)

        hidden_states = layer_outputs[0]

        if output_attentions:
            attentions += (layer_outputs[1],)

    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # (num_queries, batch, hidden) -> (batch, num_queries, hidden)
    hidden_states = hidden_states.transpose(1, 0)
    if not return_dict:
        outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
        return tuple(v for v in outputs if v is not None)

    return Mask2FormerMaskedAttentionDecoderOutput(
        last_hidden_state=hidden_states,
        hidden_states=all_hidden_states,
        attentions=attentions,
        intermediate_hidden_states=intermediate,
        masks_queries_logits=intermediate_mask_predictions,
    )
|
|
|
|
def Mask2FormerMaskedAttentionDecoder_forward_last3layers(
    self,
    inputs_embeds: torch.Tensor = None,
    multi_stage_positional_embeddings: torch.Tensor = None,
    pixel_embeddings: torch.Tensor = None,
    encoder_hidden_states: torch.Tensor = None,
    query_position_embeddings: torch.Tensor = None,
    feature_size_list: List = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    r"""
    Run the REMAINING layers (`self.layers[3:]`) of the masked-attention
    decoder, resuming from the queries produced by
    `Mask2FormerMaskedAttentionDecoder_forward_first3layers`. The layer index
    is offset by 3 (`idx = _idx + 3`) so feature-level cycling matches an
    uninterrupted forward pass.

    Args:
        inputs_embeds (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            The query embeddings that are passed into the decoder (here: the
            queries coming out of the first part, possibly modified by the caller).
        multi_stage_positional_embeddings (`torch.FloatTensor` of shape `(height*width, batch_size, num_channels)`):
            Position embeddings that are added to the keys in each cross(masked)-attention layer.
        pixel_embeddings (`torch.FloatTensor`):
            Tensor of shape `(batch_size, num_channels, height, width)`, 1/4 scale features from the last Pixel
            Decoder.
        query_position_embeddings (`torch.FloatTensor` of shape `(num_queries, batch_size, hidden_size)`):
            , *optional*): Position embeddings that are added to the queries and keys in each self-attention layer.
        encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, encoder_sequence_length, hidden_size)`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the
            cross(masked)-attention of the decoder.
        feature_size_list (`List[torch.Size]`):
            This is a list containing shapes (height & width) of multi-scale features from the Pixel Decoder.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under
            returned tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
            for more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    if inputs_embeds is not None:
        hidden_states = inputs_embeds

    # Layernormed decoder states after each layer (for deep supervision).
    intermediate = ()

    all_hidden_states = () if output_hidden_states else None
    attentions = () if output_attentions else None

    # Mask predictions from each intermediate state.
    intermediate_mask_predictions = ()

    # Re-predict the mask / attention mask from the incoming queries so the
    # first layer of this half starts with an up-to-date cross-attention mask.
    intermediate_hidden_states = self.layernorm(inputs_embeds)
    intermediate += (intermediate_hidden_states,)

    predicted_mask, attention_mask = self.mask_predictor(
        intermediate_hidden_states, pixel_embeddings, feature_size_list[0]
    )
    intermediate_mask_predictions += (predicted_mask,)

    for _idx, decoder_layer in enumerate(self.layers[3:]):
        # Offset so level cycling / feature sizes match a full forward pass.
        idx = _idx + 3
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        # Stochastic depth (LayerDrop): randomly skip whole layers in training.
        dropout_probability = torch.rand([])

        if self.training and (dropout_probability < self.layerdrop):
            continue

        if self.gradient_checkpointing and self.training:
            layer_outputs = self._gradient_checkpointing_func(
                decoder_layer.__call__,
                hidden_states,
                attention_mask,
                encoder_hidden_states,
                None,
                None,
                output_attentions,
            )

        else:
            # Cycle the cross-attended feature map through the pyramid levels.
            level_index = idx % self.num_feature_levels

            # If a query's predicted mask is empty everywhere, disable masking
            # for that query so its cross-attention still attends somewhere.
            where = (attention_mask.sum(-1) != attention_mask.shape[-1]).to(attention_mask.dtype)
            attention_mask = attention_mask * where.unsqueeze(-1)

            layer_outputs = decoder_layer(
                hidden_states,
                level_index=level_index,
                position_embeddings=multi_stage_positional_embeddings,
                query_position_embeddings=query_position_embeddings,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=attention_mask,
                output_attentions=output_attentions,
            )

        intermediate_hidden_states = self.layernorm(layer_outputs[0])

        # Predict the mask at the NEXT layer's feature resolution; the returned
        # attention mask gates the next layer's masked cross-attention.
        predicted_mask, attention_mask = self.mask_predictor(
            intermediate_hidden_states,
            pixel_embeddings,
            feature_size_list[(idx + 1) % self.num_feature_levels],
        )

        intermediate_mask_predictions += (predicted_mask,)

        intermediate += (intermediate_hidden_states,)

        hidden_states = layer_outputs[0]

        if output_attentions:
            attentions += (layer_outputs[1],)

    if output_hidden_states:
        all_hidden_states += (hidden_states,)

    # (num_queries, batch, hidden) -> (batch, num_queries, hidden)
    hidden_states = hidden_states.transpose(1, 0)
    if not return_dict:
        outputs = [hidden_states, all_hidden_states, attentions, intermediate, intermediate_mask_predictions]
        return tuple(v for v in outputs if v is not None)

    return Mask2FormerMaskedAttentionDecoderOutput(
        last_hidden_state=hidden_states,
        hidden_states=all_hidden_states,
        attentions=attentions,
        intermediate_hidden_states=intermediate,
        masks_queries_logits=intermediate_mask_predictions,
    )
|
|
|
|
def Mask2FormerTransformerModule_forward_first_part(
    self,
    multi_scale_features: List[Tensor],
    mask_features: Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
    """Project multi-scale features and run the first 3 decoder layers.

    Args:
        multi_scale_features: per-level feature maps from the pixel decoder,
            each `(batch_size, num_channels, height_l, width_l)`.
        mask_features: 1/4-scale pixel embeddings used for mask prediction.
        output_hidden_states: forwarded to the decoder.
        output_attentions: forwarded to the decoder.

    Returns:
        `Mask2FormerMaskedAttentionDecoderOutput` after the first 3 layers.
    """
    multi_stage_features = []
    multi_stage_positional_embeddings = []
    size_list = []

    for i in range(self.num_feature_levels):
        size_list.append(multi_scale_features[i].shape[-2:])
        multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
        # Project each level to the decoder dim and add a learned level embedding.
        multi_stage_features.append(
            self.input_projections[i](multi_scale_features[i]).flatten(2)
            + self.level_embed.weight[i][None, :, None]
        )

        # (batch, channels, h*w) -> (h*w, batch, channels) for the decoder.
        multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
        multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)

    _, batch_size, _ = multi_stage_features[0].shape

    # [num_queries, batch_size, num_channels]
    query_embeddings = self.queries_embedder.weight.unsqueeze(1).repeat(1, batch_size, 1)
    query_features = self.queries_features.weight.unsqueeze(1).repeat(1, batch_size, 1)

    decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_first3layers(
        inputs_embeds=query_features,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=mask_features,
        encoder_hidden_states=multi_stage_features,
        query_position_embeddings=query_embeddings,
        feature_size_list=size_list,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
    )

    return decoder_output
|
|
|
|
def Mask2FormerTransformerModule_forward_second_part(
    self,
    query_features: Tensor,
    query_embeddings: Tensor,
    multi_scale_features: List[Tensor],
    mask_features: Tensor,
    output_hidden_states: bool = False,
    output_attentions: bool = False,
) -> Mask2FormerMaskedAttentionDecoderOutput:
    """Re-project multi-scale features and run the remaining decoder layers.

    Unlike the first part, the queries are NOT re-initialized here: both
    `query_features` (decoder state after the first part, possibly modified
    by the caller) and `query_embeddings` are supplied by the caller.
    NOTE(review): the first part returns its last hidden state batch-first;
    the decoder expects `(num_queries, batch, hidden)` — the caller is
    presumably responsible for any transpose in between; confirm.

    Args:
        query_features: decoder queries to resume from.
        query_embeddings: query positional embeddings.
        multi_scale_features: per-level feature maps from the pixel decoder.
        mask_features: 1/4-scale pixel embeddings used for mask prediction.

    Returns:
        `Mask2FormerMaskedAttentionDecoderOutput` after the remaining layers.
    """
    multi_stage_features = []
    multi_stage_positional_embeddings = []
    size_list = []

    for i in range(self.num_feature_levels):
        size_list.append(multi_scale_features[i].shape[-2:])
        multi_stage_positional_embeddings.append(self.position_embedder(multi_scale_features[i], None).flatten(2))
        # Project each level to the decoder dim and add a learned level embedding.
        multi_stage_features.append(
            self.input_projections[i](multi_scale_features[i]).flatten(2)
            + self.level_embed.weight[i][None, :, None]
        )

        # (batch, channels, h*w) -> (h*w, batch, channels) for the decoder.
        multi_stage_positional_embeddings[-1] = multi_stage_positional_embeddings[-1].permute(2, 0, 1)
        multi_stage_features[-1] = multi_stage_features[-1].permute(2, 0, 1)

    # Kept for parity with the first part; batch_size itself is unused here
    # because the queries come from the caller.
    _, batch_size, _ = multi_stage_features[0].shape

    decoder_output = self.decoder.Mask2FormerMaskedAttentionDecoder_forward_last3layers(
        inputs_embeds=query_features,
        multi_stage_positional_embeddings=multi_stage_positional_embeddings,
        pixel_embeddings=mask_features,
        encoder_hidden_states=multi_stage_features,
        query_position_embeddings=query_embeddings,
        feature_size_list=size_list,
        output_hidden_states=output_hidden_states,
        output_attentions=output_attentions,
        return_dict=True,
    )

    return decoder_output
|
|
|
|
def Mask2FormerModel_forward_first_part(
    self,
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerModelOutput:
    r"""
    First half of a split `Mask2FormerModel.forward`: run the backbone +
    pixel decoder, then the first 3 transformer-decoder layers.

    Args:
        pixel_values: input images `(batch_size, num_channels, height, width)`.
        pixel_mask: optional padding mask; defaults to all ones. NOTE(review):
            it is created but not consumed by anything visible here.

    Returns:
        Tuple `(query_features, pixel_level_module_output)` where
        `query_features` is the decoder's last hidden state after 3 layers
        (batch-first) and `pixel_level_module_output` carries the pixel-level
        features needed by the second part.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    batch_size, _, height, width = pixel_values.shape

    if pixel_mask is None:
        pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)

    # Backbone + pixel decoder.
    pixel_level_module_output = self.pixel_level_module(
        pixel_values=pixel_values, output_hidden_states=output_hidden_states
    )

    # First 3 transformer-decoder layers.
    transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_first_part(
        multi_scale_features=pixel_level_module_output.decoder_hidden_states,
        mask_features=pixel_level_module_output.decoder_last_hidden_state,
        output_hidden_states=True,
        output_attentions=output_attentions,
    )

    query_features = transformer_module_output.last_hidden_state
    return query_features, pixel_level_module_output
|
|
|
|
def Mask2FormerModel_forward_second_part(
    self,
    query_features: Tensor,
    query_embeddings: Tensor,
    pixel_level_module_output,
    pixel_values: Tensor,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerModelOutput:
    r"""
    Second half of a split `Mask2FormerModel.forward`: resume from the
    intermediate queries and run the remaining transformer-decoder layers,
    then assemble the full `Mask2FormerModelOutput`.

    Args:
        query_features: decoder queries produced by the first part (possibly
            modified by the caller between the two halves).
        query_embeddings: query positional embeddings.
        pixel_level_module_output: pixel-level module output from the first
            part (re-used; the backbone is NOT re-run).
        pixel_values: input images, used only for shape / default mask.
        pixel_mask: optional padding mask; defaults to all ones. NOTE(review):
            created but not consumed by anything visible here.

    Returns:
        `Mask2FormerModelOutput` (or a tuple when `return_dict=False`).
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    batch_size, _, height, width = pixel_values.shape

    if pixel_mask is None:
        pixel_mask = torch.ones((batch_size, height, width), device=pixel_values.device)

    # Remaining transformer-decoder layers.
    transformer_module_output = self.transformer_module.Mask2FormerTransformerModule_forward_second_part(
        query_features=query_features,
        query_embeddings=query_embeddings,
        multi_scale_features=pixel_level_module_output.decoder_hidden_states,
        mask_features=pixel_level_module_output.decoder_last_hidden_state,
        output_hidden_states=True,
        output_attentions=output_attentions,
    )

    encoder_hidden_states = None
    pixel_decoder_hidden_states = None
    transformer_decoder_hidden_states = None
    transformer_decoder_intermediate_states = None

    if output_hidden_states:
        encoder_hidden_states = pixel_level_module_output.encoder_hidden_states
        pixel_decoder_hidden_states = pixel_level_module_output.decoder_hidden_states
        transformer_decoder_hidden_states = transformer_module_output.hidden_states
        transformer_decoder_intermediate_states = transformer_module_output.intermediate_hidden_states

    output = Mask2FormerModelOutput(
        encoder_last_hidden_state=pixel_level_module_output.encoder_last_hidden_state,
        pixel_decoder_last_hidden_state=pixel_level_module_output.decoder_last_hidden_state,
        transformer_decoder_last_hidden_state=transformer_module_output.last_hidden_state,
        encoder_hidden_states=encoder_hidden_states,
        pixel_decoder_hidden_states=pixel_decoder_hidden_states,
        transformer_decoder_hidden_states=transformer_decoder_hidden_states,
        transformer_decoder_intermediate_states=transformer_decoder_intermediate_states,
        attentions=transformer_module_output.attentions,
        masks_queries_logits=transformer_module_output.masks_queries_logits,
    )

    if not return_dict:
        output = tuple(v for v in output.values() if v is not None)

    return output
|
|
|
|
def Mask2FormerForUniversalSegmentation_forward_first_part(
    self,
    pixel_values: Tensor,
    mask_labels: Optional[List[Tensor]] = None,
    class_labels: Optional[List[Tensor]] = None,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_auxiliary_logits: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerForUniversalSegmentationOutput:
    """First half of the split universal-segmentation forward pass.

    Runs the backbone, pixel decoder and the first 3 transformer-decoder
    layers; labels are accepted for signature parity but losses are only
    computed in the second part.

    Returns:
        Tuple `(query_features, pixel_level_module_output)` to be passed to
        `Mask2FormerForUniversalSegmentation_forward_second_part`.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Hidden states are forced on when auxiliary losses are enabled, since
    # deep supervision needs the intermediate decoder states.
    query_features, pixel_level_module_output = self.model.Mask2FormerModel_forward_first_part(
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
        output_attentions=output_attentions,
        return_dict=True,
    )

    return query_features, pixel_level_module_output
|
|
|
|
def Mask2FormerForUniversalSegmentation_forward_second_part(
    self,
    query_features,
    query_embeddings,
    pixel_level_module_output,
    text_classifier,
    pixel_values: Tensor,
    mask_labels: Optional[List[Tensor]] = None,
    class_labels: Optional[List[Tensor]] = None,
    pixel_mask: Optional[Tensor] = None,
    output_hidden_states: Optional[bool] = None,
    output_auxiliary_logits: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    return_dict: Optional[bool] = None,
) -> Mask2FormerForUniversalSegmentationOutput:
    """Second half of the split universal-segmentation forward pass.

    Resumes from the intermediate decoder state produced by
    `Mask2FormerForUniversalSegmentation_forward_first_part`, runs the
    remaining decoder layers, scores each intermediate decoder state against
    the per-image `text_classifier` (open-vocabulary classification), and
    computes the losses when ground-truth labels are provided.

    Args:
        query_features: decoder queries after the first part (possibly
            modified by the caller between the two halves).
        query_embeddings: learned query positional embeddings.
        pixel_level_module_output: pixel-level module output from the first
            part (backbone is not re-run).
        text_classifier: per-image text-embedding matrices used as the
            open-vocabulary classifier.
        pixel_values: input images `(batch_size, channels, height, width)`.
        mask_labels / class_labels: optional ground truth; when both are given
            the loss dict and total loss are computed.
        output_auxiliary_logits: whether to keep per-layer auxiliary logits in
            the output (defaults to `self.config.output_auxiliary_logits`).

    Returns:
        `Mask2FormerForUniversalSegmentationOutput`, or a tuple when
        `return_dict=False` (with the loss prepended when available).
        `class_queries_logits` holds per-image lists of logits since class
        counts may differ per image.
    """
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # Hidden states are forced on when auxiliary losses are enabled, since
    # deep supervision needs the intermediate decoder states.
    outputs = self.model.Mask2FormerModel_forward_second_part(
        query_features=query_features,
        query_embeddings=query_embeddings,
        pixel_level_module_output=pixel_level_module_output,
        pixel_values=pixel_values,
        pixel_mask=pixel_mask,
        output_hidden_states=output_hidden_states or self.config.use_auxiliary_loss,
        output_attentions=output_attentions,
        return_dict=True,
    )

    loss, loss_dict, auxiliary_logits = None, None, None
    class_queries_logits = ()

    # Open-vocabulary class logits for every intermediate decoder state
    # (states are query-first; transpose to batch-first for the predictor).
    for decoder_output in outputs.transformer_decoder_intermediate_states:
        class_prediction = self.ov_class_predictor(decoder_output.transpose(0, 1), text_classifier)
        class_queries_logits += (class_prediction,)

    masks_queries_logits = outputs.masks_queries_logits

    auxiliary_logits = self.get_auxiliary_logits(class_queries_logits, masks_queries_logits)

    if mask_labels is not None and class_labels is not None:
        # Main loss from the final layer's predictions; auxiliary predictions
        # supply the deep-supervision terms.
        loss_dict = self.get_loss_dict(
            masks_queries_logits=masks_queries_logits[-1],
            class_queries_logits=class_queries_logits[-1],
            mask_labels=mask_labels,
            class_labels=class_labels,
            auxiliary_predictions=auxiliary_logits,
        )
        loss = self.get_loss(loss_dict)

    encoder_hidden_states = None
    pixel_decoder_hidden_states = None
    transformer_decoder_hidden_states = None

    if output_hidden_states:
        encoder_hidden_states = outputs.encoder_hidden_states
        pixel_decoder_hidden_states = outputs.pixel_decoder_hidden_states
        transformer_decoder_hidden_states = outputs.transformer_decoder_hidden_states

    output_auxiliary_logits = (
        self.config.output_auxiliary_logits if output_auxiliary_logits is None else output_auxiliary_logits
    )
    if not output_auxiliary_logits:
        auxiliary_logits = None

    output = Mask2FormerForUniversalSegmentationOutput(
        loss=loss,
        class_queries_logits=class_queries_logits[-1],
        masks_queries_logits=masks_queries_logits[-1],
        auxiliary_logits=auxiliary_logits,
        encoder_last_hidden_state=outputs.encoder_last_hidden_state,
        pixel_decoder_last_hidden_state=outputs.pixel_decoder_last_hidden_state,
        transformer_decoder_last_hidden_state=outputs.transformer_decoder_last_hidden_state,
        encoder_hidden_states=encoder_hidden_states,
        pixel_decoder_hidden_states=pixel_decoder_hidden_states,
        transformer_decoder_hidden_states=transformer_decoder_hidden_states,
        attentions=outputs.attentions,
    )

    if not return_dict:
        output = tuple(v for v in output.values() if v is not None)
        if loss is not None:
            # BUGFIX: was `(loss) + output` — parentheses alone do not make a
            # tuple, so this tried to add a Tensor to a tuple. Prepend the loss
            # as a 1-tuple instead.
            output = (loss,) + output
    return output
|