How to scale LLM Inference with KV-Caching

"When you use ChatGPT or Claude, why does the first token take longer to appear than the second one? That is thanks to KV caching! It is quite impressive when you think about the number of little optimization tricks meant to improve the serving latency of text generation services!

An LLM generates text by iteratively predicting the next token and appending it to the original prompt and the previously generated tokens. Causal LLMs are trained to predict the next token in the sequence, so each input token maps to a hidden state within the Transformer, which in turn maps to a prediction vector for the following token. The prediction vector has as many entries as there are tokens in the vocabulary, and the next token can be chosen greedily by picking the entry with the highest probability. This means that when we decode a new token, we only need to compute the hidden state of the last token in the input sequence; the outputs at the other positions are not needed for that prediction.
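
To make this concrete, here is a minimal greedy decoding loop in Python. The model callable and its interface are made up for the example: it is assumed to map a list of token ids to an array of logits with one row per position and one column per vocabulary entry.

import numpy as np

def greedy_decode(model, prompt_ids, max_new_tokens, eos_id=None):
    # `model` is a hypothetical callable mapping a list of token ids
    # to a (seq_len, vocab_size) array of logits.
    ids = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = model(ids)                   # logits for every position
        next_id = int(np.argmax(logits[-1]))  # only the last position matters
        ids.append(next_id)                   # feed the new token back in
        if eos_id is not None and next_id == eos_id:
            break
    return ids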

This required hidden state corresponds to the last token of the input sequence. To compute this hidden state, in each of the self-attention layers, we need all the Keys and Values of the whole input sequence but only the Query for the last token of this sequence. We generate the attentions for the last token by taking the Softmax of the dot product between the Query and all the Keys:

attentions for last token = Softmax(Query x all Keys)

And we get the resulting hidden state by taking the weighted average of all the Values as given by the attentions:

hidden state = all the Values x attentions
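
Here is a small NumPy sketch of that computation for a single attention head. The names (q_last, K, V) are invented for the example, and it adds the usual scaling by the square root of the head dimension, which the formulas above leave out for simplicity.

import numpy as np

def last_token_attention(q_last, K, V):
    # q_last: Query of the last token, shape (d,)
    # K, V: Keys and Values of the whole input sequence, shape (N, d)
    scores = q_last @ K.T / np.sqrt(K.shape[-1])  # dot product with all Keys
    scores = scores - scores.max()                # numerical stability
    attn = np.exp(scores) / np.exp(scores).sum()  # Softmax -> attention weights
    return attn @ V                               # weighted average of the Values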

There are a few conclusions we can draw. First, to predict the next token we only need the attentions for the last position, so the cost of each decoding iteration grows linearly with the number of input tokens (~O(N)), even with a vanilla attention mechanism. Second, the Keys and Values of the tokens already processed do not change from one iteration to the next, so there is no need to recompute them at every step. That is where the idea of KV caching comes from!
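
One rough way to see the effect in practice is to toggle KV caching in the Hugging Face transformers library, which exposes it through the use_cache flag of generate. The model choice and token counts below are arbitrary; this is only an illustrative comparison, not a careful benchmark.

import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "gpt2"  # any small causal LM will do for this illustration
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name).eval()
inputs = tokenizer("KV caching speeds up decoding because", return_tensors="pt")

for use_cache in (False, True):
    start = time.time()
    with torch.no_grad():
        model.generate(**inputs, max_new_tokens=128, do_sample=False,
                       use_cache=use_cache, pad_token_id=tokenizer.eos_token_id)
    print(f"use_cache={use_cache}: {time.time() - start:.2f} s")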

The decoding process can be divided into two phases: the initialization phase and the generation phase. In the initialization phase, the Keys and Values of every token in the input prompt need to be computed; this is why the first token takes longer to appear than the ones that follow. We then store the Keys and Values of all the attention layers in the KV cache. In the generation phase, we only need to compute the Query, Key, and Value of the newest token. We pull the stored Keys and Values from the cache to compute the required hidden state, and we finish the step by appending that token's Key and Value to the cache.
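
Putting the two phases together, here is a minimal single-layer, single-head sketch in Python with NumPy. The TinyKVCache class and the projection matrices Wq, Wk, and Wv are made up for the illustration; a real implementation keeps a cache like this for every attention layer (and every head).

import numpy as np

class TinyKVCache:
    def __init__(self):
        self.K = None  # cached Keys, shape (N, d)
        self.V = None  # cached Values, shape (N, d)

    def append(self, k_new, v_new):
        k_new, v_new = np.atleast_2d(k_new), np.atleast_2d(v_new)
        self.K = k_new if self.K is None else np.vstack([self.K, k_new])
        self.V = v_new if self.V is None else np.vstack([self.V, v_new])

def attend(q, K, V):
    scores = q @ K.T / np.sqrt(K.shape[-1])
    weights = np.exp(scores - scores.max())
    return (weights / weights.sum()) @ V

def prefill(cache, X_prompt, Wq, Wk, Wv):
    # Initialization phase: compute and cache the Keys and Values of every
    # prompt token, then return the hidden state of the last prompt token,
    # which is what predicts the first new token.
    cache.append(X_prompt @ Wk, X_prompt @ Wv)
    return attend(X_prompt[-1] @ Wq, cache.K, cache.V)

def decode_step(cache, x_new, Wq, Wk, Wv):
    # Generation phase: compute Q, K, V for the newest token only, append its
    # Key and Value to the cache, and attend over all cached Keys and Values.
    q, k, v = x_new @ Wq, x_new @ Wk, x_new @ Wv
    cache.append(k, v)
    return attend(q, cache.K, cache.V)

A full generation loop would call prefill once for the prompt and then call decode_step once per new token, exactly mirroring the two phases described above.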

This KV caching process significantly reduces the latency associated with text generation!"

Author

ABN ASIA was founded by people with deep roots in academia and with work experience in the US, Holland, Hungary, Japan, South Korea, Singapore, and Vietnam. ABN Asia is where academia and technology meet opportunity. With our cutting-edge solutions and competent software development services, we're helping businesses level up and take on the global scene. Our commitment: Faster. Better. More reliable. In most cases: Cheaper as well.

Feel free to reach out to us whenever you require IT services, digital consulting, off-the-shelf software solutions, or if you'd like to send us requests for proposals (RFPs). You can contact us at contact@abnasia.org. We're ready to assist you with all your technology needs.

ABNAsia.org

© ABN ASIA