13from typing import List
14
15import torch
16from transformers import BertTokenizer, BertModel
17
18from labml import lab, monit
21class BERTChunkEmbeddings:
29 def __init__(self, device: torch.device):
30 self.device = device
33 with monit.section('Load BERT tokenizer'):
34 self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
35 cache_dir=str(
36 lab.get_data_path() / 'cache' / 'bert-tokenizer'))
ハギングフェイスから BERT モデルをロード
39 with monit.section('Load BERT model'):
40 self.model = BertModel.from_pretrained("bert-base-uncased",
41 cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))
モデルを次の場所に移動 device
44 self.model.to(device)
この実装では、固定数のトークンでチャンクを作成しません。理由の1つは、この実装では文字レベルのトークンを使用し、BERTはサブワードトークナイザーを使用していることです
。そのため、このメソッドはテキストを切り捨てて、部分的なトークンがないことを確認します。
たとえば、チャンクは、末尾に単語の一部(部分的なサブワードトークン)s a popular programming la
があるようなものかもしれません。BERT 埋め込みの精度を高めるため、これらを削除しました。先に述べたように、トークン化後にチャンクを分割した場合は必要ありません
46 @staticmethod
47 def _trim_chunk(chunk: str):
空白を削除
61 stripped = chunk.strip()
ブレーク・ワード
63 parts = stripped.split()
最初と最後のピースを削除する
65 stripped = stripped[len(parts[0]):-len(parts[-1])]
空白を削除
68 stripped = stripped.strip()
空の場合、元の文字列を返す
71 if not stripped:
72 return chunk
それ以外の場合は、取り除いた文字列を返す
74 else:
75 return stripped
77 def __call__(self, chunks: List[str]):
勾配を計算する必要はありません
83 with torch.no_grad():
チャンクをトリミング
85 trimmed_chunks = [self._trim_chunk(c) for c in chunks]
BERT トークナイザーでチャンクをトークン化する
88 tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)
トークン ID、アテンションマスク、トークンタイプをデバイスに移動
91 input_ids = tokens['input_ids'].to(self.device)
92 attention_mask = tokens['attention_mask'].to(self.device)
93 token_type_ids = tokens['token_type_ids'].to(self.device)
モデルの評価
95 output = self.model(input_ids=input_ids,
96 attention_mask=attention_mask,
97 token_type_ids=token_type_ids)
トークンの埋め込みを入手
100 state = output['last_hidden_state']
トークン埋め込みの平均回数を計算します。注意マスクは、0
トークンが空でパディングされている場合であることに注意してください。チャンクの長さが異なるため、空のトークンが返されます
104 emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)
107 return emb
110def _test():
114 from labml.logger import inspect
[初期化]
117 device = torch.device('cuda:0')
118 bert = BERTChunkEmbeddings(device)
[サンプル]
121 text = ["Replace me by any text you'd like.",
122 "Second sentence"]
BERT トークナイザーをチェック
125 encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
126
127 inspect(encoded_input, _expand=True)
BERT モデルの出力をチェック
130 output = bert.model(input_ids=encoded_input['input_ids'].to(device),
131 attention_mask=encoded_input['attention_mask'].to(device),
132 token_type_ids=encoded_input['token_type_ids'].to(device))
133
134 inspect({'last_hidden_state': output['last_hidden_state'],
135 'pooler_output': output['pooler_output']},
136 _expand=True)
トークン ID からテキストを再作成することを確認してください
139 inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
140 inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)
チャンク埋め込みを取得
143 inspect(bert(text))
147if __name__ == '__main__':
148 _test()