medkit.training.trainer#

Classes#

Trainer

Class facilitating training and evaluation of PyTorch models to generate annotations.

Module Contents#

class medkit.training.trainer.Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)#

Class facilitating training and evaluation of PyTorch models to generate annotations.

Parameters:
component:

The component to train; it must implement the TrainableComponent protocol.

config:

A TrainerConfig holding the parameters for training; its output_dir parameter defines where checkpoints are written.

train_data:

The data to use for training. This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.

eval_data:

The data to use for evaluation (not for testing). This should be a corpus of medkit objects, for instance a torch.utils.data.Dataset that returns medkit objects.

metrics_computer:

Optional MetricsComputer object used to compute custom metrics during evaluation. By default, these metrics are only computed during evaluation; set do_metrics_in_training in the config to also compute them during training.

lr_scheduler_builder:

Optional function that builds a lr_scheduler to adjust the learning rate after each epoch. It must take an Optimizer and return a lr_scheduler (see the sketch under update_learning_rate() below). If not provided, the learning rate does not change during training.

callback:

Optional callback to customize training.
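A minimal usage sketch (hedged): my_component, train_docs and eval_docs are hypothetical placeholders, and the TrainerConfig fields shown should be checked against your medkit version.

```python
# Minimal sketch: train a TrainableComponent with the Trainer.
# `my_component`, `train_docs` and `eval_docs` are hypothetical placeholders.
from medkit.training.trainer import Trainer
from medkit.training.trainer_config import TrainerConfig

config = TrainerConfig(
    output_dir="checkpoints",  # where save() writes checkpoints
    learning_rate=5e-5,
    nb_training_epochs=3,
    batch_size=8,
)

trainer = Trainer(
    component=my_component,  # any object implementing TrainableComponent
    config=config,
    train_data=train_docs,   # corpus of medkit objects (e.g. a torch Dataset)
    eval_data=eval_docs,
)
metrics_per_epoch = trainer.train()  # list of metric dicts, one per epoch
```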

output_dir#
component#
batch_size#
dataloader_drop_last = False#
dataloader_nb_workers#
dataloader_pin_memory = False#
device#
train_dataloader#
eval_dataloader#
nb_training_epochs#
config#
optimizer#
lr_scheduler#
metrics_computer#
callback#
get_dataloader(data: dict, shuffle: bool) → torch.utils.data.DataLoader#

Return a DataLoader with transformations defined in the component to train.

Parameters:
data: dict

Training data

shuffle: bool

Whether to use sequential or shuffled sampling

Returns:
torch.utils.data.DataLoader

The corresponding instance of a DataLoader
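For illustration, a hedged example of building a loader manually (train_docs is the same hypothetical corpus as above); train() builds these loaders itself, so calling this directly is rarely needed.

```python
# Build a shuffled DataLoader over the training corpus; collation and
# preprocessing come from the trainable component, as documented above.
train_dl = trainer.get_dataloader(train_docs, shuffle=True)
```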

training_epoch() → dict[str, float]#

Perform an epoch using the training data.

When metrics in training are enabled in the config (do_metrics_in_training is True), the additional metrics are prepared per batch.

Returns:
dict of str to float

A dictionary containing the training metrics

evaluation_epoch(eval_dataloader) → dict[str, float]#

Perform an epoch using the evaluation data.

The additional metrics are prepared per batch.

Parameters:
eval_dataloader: torch.utils.data.DataLoader

The evaluation dataset as a PyTorch DataLoader

Returns:
dict of str to float

A dictionary containing the evaluation metrics
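As a hedged illustration of how training_epoch() and evaluation_epoch() fit together, here is a rough manual loop approximating what train() does (callbacks, checkpointing, and metric aggregation omitted).

```python
# Rough sketch of the per-epoch loop run internally by train().
for epoch in range(1, trainer.nb_training_epochs + 1):
    train_metrics = trainer.training_epoch()
    eval_metrics = trainer.evaluation_epoch(trainer.eval_dataloader)
    trainer.update_learning_rate(eval_metrics)  # no-op without a scheduler
```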

make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) → tuple[medkit.training.utils.BatchData, torch.Tensor]#

Run the forward pass safely, ensuring inputs are on the same device as the component.

update_learning_rate(eval_metrics: dict[str, float]) → None#

Call the learning rate scheduler if defined.
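A minimal sketch of a compatible lr_scheduler_builder to pass to the constructor; StepLR is just one possible choice, and how eval_metrics is used depends on the scheduler you return.

```python
import torch

def lr_scheduler_builder(optimizer: torch.optim.Optimizer):
    # Halve the learning rate every 10 epochs; any scheduler built from
    # the trainer's optimizer and exposing step() would fit here.
    return torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
```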

train() → list[dict]#

Run the training and evaluation loop.

Returns:
list of dict of str to float

The list of computed metrics per epoch

save(epoch: int) → str#

Save a checkpoint.

Checkpoints include the trainer configuration, model weights, and optimizer and scheduler states.

Parameters:
epoch: int

Epoch corresponding to the current training state (it will be included in the checkpoint name)

Returns:
str

Path to the saved checkpoint
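A hedged usage example, assuming training has completed:

```python
# Write a checkpoint for the last epoch; the returned path points inside
# config.output_dir and embeds the epoch number in its name.
ckpt_path = trainer.save(epoch=trainer.nb_training_epochs)
```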