Balanced Sinkhorn from the Python Optimal Transport Library

The Optimal Transport problem aims at minimizing the transport cost between a source distribution (the predictions) and a target distribution [Kan58, Mon81, Vil09]. This matching solves an entropy-regularized version of the original problem with Sinkhorn’s algorithm, which is well suited for GPU parallelization [Cut13, PeyreC+19]. As reg (or reg_dimless) tends to 0, the solution converges to the one given by the Hungarian algorithm; see Proposition 1 of [DPDPS+23] for details. Consequently, this matching can be seen as a regularized version of the Hungarian one.
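To make the regularized problem and its Hungarian limit concrete, here is a minimal NumPy sketch under stated assumptions; it is not uotod's or POT's implementation (POT exposes a comparable solver as `ot.sinkhorn`). The hypothetical `log_sinkhorn` runs Sinkhorn's updates in the log domain on a small square problem with uniform marginals 1/N_p, and the hypothetical `brute_force_assignment` enumerates permutations as a stand-in for the Hungarian algorithm. As reg decreases, N_p·P should approach the one-to-one assignment.

```python
import itertools

import numpy as np


def logsumexp(x: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along `axis`."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))


def log_sinkhorn(a, b, cost, reg, n_iter=1000):
    """Plan minimizing <P, C> - reg * H(P) with row sums a and column sums b (log-domain Sinkhorn)."""
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iter):
        f = reg * (np.log(a) - logsumexp((g[None, :] - cost) / reg, axis=1))  # match row sums
        g = reg * (np.log(b) - logsumexp((f[:, None] - cost) / reg, axis=0))  # match column sums
    return np.exp((f[:, None] + g[None, :] - cost) / reg)


def brute_force_assignment(cost):
    """Exact one-to-one assignment by enumerating permutations (tiny square problems only)."""
    n = cost.shape[0]
    rows = np.arange(n)
    best = min(itertools.permutations(range(n)), key=lambda p: cost[rows, list(p)].sum())
    return np.eye(n)[list(best)]  # row i has a single 1 at column best[i]


rng = np.random.default_rng(0)
n = 4
perm = [2, 0, 3, 1]
cost = 1.0 - np.eye(n)[perm] + 0.1 * rng.random((n, n))  # clear optimum along `perm`
a = b = np.full(n, 1.0 / n)

hard = brute_force_assignment(cost)
for reg in (1.0, 0.1, 0.01):
    soft = n * log_sinkhorn(a, b, cost, reg)  # N_p * P, as in the formulation of this page
    print(f"reg={reg}: max |N_p P - P_hungarian| = {np.abs(soft - hard).max():.2e}")
```

Enumerating permutations is only viable for a handful of boxes; it is used here solely to obtain an exact reference.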

We consider the matching cost \(\mathcal{L}_{\text{match}}\) = cls_match_module + loc_match_module between the \(N_p\) predictions \(\hat{\mathbf{y}}_i\) and \(N_t\) targets \(\mathbf{y}_j\). In particular, the cost of the background \(\mathbf{y}_{N_t+1} = \varnothing\) is given by \(\mathcal{L}_{\text{match}}\left(\hat{\mathbf{y}}_i, \varnothing\right)\) = bg_cost.

For each element in the batch, the following problem is solved and the match \(\mathbf{P}\) is retrieved.

\begin{align}
N_p \cdot \underset{\mathbf{P}\in \mathbb{R}_{\geq 0}^{N_p \times (N_t+1)}}{\mathrm{arg\,min}} &\sum_{i,j=1}^{N_p,N_t+1} P_{i,j}\mathcal{L}_{\text{match}}\left(\hat{\mathbf{y}}_i, \mathbf{y}_j\right) - \mathtt{reg} \cdot \mathrm{H}(\mathbf{P}), &\\
\mathrm{s.t.} & \sum_{j=1}^{N_t+1} P_{i,j} = 1/N_p, & \forall\; 1 \leq i \leq N_p \;\text{(predictions)},\\
& \sum_{i=1}^{N_p} P_{i,j} = 1/N_p, & \forall\; 1 \leq j \leq N_t \;\text{(targets)},\\
& \sum_{i=1}^{N_p} P_{i,j} = (N_p - N_t)/N_p, & j = N_t+1\;\text{(background)}.
\end{align}

where \(\mathrm{H}: \Delta^{N_p \times (N_t+1)} \rightarrow \mathbb{R}_{\geq 0} : \mathbf{P} \mapsto -\sum_{i,j} P_{i,j}(\log(P_{i,j})-1)\) is the entropy of the match \(\mathbf{P}\), with the convention \(0 \log(0) = 0\).

When no background is used, the problem is unchanged, but the last column of \(\mathbf{P}\) is simply not used.
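The marginals and the objective of this formulation can be written down directly. The following sketch (hypothetical helper names, NumPy arrays standing in for Tensors) builds the row and column masses, including the background mass (N_p - N_t)/N_p, which implicitly requires N_t ≤ N_p, and evaluates the entropy-regularized objective with the 0 log 0 = 0 convention. The outer product of the two marginals is used only as an easy feasible plan, not as the solution.

```python
import numpy as np


def balanced_marginals(n_pred: int, n_tgt: int):
    """Row and column marginals of the balanced problem (last column = background)."""
    if n_tgt > n_pred:
        raise ValueError("the background mass (N_p - N_t) / N_p must be non-negative")
    a = np.full(n_pred, 1.0 / n_pred)                  # each prediction sends 1/N_p
    b = np.append(np.full(n_tgt, 1.0 / n_pred),        # each target receives 1/N_p
                  (n_pred - n_tgt) / n_pred)           # the background receives the rest
    return a, b


def entropy(P):
    """H(P) = -sum_{i,j} P_ij (log P_ij - 1), with the convention 0 log 0 = 0."""
    P = np.asarray(P, dtype=float)
    log_P = np.log(np.where(P > 0, P, 1.0))  # zero entries contribute 0 * (0 - 1) = 0
    return -np.sum(P * (log_P - 1.0))


def objective(P, cost, reg):
    """Entropy-regularized transport cost <P, L_match> - reg * H(P)."""
    return np.sum(P * cost) - reg * entropy(P)


a, b = balanced_marginals(n_pred=5, n_tgt=2)
P = np.outer(a, b)  # independent coupling: feasible, since a and b both sum to 1
print(b)            # targets get 1/5 each, the background gets 3/5
print(np.allclose(P.sum(axis=1), a), np.allclose(P.sum(axis=0), b))
```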

Warning

Although the formulation converges to the Hungarian algorithm as reg (or reg_dimless) tends to 0, solving it with Sinkhorn’s algorithm becomes increasingly unstable as the regularization decreases. If no regularization at all is sought, we recommend using uotod.match.Hungarian instead, at the cost of losing parallelization.
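The instability can be reproduced with a textbook (non-log-domain) Sinkhorn, sketched below in NumPy as an illustration rather than the code uotod runs. The kernel exp(-C/reg) underflows to exactly zero once C/reg exceeds about 745 in double precision. When every entry of a row underflows, the corresponding row scaling becomes infinite and the plan turns into NaN; in the example below all entries do. Log-domain or stabilized variants, such as the 'sinkhorn_log' and 'sinkhorn_stabilized' methods listed below, are designed to avoid this underflow, although convergence still slows down as reg decreases.

```python
import numpy as np


def naive_sinkhorn(a, b, cost, reg, n_iter=100):
    """Textbook Sinkhorn on the kernel K = exp(-C / reg), without log-domain stabilization."""
    with np.errstate(divide="ignore", invalid="ignore", under="ignore"):
        K = np.exp(-cost / reg)                      # underflows to 0 when C / reg is large
        u, v = np.ones_like(a), np.ones_like(b)
        for _ in range(n_iter):
            u = a / (K * v[None, :]).sum(axis=1)     # rescale rows
            v = b / (K * u[:, None]).sum(axis=0)     # rescale columns
        return u[:, None] * K * v[None, :]


rng = np.random.default_rng(0)
n = 4
cost = 0.5 + rng.random((n, n))  # all costs in [0.5, 1.5)
a = b = np.full(n, 1.0 / n)

P_ok = naive_sinkhorn(a, b, cost, reg=0.1)    # moderate regularization: finite plan
P_bad = naive_sinkhorn(a, b, cost, reg=1e-4)  # C / reg >= 5000: every kernel entry is 0.0
print(np.isfinite(P_ok).all(), np.isnan(P_bad).all())
```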

Class

class uotod.match.BalancedPOT(**kwargs)
available_methods = ['sinkhorn', 'greenkorn', 'sinkhorn_log', 'sinkhorn_stabilized', 'sinkhorn_epsilon_scaling']
compute_cost_matrix(input: Dict[str, Tensor] | List[Dict[str, Tensor]], target: Dict[str, Tensor] | List[Dict[str, Tensor]], anchors: Tensor | None = None) → Tensor

Computes a batch of cost matrices between the predicted and target boxes.

Parameters:
  • input (dictionary) – Input containing the predicted logits and boxes. “pred_logits”: Tensor of shape (batch_size, num_pred, num_classes). “pred_boxes”: Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).

  • target (dictionary) – Target containing the target classes, boxes and mask. “labels”: Tensor of shape (batch_size, num_targets). “boxes”: Tensor of shape (batch_size, num_targets, 4), where the last dimension is (x1, y1, x2, y2). “mask”: Tensor of shape (batch_size, num_targets).

  • anchors (Tensor, optional) – The anchors used to compute the predicted boxes. Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).

  • background (bool, optional) – Indicates whether the background has to be added.

Returns:

The cost matrix between the predicted and target boxes: Tensor of shape (batch_size, num_pred, num_targets + 1), or (batch_size, num_pred, num_targets) if background is False.

Return type:

Tensor (float)
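As an illustration of the shapes involved, the sketch below assembles a batch of cost matrices from the documented input and target keys. The concrete cost is an assumption, not uotod's cls_match_module or loc_match_module: a DETR-style negative class probability plus an L1 distance between (x1, y1, x2, y2) boxes, with a constant background column filled with bg_cost. `sketch_cost_matrix`, its arguments and the default bg_cost=0.8 are hypothetical, NumPy arrays stand in for Tensors, and anchors and the target mask are ignored.

```python
import numpy as np


def softmax(x, axis=-1):
    z = np.exp(x - x.max(axis=axis, keepdims=True))  # shift for numerical stability
    return z / z.sum(axis=axis, keepdims=True)


def sketch_cost_matrix(prediction, target, bg_cost=0.8, background=True):
    """Hypothetical L_match = class cost + L1 box cost, with an optional constant background column."""
    probs = softmax(prediction["pred_logits"])                     # (B, Np, C)
    labels = target["labels"]                                       # (B, Nt), integer class ids
    B, n_pred, _ = probs.shape
    n_tgt = labels.shape[1]
    # cls_cost[b, i, j] = -probs[b, i, labels[b, j]]  (DETR-style choice, an assumption)
    idx = np.broadcast_to(labels[:, None, :], (B, n_pred, n_tgt))
    cls_cost = -np.take_along_axis(probs, idx, axis=2)
    # loc_cost[b, i, j] = L1 distance between predicted box i and target box j
    loc_cost = np.abs(prediction["pred_boxes"][:, :, None, :] - target["boxes"][:, None, :, :]).sum(-1)
    cost = cls_cost + loc_cost                                      # (B, Np, Nt)
    if background:
        bg = np.full((B, n_pred, 1), bg_cost)                       # L_match(y_hat_i, empty) = bg_cost
        cost = np.concatenate([cost, bg], axis=2)                   # (B, Np, Nt + 1)
    return cost


B, n_pred, n_cls = 2, 3, 4
prediction = {
    "pred_logits": np.zeros((B, n_pred, n_cls)),                    # uniform class probabilities 1/4
    "pred_boxes": np.tile([0.0, 0.0, 1.0, 1.0], (B, n_pred, 1)),
}
target = {
    "labels": np.array([[1, 2], [0, 3]]),
    "boxes": np.tile([[0.0, 0.0, 1.0, 1.0], [0.0, 0.0, 2.0, 2.0]], (B, 1, 1)),
    "mask": np.ones((B, 2), dtype=bool),                            # not used by this sketch
}
cost = sketch_cost_matrix(prediction, target)
print(cost.shape)  # (batch_size, num_pred, num_targets + 1)
```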

compute_matching(cost_matrix: Tensor, target_mask: Tensor) → Tensor

Computes the matching between the predicted and target boxes. The optimal transport problem is solved using the Sinkhorn algorithm.

Parameters:
  • cost_matrix (Tensor) – Cost matrix of shape (batch_size, num_pred, num_targets + 1).

  • target_mask (BoolTensor, optional) – Target mask of shape (batch_size, num_targets).

Returns:

The matching \(\mathbf{P}\) for each element of the batch. Tensor of shape (batch_size, num_pred, num_targets + 1). The last entry of the last dimension, [:, :, num_targets] (i.e., [:, :, -1]), is the background.
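A batched, masked version of the balanced problem can be sketched as follows, under the assumptions that padded targets (mask False) receive zero mass, that the background absorbs the remaining (N_p - N_t)/N_p with N_t the number of valid targets, and that the returned plan is scaled by N_p as in the formulation above. This is a log-domain NumPy illustration with hypothetical names (`sketch_compute_matching`, `reg`, `n_iter`), not the library's code path.

```python
import numpy as np


def _logsumexp(x, axis):
    """Stable log(sum(exp(x))) along `axis`; tolerates -inf entries."""
    m = np.max(x, axis=axis, keepdims=True)
    m = np.where(np.isfinite(m), m, 0.0)          # avoid (-inf) - (-inf) on all -inf slices
    with np.errstate(divide="ignore"):
        return np.squeeze(m, axis=axis) + np.log(np.exp(x - m).sum(axis=axis))


def sketch_compute_matching(cost_matrix, target_mask, reg=0.1, n_iter=500):
    """Balanced entropic OT per batch element; returns N_p * P of shape (B, Np, Nt + 1)."""
    B, n_pred, _ = cost_matrix.shape
    mask = target_mask.astype(float)                                # (B, Nt), 1 = valid target
    n_valid = mask.sum(axis=1, keepdims=True)                       # (B, 1)
    log_a = np.full((B, n_pred), -np.log(n_pred))                   # every prediction sends 1/N_p
    b = np.concatenate([mask / n_pred, (n_pred - n_valid) / n_pred], axis=1)  # targets, background
    with np.errstate(divide="ignore"):
        log_b = np.log(b)                                           # padded targets: log 0 = -inf
    C = cost_matrix / reg
    f = np.zeros((B, n_pred))
    g = np.zeros_like(log_b)
    for _ in range(n_iter):
        f = reg * (log_a - _logsumexp(g[:, None, :] / reg - C, axis=2))  # row marginals
        g = reg * (log_b - _logsumexp(f[:, :, None] / reg - C, axis=1))  # column marginals
    return n_pred * np.exp((f[:, :, None] + g[:, None, :]) / reg - C)


cost_matrix = np.array([
    [[0.1, 0.9, 0.5, 0.8], [0.7, 0.2, 0.5, 0.8], [0.6, 0.6, 0.5, 0.8]],
    [[0.3, 0.4, 0.5, 0.8], [0.9, 0.1, 0.5, 0.8], [0.2, 0.7, 0.5, 0.8]],
])  # (B=2, Np=3, Nt+1=4); the last column is the background cost
target_mask = np.array([[True, True, False], [True, False, False]])  # False = padding
match = sketch_compute_matching(cost_matrix, target_mask, reg=0.5)
print(match.sum(axis=1))  # per column: 1 for valid targets, 0 for padding, N_p - N_t for background
```

Working in the log domain lets a zero mass become log 0 = -inf, so padded targets receive exact zeros instead of triggering a division by zero.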

forward(input: Dict[str, Tensor] | List[Dict[str, Tensor]], target: Dict[str, Tensor] | List[Dict[str, Tensor]], anchors: Tensor | None = None, cost_matrix: Tensor | None = None, save: bool = True) → Tensor | Tuple[Tensor, Tensor]

Computes a batch of matchings between the predicted and target boxes.

Parameters:
  • input (dictionary) – Input containing the predicted logits and boxes. “pred_logits”: Tensor of shape (batch_size, num_pred, num_classes). “pred_boxes”: Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).

  • target (dictionary) – Target containing the target classes, boxes and mask. “labels”: Tensor of shape (batch_size, num_targets). “boxes”: Tensor of shape (batch_size, num_targets, 4), where the last dimension is (x1, y1, x2, y2). “mask”: Tensor of shape (batch_size, num_targets).

  • anchors (Tensor, optional) – The anchors used to compute the predicted boxes. Tensor of shape (batch_size, num_pred, 4), where the last dimension is (x1, y1, x2, y2).

Returns:

The matching between the predicted and target boxes, and the cost matrix if returns_cost is True. The matching is a Tensor of shape (batch_size, num_pred, num_targets + 1), whose last entry along the last dimension is the background.

Return type:

Tensor (float) or Tuple(Tensor, Tensor)
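The matching returned by forward is soft, so downstream code may want a hard decision per prediction. One possible post-processing, sketched below as a hypothetical helper rather than part of the uotod API, takes the argmax over the last dimension and maps the last (background) index to -1.

```python
import numpy as np


def hard_assignment(match):
    """Hypothetical helper: for each prediction, the target receiving most of its mass, or -1 for background."""
    num_targets = match.shape[-1] - 1          # the last column is the background
    best = match.argmax(axis=-1)               # (batch_size, num_pred), values in 0..num_targets
    return np.where(best == num_targets, -1, best)


match = np.array([[[0.90, 0.05, 0.05],
                   [0.10, 0.80, 0.10],
                   [0.00, 0.10, 0.90]]])       # (batch_size=1, num_pred=3, num_targets + 1 = 3)
print(hard_assignment(match))
```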

plot(idx=0, img: Tensor | ndarray | None = None, plot_cost: bool = True, plot_match: bool = True, max_background_match: float | int = 1.0, background: bool = True, erase: bool = False)

Plots results from the last batch: the cost matrix, the matching and, if an image is given, the boxes.

Parameters:
  • idx (int, optional) – Index of the image to be plotted.

  • img (Tensor or ndarray, optional) – Image to be plotted. If it is not None, the boxes are also plotted.

  • plot_cost (bool, optional) – Plots the cost matrix between the predictions and the targets, including background.

  • plot_match (bool, optional) – Plots the matching between the predictions and the targets, including background.

  • max_background_match (float, optional) – A threshold to only plot relevant matched predictions. The predictions are only plotted if the value matched to the background does not exceed max_background_match. Defaults to 1.

Returns:

Matplotlib figures

Return type:

Tuple(fig, fig, fig)
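The max_background_match rule described above amounts to a simple filter on the background column of the match. The sketch below is a hypothetical reading of that rule on a NumPy array of shape (batch_size, num_pred, num_targets + 1), not the code of plot itself.

```python
import numpy as np


def predictions_to_plot(match, idx=0, max_background_match=1.0):
    """Indices of predictions in batch element `idx` whose background mass is <= max_background_match."""
    background_mass = match[idx, :, -1]        # the last column is the background
    return np.flatnonzero(background_mass <= max_background_match)


match = np.array([[[0.9, 0.0, 0.1],
                   [0.0, 0.05, 0.95],
                   [0.3, 0.2, 0.5]]])
print(predictions_to_plot(match, max_background_match=0.5))  # keeps predictions 0 and 2
print(predictions_to_plot(match))                            # the default 1.0 keeps all of them
```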

Available Methods

Value | Description | Extra optional arguments | Reference