medkit.training.trainable_component
Classes
- TrainableComponent: base protocol that a component must implement to be trainable in medkit.
Module Contents
- class medkit.training.trainable_component.TrainableComponent
Bases: typing_extensions.Protocol
Base protocol that a component must implement to be trainable in medkit.
- property device: torch.device
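The protocol only specifies the type of device. A minimal sketch, assuming the component wraps a PyTorch module in a hypothetical `self.model` attribute and that device reports where that module's weights live:

```python
import torch


class MyComponent:
    """Hypothetical component; only the device property is shown."""

    def __init__(self, device: str = "cpu") -> None:
        # Store the target device and move the hypothetical model onto it
        self._device = torch.device(device)
        self.model = torch.nn.Linear(4, 2).to(self._device)

    @property
    def device(self) -> torch.device:
        # Device hosting the model weights
        return self._device


component = MyComponent()
```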
- configure_optimizer(lr: float) → torch.optim.Optimizer
Create an optimizer using the given learning rate.
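A minimal sketch of configure_optimizer, assuming the component keeps its PyTorch module in a hypothetical `self.model` attribute. Adam over all parameters is an illustrative choice; the protocol leaves the optimizer type to the implementation:

```python
import torch


class MyComponent:
    """Hypothetical component; only configure_optimizer is shown."""

    def __init__(self) -> None:
        # Hypothetical model attribute; the protocol does not mandate its name
        self.model = torch.nn.Linear(4, 2)

    def configure_optimizer(self, lr: float) -> torch.optim.Optimizer:
        # Optimize every parameter of the wrapped model with the given learning rate
        return torch.optim.Adam(self.model.parameters(), lr=lr)


optimizer = MyComponent().configure_optimizer(lr=1e-3)
```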
- preprocess(data_item: Any) → dict[str, Any]
Run preprocessing on the input data.
Preprocess the input data item and return a dictionary with everything needed for the forward pass.
This method preprocesses a single input; self.collate must then be used to group preprocessed items into batches so that self.forward runs properly. The returned dictionary should include the labels needed to compute a loss.
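For example, a text classifier might turn a data item (a text and its label) into token ids plus an integer label. This sketch is written as a free function for brevity; the data item layout, whitespace tokenizer, vocabulary and key names are all hypothetical, not medkit's API:

```python
from typing import Any

# Hypothetical vocabulary and label mapping; a real component would use its tokenizer
VOCAB = {"<unk>": 0, "no": 1, "fever": 2, "cough": 3}
LABELS = {"negative": 0, "positive": 1}


def preprocess(data_item: dict[str, str]) -> dict[str, Any]:
    # Map each whitespace-separated token to an id, falling back to <unk>
    tokens = data_item["text"].lower().split()
    input_ids = [VOCAB.get(tok, VOCAB["<unk>"]) for tok in tokens]
    # Keep the label so that forward() can compute a loss later
    return {"input_ids": input_ids, "labels": LABELS[data_item["label"]]}
```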
- collate(batch: list[dict[str, Any]]) → medkit.training.utils.BatchData
Collate a list of items returned by preprocess into a single batch.
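A minimal collate sketch for variable-length token ids, assuming each preprocessed item has hypothetical "input_ids" and "labels" keys. It pads to the longest item with a hypothetical padding id of 0 and builds an attention mask. A plain dict stands in for BatchData, and a real implementation would typically return tensors rather than lists:

```python
from typing import Any

PAD_ID = 0  # hypothetical padding id


def collate(batch: list[dict[str, Any]]) -> dict[str, list]:
    # Pad every input_ids list to the length of the longest item in the batch
    max_len = max(len(item["input_ids"]) for item in batch)
    input_ids, attention_mask = [], []
    for item in batch:
        n_pad = max_len - len(item["input_ids"])
        input_ids.append(item["input_ids"] + [PAD_ID] * n_pad)
        # 1 for real tokens, 0 for padding positions
        attention_mask.append([1] * len(item["input_ids"]) + [0] * n_pad)
    labels = [item["labels"] for item in batch]
    # A plain dict stands in for medkit's BatchData here
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
```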
- forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) → tuple[medkit.training.utils.BatchData, torch.Tensor | None]
Perform the forward pass on a batch.
Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.
Before running the forward pass, this method must set the model to training or evaluation mode depending on eval_mode. In PyTorch, this is done with model.train() and model.eval() respectively.
- Parameters:
- input_batch: BatchData
Input batch
- return_loss: bool
Whether to return the computed loss as well
- eval_mode: bool
Whether to set the model to training (False) or evaluation mode (True)
- Returns:
- output: BatchData
Output after forward pass completion
- loss: torch.Tensor, optional
Loss computed during the forward pass if return_loss is True; None otherwise.
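A minimal forward sketch, assuming a hypothetical classifier stored in `self.model`, batches with hypothetical "features" and "labels" keys, and a plain dict in place of BatchData. It sets the training or evaluation mode first, as required above, and computes the loss only when return_loss is True:

```python
from __future__ import annotations

import torch


class MyComponent:
    """Hypothetical component; only forward is shown."""

    def __init__(self) -> None:
        # Hypothetical classifier: dropout followed by 4 input features -> 2 classes
        self.model = torch.nn.Sequential(torch.nn.Dropout(0.1), torch.nn.Linear(4, 2))
        self.loss_fn = torch.nn.CrossEntropyLoss()

    def forward(
        self, input_batch: dict, return_loss: bool, eval_mode: bool
    ) -> tuple[dict, torch.Tensor | None]:
        # Set the mode before the forward pass: eval() disables dropout, train() enables it
        if eval_mode:
            self.model.eval()
        else:
            self.model.train()
        logits = self.model(input_batch["features"])
        # Compute the loss only when requested; labels come from preprocess/collate
        loss = self.loss_fn(logits, input_batch["labels"]) if return_loss else None
        # A plain dict stands in for medkit's BatchData
        return {"logits": logits}, loss


component = MyComponent()
batch = {"features": torch.randn(3, 4), "labels": torch.tensor([0, 1, 1])}
output, loss = component.forward(batch, return_loss=True, eval_mode=False)
```

Whether to also disable gradient tracking in evaluation mode (for example with torch.no_grad()) is left to the implementation; the protocol only requires setting the mode.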
- save(path: str | pathlib.Path)
Save the model to disk.
- load(path: str | pathlib.Path)
Load the model weights from disk.
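A minimal save/load sketch, assuming a hypothetical `self.model` PyTorch module and treating path as a single file holding the model's state dict. Other components may choose a different on-disk layout, such as a directory; the protocol does not fix one:

```python
from __future__ import annotations

import pathlib
import tempfile

import torch


class MyComponent:
    """Hypothetical component; only save and load are shown."""

    def __init__(self) -> None:
        self.model = torch.nn.Linear(4, 2)

    def save(self, path: str | pathlib.Path) -> None:
        # Persist only the weights (state dict), not the Python object
        torch.save(self.model.state_dict(), path)

    def load(self, path: str | pathlib.Path) -> None:
        # Restore weights into an already-built model with the same architecture
        state_dict = torch.load(path, map_location="cpu")
        self.model.load_state_dict(state_dict)


original, restored = MyComponent(), MyComponent()
with tempfile.TemporaryDirectory() as tmp_dir:
    weights_path = pathlib.Path(tmp_dir) / "model.pt"
    original.save(weights_path)
    restored.load(weights_path)
```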