medkit.training.utils
Classes
- BatchData: A BatchData packs data, allowing both column and row access.
- MetricsComputer: A MetricsComputer is the base protocol to compute metrics in training.
Module Contents
- class medkit.training.utils.BatchData
Bases: dict
A BatchData packs data, allowing both column and row access.
- __getitem__(index: int) → dict[str, list[Any] | torch.Tensor]
x.__getitem__(y) <==> x[y]
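Row versus column access can be sketched in plain Python. This is a minimal sketch, assuming that a string key returns a whole column and an integer index returns one row as a dict; `RowColumnDict` is a hypothetical stand-in, not medkit's implementation:

```python
from typing import Any, Union


class RowColumnDict(dict):
    """Hypothetical stand-in illustrating row/column access (not medkit's code)."""

    def __getitem__(self, index: Union[str, int]) -> Any:
        if isinstance(index, str):
            # Column access: behave like a plain dict lookup.
            return super().__getitem__(index)
        # Row access: pick the element at `index` from every column.
        return {key: values[index] for key, values in self.items()}


batch = RowColumnDict(text=["a", "b"], label=[0, 1])
print(batch["label"])  # column: [0, 1]
print(batch[1])        # row: {'text': 'b', 'label': 1}
```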
- to_device(device: torch.device) → typing_extensions.Self
Ensure that Tensors in the BatchData object are on the specified device.
- Parameters:
- device: torch.device
A torch.device object representing the device on which tensors will be allocated.
- Returns:
- BatchData
A new BatchData object whose tensors are on the specified device.
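The behaviour of to_device can be sketched without torch. This is a minimal sketch, assuming that tensors are moved, other values are carried over unchanged, and the original object is left as-is; `FakeTensor` and `DeviceBatch` are hypothetical stand-ins for torch.Tensor and BatchData:

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only tracks its device."""

    def __init__(self, data: list, device: str = "cpu") -> None:
        self.data = data
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        # Like torch.Tensor.to, return a tensor on the target device.
        return FakeTensor(self.data, device)


class DeviceBatch(dict):
    """Hypothetical stand-in for BatchData.to_device (not medkit's code)."""

    def to_device(self, device: str) -> "DeviceBatch":
        moved = DeviceBatch()
        for key, value in self.items():
            # Move tensors; carry other values (e.g. lists of strings) over as-is.
            moved[key] = value.to(device) if isinstance(value, FakeTensor) else value
        return moved


batch = DeviceBatch(input_ids=FakeTensor([1, 2, 3]), text=["a", "b", "c"])
gpu_batch = batch.to_device("cuda")
print(gpu_batch["input_ids"].device)  # cuda
print(batch["input_ids"].device)      # cpu (original left untouched)
```

With the real class, the call would be `batch.to_device(torch.device("cuda"))`.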
- class medkit.training.utils.MetricsComputer
Bases: typing_extensions.Protocol
A MetricsComputer is the base protocol to compute metrics in training.
- prepare_batch(model_output: BatchData, input_batch: BatchData) → dict[str, list[Any]]
Prepare a batch of data to compute the metrics.
- Parameters:
- model_output: BatchData
Output data after a model forward pass.
- input_batch: BatchData
Preprocessed input batch.
- Returns:
- dict[str, list[Any]]
A dictionary with the data required to compute the metrics.
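One possible prepare_batch reduces raw model outputs to plain lists. This is a minimal sketch, written as a free function for brevity, assuming a classification model whose output holds per-class scores under a hypothetical 'logits' key and whose input batch holds gold labels under a hypothetical 'labels' key; plain lists stand in for tensors:

```python
from typing import Any


def prepare_batch(model_output: dict, input_batch: dict) -> dict[str, list[Any]]:
    """Hypothetical prepare_batch: turn per-class scores into label predictions."""
    # 'logits' and 'labels' are assumed key names; real keys depend on the model.
    logits = model_output["logits"]
    # Argmax over each row of scores gives the predicted class index.
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    references = list(input_batch["labels"])
    return {"predictions": predictions, "references": references}


out = prepare_batch({"logits": [[0.1, 0.9], [0.8, 0.2]]}, {"labels": [1, 1]})
print(out)  # {'predictions': [1, 0], 'references': [1, 1]}
```

With real tensors, `model_output["logits"].argmax(dim=-1).tolist()` would replace the list comprehension.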
- compute(all_data: dict[str, list[Any]]) → dict[str, float]
Compute metrics using ‘all_data’.
- Parameters:
- all_data: dict[str, list[Any]]
The data needed to compute the metrics, e.g. a dictionary with a list of ‘references’ and a list of ‘predictions’.
- Returns:
- dict[str, float]
A dictionary mapping each metric name to its computed value.
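Because MetricsComputer is a Protocol, a class with matching prepare_batch and compute methods conforms to it for static type checkers without inheriting from it. This is a minimal sketch of an accuracy computer, assuming hypothetical 'predictions' and 'labels' keys and a driver loop (the hypothetical `evaluate`) that concatenates the per-batch outputs before calling compute once; `MetricsComputerLike` mirrors the documented protocol locally so the sketch runs without medkit:

```python
from collections import defaultdict
from typing import Any, Protocol


class MetricsComputerLike(Protocol):
    """Local mirror of the documented protocol, so the sketch runs without medkit."""

    def prepare_batch(self, model_output: dict, input_batch: dict) -> dict[str, list[Any]]: ...

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]: ...


class AccuracyComputer:
    """Hypothetical metrics computer; satisfies the protocol structurally."""

    def prepare_batch(self, model_output: dict, input_batch: dict) -> dict[str, list[Any]]:
        # Assumed keys: 'predictions' in the model output, 'labels' in the input batch.
        return {
            "predictions": list(model_output["predictions"]),
            "references": list(input_batch["labels"]),
        }

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        preds = all_data.get("predictions", [])
        refs = all_data.get("references", [])
        correct = sum(p == r for p, r in zip(preds, refs))
        # Guard against an empty evaluation set.
        return {"accuracy": correct / len(refs) if refs else 0.0}


def evaluate(computer: MetricsComputerLike, batches: list[tuple[dict, dict]]) -> dict[str, float]:
    # Assumed driver loop: gather prepare_batch outputs, then call compute once.
    all_data: dict[str, list[Any]] = defaultdict(list)
    for model_output, input_batch in batches:
        for key, values in computer.prepare_batch(model_output, input_batch).items():
            all_data[key].extend(values)
    return computer.compute(dict(all_data))


batches = [
    ({"predictions": [1, 0]}, {"labels": [1, 1]}),
    ({"predictions": [0, 0]}, {"labels": [0, 0]}),
]
print(evaluate(AccuracyComputer(), batches))  # {'accuracy': 0.75}
```

Collecting lists per batch and computing once at the end matters for metrics such as F1, which cannot in general be obtained by averaging per-batch scores.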