KV Caching in LLMs, explained visually
A popular interview question.

TODAY'S ISSUE
KV caching is a popular technique to speed up LLM inference.
To get some perspective, look at the inference speed difference from our demo:
Today, let’s visually understand how KV caching works.
Let's dive in!
To understand KV caching, we must know how LLMs output tokens.

As shown in the visual above, the transformer produces a hidden state for every token in the context, but only the hidden state of the last token is projected onto the vocabulary to predict the next token.
Thus, to generate a new token, we only need the hidden state of the most recent token. None of the other hidden states are required.
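Here's a minimal sketch of that last step, assuming a toy hidden size and vocabulary (the names and shapes are illustrative, not from any specific model):

```python
import torch

# Toy shapes: one sequence of 6 tokens, hidden size 16, vocabulary of 100 tokens.
batch, seq_len, d_model, vocab = 1, 6, 16, 100
hidden_states = torch.randn(batch, seq_len, d_model)   # output of the last transformer layer
lm_head = torch.nn.Linear(d_model, vocab, bias=False)  # projection to vocabulary logits

# Only the hidden state of the most recent token is needed to predict the next token.
last_hidden = hidden_states[:, -1, :]              # shape: (batch, d_model)
logits = lm_head(last_hidden)                      # shape: (batch, vocab)
next_token = torch.argmax(logits, dim=-1)          # greedy pick of the next token
```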

Next, let's see how this last hidden state is computed inside a transformer layer via the attention mechanism.

During attention, we first compute the product of the query and key matrices. Its last row involves the last token’s query vector and all the key vectors:

None of the other query vectors are needed during inference.
Similarly, the last row of the final attention output involves the last query vector and all the key & value vectors. Check this visual to see it more clearly:

The above insight suggests that to generate a new token, every attention operation in the network only needs the query vector of the most recent token and the key & value vectors of all tokens seen so far.
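Here's what that single-row attention looks like, as a sketch with toy dimensions (random tensors stand in for the real projections):

```python
import torch
import torch.nn.functional as F

# Toy dimensions: 6 tokens in the context, head dimension 16.
seq_len, d_head = 6, 16
K = torch.randn(seq_len, d_head)   # key vectors of ALL tokens in the context
V = torch.randn(seq_len, d_head)   # value vectors of ALL tokens in the context
q_last = torch.randn(d_head)       # query vector of the most recent token only

# One row of the score matrix: the last query against every key.
scores = q_last @ K.T / d_head ** 0.5    # shape: (seq_len,)
weights = F.softmax(scores, dim=-1)

# One row of the attention output: a weighted sum over all value vectors.
out_last = weights @ V                   # shape: (d_head,)
```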

But there's one more key insight here.
As we generate new tokens, the KV vectors used for ALL previous tokens do not change.

Thus, we just need to compute the KV vectors for the token generated in the previous step.
The rest of the KV vectors can be retrieved from a cache to save compute and time.
This is called KV caching!
To reiterate: instead of redundantly recomputing the KV vectors of all context tokens at every step, we cache them.

To generate a token, we compute the query, key, and value vectors only for the most recent token, fetch the KV vectors of all earlier tokens from the cache, run attention as usual, and append the newly computed KV vectors to the cache (see the sketch below).
As you can tell, this saves time during inference.
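Here's a minimal sketch of that loop, assuming a toy single-layer, single-head model with random projection matrices (real models have many layers, heads, and MLP blocks, but the role of the cache is the same):

```python
import torch
import torch.nn.functional as F

def attend(q, K, V):
    # Attention for a single query vector against all cached keys/values.
    scores = q @ K.T / K.shape[-1] ** 0.5
    return F.softmax(scores, dim=-1) @ V

d_model = 16
Wq, Wk, Wv = (torch.randn(d_model, d_model) for _ in range(3))

K_cache = torch.empty(0, d_model)   # grows by one row per token; old rows never change
V_cache = torch.empty(0, d_model)

x = torch.randn(5, d_model)         # embeddings of the 5 prompt tokens

for step in range(3):               # generate 3 new tokens
    # 1) Project ONLY the new tokens: the whole prompt at step 0 (prefill),
    #    a single token on every later step.
    new = x if step == 0 else x[-1:]
    q, k, v = new @ Wq, new @ Wk, new @ Wv

    # 2) Append the new KV vectors to the cache instead of recomputing everything.
    K_cache = torch.cat([K_cache, k])
    V_cache = torch.cat([V_cache, v])

    # 3) Run attention for the latest query against ALL cached keys & values.
    out = attend(q[-1], K_cache, V_cache)

    # 4) Pretend `out` becomes the next token's embedding and keep going.
    x = torch.cat([x, out.unsqueeze(0)])
```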
In fact, this is why ChatGPT takes longer to generate the first token than the subsequent ones. During that little pause, the KV cache of the prompt is computed.
That said, the KV cache also consumes a lot of memory.
Consider Llama3-70B, which has 80 transformer layers and a hidden size of 8,192.
Here, every layer must store a key vector and a value vector for every token in the context, and a separate cache is kept for every sequence being served.
More users → more memory.
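As a rough back-of-the-envelope estimate (this deliberately ignores grouped-query attention, which Llama3-70B uses to shrink its KV heads, so treat it as an upper bound):

```python
layers = 80          # transformer layers in Llama3-70B
d_model = 8192       # hidden size
bytes_fp16 = 2       # half-precision storage

# One key + one value vector per layer, per token.
per_token = 2 * layers * d_model * bytes_fp16    # ≈ 2.6 MB per token
per_seq_8k = per_token * 8192                    # ≈ 21 GB for an 8K-token context

print(f"{per_token / 1e6:.1f} MB per token, {per_seq_8k / 1e9:.1f} GB per 8K-token sequence")
```

Multiply the per-sequence figure by the number of concurrent users, and the memory bill adds up quickly.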
We'll cover KV optimization soon.
👉 Over to you: How can we optimize the memory consumption?
Consider the size difference between BERT-large and GPT-3:
I have fine-tuned BERT-large several times on a single GPU using traditional fine-tuning:

But this is impossible with GPT-3, which has 175B parameters. That's 350GB of memory just to store model weights under float16 precision.
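The arithmetic behind that 350GB figure:

```python
params = 175e9            # GPT-3 parameter count
bytes_per_param = 2       # float16: 2 bytes per weight
print(params * bytes_per_param / 1e9, "GB")   # -> 350.0 GB, for the weights alone
```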
This means that if OpenAI used traditional fine-tuning within its fine-tuning API, it would have to maintain one model copy per user:
And the problems don't end there:
LoRA (+ QLoRA and other variants) neatly solved this critical business problem.