Source code for quaterion.dataset.similarity_dataset
from typing import Sized
from torch.utils.data import Dataset
from quaterion.dataset.similarity_samples import SimilarityGroupSample
class SimilarityGroupDataset(Dataset[SimilarityGroupSample]):
    """Adapt a classification-style dataset to group-similarity samples.

    Wraps a dataset whose items are ``(record, label)`` pairs and exposes each
    item as a :class:`SimilarityGroupSample`, making it compatible with
    :class:`~quaterion.dataset.similarity_data_loader.GroupSimilarityDataLoader`.

    Args:
        dataset: a dataset which returns data in the format ``(record, label)``
    """

    def __init__(self, dataset: Dataset):
        # Keep a reference only; items are converted lazily in ``__getitem__``.
        self._dataset = dataset

    def __len__(self) -> int:
        """Return the length of the wrapped dataset.

        Raises:
            NotImplementedError: if the wrapped dataset does not define a length.
        """
        # Guard first: a dataset without ``__len__`` is not an instance of ``Sized``.
        if not isinstance(self._dataset, Sized):
            raise NotImplementedError
        return len(self._dataset)

    def __getitem__(self, index) -> SimilarityGroupSample:
        """Fetch the item at ``index`` and repackage it as a group sample.

        The wrapped item is unpacked into exactly two values: the first becomes
        ``obj`` and the second becomes ``group`` of the returned sample.
        """
        pair = self._dataset.__getitem__(index)
        obj, group = pair
        return SimilarityGroupSample(obj=obj, group=group)