BERT Embeddings of chunks of text

This is the code to get BERT embeddings of chunks of text for the RETRO model.

13from typing import List
15import torch
16from transformers import BertTokenizer, BertModel
18from labml import lab, monit

BERT Embeddings

For a given chunk of text $N$ this class generates BERT embeddings $\text{BERT}(N)$. $\text{BERT}(N)$ is the average of the BERT embeddings of all the tokens in $N$.

class BERTChunkEmbeddings:
    """
    BERT embeddings of chunks of text.

    For a given chunk of text N this class generates BERT(N), the average of
    the BERT embeddings of all the tokens in N.
    """

    def __init__(self, device: torch.device):
        """
        :param device: device to run BERT on; the model and every batch of
            token tensors are moved here.
        """
        self.device = device

        # Load the BERT tokenizer from HuggingFace
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                           cache_dir=str(
                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

        # Load the BERT model from HuggingFace
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained("bert-base-uncased",
                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

        # Move the model to `device`. Without this the model would stay on the CPU
        # while `__call__` moves its inputs to `self.device`, which fails when
        # `device` is a GPU.
        self.model.to(device)

    @staticmethod
    def _trim_chunk(chunk: str) -> str:
        """
        Strip partial words off both ends of a chunk.

        In this implementation, we do not make chunks with a fixed number of
        tokens. One reason is that this implementation uses character-level
        tokens, while BERT uses its own sub-word tokenizer, so a chunk boundary
        can fall in the middle of a word.

        For instance, a chunk could look like `"s a popular programming la"`,
        with partial words (and hence partial sub-word tokens) at both ends.
        We strip them off to get better BERT embeddings. This would not be
        necessary if chunks were formed after tokenizing.

        :param chunk: the chunk of text
        :return: the chunk without its first and last whitespace-separated
            pieces, or the original `chunk` if nothing would remain (this
            includes empty and whitespace-only chunks).
        """
        # Strip surrounding whitespace and break into words
        stripped = chunk.strip()
        parts = stripped.split()
        # An empty or whitespace-only chunk has no words to remove
        # (`parts[0]` would raise `IndexError`), so keep it as is
        if not parts:
            return chunk
        # Remove the first and last pieces, then the whitespace around the rest
        stripped = stripped[len(parts[0]):-len(parts[-1])].strip()
        # If nothing is left return the original string, otherwise the trimmed one
        return stripped if stripped else chunk

    @staticmethod
    def _mean_pool(state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Average token embeddings over the tokens selected by `attention_mask`.

        The attention mask is 0 for padding tokens, which appear because the
        chunks have different lengths. Padding is excluded from both the sum
        and the token count.

        :param state: token embeddings of shape `[batch_size, seq_len, d_model]`
        :param attention_mask: mask of shape `[batch_size, seq_len]`, 1 for real
            tokens and 0 for padding
        :return: mean embeddings of shape `[batch_size, d_model]`
        """
        mask = attention_mask[:, :, None]
        # Clamp the count so a chunk with no tokens (e.g. an empty string)
        # yields a zero vector instead of NaN from 0 / 0
        count = mask.sum(dim=1).clamp(min=1)
        return (state * mask).sum(dim=1) / count

    def __call__(self, chunks: List[str]):
        """
        Get BERT(N) for a list of chunks.

        :param chunks: list of chunks of text
        :return: tensor with one mean-pooled BERT embedding per chunk, in the
            same order as `chunks`, on `self.device`
        """
        # We don't need to compute gradients
        with torch.no_grad():
            # Trim partial words off the ends of the chunks
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

            # Tokenize the chunks with the BERT tokenizer, padding to the longest chunk
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

            # Move token ids, attention mask and token types to the device
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)

            # Evaluate the model
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

            # Average the token embeddings over the real (non-padding) tokens
            return self._mean_pool(output['last_hidden_state'], attention_mask)

Code to test BERT embeddings

def _test():
    """
    Smoke test for BERT embeddings.

    Loads BERT, inspects the tokenizer output and the raw model outputs for
    two sample texts, then computes their chunk embeddings.
    """
    from labml.logger import inspect

    # Use the first GPU if one is available, otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    # Initialize
    bert = BERTChunkEmbeddings(device)

    # Sample text
    text = ["Replace me by any text you'd like.",
            "Second sentence"]

    # Check BERT tokenizer
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
    inspect(encoded_input, _expand=True)

    # Check BERT model outputs (no gradients are needed just to inspect them)
    with torch.no_grad():
        output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                            attention_mask=encoded_input['attention_mask'].to(device),
                            token_type_ids=encoded_input['token_type_ids'].to(device))
    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)

    # Check recreating text from token ids
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

    # Get chunk embeddings
    inspect(bert(text))


if __name__ == '__main__':
    _test()