# Source code for quaterion.loss.pairwise_loss
from torch import Tensor
from quaterion.distances import Distance
from quaterion.loss.similarity_loss import SimilarityLoss
class PairwiseLoss(SimilarityLoss):
    """Base class for pairwise losses.

    Concrete subclasses implement :meth:`forward`, which computes a loss from
    embeddings of objects that are known to form similarity pairs in a batch.

    Args:
        distance_metric_name: Name of the distance function, one of
            :class:`~quaterion.distances.Distance`. Defaults to
            ``Distance.COSINE``.
    """

    def __init__(self, distance_metric_name: Distance = Distance.COSINE):
        # Zero-argument super() resolves to the same parent as
        # super(PairwiseLoss, self) and keeps working if the class is renamed.
        super().__init__(distance_metric_name=distance_metric_name)

    def forward(
        self,
        embeddings: Tensor,
        pairs: Tensor,
        labels: Tensor,
        subgroups: Tensor,
    ) -> Tensor:
        """Compute loss value.

        Must be overridden by concrete pairwise losses; this base
        implementation only raises.

        Args:
            embeddings: shape: (batch_size, vector_length)
            pairs: shape: (2 * pairs_count,) - contains a list of known
                similarity pairs in batch
            labels: shape: (pairs_count,) - similarity of the pair
            subgroups: shape: (2 * pairs_count,) - subgroup ids of objects

        Returns:
            Tensor: loss value as a scalar (0-dim) tensor.

        Raises:
            NotImplementedError: always, in this base class.
        """
        # Name the concrete subclass so a missing override is easy to locate.
        raise NotImplementedError(
            f"{type(self).__name__} must implement forward()"
        )