quaterion.eval.accumulators.group_accumulator module
- class GroupAccumulator
Bases: Accumulator
Accumulate embeddings and groups for group-based tasks.
- reset()
Reset accumulator state: status, accumulated embeddings, and groups.
- update(embeddings: Tensor, groups: Tensor, device=None)
Update accumulator state.
Move the provided embeddings and groups to the target device and append them to the accumulated state.
- Parameters:
embeddings – embeddings to accumulate
groups – corresponding groups to accumulate
device – device to store tensors on
- property groups
Concatenate the list of accumulated groups into a single Tensor.
This avoids concatenating groups for each batch during accumulation; concatenation happens only when the property is accessed.
- Returns:
torch.Tensor – batch of groups
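The deferred concatenation described for the groups property can be illustrated with a minimal sketch. It is not quaterion's implementation: `LazyGroupStore` is a hypothetical stand-in that uses plain Python lists in place of tensors, so it runs without torch, and it ignores embeddings and device handling.

```python
from itertools import chain
from typing import List


class LazyGroupStore:
    """Hypothetical stand-in, not part of quaterion; lists replace tensors."""

    def __init__(self) -> None:
        # One entry per update() call.
        self._batches: List[List[int]] = []

    def update(self, groups: List[int]) -> None:
        # Store the new batch; previously accumulated batches are not touched.
        self._batches.append(list(groups))

    @property
    def groups(self) -> List[int]:
        # Merge all stored batches only when the caller asks for them.
        return list(chain.from_iterable(self._batches))

    def reset(self) -> None:
        # Drop everything accumulated so far.
        self._batches = []


store = LazyGroupStore()
store.update([0, 0, 1])
store.update([1, 2])
merged = store.groups  # [0, 0, 1, 1, 2]
store.reset()
empty = store.groups   # []
```

With tensors, the merge step would presumably be a single torch.cat over the stored batches, which is cheaper than re-concatenating a growing tensor on every update.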
- property state: Dict[str, Tensor]
Accumulated state
- Returns:
Dict[str, torch.Tensor] – dictionary with embeddings and groups