medkit.training.utils

Classes

BatchData

A BatchData packs data, allowing both column and row access.

MetricsComputer

A MetricsComputer is the base protocol for computing metrics during training.

Module Contents

class medkit.training.utils.BatchData

Bases: dict

A BatchData packs data, allowing both column and row access.

__getitem__(index: int) → dict[str, list[Any] | torch.Tensor]

x.__getitem__(y) <==> x[y]
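The following is a minimal sketch of how row and column access on a dict subclass can work, assuming that string keys return whole columns and integer indices return one row built from every column. BatchDataSketch is a hypothetical name, not medkit's class, and the real implementation may differ.

```python
from typing import Any, Union


class BatchDataSketch(dict):
    """Hypothetical stand-in for BatchData: a dict of equal-length columns."""

    def __getitem__(self, index: Union[int, str]) -> Any:
        # String keys keep normal dict behaviour: return the whole column.
        if isinstance(index, str):
            return super().__getitem__(index)
        # Integer indices select one row: the index-th value of every column.
        return {key: values[index] for key, values in self.items()}


batch = BatchDataSketch(text=["a", "b", "c"], label=[0, 1, 0])
print(batch["label"])  # column access
print(batch[1])        # row access
```

Because each column only needs to support indexing, the same row lookup works whether a column is a Python list or a tensor.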

to_device(device: torch.device) → typing_extensions.Self

Return a new BatchData object in which all tensors are on the specified device.

Parameters:
device: torch.device

A torch.device object representing the device on which tensors will be allocated.

Returns:
BatchData

A new BatchData object with the tensors on the specified device.
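A minimal sketch of the documented to_device behaviour, assuming that it builds a new object and moves only tensor values. To stay runnable without PyTorch, it treats any value with a callable .to method as a tensor; code that depends on PyTorch would more likely test isinstance(value, torch.Tensor). BatchDataSketch is a hypothetical name.

```python
from typing import Any


class BatchDataSketch(dict):
    """Hypothetical stand-in for BatchData (a dict of columns)."""

    def to_device(self, device: Any) -> "BatchDataSketch":
        # Build a new object instead of mutating self, matching the
        # documented "returns a new object" behaviour.
        moved = type(self)()
        for key, value in self.items():
            # Assumption: tensor-like values expose .to(device), as torch.Tensor does;
            # every other value (e.g. lists of strings) is carried over by reference.
            if callable(getattr(value, "to", None)):
                moved[key] = value.to(device)
            else:
                moved[key] = value
        return moved
```

With PyTorch, torch.Tensor.to(device) returns a tensor on the target device, so a call such as batch.to_device(torch.device("cpu")) leaves the original mapping unmodified while the returned object holds the moved tensors.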

class medkit.training.utils.MetricsComputer

Bases: typing_extensions.Protocol

A MetricsComputer is the base protocol for computing metrics during training.
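Since MetricsComputer is a protocol, a metrics class is expected to match it structurally, by defining prepare_batch and compute, rather than by subclassing it. The sketch below shows this mechanism with the standard-library typing.Protocol; MetricsComputerLike and CountComputer are hypothetical names, and the @runtime_checkable decorator is added here only to allow an isinstance check (the source does not say whether MetricsComputer supports one).

```python
from typing import Any, Protocol, runtime_checkable


@runtime_checkable
class MetricsComputerLike(Protocol):
    """Hypothetical mirror of the two methods documented for MetricsComputer."""

    def prepare_batch(self, model_output: Any, input_batch: Any) -> dict[str, list[Any]]: ...

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]: ...


class CountComputer:
    """Satisfies the protocol structurally; it never inherits from it."""

    def prepare_batch(self, model_output: Any, input_batch: Any) -> dict[str, list[Any]]:
        return {"items": list(input_batch["label"])}

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        return {"count": float(len(all_data["items"]))}


print(isinstance(CountComputer(), MetricsComputerLike))
```

An isinstance check against a runtime-checkable protocol only verifies that the methods exist, not that their signatures match.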

prepare_batch(model_output: BatchData, input_batch: BatchData) → dict[str, list[Any]]

Prepare a batch of data to compute the metrics.

Parameters:
model_output: BatchData

Output data after a model forward pass.

input_batch: BatchData

Preprocessed input batch.

Returns:
dict[str, list[Any]]

A dictionary with the data required to compute the metrics.
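As an illustration, a hypothetical accuracy computer might turn per-example class scores into predicted labels and pair them with the gold labels. The keys "logits", "labels", "predictions" and "references" are assumptions, and plain dictionaries of lists stand in for BatchData.

```python
from typing import Any


class AccuracyComputer:
    """Hypothetical MetricsComputer for single-label classification."""

    def prepare_batch(
        self, model_output: dict[str, Any], input_batch: dict[str, Any]
    ) -> dict[str, list[Any]]:
        # Assumption: the model output holds one list of class scores per example
        # under "logits", and the input batch holds gold labels under "labels".
        predictions = [
            max(range(len(scores)), key=scores.__getitem__)  # argmax over classes
            for scores in model_output["logits"]
        ]
        return {"predictions": predictions, "references": list(input_batch["labels"])}


prepared = AccuracyComputer().prepare_batch(
    {"logits": [[0.1, 0.9], [2.0, -1.0]]}, {"labels": [1, 1]}
)
print(prepared)
```

With tensor outputs, the argmax step would typically be written as model_output["logits"].argmax(dim=-1).tolist().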

compute(all_data: dict[str, list[Any]]) → dict[str, float]

Compute metrics using ‘all_data’.

Parameters:
all_data: dict[str, list[Any]]

A dictionary containing the data needed to compute the metrics, for example a list of ‘references’ and a list of ‘predictions’.

Returns:
dict[str, float]

A dictionary with the results, mapping each metric name to its value.
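The sketch below pairs a hypothetical compute for accuracy with a helper that concatenates the per-batch dictionaries returned by prepare_batch. It assumes that a training loop gathers these dictionaries over a whole evaluation pass and calls compute once on the result; AccuracyFromLists and gather are illustrative names, not part of medkit.

```python
from collections import defaultdict
from typing import Any, Iterable


class AccuracyFromLists:
    """Hypothetical compute() step: all_data holds aligned 'predictions' and 'references'."""

    def compute(self, all_data: dict[str, list[Any]]) -> dict[str, float]:
        preds, refs = all_data["predictions"], all_data["references"]
        if len(preds) != len(refs):
            raise ValueError("predictions and references must have the same length")
        correct = sum(p == r for p, r in zip(preds, refs))
        # Guard against an empty evaluation set.
        return {"accuracy": correct / len(refs) if refs else 0.0}


def gather(prepared_batches: Iterable[dict[str, list[Any]]]) -> dict[str, list[Any]]:
    """Concatenate the per-batch outputs of prepare_batch key by key (assumed trainer behaviour)."""
    all_data: dict[str, list[Any]] = defaultdict(list)
    for prepared in prepared_batches:
        for key, values in prepared.items():
            all_data[key].extend(values)
    return dict(all_data)


batches = [
    {"predictions": [1, 0], "references": [1, 1]},
    {"predictions": [2], "references": [2]},
]
print(AccuracyFromLists().compute(gather(batches)))
```

Computing the metric once over the concatenated lists, rather than averaging per-batch scores, keeps the result exact when batches have different sizes.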