# Training

import os
import shutil
from glob import glob

import torch
from medkit.io.medkit_json import load_text_documents
from medkit.text.metrics.ner import SeqEvalMetricsComputer
from medkit.text.ner.hf_entity_matcher import HFEntityMatcher
from medkit.training import TrainerConfig, Trainer

# Pool the train/val/test splits of every corpus so that both models are
# fine-tuned on the combined data instead of a single dataset.
train, val, test = [], [], []

for corpus_name in ['quaero', 'e3c', 'casm2']:
    corpus_dir = f"/content/drive/MyDrive/datasets/{corpus_name}"
    train.extend(load_text_documents(f"{corpus_dir}/train.jsonl"))
    val.extend(load_text_documents(f"{corpus_dir}/val.jsonl"))
    test.extend(load_text_documents(f"{corpus_dir}/test.jsonl"))
# --- Fine-tune DrBERT on the merged corpora ---

# Use GPU 0 when CUDA is available; -1 is passed otherwise
# (presumably meaning "CPU" for medkit — confirm against its docs).
DEVICE = 0 if torch.cuda.is_available() else -1

# Directory handed to the trainer as its output location.
CHECKPOINT_DIR = "checkpoints_drbert/"

trainer_config = TrainerConfig(
    output_dir=CHECKPOINT_DIR,
    nb_training_epochs=10,
    batch_size=16,
    learning_rate=5e-5,
)

# Trainable NER component built from the pretrained DrBERT weights.
trainable_matcher = HFEntityMatcher.make_trainable(
    model_name_or_path="Dr-BERT/DrBERT-4GB-CP-PubMedBERT",
    labels=[
        "ANAT", "CHEM", "DEVI", "DISO", "GEOG",
        "LIVB", "OBJC", "PHEN", "PHYS", "PROC",
    ],
    tagging_scheme="iob2",
    tag_subtokens=True,
    tokenizer_max_length=512,
    device=DEVICE,
)

# Metrics use the matcher's own id -> label mapping and the same tagging
# scheme as the matcher; scores are averaged ("weighted"), not per label.
ner_metrics_computer = SeqEvalMetricsComputer(
    id_to_label=trainable_matcher.id_to_label,
    tagging_scheme="iob2",
    average="weighted",
    return_metrics_by_label=False,
)

trainer = Trainer(
    component=trainable_matcher,
    config=trainer_config,
    metrics_computer=ner_metrics_computer,
    train_data=train,
    eval_data=val,
)

# Run the fine-tuning; whatever train() returns is kept in `history`.
history = trainer.train()

# Export a DrBERT checkpoint under a stable name to Google Drive.
# NOTE(review): index 0 of the name-sorted checkpoints is taken as the "best"
# one, as in the original script — confirm against the trainer's checkpoint
# naming/retention policy.
checkpoint_paths = sorted(glob(os.path.join(CHECKPOINT_DIR, "checkpoint_*")))
if not checkpoint_paths:
    # Fail with an explicit message instead of an opaque IndexError.
    raise FileNotFoundError(
        f"No checkpoint matching 'checkpoint_*' found in {CHECKPOINT_DIR!r}"
    )
checkpoint_path = checkpoint_paths[0]

# shutil.move() only places the source *inside* the destination when the
# destination is an existing directory; otherwise the checkpoint itself would
# be renamed to ".../models". Make sure the directory exists first.
MODELS_DIR = '/content/drive/MyDrive/models'
os.makedirs(MODELS_DIR, exist_ok=True)

renamed_checkpoint = os.path.join(CHECKPOINT_DIR, 'DrBert-Generalized')
os.rename(checkpoint_path, renamed_checkpoint)
shutil.move(renamed_checkpoint, MODELS_DIR)
# --- Fine-tune CamemBERT-bio on the merged corpora ---

# Use GPU 0 when CUDA is available; -1 is passed otherwise
# (presumably meaning "CPU" for medkit — confirm against its docs).
DEVICE = 0 if torch.cuda.is_available() else -1

# Directory handed to the trainer as its output location.
CHECKPOINT_DIR = "checkpoints_cam/"

trainer_config = TrainerConfig(
    output_dir=CHECKPOINT_DIR,
    nb_training_epochs=10,
    batch_size=16,
    learning_rate=5e-5,
)

# Trainable NER component built from the pretrained CamemBERT-bio weights.
trainable_matcher = HFEntityMatcher.make_trainable(
    model_name_or_path="almanach/camembert-bio-base",
    labels=[
        "ANAT", "CHEM", "DEVI", "DISO", "GEOG",
        "LIVB", "OBJC", "PHEN", "PHYS", "PROC",
    ],
    tagging_scheme="iob2",
    tag_subtokens=True,
    tokenizer_max_length=512,
    device=DEVICE,
)

# Metrics use the matcher's own id -> label mapping and the same tagging
# scheme as the matcher; scores are averaged ("weighted"), not per label.
ner_metrics_computer = SeqEvalMetricsComputer(
    id_to_label=trainable_matcher.id_to_label,
    tagging_scheme="iob2",
    average="weighted",
    return_metrics_by_label=False,
)

trainer = Trainer(
    component=trainable_matcher,
    config=trainer_config,
    metrics_computer=ner_metrics_computer,
    train_data=train,
    eval_data=val,
)

# Run the fine-tuning; whatever train() returns is kept in `history`.
history = trainer.train()

# Export a CamemBERT-bio checkpoint under a stable name to Google Drive.
# NOTE(review): index 0 of the name-sorted checkpoints is taken as the "best"
# one, as in the original script — confirm against the trainer's checkpoint
# naming/retention policy.
checkpoint_paths = sorted(glob(os.path.join(CHECKPOINT_DIR, "checkpoint_*")))
if not checkpoint_paths:
    # Fail with an explicit message instead of an opaque IndexError.
    raise FileNotFoundError(
        f"No checkpoint matching 'checkpoint_*' found in {CHECKPOINT_DIR!r}"
    )
checkpoint_path = checkpoint_paths[0]

# shutil.move() only places the source *inside* the destination when the
# destination is an existing directory; otherwise the checkpoint itself would
# be renamed to ".../models". Make sure the directory exists first.
MODELS_DIR = '/content/drive/MyDrive/models'
os.makedirs(MODELS_DIR, exist_ok=True)

renamed_checkpoint = os.path.join(CHECKPOINT_DIR, 'CamemBert-Bio-Generalized')
os.rename(checkpoint_path, renamed_checkpoint)
shutil.move(renamed_checkpoint, MODELS_DIR)