13from typing import List
14
15import torch
16from transformers import BertTokenizer, BertModel
17
18from labml import lab, monit
21class BERTChunkEmbeddings:
29 def __init__(self, device: torch.device):
30 self.device = device
从 HuggingFac e 加载 BERT 分词器
33 with monit.section('Load BERT tokenizer'):
34 self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
35 cache_dir=str(
36 lab.get_data_path() / 'cache' / 'bert-tokenizer'))
从 HuggingFac e 加载 BERT 模型
39 with monit.section('Load BERT model'):
40 self.model = BertModel.from_pretrained("bert-base-uncased",
41 cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))
将模型移到device
44 self.model.to(device)
在此实现中,我们不会使用固定数量的令牌制作区块。原因之一是此实现使用字符级令牌,而 BERT 使用其子词分词器。
因此,此方法将截断文本以确保没有部分标记。
例如,一个块可能像s a popular programming la
,末尾带有部分单词(部分子词标记)。我们剥离它们以获得更好的 BERT 嵌入。如前所述,如果我们在标记化后破坏了区块,则没有必要这样做。
46 @staticmethod
47 def _trim_chunk(chunk: str):
去掉空白
61 stripped = chunk.strip()
断词
63 parts = stripped.split()
移除第一块和最后一块碎片
65 stripped = stripped[len(parts[0]):-len(parts[-1])]
移除空格
68 stripped = stripped.strip()
如果为空则返回原始字符串
71 if not stripped:
72 return chunk
否则,返回被剥离的字符串
74 else:
75 return stripped
77 def __call__(self, chunks: List[str]):
我们不需要计算梯度
83 with torch.no_grad():
修剪块
85 trimmed_chunks = [self._trim_chunk(c) for c in chunks]
使用 BERT 分词器对区块进行标记化
88 tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)
将令牌 ID、注意掩码和令牌类型移动到设备
91 input_ids = tokens['input_ids'].to(self.device)
92 attention_mask = tokens['attention_mask'].to(self.device)
93 token_type_ids = tokens['token_type_ids'].to(self.device)
评估模型
95 output = self.model(input_ids=input_ids,
96 attention_mask=attention_mask,
97 token_type_ids=token_type_ids)
获取令牌嵌入
100 state = output['last_hidden_state']
计算平均代币嵌入量。请注意,0
如果令牌是空填充的,则注意掩码为。我们得到空令牌,因为这些块的长度不同。
104 emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)
107 return emb
110def _test():
114 from labml.logger import inspect
初始化
117 device = torch.device('cuda:0')
118 bert = BERTChunkEmbeddings(device)
样本
121 text = ["Replace me by any text you'd like.",
122 "Second sentence"]
查看 BERT 分词器
125 encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
126
127 inspect(encoded_input, _expand=True)
检查 BERT 模型输出
130 output = bert.model(input_ids=encoded_input['input_ids'].to(device),
131 attention_mask=encoded_input['attention_mask'].to(device),
132 token_type_ids=encoded_input['token_type_ids'].to(device))
133
134 inspect({'last_hidden_state': output['last_hidden_state'],
135 'pooler_output': output['pooler_output']},
136 _expand=True)
检查从令牌 ID 中重新创建文本
139 inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
140 inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)
获取区块嵌入
143 inspect(bert(text))
147if __name__ == '__main__':
148 _test()