Cache for Intermediate Activations

During inference, the model generates its output token by token. We use this simple cache to store the keys and values of the attention layers, so that we don't have to recompute them for previous tokens.

from typing import Any

Cache

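This maintains a key-value cache together with named queues. A queue returns values in the same order they were pushed (first in, first out). The queues are useful because the model has multiple attention layers, and each layer can push and later pop its own values in a fixed order.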

class Cache:
    def __init__(self):
        self._cache = {}

Clear cache

    def clear_all(self):
        self._cache = {}

Push a value to a queue

  • name is the name of the queue
  • value is the value to be pushed

    def push(self, name: str, value: Any):

Create an empty queue if it's not present

        if name not in self._cache:
            self._cache[name] = []

Push to the queue

        self._cache[name].append(value)

Return the size of the queue

  • name is the name of the queue
  • Returns the size of the queue if it exists, else None

    def q_size(self, name: str):
        if name not in self._cache:
            return None

        # The name may hold a plain value stored with `set`, not a queue
        if not isinstance(self._cache[name], list):
            return None

        return len(self._cache[name])

Pop from a queue

  • name is the name of the queue
  • Returns the oldest value in the queue

    def pop(self, name: str):
        return self._cache[name].pop(0)
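
As a minimal sketch of how the queue methods above might be used, the example below pushes one value per layer during a first pass and pops them back in the same order during a second pass. `MiniCache` is a condensed stand-in for the queue part of `Cache`, and the queue name `'kv'`, the layer count, and the string values are hypothetical placeholders for real key/value tensors.

```python
from typing import Any


class MiniCache:
    """Condensed stand-in for the queue part of `Cache` (illustration only)."""

    def __init__(self):
        self._cache = {}

    def push(self, name: str, value: Any):
        # Create the queue on first use, then append to its tail
        self._cache.setdefault(name, []).append(value)

    def pop(self, name: str):
        # Remove and return the oldest value (head of the queue)
        return self._cache[name].pop(0)

    def q_size(self, name: str):
        queue = self._cache.get(name)
        return len(queue) if isinstance(queue, list) else None


cache = MiniCache()
n_layers = 3  # hypothetical number of attention layers

# First pass: each layer pushes its (placeholder) keys/values in layer order
for layer in range(n_layers):
    cache.push('kv', f'layer-{layer}')

size_after_push = cache.q_size('kv')

# Second pass: layers run in the same order, so each pops its own entry
popped = [cache.pop('kv') for _ in range(n_layers)]
```

Because `pop` removes from the head of the list, the values come back in push order. Popping from a queue that was never created raises `KeyError`, and popping from an empty queue raises `IndexError`, so callers may want to check `q_size` first.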

Cache a value

  • key is the name of the value to be cached
  • value is the value

    def set(self, key: str, value: Any):
        self._cache[key] = value

Retrieve a value from cache

  • key is the name used when caching
  • default is the value returned if key is not in the cache
  • Returns the cached value

    def get(self, key: str, default: Any = None):
        return self._cache.get(key, default)

Clear a cache value

  • key is the name used when caching

    def clear(self, key: str):
        del self._cache[key]
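
The following sketch illustrates how `set`, `get`, and `clear` behave together. `MiniCache` is again a condensed stand-in for `Cache`, and the key `'state_ids'` and its value are made-up examples rather than names taken from the surrounding code.

```python
from typing import Any


class MiniCache:
    """Condensed stand-in for the key-value part of `Cache` (illustration only)."""

    def __init__(self):
        self._cache = {}

    def set(self, key: str, value: Any):
        self._cache[key] = value

    def get(self, key: str, default: Any = None):
        return self._cache.get(key, default)

    def clear(self, key: str):
        del self._cache[key]


cache = MiniCache()

# Nothing is stored yet, so the supplied default is returned
before = cache.get('state_ids', [])

cache.set('state_ids', [1, 2, 3])
stored = cache.get('state_ids')

# After clearing, the key is gone and `get` falls back to its default (None)
cache.clear('state_ids')
after = cache.get('state_ids')

# Clearing a key that is not present raises KeyError, because `clear` uses `del`
try:
    cache.clear('state_ids')
    raised = False
except KeyError:
    raised = True
```

Note that queues and plain values share the same underlying dictionary, so using one name for both `push` and `set` would overwrite the queue.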

Singleton for cache

_INSTANCE = None

Get the cache instance

  • Returns the cache instance

def get_cache() -> Cache:
    global _INSTANCE

    if _INSTANCE is None:
        _INSTANCE = Cache()

    return _INSTANCE
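
To show what the singleton accessor implies in practice, here is a small sketch in which two separate calls to `get_cache` share state. The `Cache` class is reduced to `set` and `get` for brevity, and the key `'seq_len'` is a hypothetical example.

```python
from typing import Any


class Cache:
    """Reduced version of the cache with only `set` and `get` (illustration only)."""

    def __init__(self):
        self._cache = {}

    def set(self, key: str, value: Any):
        self._cache[key] = value

    def get(self, key: str, default: Any = None):
        return self._cache.get(key, default)


_INSTANCE = None


def get_cache() -> Cache:
    global _INSTANCE

    # Create the instance lazily on the first call, then reuse it
    if _INSTANCE is None:
        _INSTANCE = Cache()

    return _INSTANCE


# Two independent call sites obtain the same object
writer = get_cache()
reader = get_cache()

writer.set('seq_len', 128)

same_object = writer is reader
value_seen_by_reader = reader.get('seq_len')
```

Since every caller shares one instance, different modules can exchange cached values without passing the cache around explicitly. The trade-off is that state persists between runs in the same process unless `clear_all` is called.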