quaterion.eval.samplers.group_sampler module
- class GroupSampler(sample_size=-1, encode_batch_size=16, device: Optional[Union[device, str]] = None, log_progress: bool = True)
Bases: BaseSampler
Performs selection of embeddings and targets for group-based tasks.
- accumulate(model: SimilarityModel, dataset: Union[Sized, Iterable, Dataset])
Encodes the objects of the dataset with the model and accumulates the embeddings together with their corresponding raw labels.
- Parameters:
model – model to encode objects
dataset – Sized object (e.g. list, tuple, torch.utils.data.Dataset) whose objects are encoded and accumulated
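As a rough illustration of what accumulation involves, here is a minimal pure-Python sketch. It assumes that objects are encoded in batches of encode_batch_size (the constructor parameter) and that each dataset item is an (object, group_id) pair. The names accumulate, toy_encode, and the pair format are hypothetical; they do not reproduce quaterion's implementation, which works with a SimilarityModel and torch.Tensor objects.

```python
from typing import Callable, List, Sequence, Tuple

Embedding = List[float]


def accumulate(
    encode: Callable[[Sequence[str]], List[Embedding]],
    dataset: Sequence[Tuple[str, int]],
    encode_batch_size: int = 16,
) -> Tuple[List[Embedding], List[int]]:
    """Hypothetical stand-in for GroupSampler.accumulate (not the quaterion API).

    Encodes (obj, group_id) pairs in batches of encode_batch_size and keeps
    each embedding aligned with its raw group label.
    """
    embeddings: List[Embedding] = []
    labels: List[int] = []
    for start in range(0, len(dataset), encode_batch_size):
        batch = dataset[start:start + encode_batch_size]
        objects = [obj for obj, _ in batch]
        embeddings.extend(encode(objects))  # one encoder call per batch
        labels.extend(group for _, group in batch)  # raw labels, same order
    return embeddings, labels


# Hypothetical toy encoder standing in for a SimilarityModel:
# it embeds a string as a 1-d vector holding its length.
def toy_encode(objects: Sequence[str]) -> List[Embedding]:
    return [[float(len(o))] for o in objects]


emb, lab = accumulate(toy_encode, [("a", 0), ("bb", 0), ("ccc", 1)], encode_batch_size=2)
# emb == [[1.0], [2.0], [3.0]], lab == [0, 0, 1]
```

Keeping embeddings and raw labels in parallel sequences means index i of one always describes the same object as index i of the other, which is what a later sampling step needs.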
- sample(dataset: Sized, metric: GroupMetric, model: SimilarityModel) → Tuple[Tensor, Tensor]
Samples embeddings and targets for group-based tasks.
- Parameters:
dataset – Sized object (e.g. list, tuple, torch.utils.data.Dataset) to sample from
metric – GroupMetric instance used to compute the final label representation
model – model to encode objects
- Returns:
torch.Tensor, torch.Tensor – metric labels and the computed distance matrix
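The exact selection logic is not spelled out here, so the following is only a pure-Python sketch under stated assumptions: that sample_size=-1 means every accumulated embedding is used as a query, that the label representation is a pairwise "same group" matrix, and that distances are Euclidean. The function sample below is a hypothetical stand-in, not the quaterion API; the real method operates on torch.Tensor objects and, per the parameter list above, relies on the GroupMetric for the final label representation.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def sample(
    embeddings: Sequence[Sequence[float]],
    groups: Sequence[int],
    sample_size: int = -1,
    seed: Optional[int] = None,
) -> Tuple[List[List[bool]], List[List[float]]]:
    """Hypothetical stand-in for GroupSampler.sample after accumulation.

    Picks sample_size query rows (all rows when sample_size == -1), then returns
    (labels, distances): labels[i][j] is True if query i and object j share a
    group, and distances[i][j] is their Euclidean distance (an assumption).
    """
    n = len(embeddings)
    if sample_size == -1 or sample_size >= n:
        query_ids = list(range(n))
    else:
        # Random subset of distinct query indices, reproducible via seed.
        query_ids = random.Random(seed).sample(range(n), sample_size)
    labels = [[groups[q] == groups[j] for j in range(n)] for q in query_ids]
    distances = [
        [math.dist(embeddings[q], embeddings[j]) for j in range(n)] for q in query_ids
    ]
    return labels, distances


labels, distances = sample([[0.0, 0.0], [3.0, 4.0], [0.0, 1.0]], [0, 0, 1])
# labels == [[True, True, False], [True, True, False], [False, False, True]]
# distances[0] == [0.0, 5.0, 1.0]
```

Restricting the rows to a sample of queries while keeping every object as a column keeps the matrix at sample_size × n instead of n × n, which is one plausible reason to expose sample_size on large evaluation sets.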