Training#
This page describes all components related to medkit training.
Important
This module requires additional dependencies:
pip install 'medkit-lib[training]'
For more details, please refer to medkit.training.
Becoming Trainable#
A component must implement the TrainableComponent protocol to be trainable within medkit.
Through this protocol, you define how to preprocess data, how to call the model, and how to create the optimizer.
The Trainer then uses these methods inside the training and evaluation loop.
The following table summarizes who makes each call and where; a minimal sketch of such a component follows the table.
| Who | Where | A TrainableComponent method |
|---|---|---|
| TrainableComponent | Initialization | load: load / initialize modules to be trained |
| Trainer | Initialization | create_optimizer: define an optimizer for the training / evaluation loop |
| Trainer | Data loading | preprocess: transform annotations to input data |
| Trainer | Forward step | forward: call the internal model, return loss and model output |
| Trainer | Saving checkpoint | save: save trained modules |
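To make this concrete, here is a minimal sketch of a component implementing the protocol. The method names follow the table above, but the signatures are simplified assumptions (the actual protocol also includes collate, described below); refer to medkit.training for the authoritative definitions.
import torch

class MyTrainableComponent:
    """Minimal sketch of a trainable component; signatures are simplified."""

    def __init__(self):
        # load / initialize the module(s) to be trained
        self.model = torch.nn.Linear(128, 2)  # placeholder model

    def create_optimizer(self, lr):
        # define an optimizer for the training / evaluation loop
        return torch.optim.AdamW(self.model.parameters(), lr=lr)

    def preprocess(self, ann):
        # transform one medkit annotation into model input data
        ...

    def forward(self, input_batch):
        # call the internal model, return the model output and the loss
        ...

    def save(self, path):
        # save the trained module(s) to a checkpoint
        torch.save(self.model.state_dict(), path)

    def load(self, path):
        # load the trained module(s) from a checkpoint
        self.model.load_state_dict(torch.load(path))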
Trainable Entity Detection#
A trainable component can define how to train a model from scratch or how to fine-tune a pretrained model.
As a first implementation, medkit includes HFEntityMatcherTrainable, a trainable version of HFEntityMatcher.
An operation can contain a trainable component and expose it through its make_trainable() method.
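As a hedged sketch, fine-tuning could then look as follows; the constructor argument name "model" and the placeholder model name are assumptions, while the Trainer arguments mirror the example later on this page.
# Sketch only: the "model" argument name and its value are assumptions.
matcher = HFEntityMatcher(model="a-pretrained-ner-model")
trainer = Trainer(
    component=matcher.make_trainable(),
    config=TrainerConfig(...),
    train_data=train_dataset,
    eval_data=val_dataset,
)
history = trainer.train()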
Please refer to this example for a complete entity detection fine-tuning use case.
Important
Currently, medkit only supports the training of components using PyTorch.
For more details, please refer to the medkit.training.trainable_component module.
Trainer#
The Trainer can train any component implementing the TrainableComponent protocol.
For each step involving data transformation, the Trainer calls the corresponding method of the TrainableComponent.
For example, to train a SegmentClassifier, you define how to preprocess each Segment with its Attribute into a dictionary of tensors for the model.
Under the hood, the training loop calls SegmentClassifier.preprocess() and SegmentClassifier.collate() inside the training_dataloader to transform medkit segments into batches of tensors.
# 1. Initialize the trainable component, e.g. a segment classifier
segment_classifier = SegmentClassifier(...)
# 2. Load / prepare the sets of medkit annotations (segments)
# 3. Define hyperparameters for the trainer
trainer_config = TrainerConfig(...)
trainer = Trainer(
    component=segment_classifier,  # trainable component
    config=trainer_config,         # configuration
    train_data=train_dataset,      # training documents
    eval_data=val_dataset,         # eval documents
)
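For illustration, the preprocess and collate methods of such a classifier might look roughly like the sketch below. The attribute label "category", the tokenizer and label2id attributes, and the returned keys are all assumptions made for the example, not medkit's actual API.
import torch

class SegmentClassifier:
    # ... initialization and other trainable methods omitted ...

    def preprocess(self, segment):
        # Turn one medkit Segment and its Attribute into model input data;
        # the attribute label "category" is a made-up example.
        label = segment.attrs.get(label="category")[0].value
        encoding = self.tokenizer(segment.text, truncation=True)
        return {"input_ids": encoding["input_ids"], "label": label}

    def collate(self, batch):
        # Pad the preprocessed items and stack them into a batch of tensors.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(item["input_ids"]) for item in batch],
            batch_first=True,
        )
        labels = torch.tensor([self.label2id[item["label"]] for item in batch])
        return {"input_ids": input_ids, "labels": labels}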
History#
Once the trainer has been configured, you can start training with trainer.train().
The method returns a dictionary containing the training and evaluation metrics for each epoch.
history = trainer.train()
The trainer controls when the component's methods and the optional modules are called. Here is a simplified version of the training loop:
for input_data in training_dataloader:
    callback_on_step()
    input_data = input_data.to_device(device)
    output_data, loss = trainable_component.forward(input_data)
    optimizer.zero_grad()  # reset gradients accumulated by the previous step
    loss.backward()
    optimizer.step()
    # if a metrics_computer is defined
    data_for_metrics.extend(metrics_computer.prepare_batch(input_data, output_data))
    ...
# compute metrics
metrics_computer.compute(data_for_metrics)
For more details, please refer to the medkit.training.trainer module.
Custom Training#
Hyperparameters#
The TrainerConfig allows you to define learning parameters such as the learning rate, the number of epochs, etc.
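For instance, a configuration could be built as in the sketch below; the field names (output_dir, learning_rate, nb_training_epochs, batch_size) are assumptions based on common trainer configurations, so check the TrainerConfig API for the actual fields.
# Field names are assumptions; refer to the TrainerConfig API reference.
trainer_config = TrainerConfig(
    output_dir="checkpoints",  # where checkpoints are saved (assumed field)
    learning_rate=5e-5,        # optimizer learning rate (assumed field)
    nb_training_epochs=3,      # passes over the training data (assumed field)
    batch_size=16,             # samples per training step (assumed field)
)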
Metrics Computer#
Custom metrics can be provided to the trainer.
You define how to prepare a batch for the metric and how to compute the metric itself.
For more details, refer to the medkit.training.MetricsComputer protocol.
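A custom implementation might look roughly like the sketch below; the prepare_batch and compute signatures, the batch keys, and the metrics_computer parameter of the Trainer are assumptions inferred from the description above rather than the exact protocol.
class AccuracyComputer:
    def prepare_batch(self, model_output, input_batch):
        # Extract from each batch whatever the metric will need later;
        # the "logits" and "labels" keys are assumptions for this sketch.
        predictions = model_output["logits"].argmax(dim=-1).tolist()
        references = input_batch["labels"].tolist()
        return {"predictions": predictions, "references": references}

    def compute(self, all_data):
        # Aggregate the prepared data into final metric values.
        pairs = zip(all_data["predictions"], all_data["references"])
        correct = sum(p == r for p, r in pairs)
        return {"accuracy": correct / len(all_data["predictions"])}

# The parameter name below is an assumption:
trainer = Trainer(..., metrics_computer=AccuracyComputer())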
Tip
For the moment, medkit includes SeqEvalMetricsComputer for entity detection.
This part is still in development; you can integrate more metrics depending on your task or modality.
Learning Rate Scheduler#
You can define how the learning rate is adjusted during training.
If you use PyTorch models, you can use a scheduler from torch.optim.lr_scheduler.
For example, to update the learning rate every 5 optimization steps:
import torch

lr_scheduler_builder = lambda optimizer: torch.optim.lr_scheduler.StepLR(optimizer, step_size=5)
trainer = Trainer(..., lr_scheduler_builder=lr_scheduler_builder)
If you use Hugging Face Transformers models, you may refer to the get_scheduler function.
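For example, a builder based on transformers.get_scheduler could look like this; the schedule name and step counts are illustrative values only.
from transformers import get_scheduler

lr_scheduler_builder = lambda optimizer: get_scheduler(
    "linear",
    optimizer,
    num_warmup_steps=100,     # illustrative value
    num_training_steps=1000,  # illustrative value
)
trainer = Trainer(..., lr_scheduler_builder=lr_scheduler_builder)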
Callbacks#
medkit provides a set of callbacks to extend training with features such as logging.
To use these callbacks, implement a class derived from TrainerCallback.
If you do not provide your own callback to the Trainer, it uses the DefaultPrinterCallback.
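As a sketch, a minimal logging callback might look like this; the on_epoch_end hook name, its arguments, and the callback parameter of the Trainer are assumptions, so check medkit.training.callbacks for the actual interface.
# Hook names, signatures, and the "callback" parameter are assumptions.
class LoggingCallback(TrainerCallback):
    def on_epoch_end(self, metrics, epoch):
        # Hypothetical hook: log the metrics of the epoch that just finished.
        print(f"epoch {epoch}: {metrics}")

trainer = Trainer(..., callback=LoggingCallback())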
For more details, please refer to the medkit.training.callbacks module.
Note
This module is under development; support for more powerful callbacks may be added.