medkit.training.trainer#
Classes#
Class facilitating training and evaluation of PyTorch models to generate annotations. |
Module Contents#
- class medkit.training.trainer.Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)#
Class facilitating training and evaluation of PyTorch models to generate annotations.
- Parameters:
- component:
The component to train; it must implement the TrainableComponent protocol.
- config:
A TrainerConfig with the parameters for training; the output_dir parameter defines the path where checkpoints are saved.
- train_data:
The data to use for training. This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.
- eval_data:
The data to use for evaluation (not for testing). This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.
- metrics_computer:
Optional MetricsComputer object used to compute custom metrics during evaluation. By default, metrics are computed only during evaluation; set do_metrics_in_training in the config to also compute them during training.
- lr_scheduler_builder:
Optional function that builds a lr_scheduler to adjust the learning rate after each epoch. Must take an Optimizer and return a lr_scheduler. If not provided, the learning rate does not change during training.
- callback:
Optional callback to customize training.
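A minimal usage sketch (assumptions: `component` implements TrainableComponent, `train_docs` and `eval_docs` are corpora of medkit objects, and the TrainerConfig fields shown match the attributes listed below; all three variables are placeholders, not defined here):

```python
# Hedged sketch: typical use of Trainer. `component`, `train_docs` and
# `eval_docs` are assumed to exist; config field names are illustrative.
from medkit.training.trainer import Trainer
from medkit.training.trainer_config import TrainerConfig

config = TrainerConfig(
    output_dir="checkpoints",  # where save() writes checkpoints
    batch_size=8,
    nb_training_epochs=10,
)

trainer = Trainer(
    component=component,
    config=config,
    train_data=train_docs,
    eval_data=eval_docs,
)
metrics_per_epoch = trainer.train()  # one metrics dict per epoch
```

A lr_scheduler_builder can additionally be passed to adjust the learning rate after each epoch, e.g. a function wrapping a torch.optim.lr_scheduler class.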
- output_dir#
- component#
- batch_size#
- dataloader_drop_last = False#
- dataloader_nb_workers#
- dataloader_pin_memory = False#
- device#
- train_dataloader#
- eval_dataloader#
- nb_training_epochs#
- config#
- optimizer#
- lr_scheduler#
- metrics_computer#
- callback#
- get_dataloader(data: dict, shuffle: bool) torch.utils.data.DataLoader #
Return a DataLoader with transformations defined in the component to train.
- Parameters:
- data: dict
Training data
- shuffle: bool
Whether to use sequential or shuffled sampling
- Returns:
- torch.utils.data.DataLoader
The corresponding instance of a DataLoader
- training_epoch() dict[str, float] #
Perform an epoch using the training data.
When metrics in training are enabled in the config (do_metrics_in_training is True), the additional metrics are prepared per batch.
- Returns:
- dict of str to float
A dictionary containing the training metrics
- evaluation_epoch(eval_dataloader) dict[str, float] #
Perform an epoch using the evaluation data.
The additional metrics are prepared per batch.
- Parameters:
- eval_dataloader: torch.utils.data.DataLoader
The evaluation dataset as a PyTorch DataLoader
- Returns:
- dict of str to float
A dictionary containing the evaluation metrics
- make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) tuple[medkit.training.utils.BatchData, torch.Tensor] #
Run a forward pass safely, with inputs on the same device as the component.
- update_learning_rate(eval_metrics: dict[str, float]) None #
Call the learning rate scheduler if defined.
- train() list[dict] #
Call the training and evaluation loop.
- Returns:
- list of dict of str to float
The list of computed metrics per epoch
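For instance, the returned list can be scanned for the best epoch (a sketch assuming each per-epoch dict contains a "loss" key; actual keys depend on the component and on any MetricsComputer in use):

```python
# Hedged sketch: picking the best epoch from the metrics returned by
# train(). The "loss" key and the values are illustrative only.
metrics_per_epoch = [
    {"loss": 0.92},
    {"loss": 0.61},
    {"loss": 0.48},
]

# Index of the epoch with the lowest loss
best_epoch = min(
    range(len(metrics_per_epoch)),
    key=lambda i: metrics_per_epoch[i]["loss"],
)
print(f"best epoch: {best_epoch + 1}")
```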
- save(epoch: int) str #
Save a checkpoint.
Checkpoints include trainer configuration, model weights, optimizer and scheduler.
- Parameters:
- epoch: int
Epoch corresponding to the current training state (included in the checkpoint name)
- Returns:
- str
Path to the saved checkpoint