This is the code to get BERT embeddings of chunks for the RETRO model.
13from typing import List
15import torch
16from transformers import BertTokenizer, BertModel
18from labml import lab, monit
For a given chunk of text $N$, this class generates the BERT embedding $\text{BERT}(N)$. $\text{BERT}(N)$ is the average of the BERT embeddings of all the tokens in $N$.
class BERTChunkEmbeddings:
    """
    ## BERT Embeddings

    For a given chunk of text this class generates a single embedding vector:
    the average of the final-layer BERT hidden states of all the tokens
    in the chunk.
    """

    def __init__(self, device: torch.device):
        """
        :param device: device on which the BERT model is evaluated
        """
        self.device = device

        # Load the BERT tokenizer from HuggingFace
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                                           cache_dir=str(
                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

        # Load the BERT model from HuggingFace
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained("bert-base-uncased",
                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

        # Move the model to device. `__call__` moves every input tensor to
        # `self.device`, so the weights must live there as well.
        self.model.to(device)

    @staticmethod
    def _trim_chunk(chunk: str) -> str:
        """
        Remove the (possibly partial) first and last words of a chunk.

        In this implementation, we do not make chunks with a fixed number of tokens.
        One of the reasons is that this implementation uses character-level tokens
        and BERT uses its sub-word tokenizer.

        So this method will truncate the text to make sure there are no partial tokens.

        For instance, a chunk could be like `s a popular programming la`,
        with partial words (partial sub-word tokens) on the ends.
        We strip them off to get better BERT embeddings.
        As mentioned earlier this is not necessary if we broke chunks after tokenizing.

        :param chunk: the raw text of the chunk
        :return: the chunk without its first and last words, or the original
            `chunk` unchanged if nothing would be left after trimming
        """
        # Strip whitespace
        stripped = chunk.strip()
        # Break words
        parts = stripped.split()
        # An empty or whitespace-only chunk has no words to remove;
        # return it as is instead of failing on `parts[0]`
        if not parts:
            return chunk
        # Remove first and last pieces
        stripped = stripped[len(parts[0]):-len(parts[-1])]

        # Remove whitespace
        stripped = stripped.strip()

        # If empty return original string
        if not stripped:
            return chunk
        # Otherwise, return the stripped string
        return stripped

    def __call__(self, chunks: List[str]) -> torch.Tensor:
        """
        ### Get the BERT embeddings for a list of chunks

        :param chunks: list of chunk texts
        :return: one embedding per chunk, in the same order as `chunks`;
            each is the average of the token embeddings of that chunk
        """

        # We don't need to compute gradients
        with torch.no_grad():
            # Trim the chunks
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

            # Tokenize the chunks with BERT tokenizer
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

            # Move token ids, attention mask and token types to the device
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)
            # Evaluate the model
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)

            # Get the token embeddings
            state = output['last_hidden_state']
            # Calculate the average token embeddings.
            # Note that the attention mask is `0` if the token is empty padded.
            # We get empty tokens because the chunks are of different lengths.
            mask = attention_mask[:, :, None]
            # A chunk that produced no tokens has an all-zero mask. Clamp the
            # token count to at least one so its embedding is a zero vector
            # rather than NaN from a division by zero.
            emb = (state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

            #
            return emb
def _test():
    """
    ### Code to test BERT embeddings

    Loads the tokenizer and model, then logs the tokenizer output, the raw
    model outputs, the tokens recovered from the ids, and finally the chunk
    embeddings for two sample sentences.
    """
    from labml.logger import inspect

    # Initialize. Use the first GPU when there is one and fall back to the CPU
    # otherwise, so that this check also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    bert = BERTChunkEmbeddings(device)

    # Sample
    text = ["Replace me by any text you'd like.",
            "Second sentence"]

    # Check BERT tokenizer
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)

    # Check BERT model outputs
    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                        attention_mask=encoded_input['attention_mask'].to(device),
                        token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)

    # Check recreating text from token ids
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

    # Get chunk embeddings
    inspect(bert(text))
# Run the smoke test only when this file is executed directly as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    _test()