Loss functions

PyTorchLTR provides several common loss functions for LTR. Each loss function operates on a batch of query-document lists with corresponding relevance labels.

The input to an LTR loss function comprises three tensors:

  • scores: A tensor of size \((N, \texttt{list_size})\): the item scores

  • relevance: A tensor of size \((N, \texttt{list_size})\): the relevance labels

  • n: A tensor of size \((N)\): the number of docs per learning instance.

The loss function produces the following output:

  • output: A tensor of size \((N)\): the loss per learning instance in the batch.
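Query lists usually differ in length, so they must be padded to a common \(\texttt{list_size}\) before being stacked into the tensors above, with n recording each list's true length. A minimal plain-Python sketch of that step follows. `pad_batch` is a hypothetical helper, not part of PyTorchLTR, and the padding values are arbitrary placeholders, on the assumption that positions beyond n are ignored by the loss.

```python
from typing import List, Tuple


def pad_batch(
    score_lists: List[List[float]],
    relevance_lists: List[List[int]],
    pad_score: float = 0.0,
    pad_relevance: int = 0,
) -> Tuple[List[List[float]], List[List[int]], List[int]]:
    """Hypothetical helper: pad variable-length query lists to one list_size.

    Returns the padded scores, the padded relevance labels, and n, the
    true number of documents in each list.
    """
    n = [len(s) for s in score_lists]
    list_size = max(n)
    scores = [s + [pad_score] * (list_size - len(s)) for s in score_lists]
    relevance = [r + [pad_relevance] * (list_size - len(r)) for r in relevance_lists]
    return scores, relevance, n


scores, relevance, n = pad_batch([[0.5, 2.0, 1.0], [0.9, -1.2]], [[2, 0, 1], [0, 1]])
print(scores)     # [[0.5, 2.0, 1.0], [0.9, -1.2, 0.0]]
print(relevance)  # [[2, 0, 1], [0, 1, 0]]
print(n)          # [3, 2]
```

Passing each result to `torch.tensor` yields the same scores, relevance and n tensors that the example below uses.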

Example

The following example uses the pairwise hinge loss; the same usage pattern holds for all other losses.

>>> import torch
>>> from pytorchltr.loss import PairwiseHingeLoss
>>> scores = torch.tensor([[0.5, 2.0, 1.0], [0.9, -1.2, 0.0]])
>>> relevance = torch.tensor([[2, 0, 1], [0, 1, 0]])
>>> n = torch.tensor([3, 2])
>>> loss_fn = PairwiseHingeLoss()
>>> loss_fn(scores, relevance, n)
tensor([6.0000, 3.1000])
>>> loss_fn(scores, relevance, n).mean()
tensor(4.5500)

Additive ranking losses

Additive ranking losses optimize linearly decomposable ranking metrics [J02][ATZ+19]. These loss functions optimize an upper bound on the rank of relevant documents via either a hinge or a logistic formulation.
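Why these formulations give an upper bound: a pair \((i, j)\) with \(y_i > y_j\) is misordered when \(s_i - s_j \le 0\), and in that case both \(\max(0, 1 - (s_i - s_j))\) and \(\log_2(1 + e^{-(s_i - s_j)})\) are at least 1, while both are non-negative otherwise. Summed over pairs, each therefore bounds the number of misordered pairs from above. The plain-Python sketch below checks this on a small list (with \(\sigma = 1\)); the function names are illustrative only.

```python
import math
from typing import Iterator, List


def pair_diffs(scores: List[float], relevance: List[int]) -> Iterator[float]:
    """Yield s_i - s_j for every pair (i, j) with y_i > y_j."""
    for s_i, y_i in zip(scores, relevance):
        for s_j, y_j in zip(scores, relevance):
            if y_i > y_j:
                yield s_i - s_j


def misordered_pairs(scores: List[float], relevance: List[int]) -> int:
    # A pair is misordered when the less relevant document scores at least as high.
    return sum(1 for d in pair_diffs(scores, relevance) if d <= 0.0)


def hinge_bound(scores: List[float], relevance: List[int]) -> float:
    return sum(max(0.0, 1.0 - d) for d in pair_diffs(scores, relevance))


def logistic_bound(scores: List[float], relevance: List[int]) -> float:
    return sum(math.log2(1.0 + math.exp(-d)) for d in pair_diffs(scores, relevance))


s, y = [0.5, 2.0, 1.0], [2, 0, 1]
print(misordered_pairs(s, y), hinge_bound(s, y))  # 3 6.0
```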

class pytorchltr.loss.PairwiseHingeLoss

Pairwise hinge loss formulation of SVMRank:

\[l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y_j} \max\left( 0, 1 - (s_i - s_j) \right)\]
Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
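To make the formula concrete, here is a plain-Python reference sketch that evaluates it per list, using only the first n entries of each row. It illustrates the math only and is not PyTorchLTR's vectorized implementation; `pairwise_hinge_loss` is a hypothetical name.

```python
from typing import List


def pairwise_hinge_loss(
    scores: List[List[float]], relevance: List[List[int]], n: List[int]
) -> List[float]:
    """Reference sketch of the SVMRank pairwise hinge loss, one value per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        loss = 0.0
        for i in range(k):  # positions >= k are padding and are skipped
            for j in range(k):
                if y[i] > y[j]:
                    loss += max(0.0, 1.0 - (s[i] - s[j]))
        losses.append(loss)
    return losses


print(pairwise_hinge_loss([[0.5, 2.0, 1.0], [0.9, -1.2, 0.0]], [[2, 0, 1], [0, 1, 0]], [3, 2]))
# [6.0, 3.1]
```

The two values agree with the doctest output `tensor([6.0000, 3.1000])` in the example above; the third entry of the second row does not contribute because its n is 2.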

class pytorchltr.loss.PairwiseDCGHingeLoss

Pairwise DCG-modified hinge loss:

\[l(\mathbf{s}, \mathbf{y}) = \frac{-1}{\log\left( 2 + \sum_{y_i > y_j} \max\left(0, 1 - (s_i - s_j)\right) \right)}\]
Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
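A plain-Python sketch that transcribes this formula directly, one value per list. The formula does not state the base of \(\log\); the sketch assumes the natural logarithm, so the best attainable value (a zero hinge sum) is \(-1/\ln 2 \approx -1.4427\). `pairwise_dcg_hinge_loss` is a hypothetical name, and the sketch makes no claim to match PyTorchLTR's internal reduction.

```python
import math
from typing import List


def pairwise_dcg_hinge_loss(
    scores: List[List[float]], relevance: List[List[int]], n: List[int]
) -> List[float]:
    """Reference sketch of -1 / log(2 + sum of pairwise hinge terms), per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        hinge_sum = sum(
            max(0.0, 1.0 - (s[i] - s[j]))
            for i in range(k)
            for j in range(k)
            if y[i] > y[j]
        )
        # Assumption: natural log. The loss rises toward 0 as the hinge sum grows.
        losses.append(-1.0 / math.log(2.0 + hinge_sum))
    return losses
```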

class pytorchltr.loss.PairwiseLogisticLoss(sigma=1.0)

Pairwise logistic loss formulation of RankNet:

\[l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y_j} \log_2\left(1 + e^{ -\sigma \left(s_i - s_j\right) }\right)\]
Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

__init__(sigma=1.0)
Parameters

sigma (float) – Steepness of the logistic curve.

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
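A plain-Python sketch of this loss. It evaluates \(\log_2(1 + e^{x})\) through the identity \(\log_2(1 + e^{x}) = \left(\max(x, 0) + \ln(1 + e^{-|x|})\right) / \ln 2\), which avoids overflowing `math.exp` when \(\sigma (s_i - s_j)\) is large and negative. The function names are hypothetical, and this is an illustration of the formula rather than PyTorchLTR's implementation.

```python
import math
from typing import List


def log2_1p_exp(x: float) -> float:
    """Numerically stable log2(1 + e^x)."""
    return (max(x, 0.0) + math.log1p(math.exp(-abs(x)))) / math.log(2.0)


def pairwise_logistic_loss(
    scores: List[List[float]],
    relevance: List[List[int]],
    n: List[int],
    sigma: float = 1.0,
) -> List[float]:
    """Reference sketch of the RankNet pairwise logistic loss, one value per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        losses.append(
            sum(
                log2_1p_exp(-sigma * (s[i] - s[j]))
                for i in range(k)
                for j in range(k)
                if y[i] > y[j]
            )
        )
    return losses
```

A naive `math.log2(1 + math.exp(x))` raises `OverflowError` once \(x\) exceeds roughly 709, whereas the stable form returns approximately \(x / \ln 2\).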

LambdaLoss

LambdaLoss [WLG+18] is a probabilistic framework for ranking metric optimization. We provide implementations for ARPLoss1, ARPLoss2, NDCGLoss1 and NDCGLoss2.

class pytorchltr.loss.LambdaARPLoss1(sigma=1.0)

ARP Loss 1:

\[l(\mathbf{s}, \mathbf{y}) = -\sum_{i=1}^n \sum_{j=1}^n \log_2 \left( \frac{1}{1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})}} \right)^{y_{\pi_i}}\]

where \(\pi_i\) is the index of the item at rank \(i\) after sorting the scores in descending order.

Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

__init__(sigma=1.0)
Parameters

sigma (float) – Steepness of the logistic curve.

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
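A plain-Python sketch of ARP Loss 1. Rewriting \(-\log_2\left(\frac{1}{1 + e^{-x}}\right)^{y}\) as \(y \log_2(1 + e^{-x})\) turns each term into a non-negative weighted logistic penalty. Because the double sum runs over every pair \((i, j)\), reordering by \(\pi\) does not change the value, so the sketch skips the sort; it keeps the \(i = j\) terms that the formula includes, each of which adds the score-independent constant \(y_{\pi_i}\). The function names are hypothetical.

```python
import math
from typing import List


def log2_1p_exp(x: float) -> float:
    """Numerically stable log2(1 + e^x)."""
    return (max(x, 0.0) + math.log1p(math.exp(-abs(x)))) / math.log(2.0)


def lambda_arp_loss1(
    scores: List[List[float]],
    relevance: List[List[int]],
    n: List[int],
    sigma: float = 1.0,
) -> List[float]:
    """Reference sketch of LambdaLoss ARP Loss 1, one value per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        loss = 0.0
        for i in range(k):
            for j in range(k):  # all pairs, including i == j, as in the formula
                loss += y[i] * log2_1p_exp(-sigma * (s[i] - s[j]))
        losses.append(loss)
    return losses
```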

class pytorchltr.loss.LambdaARPLoss2(sigma=1.0)

ARP Loss 2:

\[l(\mathbf{s}, \mathbf{y}) = \sum_{y_i > y_j} |y_i - y_j| \log_2 \left( 1 + e^{-\sigma(s_i - s_j)} \right)\]
Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

__init__(sigma=1.0)
Parameters

sigma (float) – Steepness of the logistic curve.

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
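A plain-Python sketch of ARP Loss 2. With binary labels every weight \(|y_i - y_j|\) equals 1, so the loss coincides with the pairwise logistic (RankNet) formula above; graded labels scale each pair's penalty by the label gap. The function names are hypothetical.

```python
import math
from typing import List


def log2_1p_exp(x: float) -> float:
    """Numerically stable log2(1 + e^x)."""
    return (max(x, 0.0) + math.log1p(math.exp(-abs(x)))) / math.log(2.0)


def lambda_arp_loss2(
    scores: List[List[float]],
    relevance: List[List[int]],
    n: List[int],
    sigma: float = 1.0,
) -> List[float]:
    """Reference sketch of LambdaLoss ARP Loss 2, one value per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        loss = 0.0
        for i in range(k):
            for j in range(k):
                if y[i] > y[j]:
                    # Weight each pairwise logistic penalty by the label gap.
                    loss += abs(y[i] - y[j]) * log2_1p_exp(-sigma * (s[i] - s[j]))
        losses.append(loss)
    return losses
```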

class pytorchltr.loss.LambdaNDCGLoss1(sigma=1.0)

NDCG Loss 1:

\[l(\mathbf{s}, \mathbf{y}) = -\sum_{i=1}^n \sum_{j=1}^n \log_2 \left( \frac{1}{1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})}} \right)^{\frac{G_{\pi_i}}{D_i}}\]

where \(\pi_i\) is the index of the item at rank \(i\) after sorting the scores in descending order, \(G_{\pi_i} = \frac{2^{y_{\pi_i}} - 1}{\text{maxDCG}}\), and \(D_i = \log_2(1 + i)\).

Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

__init__(sigma=1.0)
Parameters

sigma (float) – Steepness of the logistic curve.

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
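A plain-Python sketch of NDCG Loss 1 following the definitions above: the list is sorted by descending score to obtain \(\pi\), gains are normalized by the ideal DCG, and each document's logistic terms are divided by the discount \(D_i\) of the rank it currently occupies. Two behaviours are assumptions not stated in the formula: tied scores keep their input order (Python's `sorted` is stable), and a list with no relevant document (\(\text{maxDCG} = 0\)) gets a loss of 0. The function names are hypothetical.

```python
import math
from typing import List


def log2_1p_exp(x: float) -> float:
    """Numerically stable log2(1 + e^x)."""
    return (max(x, 0.0) + math.log1p(math.exp(-abs(x)))) / math.log(2.0)


def max_dcg(labels: List[int]) -> float:
    """DCG of the ideal ordering, with gain 2^y - 1 and discount log2(1 + rank)."""
    ideal = sorted(labels, reverse=True)
    return sum((2 ** g - 1) / math.log2(1 + r) for r, g in enumerate(ideal, start=1))


def lambda_ndcg_loss1(
    scores: List[List[float]],
    relevance: List[List[int]],
    n: List[int],
    sigma: float = 1.0,
) -> List[float]:
    """Reference sketch of LambdaLoss NDCG Loss 1, one value per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        s, y = s[:k], y[:k]  # drop padding
        norm = max_dcg(y)
        if norm == 0.0:
            losses.append(0.0)  # assumption: no relevant document -> zero loss
            continue
        # pi[r - 1] is the document at rank r (ties keep input order).
        pi = sorted(range(k), key=lambda d: s[d], reverse=True)
        loss = 0.0
        for rank, d_i in enumerate(pi, start=1):
            weight = (2 ** y[d_i] - 1) / norm / math.log2(1 + rank)  # G_{pi_i} / D_i
            for d_j in pi:
                loss += weight * log2_1p_exp(-sigma * (s[d_i] - s[d_j]))
        losses.append(loss)
    return losses
```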

class pytorchltr.loss.LambdaNDCGLoss2(sigma=1.0)

NDCG Loss 2:

\[l(\mathbf{s}, \mathbf{y}) = -\sum_{y_{\pi_i} > y_{\pi_j}} \log_2 \left( \frac{1}{1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})}} \right)^{\delta_{ij} | G_{\pi_i} - G_{\pi_j} |}\]

where \(\pi_i\) is the index of the item at rank \(i\) after sorting the scores in descending order, \(G_{\pi_i} = \frac{2^{y_{\pi_i}} - 1}{\text{maxDCG}}\), \(\delta_{ij} = \left|\frac{1}{D_{|i-j|}} - \frac{1}{D_{|i-j|+1}} \right|\), and \(D_i = \log_2(1 + i)\).

Shape:
  • input scores: \((N, \texttt{list_size})\)

  • input relevance: \((N, \texttt{list_size})\)

  • input n: \((N)\)

  • output: \((N)\)

__init__(sigma=1.0)
Parameters

sigma (float) – Steepness of the logistic curve.

forward(scores, relevance, n)

Computes the loss for a given batch of samples.

Parameters
  • scores (FloatTensor) – A batch of per-query-document scores.

  • relevance (LongTensor) – A batch of per-query-document relevance labels.

  • n (LongTensor) – The number of documents in each query list; positions at or beyond this count are treated as padding.

Return type

FloatTensor
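A plain-Python sketch of NDCG Loss 2. It writes each term in the non-negative form \(\delta_{ij} |G_{\pi_i} - G_{\pi_j}| \log_2(1 + e^{-\sigma (s_{\pi_i} - s_{\pi_j})})\), i.e. the negative logarithm of the weighted pairwise probability, so that minimizing it pushes more relevant documents above less relevant ones. Stable tie-breaking and a zero loss for lists without any relevant document are assumptions, and the function names are hypothetical.

```python
import math
from typing import List


def log2_1p_exp(x: float) -> float:
    """Numerically stable log2(1 + e^x)."""
    return (max(x, 0.0) + math.log1p(math.exp(-abs(x)))) / math.log(2.0)


def inv_discount(m: int) -> float:
    """1 / D_m with D_m = log2(1 + m)."""
    return 1.0 / math.log2(1 + m)


def lambda_ndcg_loss2(
    scores: List[List[float]],
    relevance: List[List[int]],
    n: List[int],
    sigma: float = 1.0,
) -> List[float]:
    """Reference sketch of LambdaLoss NDCG Loss 2 (non-negative form), per list."""
    losses = []
    for s, y, k in zip(scores, relevance, n):
        s, y = s[:k], y[:k]  # drop padding
        ideal = sorted(y, reverse=True)
        norm = sum((2 ** g - 1) / math.log2(1 + r) for r, g in enumerate(ideal, start=1))
        if norm == 0.0:
            losses.append(0.0)  # assumption: no relevant document -> zero loss
            continue
        pi = sorted(range(k), key=lambda d: s[d], reverse=True)  # pi[i] = doc at rank i + 1
        gain = [(2 ** y[d] - 1) / norm for d in pi]  # G_{pi_i}, indexed by rank - 1
        loss = 0.0
        for i in range(k):
            for j in range(k):
                if y[pi[i]] > y[pi[j]]:
                    gap = abs(i - j)  # >= 1, since the labels differ
                    delta = abs(inv_discount(gap) - inv_discount(gap + 1))
                    diff = s[pi[i]] - s[pi[j]]
                    loss += delta * abs(gain[i] - gain[j]) * log2_1p_exp(-sigma * diff)
        losses.append(loss)
    return losses
```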

References

J02

Thorsten Joachims. Optimizing search engines using clickthrough data. In Proceedings of the Eighth ACM SIGKDD International Conference on Knowledge Discovery and Data Mining, KDD ’02, 133–142. New York, NY, USA, 2002. Association for Computing Machinery. doi:10.1145/775047.775067.

ATZ+19

Aman Agarwal, Kenta Takatsu, Ivan Zaitsev, and Thorsten Joachims. A general framework for counterfactual learning-to-rank. In Proceedings of the 42nd International ACM SIGIR Conference on Research and Development in Information Retrieval, SIGIR ’19, 5–14. New York, NY, USA, 2019. Association for Computing Machinery. doi:10.1145/3331184.3331202.

WLG+18

Xuanhui Wang, Cheng Li, Nadav Golbandi, Michael Bendersky, and Marc Najork. The LambdaLoss framework for ranking metric optimization. In Proceedings of the 27th ACM International Conference on Information and Knowledge Management, CIKM ’18, 1313–1322. New York, NY, USA, 2018. Association for Computing Machinery. doi:10.1145/3269206.3271784.