Code To Align Annotations With Hugging Face Tokenizers

Tal Perry

This post comes with a repo

Our previous post on aligning span annotations to Hugging Face's tokenizer outputs discussed the various tradeoffs one needs to consider, and concluded that a windowing strategy over the tokenized text and labels is optimal for our use cases.

This post demonstrates an end-to-end implementation of token alignment and windowing. We'll start by implementing utility classes that make programming a little easier, then implement the alignment functionality, which aligns offset annotations to the output of a tokenizer. Finally, we'll implement a PyTorch Dataset that stores our aligned tokens and labels as windows, a Collator to implement batching, and a simple DataLoader to be used in training.

We'll show an end-to-end flow on the DDI Corpus, recognizing pharmacological entities with BERT.

Utility Classes For Convenient APIs

We'll start by defining some types and utility classes that will make our work more convenient:

from typing_extensions import TypedDict
from typing import List,Any
IntList = List[int] # A list of token_ids
IntListList = List[IntList] # A List of List of token_ids, e.g. a Batch

The Alignment Algorithm

FastTokenizers Simplify Alignment

Recent versions of Hugging Face's transformers library include tokenizer variants whose names end with Fast and that inherit from PreTrainedTokenizerFast, such as BertTokenizerFast and GPT2TokenizerFast.

Per the tokenizer's documentation:

When the tokenizer is a “Fast” tokenizer (i.e., backed by HuggingFace tokenizers library), [the output] provides in addition several advanced alignment methods which can be used to map between the original string (character and words) and the token space (e.g., getting the index of the token comprising a given character or the span of characters corresponding to a given token).

Notably, the output provides the methods token_to_chars and char_to_token, which do exactly what their names imply: they map between tokens and character offsets in the original text. That's precisely what we need to align annotations in offset format with tokens.
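To build some intuition for what char_to_token answers, here is a minimal pure-Python sketch that does the same lookup from a list of per-token (start, end) character offsets. The helper name and the hand-written offsets are illustrative assumptions; the real fast tokenizer computes this from its own alignment data.

```python
from typing import List, Optional, Tuple


def char_to_token_from_offsets(
    offsets: List[Tuple[int, int]], char_ix: int
) -> Optional[int]:
    """Return the index of the token whose [start, end) span contains char_ix.

    Hypothetical helper mirroring the idea behind Encoding.char_to_token,
    assuming `offsets` holds one half-open character span per token.
    """
    for token_ix, (start, end) in enumerate(offsets):
        if start <= char_ix < end:
            return token_ix
    return None  # No token covers this character (e.g. a space)


# Hand-written offsets for "I am Tal Perry" tokenized as ["I", "am", "Ta", "##l", "Perry"]
offsets = [(0, 1), (2, 4), (5, 7), (7, 8), (9, 14)]
print(char_to_token_from_offsets(offsets, 7))  # 3, the "l" inside "##l"
print(char_to_token_from_offsets(offsets, 4))  # None, the space after "am"
```

The None case is the same one the alignment code below has to guard against: characters such as whitespace belong to no token.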

A warmup implementation

Our final implementation will use the BILOU scheme we mentioned before. But before we get there, let's try a simple alignment to see what it feels like:

text = "I am Tal Perry, founder of LightTag"
annotations = [
    dict(start=5, end=14, text="Tal Perry", label="Person"),
    dict(start=16, end=23, text="founder", label="Title"),
    dict(start=27, end=35, text="LightTag", label="Org"),
]
for anno in annotations:
    # Show our annotations
    print(text[anno["start"] : anno["end"]], anno["label"])

Tal Perry Person
founder Title
LightTag Org
from transformers import BertTokenizerFast, BatchEncoding
from tokenizers import Encoding

tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")  # Load a pretrained tokenizer
tokenized_batch: BatchEncoding = tokenizer(text)
tokenized_text: Encoding = tokenized_batch[0]  # The Encoding of our single input text
tokens = tokenized_text.tokens
aligned_labels = ["O"] * len(tokens)  # A list to store our labels, the same length as our tokens
for anno in annotations:
    for char_ix in range(anno["start"], anno["end"]):
        token_ix = tokenized_text.char_to_token(char_ix)
        if token_ix is not None:  # Whitespace has no token, so char_to_token returns None
            aligned_labels[token_ix] = anno["label"]
for token, label in zip(tokens, aligned_labels):
    print(token, "-", label)
[CLS] - O
I - O
am - O
Ta - Person
##l - Person
Perry - Person
, - O
founder - Title
of - O
Light - Org
##T - Org
##ag - Org
[SEP] - O

Accounting For Multi Token Annotations

In the above example, some of our annotations spanned multiple tokens. For instance, "Tal Perry" spanned "Ta", "##l" and "Perry". None of those tokens is a Person by itself, so our current alignment scheme isn't as useful as it could be. To overcome that, we'll use the previously mentioned BILOU scheme, which marks whether a token is the Beginning, Inside or Last token of a multi-token annotation, whether it is a Unit annotation that aligns perfectly with a single token, or whether it is Outside any annotation.

def align_tokens_and_annotations_bilou(tokenized: Encoding, annotations):
    tokens = tokenized.tokens
    # A list to store our labels, the same length as our tokens
    aligned_labels = ["O"] * len(tokens)
    for anno in annotations:
        # A set that stores the token indices covered by the annotation
        annotation_token_ix_set = set()
        for char_ix in range(anno["start"], anno["end"]):
            token_ix = tokenized.char_to_token(char_ix)
            if token_ix is not None:
                annotation_token_ix_set.add(token_ix)
        if len(annotation_token_ix_set) == 1:
            # This annotation spans one token, so it is prefixed with U (unit)
            token_ix = annotation_token_ix_set.pop()
            aligned_labels[token_ix] = f"U-{anno['label']}"
        else:
            # A multi-token annotation (an empty set simply skips this loop)
            last_token_in_anno_ix = len(annotation_token_ix_set) - 1
            for num, token_ix in enumerate(sorted(annotation_token_ix_set)):
                if num == 0:
                    prefix = "B"  # The first token
                elif num == last_token_in_anno_ix:
                    prefix = "L"  # It's the last token
                else:
                    prefix = "I"  # We're inside a multi-token annotation
                aligned_labels[token_ix] = f"{prefix}-{anno['label']}"
    return aligned_labels


labels = align_tokens_and_annotations_bilou(tokenized_text, annotations)
for token, label in zip(tokens, labels):
    print(token, "-", label)
[CLS] - O
I - O
am - O
Ta - B-Person
##l - I-Person
Perry - L-Person
, - O
founder - U-Title
of - O
Light - B-Org
##T - I-Org
##ag - L-Org
[SEP] - O

Notice that "founder", a single-token annotation, has a U prefix, while the multi-token annotations now follow the B, I, L pattern.
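Stripped of the tokenizer, the prefix rule in align_tokens_and_annotations_bilou depends only on how many tokens an annotation covers. The sketch below isolates that rule in a hypothetical bilou_tags helper so it can be checked on its own:

```python
from typing import List


def bilou_tags(num_tokens: int, label: str) -> List[str]:
    """BILOU tags for one annotation covering num_tokens consecutive tokens."""
    if num_tokens <= 0:
        return []  # The annotation matched no tokens (e.g. only whitespace)
    if num_tokens == 1:
        return [f"U-{label}"]  # A single-token (unit) annotation
    # Beginning, then zero or more Inside tokens, then the Last token
    return [f"B-{label}"] + [f"I-{label}"] * (num_tokens - 2) + [f"L-{label}"]


print(bilou_tags(3, "Person"))  # ['B-Person', 'I-Person', 'L-Person']
print(bilou_tags(1, "Title"))   # ['U-Title']
```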

Mapping Labels To Ids

It's great that we have our annotations aligned, but we need the labels as integer ids for training. During inference, we'll also need a way to map predicted ids back to labels. I'm going to make a custom class that handles that, called a LabelSet.

import itertools
from typing import List


class LabelSet:
    def __init__(self, labels: List[str]):
        # The OUTSIDE label "O" (no label for the token) always gets id 0
        self.labels_to_id = {"O": 0}
        self.ids_to_label = {0: "O"}
        # Iterating over "BILU" gives each label four consecutive ids, starting at 1
        for num, (label, prefix) in enumerate(itertools.product(labels, "BILU"), start=1):
            tag = f"{prefix}-{label}"
            self.labels_to_id[tag] = num
            self.ids_to_label[num] = tag

    def get_aligned_label_ids_from_annotations(self, tokenized_text, annotations):
        # Align BILOU string labels to the tokens, then map each one to its id
        raw_labels = align_tokens_and_annotations_bilou(tokenized_text, annotations)
        return list(map(self.labels_to_id.get, raw_labels))


example_label_set = LabelSet(labels=["Person", "Org", "Title"])
aligned_label_ids = example_label_set.get_aligned_label_ids_from_annotations(
    tokenized_text, annotations
)

for token, label in zip(tokens, aligned_label_ids):
    print(token, "-", label)
[CLS] - 0
I - 0
am - 0
Ta - 1
##l - 2
Perry - 3
, - 0
founder - 12
of - 0
Light - 5
##T - 6
##ag - 7
[SEP] - 0
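These numbers follow from how LabelSet enumerates itertools.product(labels, "BILU"): id 0 is reserved for "O", and each label then owns four consecutive ids in B, I, L, U order. The sketch below spells out that arithmetic, plus the reverse lookup we'll need at inference time; label_to_id and id_to_label are hypothetical helpers, not part of LabelSet.

```python
from typing import List

PREFIXES = "BILU"


def label_to_id(tag: str, labels: List[str]) -> int:
    """Compute the id LabelSet assigns to a tag such as 'U-Title'."""
    if tag == "O":
        return 0
    prefix, label = tag.split("-", 1)
    # Skip id 0, then jump four ids per preceding label, then offset by prefix
    return 1 + labels.index(label) * len(PREFIXES) + PREFIXES.index(prefix)


def id_to_label(ix: int, labels: List[str]) -> str:
    """Invert label_to_id, e.g. to turn predicted ids back into tags."""
    if ix == 0:
        return "O"
    label_ix, prefix_ix = divmod(ix - 1, len(PREFIXES))
    return f"{PREFIXES[prefix_ix]}-{labels[label_ix]}"


labels = ["Person", "Org", "Title"]
print(label_to_id("U-Title", labels))  # 12 (= 1 + 2 * 4 + 3), as printed for "founder"
print(id_to_label(5, labels))          # B-Org, as printed for "Light"
```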

Batching

Now that we have alignment logic in place, we need to figure out how to load, batch and pad the data. We also need to handle the case where our text is longer than our model can accept (BERT takes at most 512 tokens). Below we show an implementation of one particular strategy: windowing over uniform-length segments of the text. This isn't the only strategy, or even necessarily the best, but it fits our use case well. You can read more about why we use windowing when training NER models with BERT here.
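The windowing arithmetic can be previewed without a tokenizer. This hypothetical window_spans helper follows the same loop the Dataset uses later in this post: a window starts every stride tokens, is cut off at the end of the sequence, and reports how much padding would bring it back to the full window size. A stride smaller than the window gives overlapping windows.

```python
from typing import List, Tuple


def window_spans(length: int, window: int, stride: int) -> List[Tuple[int, int, int]]:
    """Return (start, end, padding) for each window over `length` tokens."""
    spans = []
    for start in range(0, length, stride):
        end = min(start + window, length)  # Don't run past the sequence
        padding = window - (end - start)   # Pad short windows up to `window`
        spans.append((start, end, padding))
    return spans


print(window_spans(10, window=4, stride=4))  # [(0, 4, 0), (4, 8, 0), (8, 10, 2)]
print(window_spans(6, window=4, stride=2))   # [(0, 4, 0), (2, 6, 0), (4, 6, 2)]
```

With overlapping windows, the last window can be a short, padded repeat of tokens that already appeared in the previous one, as (4, 6, 2) shows.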

The Raw Dataset

We'll be using the DDI Corpus. You can download a JSON version of it here. Let's take a quick look at the data:

import json
from pprint import pprint

with open("./ddi_train.json") as f:
    raw = json.load(f)
for example in raw:
    # Our simple implementation expects the label to be called "label", so we adjust the original data
    for anno in example["annotations"]:
        anno["label"] = anno["tag"]
pprint(raw[2])
{'annotations': [{'end': 58, 'label': 'drug', 'start': 47, 'tag': 'drug'},
                 {'end': 75, 'label': 'drug', 'start': 62, 'tag': 'drug'},
                 {'end': 135, 'label': 'drug', 'start': 124, 'tag': 'drug'},
                 {'end': 164, 'label': 'drug', 'start': 152, 'tag': 'drug'}],
 'content': 'Pharmacokinetic studies have demonstrated that omeprazole and '
            'erythromycin significantly increased the systemic exposure of '
            'cilostazol and/or its major metabolites.',
 'metadata': {'original_id': 'DrugDDI.d452.s1'}}
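The records have a simple, regular shape. As a hedged sketch, that shape can be written down with TypedDict (the Annotation and Example names are ours, and the field comments describe what the code in this post assumes rather than a documented schema), together with a hypothetical annotation_texts helper for eyeballing whether the offsets pick out the intended text:

```python
from typing import Dict, List, TypedDict


class Annotation(TypedDict):
    start: int  # First character of the span
    end: int    # Treated as exclusive by our alignment code: range(start, end)
    tag: str    # The label name stored in the file, e.g. "drug"
    label: str  # A copy of `tag`, added by our loading code


class Example(TypedDict):
    content: str
    annotations: List[Annotation]
    metadata: Dict[str, str]


def annotation_texts(example: Example) -> List[str]:
    """Return the surface text of each annotation."""
    text = example["content"]
    return [text[a["start"] : a["end"]] for a in example["annotations"]]


# A made-up record in the same shape as raw[2]
sample: Example = {
    "content": "omeprazole and erythromycin",
    "annotations": [
        {"start": 0, "end": 10, "tag": "drug", "label": "drug"},
        {"start": 15, "end": 27, "tag": "drug", "label": "drug"},
    ],
    "metadata": {"original_id": "hypothetical"},
}
print(annotation_texts(sample))  # ['omeprazole', 'erythromycin']
```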

Let's look at that example tokenized and aligned:

example = raw[2]
tokenized_batch = tokenizer(example["content"])
tokenized_text = tokenized_batch[0]
labels = align_tokens_and_annotations_bilou(tokenized_text, example["annotations"])
for token, label in zip(tokenized_text.tokens, labels):
    print(token, "-", label)
[CLS] - O
Ph - O
##arma - O
##co - O
##kin - O
##etic - O
studies - O
have - O
demonstrated - O
that - O
o - B-drug
##me - I-drug
##pra - I-drug
##zo - I-drug
##le - L-drug
and - O
er - B-drug
##yt - I-drug
##hr - I-drug
##omy - I-drug
##cin - L-drug
significantly - O
increased - O
the - O
systemic - O
exposure - O
of - O
c - B-drug
##ilo - I-drug
##sta - I-drug
##zo - I-drug
##l - L-drug
and - O
/ - O
or - O
its - O
major - O
meta - B-drug
##bol - I-drug
##ites - I-drug
. - L-drug
[SEP] - O

Padding and Windowing in a Dataset

Our dataset is conveniently split into sentences, but we still need to batch and pad the examples. More commonly, data is not split into sentences, so we window over fixed-size chunks of it. The windowing, padding and alignment logic lives in a PyTorch Dataset; we'll get to batching in a moment.

from dataclasses import dataclass
from typing import Any, List

from torch.utils.data import Dataset
from transformers import PreTrainedTokenizerFast


@dataclass
class TrainingExample:
    input_ids: IntList
    attention_masks: IntList
    labels: IntList


class TraingDataset(Dataset):
    def __init__(
        self,
        data: Any,
        label_set: LabelSet,
        tokenizer: PreTrainedTokenizerFast,
        tokens_per_batch=32,  # The length of each window, i.e. tokens per training example
        window_stride=None,  # Defaults to tokens_per_batch (non-overlapping windows)
    ):
        self.label_set = label_set
        # Use the given stride, or default to non-overlapping windows
        self.window_stride = (
            window_stride if window_stride is not None else tokens_per_batch
        )
        self.tokenizer = tokenizer
        for example in data:
            # Change the tag key to label
            for a in example["annotations"]:
                a["label"] = a["tag"]
        self.texts = []
        self.annotations = []

        for example in data:
            self.texts.append(example["content"])
            self.annotations.append(example["annotations"])
        ### TOKENIZE ALL THE DATA (without special tokens such as [CLS] and [SEP])
        tokenized_batch = self.tokenizer(self.texts, add_special_tokens=False)
        ### ALIGN LABELS ONE EXAMPLE AT A TIME
        aligned_labels = []
        for ix in range(len(tokenized_batch.encodings)):
            encoding = tokenized_batch.encodings[ix]
            raw_annotations = self.annotations[ix]
            aligned = label_set.get_aligned_label_ids_from_annotations(
                encoding, raw_annotations
            )
            aligned_labels.append(aligned)
        ### END OF LABEL ALIGNMENT

        ### MAKE A LIST OF TRAINING EXAMPLES (this is where we add padding)
        self.training_examples: List[TrainingExample] = []
        for encoding, label in zip(tokenized_batch.encodings, aligned_labels):
            length = len(label)  # How long is this sequence
            for start in range(0, length, self.window_stride):
                end = min(start + tokens_per_batch, length)
                # How much padding do we need?
                padding_to_add = max(0, tokens_per_batch - end + start)
                self.training_examples.append(
                    TrainingExample(
                        # The ids of the tokens, padded if needed
                        input_ids=encoding.ids[start:end]
                        + [self.tokenizer.pad_token_id] * padding_to_add,
                        # -100 marks padded label positions, which the loss ignores
                        labels=label[start:end] + [-100] * padding_to_add,
                        # Zeroed attention masks where we added padding
                        attention_masks=encoding.attention_mask[start:end]
                        + [0] * padding_to_add,
                    )
                )

    def __len__(self):
        return len(self.training_examples)

    def __getitem__(self, idx) -> TrainingExample:
        return self.training_examples[idx]
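Padding labels with -100 is safe because BertForTokenClassification computes its loss with torch.nn.CrossEntropyLoss, whose ignore_index defaults to -100, so padded positions are left out of the average. The dependency-free sketch below (masked_cross_entropy is a hypothetical name) spells out that behavior for mean reduction:

```python
import math
from typing import List

IGNORE_INDEX = -100  # The default ignore_index of torch.nn.CrossEntropyLoss


def masked_cross_entropy(logits: List[List[float]], labels: List[int]) -> float:
    """Mean cross entropy over positions whose label is not IGNORE_INDEX."""
    losses = []
    for row, label in zip(logits, labels):
        if label == IGNORE_INDEX:
            continue  # A padded position contributes nothing
        m = max(row)  # Subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        losses.append(log_sum_exp - row[label])  # -log(softmax(row)[label])
    # This sketch returns 0.0 when every position is padding
    return sum(losses) / len(losses) if losses else 0.0


real = masked_cross_entropy([[2.0, 0.0]], [0])
padded = masked_cross_entropy([[2.0, 0.0], [9.0, -9.0]], [0, IGNORE_INDEX])
print(real == padded)  # True: the padded row's logits never enter the loss
```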

Let's See What Comes Out

Below we create a dataset instance. We first create a LabelSet; in this case there is only one label, drug. We then instantiate our Dataset by passing the raw data, the tokenizer and the label set. We get back TrainingExample instances with the windowed and padded input_ids and label ids, as well as attention_masks.

label_set = LabelSet(labels=["drug"])
ds = TraingDataset(
    data=raw, tokenizer=tokenizer, label_set=label_set, tokens_per_batch=16
)
ex = ds[10]
pprint(ex)
TrainingExample(input_ids=[1233, 1621, 4420, 18061, 5165, 1114, 4267, 6066, 1465, 3171, 1306, 117, 1126, 27558, 1104, 140], attention_masks=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], labels=[3, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 0, 0, 0, 0])
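With a single label, LabelSet assigns 1, 2, 3 and 4 to B-drug, I-drug, L-drug and U-drug, so this window opens with an L-drug tag: its annotation began in an earlier window of the same example. A small sketch for decoding window labels and flagging windows that begin mid-annotation (decode and starts_mid_annotation are hypothetical helpers, and the mapping is written out by hand to match label_set.ids_to_label):

```python
from typing import Dict, List

# Hand-written to match label_set.ids_to_label for LabelSet(labels=["drug"])
IDS_TO_LABEL: Dict[int, str] = {0: "O", 1: "B-drug", 2: "I-drug", 3: "L-drug", 4: "U-drug"}


def decode(label_ids: List[int]) -> List[str]:
    """Map label ids to tags; ids not in the mapping, such as -100 padding, become 'PAD'."""
    return [IDS_TO_LABEL.get(ix, "PAD") for ix in label_ids]


def starts_mid_annotation(tags: List[str]) -> bool:
    """True if the window opens on an I- or L- tag, i.e. the annotation began earlier."""
    return bool(tags) and tags[0].startswith(("I-", "L-"))


window_labels = [3, 0, 0, 0, 0, 0, 1, 2, 2, 2, 2, 3, 0, 0, 0, 0]  # ex.labels from above
tags = decode(window_labels)
print(tags[0], starts_mid_annotation(tags))  # L-drug True
```

Whether truncated spans like this matter depends on the task; overlapping windows (a stride smaller than the window size) make it more likely that each annotation appears whole in at least one window.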

Batching

We still need a way to batch these examples. We can't feed a list of TrainingExample instances to a model; we need to make tensors out of the input_ids, attention masks and labels. This is easily achieved with a collating function: it receives a list of items from our dataset (in our case, TrainingExample instances) and returns batched tensors.

We'll keep things simple by making a TraingingBatch class whose constructor is the collating function:

import torch


class TraingingBatch:
    def __getitem__(self, item):
        return getattr(self, item)

    def __init__(self, examples: List[TrainingExample]):
        self.input_ids: torch.Tensor
        self.attention_masks: torch.Tensor
        self.labels: torch.Tensor
        input_ids: IntListList = []
        masks: IntListList = []
        labels: IntListList = []
        for ex in examples:
            input_ids.append(ex.input_ids)
            masks.append(ex.attention_masks)
            labels.append(ex.labels)
        self.input_ids = torch.LongTensor(input_ids)
        self.attention_masks = torch.LongTensor(masks)
        self.labels = torch.LongTensor(labels)
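Since every TrainingExample is already padded to the same length, collation amounts to grouping each field across the examples into a rectangular list of lists, which torch.LongTensor then converts. Here is a dependency-free sketch of that grouping step; collate_lists is a hypothetical name, and the dataclass is redeclared so the snippet runs on its own. The length check shows why padding has to happen before this point: ragged rows cannot form a rectangular tensor.

```python
from dataclasses import dataclass, fields
from typing import Dict, List


@dataclass
class TrainingExample:  # Same fields as the dataclass defined earlier
    input_ids: List[int]
    attention_masks: List[int]
    labels: List[int]


def collate_lists(examples: List[TrainingExample]) -> Dict[str, List[List[int]]]:
    """Group each field across examples into a rectangular list of lists."""
    batch: Dict[str, List[List[int]]] = {}
    for f in fields(TrainingExample):
        rows = [getattr(ex, f.name) for ex in examples]
        if len({len(r) for r in rows}) > 1:
            raise ValueError(f"{f.name} rows differ in length; pad examples first")
        batch[f.name] = rows
    return batch


batch = collate_lists([
    TrainingExample(input_ids=[12, 7, 0], attention_masks=[1, 1, 0], labels=[0, 1, -100]),
    TrainingExample(input_ids=[5, 6, 8], attention_masks=[1, 1, 1], labels=[1, 3, 0]),
])
print(batch["attention_masks"])  # [[1, 1, 0], [1, 1, 1]]
```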

Training Our Model

With our batching ready, let's take a pretrained model and fine-tune it on our new dataset.

from torch.optim import AdamW
from torch.utils.data.dataloader import DataLoader
from transformers import BertForTokenClassification

model = BertForTokenClassification.from_pretrained(
    "bert-base-cased", num_labels=len(ds.label_set.ids_to_label)
)
model.train()  # from_pretrained returns the model in eval mode; enable dropout for training
optimizer = AdamW(model.parameters(), lr=5e-6)

dataloader = DataLoader(
    ds,
    collate_fn=TraingingBatch,
    batch_size=4,
    shuffle=True,
)
for batch in dataloader:
    optimizer.zero_grad()  # Clear gradients left over from the previous step
    outputs = model(
        input_ids=batch.input_ids,
        attention_mask=batch.attention_masks,
        labels=batch.labels,
    )
    loss = outputs[0]  # The loss comes first when labels are passed
    loss.backward()
    optimizer.step()
    print(loss)

tensor(1.6987, grad_fn=<NllLossBackward>)
tensor(1.6388, grad_fn=<NllLossBackward>)
tensor(1.6135, grad_fn=<NllLossBackward>)
...
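The loop above only trains. At inference time, BertForTokenClassification returns logits with one row of scores per token, and, as noted in the LabelSet section, we need to map the predicted ids back to labels. A minimal sketch for one window, assuming the logits have already been converted to nested Python lists (with tensors, logits.argmax(dim=-1) does the argmax step) and using a hand-written id mapping that matches LabelSet(labels=["drug"]); decode_predictions is a hypothetical name:

```python
from typing import Dict, List

# Hand-written to match label_set.ids_to_label for LabelSet(labels=["drug"])
IDS_TO_LABEL: Dict[int, str] = {0: "O", 1: "B-drug", 2: "I-drug", 3: "L-drug", 4: "U-drug"}


def decode_predictions(
    logits: List[List[float]], attention_mask: List[int], ids_to_label: Dict[int, str]
) -> List[str]:
    """Turn one window's per-token logits into BILOU tags, skipping padding."""
    tags = []
    for row, mask in zip(logits, attention_mask):
        if mask == 0:
            continue  # A padded position, not a real token
        best_id = max(range(len(row)), key=row.__getitem__)  # Argmax over label ids
        tags.append(ids_to_label[best_id])
    return tags


window_logits = [
    [0.1, 2.0, 0.0, 0.0, 0.0],  # Highest score at id 1
    [0.0, 0.0, 0.0, 3.0, 0.0],  # Highest score at id 3
    [9.0, 0.0, 0.0, 0.0, 0.0],  # Padding: ignored because its mask is 0
]
print(decode_predictions(window_logits, [1, 1, 0], IDS_TO_LABEL))  # ['B-drug', 'L-drug']
```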

The End

This is where this post ends. Check back soon for the follow-up, where we'll share examples and tips for training sequence labeling models from pretrained transformers.

This post showed an implementation of the ideas in our previous post on Sequence Labeling With Transformers. You can find this post as a notebook with some additional utilities here. Follow us on Twitter for updates and share this post if you liked it.
