medkit.training.utils
Classes
| Class | Description |
|---|---|
| BatchData | A BatchData packs data, allowing both column and row access. |
| MetricsComputer | A MetricsComputer is the base protocol for computing metrics during training. |
Module Contents
- class medkit.training.utils.BatchData
 Bases: dict
 A BatchData packs data, allowing both column and row access.
- __getitem__(index: int) → dict[str, list[Any] | torch.Tensor]
 x.__getitem__(y) is equivalent to x[y].
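The class description implies that a key selects a column and an integer selects a row. A minimal pure-Python sketch of that behaviour follows, assuming string keys address columns and integer indices address rows. The name `RowColumnBatch` is hypothetical, plain lists stand in for tensors, and this is not medkit's source.

```python
from __future__ import annotations

from typing import Any


class RowColumnBatch(dict):
    """Hypothetical stand-in for BatchData: a dict whose values are
    equal-length columns (plain lists here)."""

    def __getitem__(self, index: int | str) -> Any:
        if isinstance(index, str):
            # Column access: ordinary dict lookup by key.
            return super().__getitem__(index)
        # Row access: take the index-th element of every column.
        return {key: column[index] for key, column in self.items()}


batch = RowColumnBatch(text=["first", "second"], label=[0, 1])
print(batch["label"])  # column -> [0, 1]
print(batch[1])        # row    -> {'text': 'second', 'label': 1}
```

Subclassing `dict` keeps every other mapping operation (`keys()`, `items()`, `len()`) unchanged; only indexing gains the row view.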
- to_device(device: torch.device) → typing_extensions.Self
 Ensure that Tensors in the BatchData object are on the specified device.
- Parameters:
 - device: torch.device
 A torch.device object representing the device on which tensors will be allocated.
- Returns:
 - BatchData
 A new BatchData object whose tensors are on the specified device.
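The behaviour of `to_device` can be sketched as a standalone helper. This is a minimal sketch, not medkit's implementation: the name `batch_to_device` is hypothetical, and it duck-types on a callable `.to(device)` method (which `torch.Tensor` provides) so it runs without PyTorch installed; real code would more likely test `isinstance(value, torch.Tensor)`.

```python
from __future__ import annotations

from typing import Any


def batch_to_device(batch: dict[str, Any], device: Any) -> dict[str, Any]:
    """Hypothetical helper mirroring BatchData.to_device.

    Values exposing a callable ``.to`` (as torch.Tensor does) are moved to
    ``device``; all other values (lists, strings, ...) are reused unchanged.
    The input mapping is not modified: a new mapping is returned.
    """
    moved = {}
    for key, value in batch.items():
        to = getattr(value, "to", None)
        if callable(to):
            moved[key] = to(device)
        else:
            moved[key] = value
    # Rebuild the same mapping type so a dict subclass keeps its behaviour
    # (assumes the subclass accepts a plain dict in its constructor).
    return type(batch)(moved)


# With PyTorch installed, for example:
#   batch_to_device({"ids": torch.tensor([1, 2]), "text": ["a", "b"]}, torch.device("cpu"))
```

Returning a new object rather than mutating in place matches the documented return value and leaves the original batch usable on its old device.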
- class medkit.training.utils.MetricsComputer
 Bases: typing_extensions.Protocol
 A MetricsComputer is the base protocol for computing metrics during training.
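Because MetricsComputer is a Protocol, a class satisfies it structurally by defining the same methods; it does not have to inherit from it. The sketch below illustrates this with the standard-library `typing.Protocol` (which `typing_extensions.Protocol` backports). `MetricsComputerLike` and `NoOpMetrics` are hypothetical names, plain dicts approximate BatchData, and `@runtime_checkable` is added only so that `isinstance` can be demonstrated; that runtime check verifies method names, not signatures.

```python
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class MetricsComputerLike(Protocol):
    """Hypothetical mirror of the two methods MetricsComputer declares."""

    def prepare_batch(
        self, model_output: dict[str, Any], input_batch: dict[str, Any]
    ) -> dict[str, list[Any]]: ...

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]: ...


class NoOpMetrics:
    """Satisfies the protocol structurally, without inheriting from it."""

    def prepare_batch(self, model_output, input_batch):
        return {}

    def compute(self, all_data):
        return {}


print(isinstance(NoOpMetrics(), MetricsComputerLike))  # True
```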
- prepare_batch(model_output: BatchData, input_batch: BatchData) → dict[str, list[Any]]
 Prepare a batch of data to compute the metrics.
- Parameters:
 - model_output: BatchData
 Output data after a model forward pass.
 - input_batch: BatchData
 Preprocessed input batch.
- Returns:
 - dict[str, list[Any]]
 A dictionary containing the data required to compute the metrics.
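As an illustration of the `prepare_batch` contract, here is a hypothetical version for a classifier. The input keys `"logits"` and `"labels"` are assumptions, and the output keys mirror the `'references'`/`'predictions'` example given for `compute`. Plain dicts of lists stand in for BatchData so the sketch runs without PyTorch; with tensor outputs one would more likely write `model_output["logits"].argmax(dim=-1).tolist()`.

```python
from __future__ import annotations

from typing import Any


def prepare_batch(
    model_output: dict[str, Any], input_batch: dict[str, Any]
) -> dict[str, list[Any]]:
    """Hypothetical prepare_batch for a classifier.

    Assumes model_output["logits"] holds one score list per example and
    input_batch["labels"] holds the gold class index for each example.
    """
    predictions = [
        # argmax over the class scores of one example (first index wins ties)
        max(range(len(scores)), key=scores.__getitem__)
        for scores in model_output["logits"]
    ]
    references = list(input_batch["labels"])
    return {"predictions": predictions, "references": references}


prepared = prepare_batch({"logits": [[0.1, 0.9], [0.8, 0.2]]}, {"labels": [1, 1]})
print(prepared)  # {'predictions': [1, 0], 'references': [1, 1]}
```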
- compute(all_data: dict[str, list[Any]]) → dict[str, float]
 Compute metrics using ‘all_data’.
- Parameters:
 - all_data: dict[str, list[Any]]
 A dictionary containing the data needed to compute the metrics, e.g. a list of ‘references’ and a list of ‘predictions’.
- Returns:
 - dict[str, float]
 A dictionary mapping each metric name to its computed value.
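A matching `compute` sketch follows, computing accuracy. It assumes, without confirmation from this page, that a trainer builds `all_data` by concatenating the per-batch outputs of `prepare_batch` key by key; that accumulation step is shown explicitly. The function is hypothetical, and returning `0.0` for an empty input is an arbitrary choice.

```python
from __future__ import annotations

from typing import Any


def compute(all_data: dict[str, list[Any]]) -> dict[str, float]:
    """Hypothetical compute(): accuracy over all accumulated examples."""
    predictions = all_data["predictions"]
    references = all_data["references"]
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not references:
        # Arbitrary convention for an empty evaluation set.
        return {"accuracy": 0.0}
    correct = sum(p == r for p, r in zip(predictions, references))
    return {"accuracy": correct / len(references)}


# Outputs of prepare_batch for two batches, concatenated key by key
# (this accumulation step is an assumption about how all_data is built).
batches = [
    {"predictions": [1, 0], "references": [1, 1]},
    {"predictions": [2], "references": [2]},
]
all_data: dict[str, list[Any]] = {"predictions": [], "references": []}
for prepared in batches:
    for key, values in prepared.items():
        all_data[key].extend(values)

print(compute(all_data))  # {'accuracy': 0.6666666666666666}
```

Keeping per-batch preparation separate from the final computation lets metrics that are not simple averages over batches (for example F1) be computed once over the whole evaluation set.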