medkit.text.ner.hf_entity_matcher_trainable#

Classes#

HFEntityMatcherTrainable

Trainable entity matcher based on HuggingFace transformers model.

Module Contents#

class medkit.text.ner.hf_entity_matcher_trainable.HFEntityMatcherTrainable(model_name_or_path: str | pathlib.Path, labels: list[str], tagging_scheme: typing_extensions.Literal['bilou', 'iob2'], tag_subtokens: bool = False, tokenizer_max_length: int | None = None, hf_auth_token: str | None = None, device: int = -1)#

Trainable entity matcher based on HuggingFace transformers model.

Any token classification model from the HuggingFace hub can be used (for instance “samrawal/bert-base-uncased_clinical-ner”).

Parameters:
model_name_or_path : str or Path

Name (on the HuggingFace models hub) or path of the NER model. Must be a model compatible with the TokenClassification transformers class.

labels : list of str

List of labels to detect

tagging_scheme : {“bilou”, “iob2”}

Tagging scheme to use in the segment-entities preprocessing and label mapping definition.

tag_subtokens : bool, default=False

Whether to tag subtokens in a word. Pretrained models require a tokenization step, and any word of the segment that is not in the vocabulary of the model's tokenizer is split into subtokens. It is recommended to tag only the first subtoken of a word, but all subtokens can be tagged by setting this value to True. This may affect fine-tuning time and results.

tokenizer_max_length : int, optional

Maximum length for the tokenizer; by default, the model's model_max_length is used.

hf_auth_token : str, optional

HuggingFace authentication token (to access private models on the hub).

device : int, default=-1

Device to use for the transformer model. Follows the HuggingFace convention: -1 for CPU, or the GPU device number (for instance 0 for “cuda:0”).
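
A minimal instantiation sketch (the model name is the example mentioned above; the label set and other values are illustrative assumptions, not defaults of the class):

from medkit.text.ner.hf_entity_matcher_trainable import HFEntityMatcherTrainable

# Hypothetical label set; use the entity labels of your own corpus.
matcher = HFEntityMatcherTrainable(
    model_name_or_path="samrawal/bert-base-uncased_clinical-ner",  # any TokenClassification model
    labels=["problem", "treatment", "test"],
    tagging_scheme="bilou",
    tag_subtokens=False,        # tag only the first subtoken of each word (recommended)
    tokenizer_max_length=None,  # fall back to the tokenizer's model_max_length
    device=-1,                  # -1 = CPU, 0 would select "cuda:0"
)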

valid_model#
model_name_or_path#
tagging_scheme#
tag_subtokens#
tokenizer_max_length#
model_config#
label_to_id#
id_to_label#
device#
_data_collator#
configure_optimizer(lr: float) → torch.optim.Optimizer#
preprocess(data_item: medkit.core.text.TextDocument) → dict[str, Any]#
_encode_text(text)#

Return an EncodingFast instance.

collate(batch: list[dict[str, Any]]) → medkit.training.utils.BatchData#
forward(input_batch: medkit.training.utils.BatchData, return_loss: bool, eval_mode: bool) → tuple[medkit.training.utils.BatchData, torch.Tensor | None]#
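
The methods above implement medkit's trainable-component protocol and are normally driven by a training loop rather than called directly. A rough sketch of how they chain together for one training step, assuming train_docs is a list of annotated TextDocument objects and that configure_optimizer returns a standard torch optimizer:

# Hypothetical orchestration; a real training loop (e.g. medkit's trainer) handles this.
data_items = [matcher.preprocess(doc) for doc in train_docs]  # one dict per document
batch = matcher.collate(data_items)                           # padded tensors as BatchData
output, loss = matcher.forward(batch, return_loss=True, eval_mode=False)

optimizer = matcher.configure_optimizer(lr=5e-5)  # assumed to be a torch.optim.Optimizer
optimizer.zero_grad()
loss.backward()
optimizer.step()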
save(path: str | pathlib.Path)#
load(path: str | pathlib.Path, hf_auth_token: str | None = None)#
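
A hedged round-trip sketch, assuming save writes the fine-tuned model to a directory and load restores it from that same directory (hf_auth_token is only needed for private hub models):

matcher.save("./clinical_ner_model")  # persist the fine-tuned model
matcher.load("./clinical_ner_model")  # restore it later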
_get_valid_model_config(labels: list[str], hf_auth_token: str | None = None)#

Return a model config with the correct mapping of labels.