LLM Infra 101 v0.2: KV Cache
系列的第二集,前面的可以看:
这一期的代码在 https://github.com/iFurySt/nanoLLMServe/tree/release/v0.2.0
上一期过完,能通过API调用模型了。这期我们来支持KV Cache。在第一集的时候我们发现,每次forward 的时候都会重复计算:
prompt
-> forward(prompt)
-> 采样 token1
-> forward(prompt + token1)
-> 采样 token2
-> forward(prompt + token1 + token2)
-> ...
这里每次推入的序列都会重新计算一遍,Transformer的计算就贵在Attention的计算:
Q = xW_Q
K = xW_K
V = xW_V
Attention(Q, K, V)
所以当我们没有KV缓存的时候,大概流程是这样的:
-
forward(prompt)
-
计算prompt里每个token的Q/K/V
-
计算prompt内部的Attention
-
采样得到token1
-
forward(prompt+token1)
-
计算prompt里每个token的Q/K/V
-
计算prompt内部的Attention
-
采样得到token2
-
forward(prompt+token1+token2)
-
计算prompt里每个token的Q/K/V
-
计算prompt内部的Attention
-
采样得到token3
如果有了KV Cache,那流程会是这样的:
-
forward(prompt)
-
计算prompt里每个token的Q/K/V
-
计算prompt内部的Attention
-
保存K/V到KV Cache
-
采样得到token1
-
forward(token1+past_kv(也就是prompt的))
-
只计算token1的Q/K/V
-
读取prompt的K/V
-
计算Attention(Q_token1, K_prompt+token1, V_prompt+token1)
-
保存token1的K/V
-
采样得到token2
-
forward(token2+past_kv(也就是prompt+token1的))
-
只计算token2的Q/K/V
-
读取prompt+token1的K/V
-
计算Attention(Q_token2, K_prompt+token1+token2, V_prompt+token1+token2)
-
保存token2的K/V
-
采样得到token3
本质上KV Cache就是为了后续计算可以重复利用,我们来看一个实际推理过程中的环节:
Token
↓
Attention(看上下文)
↓
FFN(自己思考)
↓
下一层
可以看到Attention里的K/V都Cache里,但是FFN里没有任何Cache的,这个是因为Attention的计算都是依赖于之前计算的,但是FFN都是针对当前token自己去做计算(非线性变换)的
token3
↓
Linear Up Projection(升维,高纬空间有更复杂的表达能力)
↓
Activation (GELU / SwiGLU)
↓
Linear Down Projection(降维)
↓
output
这个过程中只涉及到token3本身的计算,输出的FFN(hidden3)只会在当前layer使用一次,后续就没用了,所以没办法做Cache
知道了原理后,来看看实现
实现
改动文件涉及这些:
.
├── benchmarks/
│ └── benchmark_generate.py # 增加 KV cache vs v0.0 naive baseline 对比,输出 TTFT/TPOT
├── src/
│ └── nanollmserve/
│ ├── cli/
│ │ └── generate.py # show-stats 新增 TTFT / TPOT
│ └── engine/
│ ├── engine.py # 核心改动:prefill + decode + past_key_values 复用
│ └── request.py # 新增 GenerationRequestState,保存单请求生成状态
└── tests/
├── test_engine.py # 验证 decode 阶段只喂单 token,且复用 past_key_values
├── test_request_state.py # 验证 request state 的 token 统计和 TPOT
└── test_benchmark_generate.py # 验证 benchmark 汇总字段和 speedup 计算
Prefill
src/nanollmserve/engine/engine.py:160
model.eval()
with torch.inference_mode():
prefill_start = perf_counter()
outputs = model(input_ids=input_ids, attention_mask=state.attention_mask, use_cache=True)
state.prefill_seconds = perf_counter() - prefill_start
state.past_key_values = getattr(outputs, "past_key_values", None)
if state.past_key_values is None:
raise RuntimeError("model did not return past_key_values; KV cache decode requires use_cache support")
这边的model是基于transformers加载进来的模型对象
loaded = load_model_and_tokenizer(...)
result = generate_one(
loaded.model,
loaded.tokenizer,
prompt,
...
)
传入use_cache=True 参数后,会要求模型forward后返回past_key_values ,后续decode的时候再把这个KV Cache传回去。
这里做的就是预填充Prefill,简单说就是把传入的prompt完整的处理一遍,建立KV Cache,后续就只要做新的token的Q计算,然后就可以服用之前的KV Cache做Attention的计算了
Decode
src/nanollmserve/engine/engine.py:179
next_token = _sample_from_outputs(outputs, temperature=temperature, generator=generator)
yield _record_step(
tokenizer,
state,
next_token,
eos_token_ids=eos_token_ids,
start=start,
max_new_tokens=max_new_tokens,
)
if state.finished:
return
for _ in range(max_new_tokens - 1):
decode_start = perf_counter()
outputs = model(
input_ids=next_token.to(input_ids.device),
attention_mask=state.attention_mask,
past_key_values=state.past_key_values,
use_cache=True,
)
state.past_key_values = getattr(outputs, "past_key_values", None)
if state.past_key_values is None:
raise RuntimeError("model did not return past_key_values during decode")
next_token = _sample_from_outputs(outputs, temperature=temperature, generator=generator)
yield _record_step(
tokenizer,
state,
next_token,
eos_token_ids=eos_token_ids,
start=start,
max_new_tokens=max_new_tokens,
decode_start=decode_start,
)
if state.finished:
break
后续的循环这里,可以看到进入的已经不再是不断拼接的input_ids了,而是next_token ,也就是前一次生成的token,然后会通过past_key_values=state.past_key_values,带上前面的KV
推理
因此这次改动是单个请求内的KV Cache Reuse,prefill后decode复用,所以没办法在多个请求之间命中缓存,就没办法做那种演示了,但是bench是可以看出来kv_cache_decode.mean_prefill_seconds是非0
"elapsed_speedup": 1.066
"tpot_speedup": 1.073
现在总耗时和TOPT(Time per Output Token)都变快了,但是因为输入的prompt很短,没有更明显的差距体现
总结
这些大概就是引入KV Cache带来的变化,代码改动不多,也相对简洁,因为transformers这类框架帮我屏蔽了很多实现细节。
另外这里的KV Cache在GPU显存里,会涉及到每层 和每个token 都要存K/V,KV Cache的大小近似于:
2*L*T*H*dtype
| 参数 | 含义 |
|---|---|
| L | layer 数 |
| T | sequence length |
| H | hidden size |
| 2 | K+V |
比如我们简单算一个Qwen3 32B的:
264128k51202bytes/1024^3=~156.25GB
但是实际上Qwen3走了GQA(attention heads是40,kv heads是8,head_dim是128),所以实际大概会是33.5GB左右(GQA这些技术的意义来了)
可以看出大模型在推理的时候,显存会被大量的KV Cache占满!这个也是Infra里需要解决的一个重要课题。现在很多模型使用一些技术来降低KV Cache,列举几个,比如模型层可以做的有:
-
GQA(Grouped Query Attention)这种技术,Q Heads很多KV Heads很少,这样可以大量降低KV Cache
-
MQA(Multi-Query Attention):比GQA更激进,所有的Q共享同一组KV,但是效果会下降比较多
-
MLA(Multi-head Latent Attention):是DeepSeek很关键的方向,不直接存完整的KV,而是存压缩的latent(KV Compression),需要的时候再恢复
-
Sliding Window Attention:只看最近的窗口,比如看最近4k,而不是完整的1M上下文
-
Sparse Attention:不是所有的token都两两attention(比如只关注附近的token、少量关键的token以及一些summary token等)
Inference Engine层可以做的有:
-
PagedAttention:vllm主要的特性,kv cache做分页
-
Prefix Cache:共享相同前缀的prompt的kv,不重复做prefill
-
KV Quantization:KV不存bf16,改成存int8/int4,但是伴随量化也会带来精度下降
-
Distributed KV Cache:KV分布到多GPU,按head/layer/sequence去做shard
-
PD分离(Prefill-Decode Disaggregation):Prefill和Decode分不同机器,因为前者是Compute-bound型,后者是Memory-bound型,这也可以有不同的机器支撑
这些手段或多或少都在解决KV Cache相关的问题,只不过关注的角度不太一样。后续我们也会接触到里面的某些内容,其他的有价值值得写的也会单独有文章来聊
Enjoy Reading This Article?
Here are some more articles you might like to read next: