medkit.training.trainable_component

Classes

TrainableComponent

TrainableComponent is the base protocol that a component must implement to be trainable in medkit.

Module Contents

class medkit.training.trainable_component.TrainableComponent

Bases: typing_extensions.Protocol

TrainableComponent is the base protocol that a component must implement to be trainable in medkit.
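Because TrainableComponent is a Protocol, conformance is structural. A class counts as trainable when it provides the listed members, and it does not need to inherit from the protocol. The sketch below shows this idea using a hypothetical stand-in protocol, `TrainableLike`. It mirrors the documented member names but uses `Any` instead of torch and medkit types. It is marked `runtime_checkable` so that `isinstance` can be used for the demonstration. This documentation does not say whether the real protocol is runtime-checkable.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any, Protocol, runtime_checkable


# Hypothetical stand-in mirroring the documented member names; the real
# protocol lives in medkit.training.trainable_component and uses torch types.
@runtime_checkable
class TrainableLike(Protocol):
    @property
    def device(self) -> Any: ...

    def configure_optimizer(self, lr: float) -> Any: ...

    def preprocess(self, data_item: Any) -> dict[str, Any]: ...

    def collate(self, batch: list[dict[str, Any]]) -> dict[str, Any]: ...

    def forward(
        self, input_batch: dict[str, Any], return_loss: bool, eval_mode: bool
    ) -> tuple[dict[str, Any], Any]: ...

    def save(self, path: str | Path) -> None: ...

    def load(self, path: str | Path) -> None: ...


class DummyComponent:
    """Does not inherit from TrainableLike; it only provides matching members."""

    @property
    def device(self) -> str:
        return "cpu"

    def configure_optimizer(self, lr: float) -> Any:
        return None

    def preprocess(self, data_item: Any) -> dict[str, Any]:
        return {"item": data_item}

    def collate(self, batch: list[dict[str, Any]]) -> dict[str, Any]:
        return {"items": [item["item"] for item in batch]}

    def forward(self, input_batch, return_loss, eval_mode):
        return input_batch, None

    def save(self, path: str | Path) -> None:
        pass

    def load(self, path: str | Path) -> None:
        pass


print(isinstance(DummyComponent(), TrainableLike))  # True
print(isinstance(object(), TrainableLike))  # False
```

A `runtime_checkable` isinstance check only verifies that the members exist. It does not check that their signatures match. Static type checkers such as mypy perform the full structural comparison.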

property device: torch.device

configure_optimizer(lr: float) → torch.optim.Optimizer

Create an optimizer using the given learning rate.
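Below is a minimal sketch of configure_optimizer. It assumes the component keeps its PyTorch model in a `self.model` attribute, which is a hypothetical name. AdamW is an arbitrary choice here, and any `torch.optim` optimizer would fit the return type.

```python
import torch
from torch import nn


class LinearClassifier:
    """Hypothetical component wrapping a small torch model."""

    def __init__(self, n_features: int = 4, n_labels: int = 2) -> None:
        self.model = nn.Linear(n_features, n_labels)

    def configure_optimizer(self, lr: float) -> torch.optim.Optimizer:
        # Optimize every trainable parameter of the wrapped model with AdamW;
        # the choice of AdamW is an assumption, not a medkit requirement.
        return torch.optim.AdamW(self.model.parameters(), lr=lr)


optimizer = LinearClassifier().configure_optimizer(lr=1e-3)
print(type(optimizer).__name__, optimizer.param_groups[0]["lr"])  # AdamW 0.001
```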

preprocess(data_item: Any) → dict[str, Any]

Run preprocessing on the input data.

Preprocess the input data item and return a dictionary with everything needed for the forward pass.

This method preprocesses a single input item. self.collate must then be used to assemble preprocessed items into batches so that self.forward can run properly. The output of preprocess should include the labels needed to compute a loss.
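The hypothetical function below sketches what preprocess might return for a toy text-classification example. It splits the text on whitespace, maps tokens through a made-up vocabulary, and keeps the label so that a loss can be computed later. In a real component this would be a method, and a proper tokenizer would usually handle tokenization.

```python
from __future__ import annotations

from typing import Any

# Hypothetical vocabulary and label mapping for a toy text classifier.
VOCAB = {"<pad>": 0, "<unk>": 1, "no": 2, "fever": 3, "cough": 4}
LABELS = {"negative": 0, "positive": 1}


def preprocess(data_item: dict[str, str]) -> dict[str, Any]:
    """Turn one raw example into model-ready fields, including its label."""
    tokens = data_item["text"].lower().split()
    # Unknown tokens map to the <unk> index.
    input_ids = [VOCAB.get(token, VOCAB["<unk>"]) for token in tokens]
    return {
        "input_ids": input_ids,
        # Keeping the label here lets the forward pass compute a loss later.
        "label": LABELS[data_item["label"]],
    }


print(preprocess({"text": "Fever and cough", "label": "positive"}))
# {'input_ids': [3, 1, 4], 'label': 1}
```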

collate(batch: list[dict[str, Any]]) → medkit.training.utils.BatchData

Collate a list of items returned by preprocess into a single batch.
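The sketch below pads variable-length items into a rectangular batch and adds an attention mask. The items are assumed to look like preprocess output with hypothetical `input_ids` and `label` fields. The sketch returns a plain dict of Python lists as a stand-in for medkit.training.utils.BatchData. A real implementation would usually build torch tensors instead. `PAD_ID` is a hypothetical padding index.

```python
from __future__ import annotations

from typing import Any

PAD_ID = 0  # hypothetical padding index


def collate(batch: list[dict[str, Any]]) -> dict[str, Any]:
    """Pad variable-length items into one rectangular batch."""
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, attention_mask = [], []
    for item in batch:
        ids = item["input_ids"]
        n_pad = max_len - len(ids)
        input_ids.append(ids + [PAD_ID] * n_pad)
        # 1 marks real tokens, 0 marks padding.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": [item["label"] for item in batch],
    }


batch = collate([
    {"input_ids": [3, 1, 4], "label": 1},
    {"input_ids": [2, 3], "label": 0},
])
print(batch["input_ids"])       # [[3, 1, 4], [2, 3, 0]]
print(batch["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```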

forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) → tuple[medkit.training.utils.BatchData, torch.Tensor | None]

Perform the forward pass on a batch.

Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.

Before running the model, this method must set it to training or evaluation mode depending on eval_mode. PyTorch models provide two methods for this: model.train() and model.eval().

Parameters:
input_batch : BatchData

Input batch, as produced by collate.

return_loss : bool

Whether to also return the computed loss.

eval_mode : bool

Whether to set the model to evaluation mode (True) or training mode (False).

Returns:
output : BatchData

Output of the forward pass.

loss : torch.Tensor, optional

Loss computed during the forward pass if return_loss was set to True, None otherwise.
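Below is a minimal sketch of forward for a hypothetical linear classifier. It first sets the model to evaluation or training mode, as described above. It turns off gradient tracking in evaluation mode and computes a cross-entropy loss only when return_loss is True. Plain dicts stand in for BatchData, and the `features` and `labels` keys are assumptions.

```python
from __future__ import annotations

import torch
from torch import nn


class LinearClassifier:
    """Hypothetical component; plain dicts stand in for BatchData."""

    def __init__(self, n_features: int = 4, n_labels: int = 2) -> None:
        self.model = nn.Linear(n_features, n_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(
        self, input_batch: dict, return_loss: bool, eval_mode: bool
    ) -> tuple[dict, torch.Tensor | None]:
        # Set the mode before running the model, as the protocol requires.
        if eval_mode:
            self.model.eval()
        else:
            self.model.train()
        # Track gradients only in training mode.
        with torch.set_grad_enabled(not eval_mode):
            logits = self.model(input_batch["features"])
            loss = self.loss_fn(logits, input_batch["labels"]) if return_loss else None
        return {"logits": logits}, loss


component = LinearClassifier()
batch = {"features": torch.randn(8, 4), "labels": torch.randint(0, 2, (8,))}
output, loss = component.forward(batch, return_loss=True, eval_mode=False)
print(output["logits"].shape, loss.requires_grad)  # torch.Size([8, 2]) True
_, no_loss = component.forward(batch, return_loss=False, eval_mode=True)
print(no_loss, component.model.training)  # None False
```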

save(path: str | pathlib.Path)

Save the model to disk.

load(path: str | pathlib.Path)

Load the model weights from disk.
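In PyTorch, a common way to implement save and load is to round-trip the model's `state_dict`. This sketch assumes path is a directory and writes the weights to a hypothetical `model.pt` file inside it. The on-disk layout that medkit components actually use may differ.

```python
from __future__ import annotations

import tempfile
from pathlib import Path

import torch
from torch import nn


class LinearClassifier:
    """Hypothetical component whose save/load round-trip the model weights."""

    def __init__(self) -> None:
        self.model = nn.Linear(4, 2)

    def save(self, path: str | Path) -> None:
        # Assumption: `path` is a directory; weights go into a file inside it.
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), path / "model.pt")

    def load(self, path: str | Path) -> None:
        # Load onto the CPU, then copy the tensors into the existing model.
        state_dict = torch.load(Path(path) / "model.pt", map_location="cpu")
        self.model.load_state_dict(state_dict)


with tempfile.TemporaryDirectory() as tmp:
    source, target = LinearClassifier(), LinearClassifier()
    source.save(tmp)
    target.load(tmp)
    print(torch.equal(source.model.weight, target.model.weight))  # True
```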