Balanced Sinkhorn from the Python Optimal Transport Library
The Optimal Transport problem aims at minimizing the transport cost between a source distribution (the predictions) and a target distribution [Kan58, Mon81, Vil09]. This class solves an entropy-regularized version of the original problem with Sinkhorn's algorithm, which is well suited for GPU parallelization [Cut13, PeyreC+19]. As reg (or reg_dimless) tends to 0, the solution converges to that of the Hungarian algorithm; we refer to Proposition 1 of [DPDPS+23] for further information. Consequently, this matching can be seen as a regularized version of the Hungarian one.
We consider the matching cost \(\mathcal{L}_{\text{match}}\) = cls_match_module + loc_match_module between the \(N_p\) predictions \(\hat{\mathbf{y}}_i\) and \(N_t\) targets \(\mathbf{y}_j\). In particular, the cost of the background \(\mathbf{y}_{N_t+1} = \varnothing\) is given by \(\mathcal{L}_{\text{match}}\left(\hat{\mathbf{y}}_i, \varnothing\right)\) = bg_cost.
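For each element in the batch, the following problem is solved and its minimizer, the match \(\mathbf{P}\), is retrieved:
\[\min_{\mathbf{P} \in \mathbb{R}_{\geq 0}^{N_p \times (N_t+1)}} \; \sum_{i=1}^{N_p} \sum_{j=1}^{N_t+1} P_{i,j} \, \mathcal{L}_{\text{match}}\left(\hat{\mathbf{y}}_i, \mathbf{y}_j\right) - \mathrm{reg} \cdot \mathrm{H}(\mathbf{P})\]
subject to \(\sum_{j=1}^{N_t+1} P_{i,j} = 1\) for every prediction \(i\), \(\sum_{i=1}^{N_p} P_{i,j} = 1\) for every target \(j \leq N_t\), and \(\sum_{i=1}^{N_p} P_{i,N_t+1} = N_p - N_t\) for the background,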
where \(\mathrm{H}: \Delta^{N_p \times (N_t+1)} \rightarrow \mathbb{R}_{\geq 0} : \mathbf{P} \mapsto -\sum_{i,j} P_{i,j}(\log(P_{i,j})-1)\) is the entropy of the match \(\mathbf{P}\), with the convention \(0 \log(0) = 0\).
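The entropy term can be evaluated directly. The short NumPy sketch below uses a hypothetical helper, `entropy`, which is not part of uotod; it implements \(\mathrm{H}\) with the \(0 \log(0) = 0\) convention. A hard 0/1 matching with exactly one nonzero entry per row has an entropy equal to its number of rows, while a spread-out plan with the same marginals scores higher, which is why subtracting reg · H(P) favors softer matchings.

```python
import numpy as np


def entropy(P: np.ndarray) -> float:
    """H(P) = -sum_{i,j} P_ij (log(P_ij) - 1), using the convention 0 log(0) = 0."""
    P = np.asarray(P, dtype=float)
    plogp = np.zeros_like(P)
    pos = P > 0
    plogp[pos] = P[pos] * np.log(P[pos])  # zero entries keep plogp = 0
    return float(-(plogp - P).sum())


hard = np.eye(3)               # 0/1 matching: every prediction fully assigned
soft = np.full((3, 3), 1 / 3)  # spread-out matching with the same marginals
print(entropy(hard), entropy(soft))
```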
When no background is used, the problem remains the same, except that the last column of \(\mathbf{P}\) is simply unused.
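The regularized problem can be solved with a log-domain Sinkhorn iteration. The sketch below is a plain NumPy illustration for a single batch element, not the uotod or POT implementation; the helper names are hypothetical. It assumes that each prediction sends unit mass, each target receives unit mass and the background column receives the remaining \(N_p - N_t\), so it requires \(N_p \geq N_t\).

```python
import numpy as np


def _logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log(sum(exp(x))) along one axis.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m + np.log(np.sum(np.exp(x - m), axis=axis, keepdims=True)), axis=axis)


def balanced_sinkhorn_log(C: np.ndarray, reg: float, n_iter: int = 1000) -> np.ndarray:
    """Entropy-regularized balanced OT for one batch element (illustrative sketch).

    C has shape (N_p, N_t + 1); its last column holds the background cost.
    Assumed marginals: each prediction sends mass 1, each target receives
    mass 1 and the background receives the remaining N_p - N_t.
    """
    n_pred, n_cols = C.shape
    n_tgt = n_cols - 1
    if n_pred < n_tgt:
        raise ValueError("this sketch assumes N_p >= N_t")
    log_a = np.zeros(n_pred)  # log(1) for every prediction
    b = np.concatenate([np.ones(n_tgt), [n_pred - n_tgt]])
    with np.errstate(divide="ignore"):
        log_b = np.log(b)  # -inf for an empty background when N_p == N_t
    f = np.zeros(n_pred)  # dual potential of the predictions
    g = np.zeros(n_cols)  # dual potential of the targets and the background
    for _ in range(n_iter):
        # Alternately fit the row marginals (predictions) and column marginals.
        f = reg * (log_a - _logsumexp((g[None, :] - C) / reg, axis=1))
        g = reg * (log_b - _logsumexp((f[:, None] - C) / reg, axis=0))
    # Recover the transport plan P_ij = exp((f_i + g_j - C_ij) / reg).
    return np.exp((f[:, None] + g[None, :] - C) / reg)


# Toy cost: 3 predictions x (2 targets + background), with bg_cost = 0.5.
C = np.array([[0.1, 0.9, 0.5],
              [0.8, 0.2, 0.5],
              [0.7, 0.6, 0.5]])
P = balanced_sinkhorn_log(C, reg=0.1)
print(np.round(P, 3))
print(P.sum(axis=1), P.sum(axis=0))  # row and column marginals
```

On this toy cost, the Hungarian matching is prediction 0 to target 0, prediction 1 to target 1 and prediction 2 to the background (total cost 0.8). Decreasing reg moves P closer to that 0/1 matrix, but the number of iterations needed also grows.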
Warning
Although the formulation converges to the Hungarian algorithm as reg (or reg_dimless) tends to 0, solving it with Sinkhorn's algorithm becomes increasingly unstable. If no regularization at all is explicitly sought, we encourage using uotod.match.Hungarian instead, at the cost of losing parallelization.
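One source of this instability is the Gibbs kernel \(\exp(-C/\mathrm{reg})\) used by the standard Sinkhorn scaling: for small reg, its entries underflow to zero in floating point and the scaling vectors become infinite or undefined. The NumPy sketch below uses toy costs and a hypothetical helper, not uotod code, and shows this for a single scaling step. Log-domain variants such as sinkhorn_log avoid the underflow, although convergence still slows down as reg decreases.

```python
import numpy as np

# Toy cost: 3 predictions x (2 targets + background), all entries in [0.1, 0.9].
C = np.array([[0.1, 0.9, 0.5],
              [0.8, 0.2, 0.5],
              [0.7, 0.6, 0.5]])
a = np.ones(3)  # prediction masses
b = np.ones(3)  # target masses, plus N_p - N_t = 1 for the background


def naive_sinkhorn_step(C: np.ndarray, reg: float):
    """One scaling step of standard (non-log) Sinkhorn, starting from v = 1."""
    K = np.exp(-C / reg)  # Gibbs kernel; entries underflow to 0.0 when C / reg is large
    v = np.ones(C.shape[1])
    with np.errstate(divide="ignore", invalid="ignore"):
        u = a / (K @ v)    # becomes inf where a whole row of K underflowed
        v = b / (K.T @ u)  # 0 * inf = nan propagates into v
    return K, u, v


for reg in (1e-1, 1e-4):
    K, u, v = naive_sinkhorn_step(C, reg)
    finite = bool(np.isfinite(u).all() and np.isfinite(v).all())
    print(f"reg={reg:g}: zero kernel entries={np.count_nonzero(K == 0.0)}, finite scalings={finite}")
```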
Class
- class uotod.match.BalancedPOT(**kwargs)
- available_methods = ['sinkhorn', 'greenkorn', 'sinkhorn_log', 'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']
- compute_cost_matrix(input: Dict[str, Tensor] | List[Dict[str, Tensor]], target: Dict[str, Tensor] | List[Dict[str, Tensor]], anchors: Tensor | None = None) → Tensor
Computes a batch of cost matrices between the predicted and target boxes.
- Parameters:
input (dictionary) – Input containing the predicted logits and boxes. “pred_logits”: Tensor of shape (batch_size, num_pred, num_classes). “pred_boxes”: Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).
target (dictionary) – Target containing the target classes, boxes and mask. “labels”: Tensor of shape (batch_size, num_targets). “boxes”: Tensor of shape (batch_size, num_targets, 4), where the last dimension is (x1, y1, x2, y2). “mask”: Tensor of shape (batch_size, num_targets).
anchors (Tensor, optional) – The anchors used to compute the predicted boxes. Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).
background (bool, optional) – Indicates whether the background has to be added.
- Returns:
The cost matrix between the predicted and target boxes: Tensor of shape (batch_size, num_pred, num_targets + 1), or (batch_size, num_pred, num_targets) if background is False.
- Return type:
Tensor (float)
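To make the documented input and target layouts concrete, the sketch below builds both dictionaries with NumPy arrays standing in for PyTorch tensors and computes a cost matrix of the documented output shape. The helper `toy_cost_matrix` is hypothetical, and so is its cost (negative class probability plus L1 box distance, with a constant background column); uotod's actual cost comes from cls_match_module and loc_match_module. This sketch also ignores the "mask" entry.

```python
import numpy as np


def toy_cost_matrix(inp: dict, tgt: dict, bg_cost: float = 0.8) -> np.ndarray:
    """Hypothetical cost: -p(target class) + L1 box distance, plus a background column.

    Returns an array of shape (batch_size, num_pred, num_targets + 1).
    """
    logits = inp["pred_logits"]                       # (B, P, K)
    probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
    probs /= probs.sum(axis=-1, keepdims=True)        # softmax over classes
    labels = tgt["labels"]                            # (B, T) integer class ids
    b, p, _ = probs.shape
    t = labels.shape[1]
    idx = np.broadcast_to(labels[:, None, :], (b, p, t))
    cls_cost = -np.take_along_axis(probs, idx, axis=2)  # (B, P, T)
    # L1 distance between every predicted and every target (x1, y1, x2, y2) box.
    loc_cost = np.abs(inp["pred_boxes"][:, :, None, :] - tgt["boxes"][:, None, :, :]).sum(axis=-1)
    bg = np.full((b, p, 1), bg_cost)                  # constant background cost
    return np.concatenate([cls_cost + loc_cost, bg], axis=-1)


inp = {
    "pred_logits": np.array([[[5.0, 0.0], [0.0, 5.0], [0.0, 0.0]]]),              # (1, 3, 2)
    "pred_boxes": np.array([[[0, 0, 1, 1], [2, 2, 3, 3], [5, 5, 6, 6]]], dtype=float),
}
tgt = {
    "labels": np.array([[0, 1]]),                                                  # (1, 2)
    "boxes": np.array([[[0, 0, 1, 1], [2, 2, 3, 3]]], dtype=float),                # (1, 2, 4)
    "mask": np.array([[True, True]]),                                              # unused here
}
cost = toy_cost_matrix(inp, tgt)
print(cost.shape)
```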
- compute_matching(cost_matrix: Tensor, target_mask: Tensor) → Tensor
Computes the matching between the predicted and target boxes. The optimal transport problem is solved using the Sinkhorn algorithm.
- Parameters:
cost_matrix (Tensor) – Cost matrix of shape (batch_size, num_pred, num_targets + 1).
target_mask (BoolTensor, optional) – Target mask of shape (batch_size, num_targets).
- Returns:
The matching \(\mathbf{P}\) for each element of the batch. Tensor of shape (batch_size, num_pred, num_targets + 1). The last entry of the last dimension, [:, :, num_targets] (equivalently [:, :, -1]), is the background.
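One common way to read the returned soft matching is to take the argmax over its last dimension; an index equal to num_targets, that is the last column [:, :, -1], then means the prediction is matched to the background. The sketch below uses NumPy, a hand-written toy matching and a hypothetical helper `hard_assignment`; it is not part of uotod.

```python
import numpy as np


def hard_assignment(P: np.ndarray) -> np.ndarray:
    """Convert a soft matching of shape (B, num_pred, num_targets + 1) into indices.

    Returns an integer array of shape (B, num_pred): an entry t < num_targets
    means the prediction is assigned to target t, and an entry equal to
    num_targets (the last column) means it is assigned to the background.
    """
    return P.argmax(axis=-1)


# Toy matching: batch of 1, 3 predictions, 2 targets + background.
P = np.array([[[0.95, 0.03, 0.02],
               [0.02, 0.90, 0.08],
               [0.03, 0.07, 0.90]]])
idx = hard_assignment(P)
num_targets = P.shape[-1] - 1
is_background = idx == num_targets
print(idx, is_background)
```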
- forward(input: Dict[str, Tensor] | List[Dict[str, Tensor]], target: Dict[str, Tensor] | List[Dict[str, Tensor]], anchors: Tensor | None = None, cost_matrix: Tensor | None = None, save: bool = True) → Tensor | Tuple[Tensor, Tensor]
Computes a batch of matchings between the predicted and target boxes.
- Parameters:
input (dictionary) – Input containing the predicted logits and boxes. “pred_logits”: Tensor of shape (batch_size, num_pred, num_classes). “pred_boxes”: Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).
target (dictionary) – Target containing the target classes, boxes and mask. “labels”: Tensor of shape (batch_size, num_targets). “boxes”: Tensor of shape (batch_size, num_targets, 4), where the last dimension is (x1, y1, x2, y2). “mask”: Tensor of shape (batch_size, num_targets).
anchors (Tensor, optional) – The anchors used to compute the predicted boxes. Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).
- Returns:
the matching between the predicted and target boxes, and the cost matrix if returns_cost is True: Tensor of shape (batch_size, num_pred, num_targets + 1). The last entry of the last dimension is the background.
- Return type:
Tensor (float) or Tuple(Tensor, Tensor)
- plot(idx=0, img: Tensor | ndarray | None = None, plot_cost: bool = True, plot_match: bool = True, max_background_match: float | int = 1.0, background: bool = True, erase: bool = False)
Plots the results from the last batch.
- Parameters:
idx (int, optional) – Index of the image to be plotted.
img (Tensor or ndarray, optional) – Image to be plotted. If it is not None, the boxes are also plotted on it.
plot_cost (bool, optional) – Plots the cost matrix between the predictions and the targets, including background.
plot_match (bool, optional) – Plots the matching between the predictions and the targets, including background.
max_background_match (float, optional) – A threshold to only plot relevant matched predictions. The predictions are only plotted if the value matched to the background does not exceed max_background_match. Defaults to 1.
- Returns:
Matplotlib figures
- Return type:
Tuple(fig, fig, fig)
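The max_background_match rule described above fits in a few lines. The hypothetical helper below is not the uotod implementation. For one batch element idx, it selects the predictions whose mass matched to the background (the last column of the matching) does not exceed the threshold, using a toy NumPy matching.

```python
import numpy as np


def predictions_to_plot(P: np.ndarray, idx: int = 0, max_background_match: float = 1.0) -> np.ndarray:
    """Indices of predictions whose background mass does not exceed the threshold.

    P has shape (batch_size, num_pred, num_targets + 1), background in the last column.
    """
    bg_mass = P[idx, :, -1]  # mass of each prediction sent to the background
    return np.flatnonzero(bg_mass <= max_background_match)


# Toy matching: batch of 1, 3 predictions, 2 targets + background.
P = np.array([[[0.95, 0.03, 0.02],
               [0.02, 0.90, 0.08],
               [0.03, 0.07, 0.90]]])
print(predictions_to_plot(P, idx=0, max_background_match=0.5))
```

With the default threshold of 1, every prediction is kept, since a prediction cannot send more than its unit mass to the background.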
Available Methods
| Value | Description | Extra optional arguments | Reference |
|---|---|---|---|