medkit.training

Submodules

Classes

DefaultPrinterCallback

Default implementation of TrainerCallback.

TrainerCallback

A TrainerCallback is the base class for trainer callbacks.

TrainableComponent

TrainableComponent is the base protocol that a component must implement to be trainable in medkit.

Trainer

Class facilitating training and evaluation of PyTorch models to generate annotations.

TrainerConfig

Trainer configuration.

BatchData

A BatchData packs data, allowing both column and row access.

MetricsComputer

A MetricsComputer is the base protocol for computing metrics during training.

Package Contents

class medkit.training.DefaultPrinterCallback

Bases: TrainerCallback

Default implementation of TrainerCallback.

logger
console_handler
formatter
_progress_bar = None
on_train_begin(config)

Event called at the beginning of training.

on_epoch_end(metrics, epoch, epoch_duration)

Event called at the end of an epoch.

on_train_end()

Event called at the end of training.

on_save(checkpoint_dir)

Event called on saving a checkpoint.

on_step_begin(step_idx: int, nb_batches: int, phase: str)

Event called at the beginning of a step in training.

on_step_end(step_idx: int, nb_batches: int, phase: str)

Event called at the end of a step in training.

class medkit.training.TrainerCallback

A TrainerCallback is the base class for trainer callbacks.

on_train_begin(config: medkit.training.trainer_config.TrainerConfig)

Event called at the beginning of training.

on_train_end()

Event called at the end of training.

on_epoch_begin(epoch: int)

Event called at the beginning of an epoch.

on_epoch_end(metrics: dict[str, float], epoch: int, epoch_time: float)

Event called at the end of an epoch.

on_step_begin(step_idx: int, nb_batches: int, phase: str)

Event called at the beginning of a step in training.

on_step_end(step_idx: int, nb_batches: int, phase: str)

Event called at the end of a step in training.

on_save(checkpoint_dir: str)

Event called on saving a checkpoint.
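A custom callback overrides the hooks listed above. The sketch below is a hypothetical HistoryCallback (not part of medkit) that records what on_epoch_end receives and counts finished steps per phase; it assumes the Trainer calls each hook with the arguments shown in these signatures. When medkit is not installed it falls back to a plain base class, which is why it implements every hook itself.

```python
try:
    from medkit.training import TrainerCallback
except ImportError:  # medkit not installed: fall back to a plain base class
    TrainerCallback = object


class HistoryCallback(TrainerCallback):
    """Hypothetical callback keeping a per-epoch history and step counts."""

    def __init__(self):
        super().__init__()
        self.history = []  # one entry per finished epoch
        self.steps = {}  # phase name -> number of finished steps
        self.saved_dirs = []  # checkpoint directories reported by on_save

    def on_train_begin(self, config):
        # start from a clean state for each training run
        self.history.clear()
        self.steps.clear()

    def on_train_end(self):
        pass

    def on_epoch_begin(self, epoch):
        pass

    def on_epoch_end(self, metrics, epoch, epoch_time):
        self.history.append({"epoch": epoch, "epoch_time": epoch_time, "metrics": metrics})

    def on_step_begin(self, step_idx, nb_batches, phase):
        pass

    def on_step_end(self, step_idx, nb_batches, phase):
        self.steps[phase] = self.steps.get(phase, 0) + 1

    def on_save(self, checkpoint_dir):
        self.saved_dirs.append(checkpoint_dir)
```

An instance is passed through the callback parameter of Trainer. Since that parameter takes a single callback, subclassing DefaultPrinterCallback instead is one way to keep its default printing while adding custom behavior.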

class medkit.training.TrainableComponent

Bases: typing_extensions.Protocol

TrainableComponent is the base protocol that a component must implement to be trainable in medkit.

property device: torch.device
configure_optimizer(lr: float) -> torch.optim.Optimizer

Create an optimizer using the given learning rate.

preprocess(data_item: Any) -> dict[str, Any]

Run preprocessing on the input data.

Preprocess the input data item and return a dictionary with everything needed for the forward pass.

This method preprocesses a single input item; self.collate must then be used to build batches so that self.forward can run properly. The output of preprocess should include the labels needed to compute a loss.

collate(batch: list[dict[str, Any]]) -> medkit.training.utils.BatchData

Collate a list of items returned by preprocess into a batch.
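Since preprocess works on one item and collate on a list of them, collate essentially turns rows into columns. A minimal sketch of that reshaping, using a hypothetical collate_as_columns helper and plain Python lists; a real component would return a medkit BatchData and would usually stack or pad tensors with torch instead.

```python
from typing import Any


def collate_as_columns(batch: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Hypothetical helper: turn preprocessed items (rows) into one dict of columns."""
    if not batch:
        return {}
    # every item returned by preprocess is assumed to share the same keys
    keys = batch[0].keys()
    return {key: [item[key] for item in batch] for key in keys}


items = [{"input_ids": [5, 7], "label": 0}, {"input_ids": [9], "label": 1}]
print(collate_as_columns(items))  # {'input_ids': [[5, 7], [9]], 'label': [0, 1]}
```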

forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) -> tuple[medkit.training.utils.BatchData, torch.Tensor | None]

Perform the forward pass on a batch.

Perform the forward pass on a batch and return the corresponding output as well as the loss if return_loss is True.

Before running the model, this method must set it to training or evaluation mode depending on eval_mode. In PyTorch, the mode is set with model.train() or model.eval().

Parameters:
input_batch: BatchData

Input batch

return_loss: bool

Whether to return the computed loss as well

eval_mode: bool

Whether to set the model to training (False) or evaluation mode (True)

Returns:
output: BatchData

Output after forward pass completion

loss: torch.Tensor, optional

Loss after forward pass completion, if return_loss was set to True.

save(path: str | pathlib.Path)

Save model to disk.

load(path: str | pathlib.Path)

Load weights from disk.

class medkit.training.Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)

Class facilitating training and evaluation of PyTorch models to generate annotations.

Parameters:
component:

The component to train, the component must implement the TrainableComponent protocol.

config:

A TrainerConfig with the parameters for training; its output_dir parameter defines where the checkpoints are saved.

train_data:

The data to use for training. This should be a corpus of medkit objects. The data could be, for instance, a torch.utils.data.Dataset that returns medkit objects for training.

eval_data:

The data to use for evaluation (not for testing). This should be a corpus of medkit objects. The data can be a torch.utils.data.Dataset that returns medkit objects for evaluation.

metrics_computer:

Optional MetricsComputer object used to compute custom metrics during evaluation. By default, custom metrics are only computed on the evaluation data; set do_metrics_in_training to True in the config to also compute them during training.

lr_scheduler_builder:

Optional function that builds a learning rate scheduler to adjust the learning rate after each epoch. It must take an Optimizer and return a learning rate scheduler. If not provided, the learning rate does not change during training.

callback:

Optional callback to customize training.

output_dir
component
batch_size
dataloader_drop_last = False
dataloader_nb_workers
dataloader_pin_memory = False
device
train_dataloader
eval_dataloader
nb_training_epochs
config
optimizer
lr_scheduler
metrics_computer
callback
get_dataloader(data: dict, shuffle: bool) -> torch.utils.data.DataLoader

Return a DataLoader with transformations defined in the component to train.

Parameters:
data: dict

Training data

shuffle: bool

Whether to use sequential or shuffled sampling

Returns:
torch.utils.data.DataLoader

The corresponding instance of a DataLoader

training_epoch() -> dict[str, float]

Perform an epoch using the training data.

When metrics in training are enabled in the config (‘do_metrics_in_training’ is True), the data for the additional metrics is prepared per batch.

Returns:
dict of str to float

A dictionary containing the training metrics

evaluation_epoch(eval_dataloader) -> dict[str, float]

Perform an epoch using the evaluation data.

The additional metrics are prepared per batch.

Parameters:
eval_dataloader: torch.utils.data.DataLoader

The evaluation dataset as a PyTorch DataLoader

Returns:
dict of str to float

A dictionary containing the evaluation metrics
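The pattern of preparing metric data per batch and computing the metrics once per epoch can be sketched as follows. The evaluate_metrics loop, the MeanAbsoluteError class and the "scores"/"targets" keys are hypothetical, plain dicts stand in for BatchData, and the sketch assumes (this page does not state it) that the per-batch lists are concatenated before a single call to compute.

```python
from typing import Any, Iterable


class MeanAbsoluteError:
    """Hypothetical metrics computer following the MetricsComputer protocol."""

    def prepare_batch(self, model_output: dict, input_batch: dict) -> dict[str, list[Any]]:
        return {
            "predictions": list(model_output["scores"]),
            "references": list(input_batch["targets"]),
        }

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        pairs = zip(all_data["predictions"], all_data["references"])
        errors = [abs(pred - ref) for pred, ref in pairs]
        return {"mae": sum(errors) / len(errors)}


def evaluate_metrics(batches: Iterable[tuple[dict, dict]], computer) -> dict[str, float]:
    """Prepare metric data batch by batch, then compute the metrics once."""
    all_data: dict[str, list[Any]] = {}
    for model_output, input_batch in batches:
        prepared = computer.prepare_batch(model_output, input_batch)
        for key, values in prepared.items():
            # concatenate per-batch lists so compute() sees the whole epoch
            all_data.setdefault(key, []).extend(values)
    return computer.compute(all_data)


batches = [
    ({"scores": [1.0, 2.0]}, {"targets": [1.0, 3.0]}),
    ({"scores": [4.0]}, {"targets": [2.0]}),
]
print(evaluate_metrics(batches, MeanAbsoluteError()))  # {'mae': 1.0}
```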

make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) -> tuple[medkit.training.utils.BatchData, torch.Tensor]

Run the forward pass safely, on the same device as the component.

update_learning_rate(eval_metrics: dict[str, float]) -> None

Call the learning rate scheduler if defined.

train() -> list[dict]

Run the training and evaluation loop.

Returns:
list of dict of str to float

The list of computed metrics per epoch

save(epoch: int) -> str

Save a checkpoint.

Checkpoints include the trainer configuration, the model weights, the optimizer and the scheduler.

Parameters:
epoch: int

Epoch corresponding to the current training state (included in the checkpoint name)

Returns:
str

Path to the saved checkpoint

class medkit.training.TrainerConfig

Trainer configuration.

Parameters:
output_dir:

The output directory where the checkpoints will be saved.

learning_rate:

The initial learning rate.

nb_training_epochs:

Total number of training/evaluation epochs to run.

dataloader_nb_workers:

Number of subprocesses used for data loading. The default value is 0, meaning that the data is loaded in the main process. If this config is for a HuggingFace model, do not change this value.

batch_size:

Number of samples per batch to load.

seed:

Random seed to use with PyTorch and numpy. It should be set to ensure reproducibility between experiments.

gradient_accumulation_steps:

Number of steps over which gradients are accumulated before performing an optimization step.

do_metrics_in_training:

By default, the custom metrics are only computed using eval_data. If set to True, they are also computed using train_data.

metric_to_track_lr:

Name of the eval metric to be tracked for updating the learning rate. By default, eval loss is tracked.

checkpoint_period:

How often, in number of epochs, a checkpoint should be saved. Use 0 to only save the last checkpoint.

checkpoint_metric:

Name of the eval metric to be tracked for selecting the best checkpoint. By default, eval loss is tracked.

minimize_checkpoint_metric:

If True, the checkpoint with the lowest metric value will be selected as best, otherwise the checkpoint with the highest metric value.
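checkpoint_metric and minimize_checkpoint_metric together decide which epoch counts as best. A hypothetical select_best_epoch helper illustrating that rule, assuming a list of flat eval-metric dicts with one entry per epoch (the Trainer's real bookkeeping may differ):

```python
def select_best_epoch(
    history: list[dict[str, float]], metric: str = "loss", minimize: bool = True
) -> int:
    """Hypothetical helper: index of the epoch with the best value of `metric`."""
    values = [epoch_metrics[metric] for epoch_metrics in history]
    best = min(values) if minimize else max(values)
    # on ties, the earliest epoch wins
    return values.index(best)


history = [{"loss": 0.9, "f1": 0.5}, {"loss": 0.4, "f1": 0.7}, {"loss": 0.6, "f1": 0.8}]
print(select_best_epoch(history))  # 1
print(select_best_epoch(history, metric="f1", minimize=False))  # 2
```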

output_dir: str
learning_rate: float = 1e-05
nb_training_epochs: int = 3
dataloader_nb_workers: int = 0
batch_size: int = 1
seed: int | None = None
gradient_accumulation_steps: int = 1
do_metrics_in_training: bool = False
metric_to_track_lr: str = 'loss'
checkpoint_period: int = 1
checkpoint_metric: str = 'loss'
minimize_checkpoint_metric: bool = True
to_dict() -> dict[str, Any]
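With gradient_accumulation_steps = N and batch_size = B, gradients from N consecutive batches are combined before one optimizer step, so the effective batch size is roughly N × B. The stdlib-only sketch below checks the idea on a one-parameter mean-squared-error loss: summing the gradients of N equal-sized micro-batches, each scaled by 1/N, gives the gradient of the whole batch. The 1/N scaling is a common convention and an assumption here; this page does not say how medkit's Trainer scales the loss.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: list[float], ys: list[float], steps: int) -> float:
    """Accumulate gradients over `steps` equal-sized micro-batches, each scaled by 1/steps."""
    size = len(xs) // steps  # assumes len(xs) is a multiple of steps
    grad = 0.0
    for k in range(steps):
        chunk = slice(k * size, (k + 1) * size)
        grad += mse_grad(w, xs[chunk], ys[chunk]) / steps
    return grad


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
print(mse_grad(0.5, xs, ys), accumulated_grad(0.5, xs, ys, steps=2))  # -22.5 -22.5
```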
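class medkit.training.BatchData

Bases: dict

A BatchData packs data, allowing both column and row access.

__getitem__(index: int) -> dict[str, list[Any] | torch.Tensor]

Access by string key returns a column, as in a regular dict; access by integer index returns a row, as a dict mapping each key to its value at that index.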

to_device(device: torch.device) -> typing_extensions.Self

Ensure that Tensors in the BatchData object are on the specified device.

Parameters:
device:

A torch.device object representing the device on which tensors will be allocated.

Returns:
BatchData

A new object with the tensors on the proper device.
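To make row and column access concrete, here is a hypothetical RowColumnBatch class written for illustration only (it is not medkit's BatchData implementation). It assumes that string keys return whole columns, as in a plain dict, and that integer indices return one row, which is what the int-indexed __getitem__ signature suggests.

```python
class RowColumnBatch(dict):
    """Hypothetical stand-in for BatchData: a dict of equal-length columns."""

    def __getitem__(self, index):
        if isinstance(index, str):
            # column access: behave like a regular dict
            return super().__getitem__(index)
        # row access: pick the index-th value of every column
        return {key: values[index] for key, values in self.items()}


batch = RowColumnBatch(text=["first", "second"], label=[0, 1])
print(batch["label"])  # [0, 1]
print(batch[1])  # {'text': 'second', 'label': 1}
```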

class medkit.training.MetricsComputer

Bases: typing_extensions.Protocol

A MetricsComputer is the base protocol for computing metrics during training.

prepare_batch(model_output: BatchData, input_batch: BatchData) -> dict[str, list[Any]]

Prepare a batch of data to compute the metrics.

Parameters:
model_output: BatchData

Output data after a model forward pass.

input_batch: BatchData

Preprocessed input batch

Returns:
dict[str, list[Any]]

A dictionary with the required data to calculate the metric

compute(all_data: dict[str, list[Any]]) -> dict[str, float]

Compute metrics using ‘all_data’.

Parameters:
all_data: dict[str, list[Any]]

A dictionary with the data needed to compute the metrics, for example a list of ‘references’ and a list of ‘predictions’.

Returns:
dict[str, float]

A dictionary with the results
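As a minimal sketch of the protocol, the hypothetical AccuracyComputer below implements prepare_batch and compute for label accuracy. It assumes the model output holds predicted label ids under "predicted_ids" and the input batch holds gold label ids under "labels"; both key names are illustrative, not medkit conventions. Plain dicts stand in for BatchData, which is a dict subclass.

```python
from typing import Any


class AccuracyComputer:
    """Hypothetical MetricsComputer computing label accuracy."""

    def prepare_batch(self, model_output, input_batch) -> dict[str, list[Any]]:
        # keep only what compute() needs, as plain Python lists
        return {
            "predictions": list(model_output["predicted_ids"]),
            "references": list(input_batch["labels"]),
        }

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        refs = all_data["references"]
        preds = all_data["predictions"]
        if not refs:
            return {"accuracy": 0.0}
        correct = sum(pred == ref for pred, ref in zip(preds, refs))
        return {"accuracy": correct / len(refs)}


computer = AccuracyComputer()
data = computer.prepare_batch({"predicted_ids": [1, 0, 2, 2]}, {"labels": [1, 1, 2, 0]})
print(computer.compute(data))  # {'accuracy': 0.5}
```

An instance would be passed as Trainer(..., metrics_computer=AccuracyComputer()). When the batches contain tensors, converting them with .tolist() in prepare_batch keeps compute independent of torch.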