检索增强型变压器(复古)

这是 PyTorch 对论文《通过从数万亿个代币中检索来改进语言模型》的实现。

它建立了一个包含大量文本的数据库。它是一个键值数据库,其中的密钥由区块的 BERT 嵌入索引。他们使用冻结的预训练的 BERT 模型来计算这些嵌入。这些值是相应的区块和该区块的等长度文本。

然后,模型从该数据库检索与模型输入相似(最近邻域)的文本。这些检索到的文本用于预测输出。

由于我们使用冻结的 BERT 模型进行检索,因此我们可以预先计算训练数据集的所有最近邻域。这加快了训练过程。

组件: