用于中间激活的缓存

在推理过程中,模型逐个输出令牌。我们使用这个简单的缓存来存储键和值的注意层,这样我们就不必为以前的令牌重新计算它们了。

15from typing import Any

缓存

这将维护一个键值缓存,并将推送值排队并按相同的顺序弹出它们。队列非常有用,因为我们有多个关注层。

18class Cache:
26    def __init__(self):
27        self._cache = {}

清除缓存

29    def clear_all(self):
33        self._cache = {}

将值推送到队列

  • name 是队列的名称
  • value 是要推送的值
35    def push(self, name: str, value: Any):

如果队列不存在,请创建一个空队列

44        if name not in self._cache:
45            self._cache[name] = []

推送到队列

48        self._cache[name].append(value)

返回队列的大小

  • name 是队列的名称
  • 返回队列的大小(如果存在)否则 None

50    def q_size(self, name):
58        if name not in self._cache:
59            return None
60
61        if type(self._cache[name]) != list:
62            return None
63
64        return len(self._cache[name])

从队列中弹出

  • name 是队列的名称
  • 返回

66    def pop(self, name: str):
73        return self._cache[name].pop(0)

缓存一个值

  • key 是要缓存的值的名称
  • value 是价值
75    def set(self, key: str, value: Any):
82        self._cache[key] = value

从缓存中检索值

  • key 是缓存时使用的名称
  • default 如果缓存为空,则为默认值
  • 返回缓存的值

84    def get(self, key: str, default: Any = None):
92        return self._cache.get(key, default)

清除缓存值

  • key 是缓存时使用的名称
94    def clear(self, key: str):
100        del self._cache[key]

缓存的单例

104_INSTANCE = None

获取缓存实例

    返回缓存实例

107def get_cache() -> Cache:
113    global _INSTANCE
114
115    if _INSTANCE is None:
116        _INSTANCE = Cache()
117
118    return _INSTANCE