
quaterion.train.trainable_model module

class TrainableModel(*args: Any, **kwargs: Any)

Bases: LightningModule, CacheMixin

Base class for models to be trained.

TrainableModel describes which components of the model should be trained and how.

It assembles a model from building blocks such as Encoder, EncoderHead, etc.

┌─────────┐ ┌─────────┐ ┌─────────┐
│Encoder 1│ │Encoder 2│ │Encoder 3│
└────┬────┘ └────┬────┘ └────┬────┘
     │           │           │
     └────────┐  │  ┌────────┘
              │  │  │
          ┌───┴──┴──┴───┐
          │   concat    │
          └──────┬──────┘
                 │
          ┌──────┴──────┐
          │    Head     │
          └─────────────┘
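The assembly in the diagram can be sketched in plain Python. This is a conceptual illustration, not Quaterion code: the encoders and the head are hypothetical functions on lists of floats, and iterating over encoders in sorted key order is a choice made only for this sketch. Because the head receives the concatenated vector, its input size is the sum of the encoders' output sizes (here 1 + 2 = 3).

```python
from typing import Callable, Dict, List

# Hypothetical stand-ins: an "encoder" maps a raw record to a fixed-size vector.
Encoder = Callable[[str], List[float]]
Head = Callable[[List[float]], List[float]]


def assemble(encoders: Dict[str, Encoder], head: Head) -> Callable[[str], List[float]]:
    """Return a function that runs every encoder, concatenates the outputs, and applies the head."""

    def forward(record: str) -> List[float]:
        concatenated: List[float] = []
        # Sorted key order keeps the layout of the concatenated vector stable.
        for name in sorted(encoders):
            concatenated.extend(encoders[name](record))
        return head(concatenated)

    return forward


def length_encoder(text: str) -> List[float]:
    # 1-dimensional output: number of characters.
    return [float(len(text))]


def vowel_encoder(text: str) -> List[float]:
    # 2-dimensional output: number of vowels, number of spaces.
    return [float(sum(c in "aeiou" for c in text)), float(text.count(" "))]


def doubling_head(vector: List[float]) -> List[float]:
    # Toy head operating on the 3-dimensional concatenated vector.
    return [2.0 * x for x in vector]


model = assemble({"length": length_encoder, "vowels": vowel_encoder}, doubling_head)
print(model("hello world"))  # [22.0, 6.0, 2.0]
```

Keeping the encoders independent and joining them only at the concatenation step is what allows each encoder to be configured separately.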

TrainableModel also handles most of the training routine: training and validation steps, tensor device management, logging, and more. Most of these routines are inherited from LightningModule, a direct parent class of TrainableModel.

To train a model, subclass TrainableModel and implement the required methods and attributes.

Minimal Example:

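from quaterion import TrainableModel
from quaterion.loss import ContrastiveLoss
from quaterion_models.heads import SkipConnectionHead
from torch.optim import Adam


class ExampleModel(TrainableModel):
    def __init__(self, lr=10e-5, *args, **kwargs):
        # Set attributes before calling the parent constructor so they are
        # already available when the configure_* methods are called.
        self.lr = lr
        super().__init__(*args, **kwargs)

    # backbone of the model
    def configure_encoders(self):
        # YourAwesomeEncoder is a placeholder for your own Encoder subclass
        return YourAwesomeEncoder()

    # top layer of the model
    def configure_head(self, input_embedding_size: int):
        return SkipConnectionHead(input_embedding_size)

    def configure_optimizers(self):
        return Adam(self.model.parameters(), lr=self.lr)

    def configure_loss(self):
        return ContrastiveLoss()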

configure_caches() → Optional[CacheConfig]

Method to provide cache configuration.

Use this method to define which encoders should cache calculated embeddings and what kind of cache they should use.

Returns:

Optional[CacheConfig] – cache configuration to be applied if provided, None otherwise.

Examples:

Do not use cache (default):

return None

Configure cache automatically for all non-trainable encoders:

return CacheConfig(CacheType.AUTO)

Specify cache type for each encoder individually:

return CacheConfig(
    mapping={
        "text_encoder": CacheType.GPU,  # store cache in GPU memory for `text_encoder`
        "image_encoder": CacheType.CPU,  # store cache in RAM for `image_encoder`
    }
)

Specify a key extractor for cache object disambiguation:

return CacheConfig(
    cache_type=CacheType.AUTO,
    key_extractors={"text_encoder": hash}
)

A key extractor is useful when you need a more sophisticated way to associate cached vectors with the original objects. If no key extractor is specified, item numbers from the dataset are used as keys by default.
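The effect of a key extractor can be illustrated with a toy in-memory cache. EmbeddingCache below is a hypothetical class, not Quaterion's cache implementation: by default it keys entries by dataset index, as described above, while a key extractor such as hash lets identical objects at different positions share one cached embedding.

```python
from typing import Any, Callable, Dict, Hashable, List, Optional


class EmbeddingCache:
    """Toy in-memory embedding cache (hypothetical, for illustration only)."""

    def __init__(self, encode: Callable[[Any], List[float]],
                 key_extractor: Optional[Callable[[Any], Hashable]] = None):
        self.encode = encode
        self.key_extractor = key_extractor
        self.storage: Dict[Hashable, List[float]] = {}
        self.encode_calls = 0  # counts how often the (expensive) encoder actually ran

    def get(self, index: int, obj: Any) -> List[float]:
        # Default key: the item's index in the dataset; otherwise derive the key from the object.
        key = index if self.key_extractor is None else self.key_extractor(obj)
        if key not in self.storage:
            self.encode_calls += 1
            self.storage[key] = self.encode(obj)
        return self.storage[key]


dataset = ["cat", "dog", "cat"]
by_index = EmbeddingCache(encode=lambda s: [float(len(s))])
by_hash = EmbeddingCache(encode=lambda s: [float(len(s))], key_extractor=hash)
for i, item in enumerate(dataset):
    by_index.get(i, item)
    by_hash.get(i, item)
print(by_index.encode_calls, by_hash.encode_calls)  # 3 2
```

The trade-off is that a key extractor must be stable and collision-free for the objects it sees: two different objects mapped to the same key would silently share one embedding.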

configure_encoders() → Union[Encoder, Dict[str, Encoder]]

Method to provide encoders configuration.

Use this method to define the initial state of the encoders. It is used to assign the initial encoders both before training and when loading a checkpoint.

Returns:

Union[Encoder, Dict[str, Encoder]] – a single encoder instance, which will be assigned to DEFAULT_ENCODER_KEY, or a mapping of names to encoders.
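The two accepted return shapes can be reduced to a single name-to-encoder mapping, as sketched below. Encoder and normalize_encoders are hypothetical stand-ins, and the value "default" for DEFAULT_ENCODER_KEY is an assumption made for illustration; use the constant from your installed version rather than hard-coding it.

```python
from typing import Dict, Union

# Assumed value, for illustration only; import the real constant instead of hard-coding it.
DEFAULT_ENCODER_KEY = "default"


class Encoder:
    """Hypothetical stand-in for a real encoder class."""

    def __init__(self, name: str):
        self.name = name


def normalize_encoders(encoders: Union[Encoder, Dict[str, Encoder]]) -> Dict[str, Encoder]:
    """Turn the return value of configure_encoders() into a name -> encoder mapping."""
    if isinstance(encoders, Encoder):
        # A single encoder is stored under the default key.
        return {DEFAULT_ENCODER_KEY: encoders}
    # A mapping is used as-is (copied so the caller's dict is not shared).
    return dict(encoders)


print(list(normalize_encoders(Encoder("text"))))  # ['default']
print(sorted(normalize_encoders({"text": Encoder("t"), "image": Encoder("i")})))  # ['image', 'text']
```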

configure_head(input_embedding_size: int) → EncoderHead

Use this method to define the initial state of the model's head layer.

Parameters:

input_embedding_size – size of embeddings produced by encoders

Returns:

EncoderHead – head to be added on top of a model

configure_loss() → SimilarityLoss

Method to configure the loss function to use.

configure_metrics() → Union[AttachedMetric, List[AttachedMetric]]

Method to configure batch-wise metrics for a training process.

Use this method to attach batch-wise metrics to a training process. Provided metrics must have an interface similar to PairMetric or GroupMetric.

Returns:

Union[AttachedMetric, List[AttachedMetric]] – metrics attached to the model

Examples:

return [
    AttachedMetric(
        "RetrievalPrecision",
        RetrievalPrecision(k=1),
        prog_bar=True,
        on_epoch=True,
    ),
    AttachedMetric(
        "RetrievalReciprocalRank",
        RetrievalReciprocalRank(),
        prog_bar=True,
    ),
]
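To make the example metrics more concrete, here is one common way to compute retrieval precision@k over a single batch: for each embedding, take its k most similar other embeddings and count how many share its group label. This is a plain-Python sketch assuming cosine similarity and integer group labels; Quaterion's RetrievalPrecision may differ in details such as the distance metric or the handling of ties.

```python
import math
from typing import List, Sequence


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def retrieval_precision_at_k(embeddings: List[List[float]], groups: List[int], k: int = 1) -> float:
    """Mean fraction of each item's k nearest neighbours (excluding itself) that share its group."""
    scores = []
    for i, query in enumerate(embeddings):
        # Rank all other items by cosine similarity to the query, most similar first.
        others = sorted(
            (j for j in range(len(embeddings)) if j != i),
            key=lambda j: cosine(query, embeddings[j]),
            reverse=True,
        )
        top_k = others[:k]
        scores.append(sum(groups[j] == groups[i] for j in top_k) / k)
    return sum(scores) / len(scores)


emb = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
print(retrieval_precision_at_k(emb, [0, 0, 1, 1], k=1))  # 1.0
```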

configure_xbm() → XbmConfig

Method to enable and configure Cross-Batch Memory (XBM).

XBM is a method that relies on the idea of “slow drift” of embeddings over the course of training. It keeps the most recent N embeddings and their target values in a ring buffer, where N is much greater than the batch size. It then calculates a scaled loss against the values in this buffer and adds it to the regular loss. This makes it possible to mine a large number of hard negatives.

See the paper for more details: https://arxiv.org/pdf/1912.06798.pdf
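The memory component of XBM can be sketched as a fixed-size FIFO buffer. CrossBatchMemory and total_loss are hypothetical names for illustration, not Quaterion's implementation; the loss against the memory contents is left abstract and passed in as a number, and weight stands for the scaling factor mentioned above.

```python
from collections import deque
from typing import Deque, List, Tuple


class CrossBatchMemory:
    """Ring buffer of the most recent (embedding, target) pairs (hypothetical sketch)."""

    def __init__(self, buffer_size: int):
        # A deque with maxlen silently discards the oldest items once it is full.
        self.buffer: Deque[Tuple[List[float], int]] = deque(maxlen=buffer_size)

    def enqueue(self, embeddings: List[List[float]], targets: List[int]) -> None:
        self.buffer.extend(zip(embeddings, targets))

    def contents(self) -> Tuple[List[List[float]], List[int]]:
        return [e for e, _ in self.buffer], [t for _, t in self.buffer]


def total_loss(batch_loss: float, memory_loss: float, weight: float) -> float:
    # Regular in-batch loss plus the scaled loss computed against the memory contents.
    return batch_loss + weight * memory_loss


memory = CrossBatchMemory(buffer_size=4)
for step in range(3):
    # Each toy "batch" has two embeddings; in XBM the buffer is much larger than a batch.
    memory.enqueue([[float(step)], [step + 0.5]], [step, step])
embeddings, targets = memory.contents()
print(targets)  # [1, 1, 2, 2]
```

Because collections.deque with maxlen drops the oldest entries automatically, the buffer always holds the most recent buffer_size embeddings, which is the ring-buffer behaviour described above.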

To enable it in a training process, you must return an instance of XbmConfig. The default return value is None, i.e., no XBM is applied.

Note

XBM is currently supported only with GroupLoss instances.

process_results(embeddings: Tensor, targets: Dict[str, Any], batch_idx: int, stage: TrainStage, **kwargs)

Method to provide any additional evaluations of embeddings.

Parameters:
  • embeddings – model output, shape (batch_size, embedding_size).

  • targets – output of the batch target collate function.

  • batch_idx – index of the batch being processed.

  • stage – train, validation, or test stage.
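As an example of an additional evaluation, the helper below computes embedding-norm statistics for one batch. It is a hypothetical plain-Python function that mirrors the parameters of process_results (embeddings as nested lists instead of a Tensor, stage as a string instead of a TrainStage, and a non-empty batch assumed); in an actual override you would compute the same values on the Tensor and log them.

```python
import math
from typing import Any, Dict, List


def embedding_norm_stats(embeddings: List[List[float]], targets: Dict[str, Any],
                         batch_idx: int, stage: str) -> Dict[str, float]:
    """Hypothetical extra evaluation: L2-norm statistics of a (batch_size x embedding_size) batch."""
    # targets and batch_idx are accepted to mirror process_results, but unused here.
    norms = [math.sqrt(sum(x * x for x in row)) for row in embeddings]
    prefix = f"{stage}/"
    return {
        prefix + "mean_norm": sum(norms) / len(norms),
        prefix + "max_norm": max(norms),
    }


stats = embedding_norm_stats([[3.0, 4.0], [0.0, 1.0]], targets={}, batch_idx=0, stage="validation")
print(stats)  # {'validation/mean_norm': 3.0, 'validation/max_norm': 5.0}
```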

save_servable(path: str)

Save the model for serving, independently of PyTorch Lightning.

Parameters:

path – path to save to

setup_dataloader(dataloader: SimilarityDataLoader)

Set up the data loader with encoder-specific settings, such as an encoder-specific collate function.

Each encoder has its own way of transforming a list of records into an NN-compatible format. These transformations are usually done during the data pre-processing step.
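An encoder-specific collate function can be as simple as the toy text collate below, which maps whitespace-separated tokens to ids and pads every sequence to the longest one in the batch. The vocabulary, PAD_ID, UNK_ID, text_collate, and make_batch are hypothetical; real encoders typically rely on their own tokenizers.

```python
from functools import partial
from typing import Callable, Dict, List

PAD_ID, UNK_ID = 0, 1  # hypothetical special token ids


def text_collate(records: List[str], vocab: Dict[str, int]) -> List[List[int]]:
    """Convert raw strings into a padded batch of token ids."""
    # Whitespace tokenization; unknown tokens map to UNK_ID.
    batch = [[vocab.get(token, UNK_ID) for token in text.lower().split()] for text in records]
    max_len = max((len(ids) for ids in batch), default=0)
    # Right-pad every sequence with PAD_ID up to the longest sequence in the batch.
    return [ids + [PAD_ID] * (max_len - len(ids)) for ids in batch]


def make_batch(records: List[str],
               collate_fns: Dict[str, Callable[[List[str]], List[List[int]]]]) -> Dict[str, List[List[int]]]:
    """Apply each encoder's own collate function to the same list of records."""
    return {name: collate(records) for name, collate in collate_fns.items()}


vocab = {"a": 2, "cat": 3, "dog": 4}
batch = make_batch(["A cat", "a big dog"], {"text_encoder": partial(text_collate, vocab=vocab)})
print(batch)  # {'text_encoder': [[2, 3, 0], [2, 1, 4]]}
```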

property loss: SimilarityLoss

Property to get the loss function to use.

property model: SimilarityModel

Original model to be trained

Returns:

SimilarityModel – model to be trained

training: bool
