During inference the model generates tokens one at a time. We use this simple cache to store the keys and values of the attention layers, so that we don't have to recompute them for previous tokens.
from typing import Any
This maintains a key-value cache, along with queues that push values and pop them in the same order. The queues are useful since we have multiple attention layers.
class Cache:
    def __init__(self):
        self._cache = {}
    def clear_all(self):
        self._cache = {}
    def push(self, name: str, value: Any):
Create an empty queue if it's not present
        if name not in self._cache:
            self._cache[name] = []
Push to the queue
        self._cache[name].append(value)
name is the name of the queue. Returns the size of the queue if it exists, otherwise None.
    def q_size(self, name: str):
        if name not in self._cache:
            return None

        if not isinstance(self._cache[name], list):
            return None

        return len(self._cache[name])
    def pop(self, name: str):
        return self._cache[name].pop(0)
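The queue interface above is enough to cache per-layer attention states during step-by-step decoding. Here is a minimal sketch of that usage; it is not part of the original code, and the queue names and placeholder key/value lists are illustrative assumptions.

cache = Cache()

# Each attention layer pushes the keys and values it computed for the
# current step onto its own named queues; the names here are made up.
cache.push('layer_0.key', [0.1, 0.2, 0.3])
cache.push('layer_0.value', [1.0, 2.0, 3.0])

# The queue grows by one entry per push.
assert cache.q_size('layer_0.key') == 1

# At a later step the same layer pops the cached keys and values in the
# order they were pushed, instead of recomputing them.
prev_key = cache.pop('layer_0.key')
prev_value = cache.pop('layer_0.value')
assert cache.q_size('layer_0.key') == 0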
    def set(self, key: str, value: Any):
        self._cache[key] = value
key is the name used when caching. default is the value returned if the key is not in the cache. Returns the cached value.
    def get(self, key: str, default: Any = None):
        return self._cache.get(key, default)
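As a small illustrative sketch (the key name and values are assumptions, not from the source), set and get behave like a plain dictionary, with get falling back to the default when the key has not been cached.

cache = Cache()

# Nothing has been cached under 'mask' yet, so the default is returned.
assert cache.get('mask', default='no-mask') == 'no-mask'

# Cache a value and read it back.
cache.set('mask', [True, False, True])
assert cache.get('mask') == [True, False, True]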
    def clear(self, key: str):
        del self._cache[key]
Singleton instance of the cache
_INSTANCE = None
def get_cache() -> Cache:
    global _INSTANCE

    if _INSTANCE is None:
        _INSTANCE = Cache()

    return _INSTANCE
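A short usage sketch of the singleton accessor (the key name is illustrative): get_cache always returns the same Cache object, so separate modules can share cached values without passing the cache around explicitly.

cache_a = get_cache()
cache_b = get_cache()

# Both names refer to the same underlying Cache instance.
assert cache_a is cache_b

# A value set through one reference is visible through the other.
cache_a.set('step', 5)
assert cache_b.get('step') == 5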