This code computes BERT embeddings of text chunks for the RETRO model.
from typing import List

import torch
from transformers import BertTokenizer, BertModel

from labml import lab, monit
For a given chunk of text, this class generates a BERT embedding, which is the average of the BERT embeddings of all the tokens in the chunk.
class BERTChunkEmbeddings:
    def __init__(self, device: torch.device):
        self.device = device
Load the BERT tokenizer from HuggingFace
        with monit.section('Load BERT tokenizer'):
            self.tokenizer = BertTokenizer.from_pretrained(
                'bert-base-uncased',
                cache_dir=str(lab.get_data_path() / 'cache' / 'bert-tokenizer'))
Load the BERT model from HuggingFace
        with monit.section('Load BERT model'):
            self.model = BertModel.from_pretrained(
                'bert-base-uncased',
                cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))
Move the model to the device
        self.model.to(device)
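A minimal construction sketch (not from the original; the CPU fallback is our assumption):

    # Hypothetical usage; falls back to CPU when no GPU is available
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    bert = BERTChunkEmbeddings(device)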
In this implementation, we do not make chunks with a fixed number of tokens. One reason is that this implementation uses character-level tokens, while BERT uses its own sub-word tokenizer. So this method truncates the text to make sure there are no partial tokens. For instance, a chunk could look like "s a popular programming la", with partial words (partial sub-word tokens) at both ends. We strip them off to get better BERT embeddings; a short illustration follows the method below. As mentioned earlier, this would not be necessary if we split the text into chunks after tokenizing.
    @staticmethod
    def _trim_chunk(chunk: str):
Strip whitespace
        stripped = chunk.strip()
        # An empty or all-whitespace chunk has no words to trim; return it unchanged
        if not stripped:
            return chunk
Split into words
        parts = stripped.split()
Remove first and last pieces
        stripped = stripped[len(parts[0]):-len(parts[-1])]
Remove whitespace
        stripped = stripped.strip()
If empty return original string
        if not stripped:
            return chunk
Otherwise, return the stripped string
        else:
            return stripped
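A quick, hypothetical check of the trimming logic on the example chunk from above:

    # Partial sub-word tokens at both ends are stripped off
    print(BERTChunkEmbeddings._trim_chunk('s a popular programming la'))
    # prints: a popular programming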
    def __call__(self, chunks: List[str]):
We don't need to compute gradients
        with torch.no_grad():
Trim the chunks
            trimmed_chunks = [self._trim_chunk(c) for c in chunks]
Tokenize the chunks with the BERT tokenizer
            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)
Move token ids, attention mask and token types to the device
            input_ids = tokens['input_ids'].to(self.device)
            attention_mask = tokens['attention_mask'].to(self.device)
            token_type_ids = tokens['token_type_ids'].to(self.device)
Evaluate the model
            output = self.model(input_ids=input_ids,
                                attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
Get the token embeddings
            state = output['last_hidden_state']
Calculate the average of the token embeddings. Note that the attention mask is 0 for padding tokens; padding appears because the chunks have different lengths.
            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)
            return emb
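To see why this masked average ignores padding, here is a toy sketch with made-up numbers (not from the original): the mask zeroes out padded positions in the sum, and the denominator counts only the real tokens. With bert-base-uncased the hidden size is 768, so the returned tensor has shape [len(chunks), 768].

    # Toy masked-average example: 2 chunks padded to length 3, embedding size 2
    state = torch.tensor([[[1., 2.], [3., 4.], [5., 6.]],
                          [[7., 8.], [0., 0.], [0., 0.]]])
    mask = torch.tensor([[1., 1., 1.],
                         [1., 0., 0.]])
    emb = (state * mask[:, :, None]).sum(dim=1) / mask[:, :, None].sum(dim=1)
    # First chunk averages all three positions  -> [3., 4.]
    # Second chunk ignores the padded positions -> [7., 8.]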
def _test():
    from labml.logger import inspect
Initialize
    device = torch.device('cuda:0')
    bert = BERTChunkEmbeddings(device)
Sample
    text = ["Replace me by any text you'd like.",
            "Second sentence"]
Check BERT tokenizer
    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)

    inspect(encoded_input, _expand=True)
Check BERT model outputs
    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
                        attention_mask=encoded_input['attention_mask'].to(device),
                        token_type_ids=encoded_input['token_type_ids'].to(device))

    inspect({'last_hidden_state': output['last_hidden_state'],
             'pooler_output': output['pooler_output']},
            _expand=True)
Check recreating text from token ids
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)
Get chunk embeddings
    inspect(bert(text))
if __name__ == '__main__':
    _test()