medkit.training.trainer#
Classes#
Trainer: Class facilitating training and evaluation of PyTorch models to generate annotations.
Module Contents#
- class medkit.training.trainer.Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)#
 Class facilitating training and evaluation of PyTorch models to generate annotations.
- Parameters:
 - component:
 The component to train; it must implement the TrainableComponent protocol.
- config:
 A TrainerConfig with the parameters for training; its output_dir parameter defines where checkpoints are saved.
- train_data:
 The data to use for training: a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.
- eval_data:
 The data to use for evaluation (not for testing): a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.
- metrics_computer:
 Optional MetricsComputer object used to compute custom metrics during evaluation. By default, metrics are only computed during evaluation; setting do_metrics_in_training to True in the config also computes them during training.
- lr_scheduler_builder:
 Optional function that builds a lr_scheduler to adjust the learning rate after each epoch, as sketched in the example below. It must take an Optimizer and return a lr_scheduler. If not provided, the learning rate does not change during training.
- callback:
 Optional callback to customize training.
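A minimal instantiation sketch (hedged: MyTrainableComponent, train_docs and eval_docs are hypothetical placeholders, and TrainerConfig is assumed to only require output_dir here):

```python
import torch

from medkit.training.trainer import Trainer
from medkit.training.trainer_config import TrainerConfig

# Hypothetical component implementing the TrainableComponent protocol,
# and hypothetical corpora of medkit objects (e.g. torch.utils.data.Dataset)
component = MyTrainableComponent()

# output_dir defines where checkpoints are saved
config = TrainerConfig(output_dir="checkpoints")

trainer = Trainer(
    component=component,
    config=config,
    train_data=train_docs,
    eval_data=eval_docs,
    # Optional: multiply the learning rate by 0.9 after each epoch
    lr_scheduler_builder=lambda optim: torch.optim.lr_scheduler.ExponentialLR(
        optim, gamma=0.9
    ),
)
```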
- output_dir#
 
- component#
 
- batch_size#
 
- dataloader_drop_last = False#
 
- dataloader_nb_workers#
 
- dataloader_pin_memory = False#
 
- device#
 
- train_dataloader#
 
- eval_dataloader#
 
- nb_training_epochs#
 
- config#
 
- optimizer#
 
- lr_scheduler#
 
- metrics_computer#
 
- callback#
 
- get_dataloader(data: dict, shuffle: bool) torch.utils.data.DataLoader#
 Return a DataLoader with transformations defined in the component to train.
- Parameters:
 - data: dict
 Training data
 - shuffle: bool
 Whether to use shuffled (True) or sequential (False) sampling
- Returns:
 - torch.utils.data.DataLoader
 The corresponding instance of a DataLoader
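For illustration, a hedged call sketch (raw_train_data and raw_eval_data are placeholders for data in the format expected by the component):

```python
train_dl = trainer.get_dataloader(raw_train_data, shuffle=True)  # shuffled sampling
eval_dl = trainer.get_dataloader(raw_eval_data, shuffle=False)   # sequential sampling
```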
- training_epoch() dict[str, float]#
 Perform an epoch using the training data.
When metrics in training are enabled in the config (do_metrics_in_training is True), the additional metrics are prepared per batch.
- Returns:
 - dict of str to float
 A dictionary containing the training metrics
- evaluation_epoch(eval_dataloader) dict[str, float]#
 Perform an epoch using the evaluation data.
The additional metrics are prepared per batch.
- Parameters:
 - eval_dataloader: torch.utils.data.DataLoader
 The evaluation dataset as a PyTorch DataLoader
- Returns:
 - dict of str to float
 A dictionary containing the evaluation metrics
- make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) tuple[medkit.training.utils.BatchData, torch.Tensor]#
 Run the forward pass safely, on the same device as the component.
- update_learning_rate(eval_metrics: dict[str, float]) None#
 Call the learning rate scheduler if defined.
- train() list[dict]#
 Run the complete training and evaluation loop.
- Returns:
 - list of dict of str to float
 The list of computed metrics per epoch
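The per-epoch metric keys depend on the component's loss and on the configured MetricsComputer; a hedged consumption sketch:

```python
history = trainer.train()  # one dict of metrics per epoch
for epoch_idx, metrics in enumerate(history):
    print(f"epoch {epoch_idx}: {metrics}")
```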
- save(epoch: int) str#
 Save a checkpoint.
Checkpoints include trainer configuration, model weights, optimizer and scheduler.
- Parameters:
 - epoch: int
 Epoch corresponding to the current training state (will be included in the checkpoint name)
- Returns:
 - str
 Path to the saved checkpoint
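A hedged usage sketch (save() is also invoked internally by the training loop; the exact checkpoint naming is an implementation detail):

```python
# Manually checkpoint the current training state after a given epoch
checkpoint_path = trainer.save(epoch=3)
print(checkpoint_path)  # e.g. a checkpoint path under config.output_dir
```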