This is the code to get BERT embeddings of chunks for the RETRO model.
from typing import List

import torch
from transformers import BertTokenizer, BertModel

from labml import lab, monit

For a given chunk of text $N$, this class generates BERT embeddings $\text{BERT}(N)$. $\text{BERT}(N)$ is the average of the BERT embeddings of all the tokens in $N$.
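Written out (with $e_t$ denoting the BERT output embedding of token $t$, a symbol introduced here only for illustration), the chunk embedding computed in __call__ below is the masked average

$$\text{BERT}(N) = \frac{1}{|N|} \sum_{t \in N} e_t$$

where $|N|$ counts only the non-padding tokens.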
class BERTChunkEmbeddings:
    def __init__(self, device: torch.device):
        self.device = device

Load the BERT tokenizer from HuggingFace.
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                           cache_dir=str(
                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

Load the BERT model from HuggingFace.
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained("bert-base-uncased",
                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

Move the model to the device.

            self.model.to(device)

In this implementation, we do not make chunks with a fixed number of tokens. One of the reasons is that this implementation uses character-level tokens, while BERT uses its own sub-word tokenizer.
So this method truncates the text to make sure there are no partial tokens. For instance, a chunk could look like "s a popular programming la", with partial words (partial sub-word tokens) at both ends. We strip these off to get better BERT embeddings. As mentioned earlier, this would not be necessary if we broke chunks after tokenizing.
    @staticmethod
    def _trim_chunk(chunk: str):

Strip whitespace.
        stripped = chunk.strip()

Break into words.

        parts = stripped.split()

Remove the first and last pieces.

        stripped = stripped[len(parts[0]):-len(parts[-1])]

Remove whitespace.

        stripped = stripped.strip()

If empty, return the original string.
        if not stripped:
            return chunk

Otherwise, return the stripped string.

        else:
            return stripped
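As a quick illustration of what the trimming does (a hypothetical usage sketch, not part of the original file), the example chunk from above loses its partial end pieces:

# Hypothetical usage sketch of _trim_chunk (a static method, so no instance is needed).
# The partial pieces "s" and "la" at the ends are removed.
chunk = "s a popular programming la"
print(BERTChunkEmbeddings._trim_chunk(chunk))  # prints: a popular programming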
    def __call__(self, chunks: List[str]):

We don't need to compute gradients.

        with torch.no_grad():

Trim the chunks.
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

Tokenize the chunks with the BERT tokenizer.
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

Move the token ids, attention mask and token types to the device.
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)

Evaluate the model.
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

Get the token embeddings.
            state = output['last_hidden_state']

Calculate the average of the token embeddings. Note that the attention mask is 0 for padding tokens; we get padding tokens because the chunks are of different lengths.

            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)

            return emb
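A minimal sketch of this masked averaging with made-up tensors (shapes and values are invented for illustration: one chunk, three token slots, two-dimensional embeddings, with the last slot being padding):

import torch

state = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [9.0, 9.0]]])  # [batch, seq, d_model]
attention_mask = torch.tensor([[1, 1, 0]])                    # last slot is padding

emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)
print(emb)  # tensor([[2., 3.]]); the padded [9., 9.] slot is ignored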
def _test():
    from labml.logger import inspect

Initialize.

    device = torch.device('cuda:0')
    bert = BERTChunkEmbeddings(device)

Sample text.
    text = ["Replace me by any text you'd like.",
            "Second sentence"]

Check the BERT tokenizer.
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)

Check the BERT model outputs.
    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                        attention_mask=encoded_input['attention_mask'].to(device),
                        token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)

Check recreating the text from token ids.
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

Get the chunk embeddings.
    inspect(bert(text))


if __name__ == '__main__':
    _test()