| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ PyTorch Conditional DETR model.""" |
|
|
| from transformers.utils import ( |
| is_scipy_available, |
| is_vision_available, |
| logging |
| ) |
|
|
| import torch |
| from torch import Tensor, nn |
|
|
| if is_scipy_available(): |
| from scipy.optimize import linear_sum_assignment |
|
|
| if is_vision_available(): |
| from transformers.image_transforms import center_to_corners_format |
|
|
| logger = logging.get_logger(__name__) |
|
|
| |
class ConditionalDetrHungarianMatcher(nn.Module):
    """
    This class computes an assignment between the targets and the predictions of the network.

    For efficiency reasons, the targets don't include the no_object. Because of this, in general, there are more
    predictions than targets. In this case, we do a 1-to-1 matching of the best predictions, while the others are
    un-matched (and thus treated as non-objects).

    Args:
        class_cost:
            The relative weight of the classification error in the matching cost.
        bbox_cost:
            The relative weight of the L1 error of the bounding box coordinates in the matching cost.
        giou_cost:
            The relative weight of the giou loss of the bounding box in the matching cost.
        focal_alpha:
            Alpha of the focal-style classification cost (default 0.25, the previously hard-coded value).
        focal_gamma:
            Gamma (focusing exponent) of the focal-style classification cost (default 2.0, the previously
            hard-coded value).

    Raises:
        ValueError: if `class_cost`, `bbox_cost` and `giou_cost` are all 0.
    """

    def __init__(
        self,
        class_cost: float = 1,
        bbox_cost: float = 1,
        giou_cost: float = 1,
        focal_alpha: float = 0.25,
        focal_gamma: float = 2.0,
    ):
        super().__init__()
        if class_cost == 0 and bbox_cost == 0 and giou_cost == 0:
            raise ValueError("All costs of the Matcher can't be 0")

        self.class_cost = class_cost
        self.bbox_cost = bbox_cost
        self.giou_cost = giou_cost
        self.focal_alpha = focal_alpha
        self.focal_gamma = focal_gamma

    @torch.no_grad()
    def forward(self, outputs, targets):
        """
        Args:
            outputs (`dict`):
                A dictionary that contains at least these entries:
                * "logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
                * "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates.
            targets (`List[dict]`):
                A list of targets (len(targets) = batch_size), where each target is a dict containing:
                * "class_labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of
                  ground-truth objects in the target) containing the class labels
                * "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates.

        Returns:
            `List[Tuple]`: A list with one entry per element of `targets`, containing tuples of
            (index_i, index_j) of dtype int64 where:
            - index_i is the indices of the selected predictions (in order)
            - index_j is the indices of the corresponding selected targets (in order)
            For each batch element, it holds: len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
            If the whole batch has no target boxes, every tuple holds two empty tensors.
        """
        batch_size, num_queries = outputs["logits"].shape[:2]
        sizes = [len(v["boxes"]) for v in targets]
        num_targets = sum(sizes)

        # Nothing to match anywhere in the batch. This must be handled explicitly, because
        # `view(batch_size, num_queries, -1)` cannot infer the last dimension of a 0-element tensor and raises.
        if num_targets == 0:
            return [(torch.empty(0, dtype=torch.int64), torch.empty(0, dtype=torch.int64)) for _ in sizes]

        # Imported here so that a missing scipy surfaces as a clear ImportError instead of a NameError
        # (the module-level import is skipped when scipy is unavailable).
        from scipy.optimize import linear_sum_assignment

        # Flatten batch and query dimensions so all costs are computed in one shot.
        out_prob = outputs["logits"].flatten(0, 1).sigmoid()  # [batch_size * num_queries, num_classes]
        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]

        # Concatenate the target labels and boxes of every image in the batch.
        target_ids = torch.cat([v["class_labels"] for v in targets])
        target_bbox = torch.cat([v["boxes"] for v in targets])

        # Focal-loss style classification cost; the 1e-8 keeps log() away from log(0).
        alpha, gamma = self.focal_alpha, self.focal_gamma
        neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(1 - out_prob + 1e-8).log())
        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
        cost_class = pos_cost_class[:, target_ids] - neg_cost_class[:, target_ids]

        # L1 distance between predicted and target boxes.
        cost_bbox = torch.cdist(out_bbox, target_bbox, p=1)

        # Negative generalized IoU; boxes are converted to corner format first.
        cost_giou = -generalized_box_iou(center_to_corners_format(out_bbox), center_to_corners_format(target_bbox))

        cost_matrix = self.bbox_cost * cost_bbox + self.class_cost * cost_class + self.giou_cost * cost_giou
        # Explicit last dimension (num_targets > 0 here); moved to CPU for scipy.
        cost_matrix = cost_matrix.view(batch_size, num_queries, num_targets).cpu()

        # Split the target columns per image and solve each image's assignment independently.
        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(cost_matrix.split(sizes, -1))]
        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
|
|
|
|
| |
| def _upcast(t: Tensor) -> Tensor: |
| |
| if t.is_floating_point(): |
| return t if t.dtype in (torch.float32, torch.float64) else t.float() |
| else: |
| return t if t.dtype in (torch.int32, torch.int64) else t.int() |
|
|
|
|
| |
def box_area(boxes: Tensor) -> Tensor:
    """
    Compute the area of each box in a set of (x1, y1, x2, y2) corner-format boxes.

    Args:
        boxes (`torch.FloatTensor` of shape `(number_of_boxes, 4)`):
            Boxes whose area is computed. Expected in (x1, y1, x2, y2) format with `0 <= x1 < x2` and
            `0 <= y1 < y2`.

    Returns:
        `torch.FloatTensor`: one area per box, computed after widening the dtype via `_upcast`.
    """
    widened = _upcast(boxes)
    widths = widened[:, 2] - widened[:, 0]
    heights = widened[:, 3] - widened[:, 1]
    return widths * heights
|
|
|
|
| |
def box_iou(boxes1, boxes2):
    """
    Pairwise intersection-over-union between two sets of corner-format (x0, y0, x1, y1) boxes.

    Args:
        boxes1: boxes of shape [N, 4].
        boxes2: boxes of shape [M, 4].

    Returns:
        Tuple `(iou, union)`, both [N, M] matrices; `union` is returned too because `generalized_box_iou`
        reuses it.
    """
    # Overlap rectangle for every (i, j) pair, via broadcasting [N, 1, 2] against [M, 2].
    overlap_min = torch.max(boxes1[:, None, :2], boxes2[:, :2])
    overlap_max = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])

    # Non-overlapping pairs get a clamped width/height of 0, hence zero intersection.
    overlap_wh = (overlap_max - overlap_min).clamp(min=0)
    intersection = overlap_wh[:, :, 0] * overlap_wh[:, :, 1]

    union = box_area(boxes1)[:, None] + box_area(boxes2) - intersection
    return intersection / union, union
|
|
|
|
| |
def generalized_box_iou(boxes1, boxes2):
    """
    Generalized IoU from https://giou.stanford.edu/. The boxes should be in [x0, y0, x1, y1] (corner) format.

    Computes GIoU = IoU - (area(C) - union) / area(C), where C is the smallest axis-aligned box enclosing
    both boxes of a pair.

    Raises:
        ValueError: if any box in `boxes1` or `boxes2` has x1 < x0 or y1 < y0. Zero-size boxes are accepted.

    Returns:
        `torch.FloatTensor`: a [N, M] pairwise matrix, where N = len(boxes1) and M = len(boxes2)

    NOTE(review): if a pair's enclosing box has zero area (two identical degenerate boxes), the division
    below yields NaN.
    """
    # Validate boxes1 first, then boxes2, so the reported offender matches the argument order.
    for label, candidate in (("boxes1", boxes1), ("boxes2", boxes2)):
        if not (candidate[:, 2:] >= candidate[:, :2]).all():
            raise ValueError(f"{label} must be in [x0, y0, x1, y1] (corner) format, but got {candidate}")

    iou, union = box_iou(boxes1, boxes2)

    # Smallest enclosing box for every (i, j) pair.
    enclosing_min = torch.min(boxes1[:, None, :2], boxes2[:, :2])
    enclosing_max = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
    enclosing_wh = (enclosing_max - enclosing_min).clamp(min=0)
    enclosing_area = enclosing_wh[:, :, 0] * enclosing_wh[:, :, 1]

    return iou - (enclosing_area - union) / enclosing_area
|
|