BERT 文本块的嵌入

这是获取 RETRO 模型块的 BERT 嵌入的代码。

13from typing import List
14
15import torch
16from transformers import BertTokenizer, BertModel
17
18from labml import lab, monit

BERT 嵌入式

对于给定的文本块,这个类会生成 BERT 嵌入是所有令牌的 BERT 嵌入的平均值

21class BERTChunkEmbeddings:
29    def __init__(self, device: torch.device):
30        self.device = device

HuggingFac e 加载 BERT 分词器

33        with monit.section('Load BERT tokenizer'):
34            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
35                                                           cache_dir=str(
36                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

HuggingFac e 加载 BERT 模型

39        with monit.section('Load BERT model'):
40            self.model = BertModel.from_pretrained("bert-base-uncased",
41                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

将模型移到device

44            self.model.to(device)

在此实现中,我们不会使用固定数量的令牌制作区块。原因之一是此实现使用字符级令牌,而 BERT 使用其子词分词器。

因此,此方法将截断文本以确保没有部分标记。

例如,一个块可能像s a popular programming la ,末尾带有部分单词(部分子词标记)。我们剥离它们以获得更好的 BERT 嵌入。如前所述,如果我们在标记化后破坏了区块,则没有必要这样做。

46    @staticmethod
47    def _trim_chunk(chunk: str):

去掉空白

61        stripped = chunk.strip()

断词

63        parts = stripped.split()

移除第一块和最后一块碎片

65        stripped = stripped[len(parts[0]):-len(parts[-1])]

移除空格

68        stripped = stripped.strip()

如果为空则返回原始字符串

71        if not stripped:
72            return chunk

否则,返回被剥离的字符串

74        else:
75            return stripped

获取区块列表。

77    def __call__(self, chunks: List[str]):

我们不需要计算梯度

83        with torch.no_grad():

修剪块

85            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

使用 BERT 分词器对区块进行标记化

88            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

将令牌 ID、注意掩码和令牌类型移动到设备

91            input_ids = tokens['input_ids'].to(self.device)
92            attention_mask = tokens['attention_mask'].to(self.device)
93            token_type_ids = tokens['token_type_ids'].to(self.device)

评估模型

95            output = self.model(input_ids=input_ids,
96                                attention_mask=attention_mask,
97                                token_type_ids=token_type_ids)

获取令牌嵌入

100            state = output['last_hidden_state']

计算平均代币嵌入量。请注意,0 如果令牌是空填充的,则注意掩码为。我们得到空令牌,因为这些块的长度不同。

104            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)

107            return emb

用于测试 BERT 嵌入的代码

110def _test():
114    from labml.logger import inspect

初始化

117    device = torch.device('cuda:0')
118    bert = BERTChunkEmbeddings(device)

样本

121    text = ["Replace me by any text you'd like.",
122            "Second sentence"]

查看 BERT 分词器

125    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
126
127    inspect(encoded_input, _expand=True)

检查 BERT 模型输出

130    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
131                        attention_mask=encoded_input['attention_mask'].to(device),
132                        token_type_ids=encoded_input['token_type_ids'].to(device))
133
134    inspect({'last_hidden_state': output['last_hidden_state'],
135             'pooler_output': output['pooler_output']},
136            _expand=True)

检查从令牌 ID 中重新创建文本

139    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
140    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

获取区块嵌入

143    inspect(bert(text))

147if __name__ == '__main__':
148    _test()