medkit.training
Submodules
Classes
- DefaultPrinterCallback: Default implementation of TrainerCallback.
- TrainerCallback: A TrainerCallback is the base class for trainer callbacks.
- TrainableComponent: TrainableComponent is the base protocol a component must implement to be trainable in medkit.
- Trainer: Class facilitating the training and evaluation of PyTorch models used to generate annotations.
- TrainerConfig: Trainer configuration.
- BatchData: A BatchData packs data, allowing both column and row access.
- MetricsComputer: A MetricsComputer is the base protocol to compute metrics during training.
Package Contents
- class medkit.training.DefaultPrinterCallback
Bases: TrainerCallback
Default implementation of TrainerCallback.
- logger
- console_handler
- formatter
- _progress_bar = None
- on_train_begin(config)
Event called at the beginning of training.
- on_epoch_end(metrics, epoch, epoch_duration)
Event called at the end of an epoch.
- on_train_end()
Event called at the end of training.
- on_save(checkpoint_dir)
Event called on saving a checkpoint.
- on_step_begin(step_idx: int, nb_batches: int, phase: str)
Event called at the beginning of a step in training.
- on_step_end(step_idx: int, nb_batches: int, phase: str)
Event called at the end of a step in training.
- class medkit.training.TrainerCallback
A TrainerCallback is the base class for trainer callbacks.
- on_train_begin(config: medkit.training.trainer_config.TrainerConfig)
Event called at the beginning of training.
- on_train_end()
Event called at the end of training.
- on_epoch_begin(epoch: int)
Event called at the beginning of an epoch.
- on_epoch_end(metrics: dict[str, float], epoch: int, epoch_time: float)
Event called at the end of an epoch.
- on_step_begin(step_idx: int, nb_batches: int, phase: str)
Event called at the beginning of a step in training.
- on_step_end(step_idx: int, nb_batches: int, phase: str)
Event called at the end of a step in training.
- on_save(checkpoint_dir: str)
Event called on saving a checkpoint.
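A custom callback subclasses TrainerCallback and overrides only the events it cares about. The sketch below is hypothetical (StepCounterCallback is not part of medkit): it counts finished steps per phase and records checkpoint directories. It uses medkit's base class when medkit is installed and otherwise falls back to an empty stand-in so that it runs on its own; the phase names it receives are whatever the Trainer passes and are not assumed here.

```python
# Use medkit's base class when available; otherwise fall back to a
# hypothetical empty stand-in so this sketch runs without medkit.
try:
    from medkit.training import TrainerCallback
except ImportError:
    class TrainerCallback:  # hypothetical stand-in, not medkit's class
        pass


class StepCounterCallback(TrainerCallback):
    """Hypothetical callback counting steps per phase and saved checkpoints."""

    def __init__(self) -> None:
        super().__init__()
        self.steps_per_phase: dict[str, int] = {}
        self.checkpoints: list[str] = []

    def on_step_end(self, step_idx: int, nb_batches: int, phase: str) -> None:
        # Increment the counter of finished steps for this phase
        self.steps_per_phase[phase] = self.steps_per_phase.get(phase, 0) + 1

    def on_save(self, checkpoint_dir: str) -> None:
        # Remember every directory the Trainer reports as a saved checkpoint
        self.checkpoints.append(checkpoint_dir)
```

Such an object would be handed to the Trainer through its callback parameter, e.g. Trainer(..., callback=StepCounterCallback()).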
- class medkit.training.TrainableComponent
Bases: typing_extensions.Protocol
TrainableComponent is the base protocol a component must implement to be trainable in medkit.
- property device: torch.device
- configure_optimizer(lr: float) → torch.optim.Optimizer
Create an optimizer using the given learning rate.
- preprocess(data_item: Any) → dict[str, Any]
Run preprocessing on the input data.
Preprocess the input data item and return a dictionary with everything needed for the forward pass.
This method preprocesses a single input; self.collate must then be used to build batches so that self.forward runs properly. The preprocessed output should include labels so that a loss can be computed.
- collate(batch: list[dict[str, Any]]) → medkit.training.utils.BatchData
Collate a list of data processed by preprocess to form a batch.
- forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) → tuple[medkit.training.utils.BatchData, torch.Tensor | None]
Perform the forward pass on a batch.
Perform the forward pass on a batch and return the corresponding output, as well as the loss if return_loss is True.
Before running the model, this method must set it to training or evaluation mode depending on eval_mode. In PyTorch, the mode is set with model.train() or model.eval().
- Parameters:
- input_batch: BatchData
Input batch
- return_loss: bool
Whether to return the computed loss as well
- eval_mode: bool
Whether to set the model to training (False) or evaluation mode (True)
- Returns:
- output: BatchData
Output after forward pass completion
- loss: torch.Tensor, optional
Loss after forward pass completion, if return_loss was set to True.
- save(path: str | pathlib.Path)
Save model to disk.
- load(path: str | pathlib.Path)
Load weights from disk.
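Because TrainableComponent is a Protocol, a component only needs methods with matching names and signatures. The following is a minimal, hypothetical sketch (ToyLabelClassifier is not a medkit class) built on a single torch.nn.Linear layer. It assumes items shaped like {"features": [...], "label": int}, returns plain dicts where medkit expects BatchData (a dict subclass), and treats the save/load path as a directory holding a file named model.pt; all of these are illustrative choices.

```python
from __future__ import annotations

from pathlib import Path
from typing import Any

import torch
from torch import nn


class ToyLabelClassifier:
    """Hypothetical component following the TrainableComponent method names."""

    def __init__(self, nb_features: int, nb_labels: int) -> None:
        self._device = torch.device("cpu")
        self.model = nn.Linear(nb_features, nb_labels).to(self._device)

    @property
    def device(self) -> torch.device:
        return self._device

    def configure_optimizer(self, lr: float) -> torch.optim.Optimizer:
        return torch.optim.AdamW(self.model.parameters(), lr=lr)

    def preprocess(self, data_item: dict[str, Any]) -> dict[str, Any]:
        # One item in, everything forward() needs out (label kept for the loss)
        return {
            "features": torch.tensor(data_item["features"], dtype=torch.float32),
            "labels": torch.tensor(data_item["label"], dtype=torch.long),
        }

    def collate(self, batch: list[dict[str, Any]]) -> dict[str, torch.Tensor]:
        # Stack the per-item tensors into batch tensors
        return {
            "features": torch.stack([item["features"] for item in batch]),
            "labels": torch.stack([item["labels"] for item in batch]),
        }

    def forward(
        self, input_batch: dict[str, torch.Tensor], return_loss: bool, eval_mode: bool
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor | None]:
        # Set the mode before running the model, as the protocol requires
        if eval_mode:
            self.model.eval()
        else:
            self.model.train()
        logits = self.model(input_batch["features"])
        loss = None
        if return_loss:
            loss = nn.functional.cross_entropy(logits, input_batch["labels"])
        return {"logits": logits}, loss

    def save(self, path: str | Path) -> None:
        # Assumption: path is a directory that receives the weights file
        Path(path).mkdir(parents=True, exist_ok=True)
        torch.save(self.model.state_dict(), Path(path) / "model.pt")

    def load(self, path: str | Path) -> None:
        state = torch.load(Path(path) / "model.pt", map_location=self._device)
        self.model.load_state_dict(state)
```

The preprocess/collate split mirrors the protocol: preprocess handles one item, and collate is the only place where items are combined into a batch.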
- class medkit.training.Trainer(component: medkit.training.trainable_component.TrainableComponent, config: medkit.training.trainer_config.TrainerConfig, train_data: Any, eval_data: Any, metrics_computer: medkit.training.utils.MetricsComputer | None = None, lr_scheduler_builder: Callable[[torch.optim.Optimizer], Any] | None = None, callback: medkit.training.callbacks.TrainerCallback | None = None)
Class facilitating the training and evaluation of PyTorch models used to generate annotations.
- Parameters:
- component:
The component to train. The component must implement the TrainableComponent protocol.
- config:
A TrainerConfig with the parameters for training. The parameter output_dir defines the path of the checkpoints.
- train_data:
The data to use for training. This should be a corpus of medkit objects. The data could be, for instance, a torch.utils.data.Dataset that returns medkit objects for training.
- eval_data:
The data to use for evaluation (not for testing). This should be a corpus of medkit objects. The data can be a torch.utils.data.Dataset that returns medkit objects for evaluation.
- metrics_computer:
Optional MetricsComputer object used to compute custom metrics. By default, custom metrics are only computed during evaluation; set do_metrics_in_training in the config to also compute them during training.
- lr_scheduler_builder:
Optional function that builds a learning rate scheduler to adjust the learning rate after an epoch. It must take an Optimizer and return a learning rate scheduler. If not provided, the learning rate does not change during training.
- callback:
Optional callback to customize training.
- output_dir
- component
- batch_size
- dataloader_drop_last = False
- dataloader_nb_workers
- dataloader_pin_memory = False
- device
- train_dataloader
- eval_dataloader
- nb_training_epochs
- config
- optimizer
- lr_scheduler
- metrics_computer
- callback
- get_dataloader(data: dict, shuffle: bool) → torch.utils.data.DataLoader
Return a DataLoader with the transformations defined in the component to train.
- Parameters:
- data: dict
Training data
- shuffle: bool
Whether to use shuffled (True) or sequential (False) sampling
- Returns:
- torch.utils.data.DataLoader
The corresponding instance of a DataLoader
- training_epoch() → dict[str, float]
Perform an epoch using the training data.
When metrics in training are enabled in the config (do_metrics_in_training is True), the additional metrics are prepared per batch.
- Returns:
- dict of str to float
A dictionary containing the training metrics
- evaluation_epoch(eval_dataloader) → dict[str, float]
Perform an epoch using the evaluation data.
The additional metrics are prepared per batch.
- Parameters:
- eval_dataloader: torch.utils.data.DataLoader
The evaluation dataset as a PyTorch DataLoader
- Returns:
- dict of str to float
A dictionary containing the evaluation metrics
- make_forward_pass(inputs: medkit.training.utils.BatchData, eval_mode: bool) → tuple[medkit.training.utils.BatchData, torch.Tensor]
Run the forward pass safely, on the same device as the component.
- update_learning_rate(eval_metrics: dict[str, float]) → None
Call the learning rate scheduler, if one is defined.
- train() → list[dict]
Call the training and evaluation loop.
- Returns:
- list of dict of str to float
The list of computed metrics per epoch
- save(epoch: int) → str
Save a checkpoint.
Checkpoints include the trainer configuration, model weights, optimizer and scheduler.
- Parameters:
- epoch: int
Epoch corresponding to the current training state (included in the checkpoint name)
- Returns:
- str
Path to the saved checkpoint
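lr_scheduler_builder expects a function that takes an Optimizer and returns a learning rate scheduler. Since metric_to_track_lr names an eval metric, a metric-driven scheduler such as torch's ReduceLROnPlateau seems a natural fit. This is a minimal sketch under the assumption that the Trainer passes the tracked eval metric to the scheduler's step() for this scheduler type; build_lr_scheduler is a hypothetical name and the factor/patience values are illustrative.

```python
import torch


def build_lr_scheduler(
    optimizer: torch.optim.Optimizer,
) -> torch.optim.lr_scheduler.ReduceLROnPlateau:
    # mode="min" suits a loss-like metric such as the default "loss";
    # with patience=0 the learning rate is halved as soon as the
    # tracked metric stops improving.
    return torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=0
    )
```

It would then be passed as Trainer(..., lr_scheduler_builder=build_lr_scheduler). For a metric that should increase (e.g. an F1 score), mode="max" would be the matching choice.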
- class medkit.training.TrainerConfig
Trainer configuration.
- Parameters:
- output_dir:
The output directory where checkpoints will be saved.
- learning_rate:
The initial learning rate.
- nb_training_epochs:
Total number of training/evaluation epochs to run.
- dataloader_nb_workers:
Number of subprocesses to use for data loading. The default value, 0, means that data is loaded in the main process. If this config is for a HuggingFace model, do not change this value.
- batch_size:
Number of samples per batch to load.
- seed:
Random seed to use with PyTorch and numpy. It should be set to ensure reproducibility between experiments.
- gradient_accumulation_steps:
Number of steps over which gradients are accumulated before performing an optimization step.
- do_metrics_in_training:
By default, custom metrics are only computed on eval_data. If set to True, they are also computed on train_data.
- metric_to_track_lr:
Name of the eval metric to track for updating the learning rate. By default, the eval loss is tracked.
- checkpoint_period:
How often, in number of epochs, a checkpoint should be saved. Use 0 to save only the last checkpoint.
- checkpoint_metric:
Name of the eval metric to track for selecting the best checkpoint. By default, the eval loss is tracked.
- minimize_checkpoint_metric:
If True, the checkpoint with the lowest metric value is selected as best; otherwise, the checkpoint with the highest metric value is selected.
- output_dir: str
- learning_rate: float = 1e-05
- nb_training_epochs: int = 3
- dataloader_nb_workers: int = 0
- batch_size: int = 1
- seed: int | None = None
- gradient_accumulation_steps: int = 1
- do_metrics_in_training: bool = False
- metric_to_track_lr: str = 'loss'
- checkpoint_period: int = 1
- checkpoint_metric: str = 'loss'
- minimize_checkpoint_metric: bool = True
- to_dict() → dict[str, Any]
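checkpoint_metric and minimize_checkpoint_metric together decide which checkpoint counts as best. The hypothetical helper below illustrates that rule on a list of per-epoch eval metrics; the history format and the tie-breaking (the earliest epoch wins) are assumptions made for the sketch, not a description of medkit's implementation.

```python
def select_best_epoch(
    eval_history: list[dict[str, float]],
    checkpoint_metric: str = "loss",
    minimize_checkpoint_metric: bool = True,
) -> int:
    """Hypothetical helper: index of the epoch with the best tracked metric."""
    values = [metrics[checkpoint_metric] for metrics in eval_history]
    best_value = min(values) if minimize_checkpoint_metric else max(values)
    # list.index returns the first match, so the earliest epoch wins ties
    return values.index(best_value)


history = [
    {"loss": 0.9, "f1": 0.60},
    {"loss": 0.7, "f1": 0.72},
    {"loss": 0.8, "f1": 0.75},
]
print(select_best_epoch(history))  # 1: lowest loss
print(select_best_epoch(history, "f1", minimize_checkpoint_metric=False))  # 2: highest f1
```

With the defaults (checkpoint_metric='loss', minimize_checkpoint_metric=True), the lowest eval loss wins; tracking a score such as F1 requires setting minimize_checkpoint_metric to False.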
- class medkit.training.BatchData
Bases: dict
A BatchData packs data, allowing both column and row access.
- __getitem__(index: int) → dict[str, list[Any] | torch.Tensor]
x.__getitem__(y) <==> x[y]
- to_device(device: torch.device) → typing_extensions.Self
Ensure that Tensors in the BatchData object are on the specified device.
- Parameters:
- device:
A torch.device object representing the device on which tensors will be allocated.
- Returns:
- BatchData
A new object with the tensors on the proper device.
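To illustrate what "column and row access" means for a dict of equal-length columns, here is a hypothetical, simplified class (RowColumnDict, not medkit's BatchData): a string key returns a whole column, while an integer key returns one row as a dict. It uses plain lists only and leaves out tensors and to_device.

```python
from typing import Any


class RowColumnDict(dict):
    """Hypothetical illustration of column/row access, not medkit's BatchData."""

    def __getitem__(self, key: Any) -> Any:
        if isinstance(key, int):
            # Row access: the key-th element of every column
            return {name: values[key] for name, values in self.items()}
        # Column access: ordinary dict lookup
        return super().__getitem__(key)


batch = RowColumnDict(text=["no fever", "cough"], label=[0, 1])
print(batch["label"])  # [0, 1]
print(batch[1])  # {'text': 'cough', 'label': 1}
```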
- class medkit.training.MetricsComputer
Bases: typing_extensions.Protocol
A MetricsComputer is the base protocol to compute metrics during training.
- prepare_batch(model_output: BatchData, input_batch: BatchData) → dict[str, list[Any]]
Prepare a batch of data to compute the metrics.
- Parameters:
- model_output: BatchData
Output data after a model forward pass.
- input_batch: BatchData
Preprocessed input batch
- Returns:
- dict[str, List[Any]]
A dictionary with the data required to compute the metrics
- compute(all_data: dict[str, list[Any]]) → dict[str, float]
Compute metrics using all_data.
- Parameters:
- all_data: dict[str, List[Any]]
A dictionary used to compute the metrics, e.g. a dictionary with a list of 'references' and a list of 'predictions'.
- Returns:
- dict[str, float]
A dictionary with the results
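Because MetricsComputer is a Protocol, any class with matching prepare_batch and compute methods fits. The following is a minimal, hypothetical accuracy computer using the 'references'/'predictions' keys mentioned above. It assumes the model output has a "predictions" column and the input batch a "labels" column (illustrative names), and it simulates the per-batch preparation by concatenating each batch's lists before a single compute call; how the Trainer actually merges batches is an assumption here.

```python
from typing import Any


class AccuracyComputer:
    """Hypothetical MetricsComputer computing accuracy."""

    def prepare_batch(
        self, model_output: dict[str, Any], input_batch: dict[str, Any]
    ) -> dict[str, list[Any]]:
        # Keep only what compute() needs, as plain lists
        return {
            "predictions": list(model_output["predictions"]),
            "references": list(input_batch["labels"]),
        }

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        predictions = all_data["predictions"]
        references = all_data["references"]
        correct = sum(p == r for p, r in zip(predictions, references))
        return {"accuracy": correct / len(references) if references else 0.0}


# Simulated evaluation: prepare each batch, concatenate, then compute once
computer = AccuracyComputer()
all_data: dict[str, list[Any]] = {}
batches = [
    ({"predictions": [1, 0]}, {"labels": [1, 1]}),
    ({"predictions": [2]}, {"labels": [2]}),
]
for model_output, input_batch in batches:
    for key, values in computer.prepare_batch(model_output, input_batch).items():
        all_data.setdefault(key, []).extend(values)
print(computer.compute(all_data))  # 2 of the 3 predictions match the references
```

Computing the metric once over all prepared batches, rather than averaging per-batch scores, keeps the result correct when batches have different sizes.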