13from typing import List
14
15import torch
16from transformers import BertTokenizer, BertModel
17
18from labml import lab, monit

BERT エンベディング

このクラスは、特定のテキストチャンクに対して BERT 埋め込みを生成します。は、内のすべてのトークンのBERT埋め込みの平均です。

21class BERTChunkEmbeddings:
29    def __init__(self, device: torch.device):
30        self.device = device
33        with monit.section('Load BERT tokenizer'):
34            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
35                                                           cache_dir=str(
36                                                               lab.get_data_path() / 'cache' / 'bert-tokenizer'))

ハギングフェイスから BERT モデルをロード

39        with monit.section('Load BERT model'):
40            self.model = BertModel.from_pretrained("bert-base-uncased",
41                                                   cache_dir=str(lab.get_data_path() / 'cache' / 'bert-model'))

モデルを次の場所に移動 device

44            self.model.to(device)

この実装では、固定数のトークンでチャンクを作成しません。理由の1つは、この実装では文字レベルのトークンを使用し、BERTはサブワードトークナイザーを使用していることです

そのため、このメソッドはテキストを切り捨てて、部分的なトークンがないことを確認します。

たとえば、チャンクは、末尾に単語の一部(部分的なサブワードトークン)s a popular programming la があるようなものかもしれません。BERT 埋め込みの精度を高めるため、これらを削除しました。先に述べたように、トークン化後にチャンクを分割した場合は必要ありません

46    @staticmethod
47    def _trim_chunk(chunk: str):

空白を削除

61        stripped = chunk.strip()

ブレーク・ワード

63        parts = stripped.split()

最初と最後のピースを削除する

65        stripped = stripped[len(parts[0]):-len(parts[-1])]

空白を削除

68        stripped = stripped.strip()

空の場合、元の文字列を返す

71        if not stripped:
72            return chunk

それ以外の場合は、取り除いた文字列を返す

74        else:
75            return stripped

チャンクのリストを入手してください。

77    def __call__(self, chunks: List[str]):

勾配を計算する必要はありません

83        with torch.no_grad():

チャンクをトリミング

85            trimmed_chunks = [self._trim_chunk(c) for c in chunks]

BERT トークナイザーでチャンクをトークン化する

88            tokens = self.tokenizer(trimmed_chunks, return_tensors='pt', add_special_tokens=False, padding=True)

トークン ID、アテンションマスク、トークンタイプをデバイスに移動

91            input_ids = tokens['input_ids'].to(self.device)
92            attention_mask = tokens['attention_mask'].to(self.device)
93            token_type_ids = tokens['token_type_ids'].to(self.device)

モデルの評価

95            output = self.model(input_ids=input_ids,
96                                attention_mask=attention_mask,
97                                token_type_ids=token_type_ids)

トークンの埋め込みを入手

100            state = output['last_hidden_state']

トークン埋め込みの平均回数を計算します。注意マスクは、0 トークンが空でパディングされている場合であることに注意してください。チャンクの長さが異なるため、空のトークンが返されます

104            emb = (state * attention_mask[:, :, None]).sum(dim=1) / attention_mask[:, :, None].sum(dim=1)

107            return emb

BERT 埋め込みをテストするコード

110def _test():
114    from labml.logger import inspect

[初期化]

117    device = torch.device('cuda:0')
118    bert = BERTChunkEmbeddings(device)

[サンプル]

121    text = ["Replace me by any text you'd like.",
122            "Second sentence"]

BERT トークナイザーをチェック

125    encoded_input = bert.tokenizer(text, return_tensors='pt', add_special_tokens=False, padding=True)
126
127    inspect(encoded_input, _expand=True)

BERT モデルの出力をチェック

130    output = bert.model(input_ids=encoded_input['input_ids'].to(device),
131                        attention_mask=encoded_input['attention_mask'].to(device),
132                        token_type_ids=encoded_input['token_type_ids'].to(device))
133
134    inspect({'last_hidden_state': output['last_hidden_state'],
135             'pooler_output': output['pooler_output']},
136            _expand=True)

トークン ID からテキストを再作成することを確認してください

139    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][0]), _n=-1)
140    inspect(bert.tokenizer.convert_ids_to_tokens(encoded_input['input_ids'][1]), _n=-1)

チャンク埋め込みを取得

143    inspect(bert(text))

147if __name__ == '__main__':
148    _test()