大模型缓存机制与 KV Cache 深度解析

大模型缓存机制与 KV Cache 深度解析

一、大模型服务的缓存机制全景

目前主流大模型服务的缓存机制从底层到上层可以分为三个层次。

KV Cache(推理层) 是最底层的机制。Transformer 自回归生成时,每生成一个新 token 都需要 attend 到之前所有 token。KV Cache 把已经计算过的 Key/Value 张量缓存在显存中,避免逐步重复计算,是所有现代 Transformer 推理系统的基础设施。

Prompt Cache / Prefix Cache(跨请求层) 是 KV Cache 的”跨请求复用”。如果两次 API 请求的 prompt 前缀相同,前缀部分的 KV 计算结果完全一致,可以直接加载复用,跳过 prefill 阶段。OpenAI(自动 128-token 块匹配)、Anthropic(显式 cache_control 标记,缓存命中价格降至 10%)、Google Gemini、DeepSeek 都实现了这一机制。

Semantic Cache(应用层) 则不在模型推理层面。对用户 query 做 embedding,如果新 query 和历史 query 的语义相似度超过阈值,直接返回之前缓存的回答,完全跳过推理。GPTCache 是代表方案,适合 FAQ 等高重复场景,但语义匹配的精度有限。


二、KV Cache 的核心原理

2.1 为什么需要 KV Cache

Transformer 的自回归生成是逐 token 产出的。每生成一个新 token,Self-Attention 层都需要让这个新 token “看到”前面所有 token。如果每次都从头把整个序列跑一遍 attention 计算,前面已经算过的 token 的 K 和 V 会被反复重算,造成大量浪费。

KV Cache 的核心观察是:对于已经处理过的 token,它们的 K 和 V 在后续步骤中不会改变(因为线性变换权重固定,输入也没变),所以只需算一次,存下来复用。

2.2 Attention 基本公式

标准 Scaled Dot-Product Attention:

Attention(Q, K, V) = softmax(Q · Kᵀ / √d_k) · V

其中 Q、K、V 分别是输入 token 经过线性变换得到的 Query、Key、Value 矩阵,d_k 是 Key 的维度。

训练阶段整个序列一次性输入,Q/K/V 一次性算出,没有浪费。但推理的生成阶段每步只新增一个 token,重复计算的问题就出现了。

2.3 有无 KV Cache 的对比

无 KV Cache 的推理:

假设已经生成了 “I love” 两个 token,要生成第 3 个:

  • 第 1 步生成 “I”:输入 [prompt],计算所有 token 的 Q/K/V → 得到 “I”
  • 第 2 步生成 “love”:输入 [prompt, I],重新计算所有 token 的 Q/K/V → 得到 “love”
  • 第 3 步生成 “?”:输入 [prompt, I, love],再次重新计算所有 Q/K/V → 得到 “?”

每步都重算所有历史 token 的 K/V,到第 n 步总计算量为 O(1+2+…+n) = O(n²)。

有 KV Cache 的推理:

  • 第 1 步:输入 [prompt],计算所有 token 的 K/V → 存入 cache → 得到 “I”
  • 第 2 步:只输入 [I],计算它的 K₂/V₂ → 拼到 cache → 用完整 cache 做 attention → 得到 “love”
  • 第 3 步:只输入 [love],计算 K₃/V₃ → 拼到 cache → attention → 得到 “?”

每步只需对 1 个新 token 做线性变换,投影计算从 O(n²) 降到 O(n)。

2.4 KV Cache 里到底存了什么?

KV Cache 存的是每一层 Attention 中,每个已处理 token 经过线性变换后得到的 K 向量和 V 向量。不是原始 token,不是 embedding,而是投影之后的浮点数向量。

用一个极简模型来走一遍(1 层 Attention,d_model=4,d_k=2):

模型权重(固定不变):
W_k = [[0.1, 0.2],    # 形状 [4, 2]
        [0.3, 0.4],
        [0.5, 0.6],
        [0.7, 0.8]]

W_v = [[0.2, 0.1],    # 形状 [4, 2]
        [0.4, 0.3],
        [0.6, 0.5],
        [0.8, 0.7]]

Prefill 阶段,输入 prompt “我喜欢” 共 3 个 token,embedding 后:

x = [[1.0, 0.5, 0.3, 0.1],   ← "我"
     [0.2, 0.8, 0.6, 0.4],   ← "喜"
     [0.3, 0.1, 0.9, 0.7]]   ← "欢"
形状: [3, 4]

对每个 token 做投影:

K = x @ W_k   →  [[0.42, 0.56],    ← "我" 的 key 向量
                   [0.64, 0.84],    ← "喜" 的 key 向量
                   [0.97, 1.20]]    ← "欢" 的 key 向量

V = x @ W_v   →  [[0.38, 0.32],    ← "我" 的 value 向量
                   [0.72, 0.58],    ← "喜" 的 value 向量
                   [0.90, 0.79]]    ← "欢" 的 value 向量

此时 KV Cache 里存的就是这两个矩阵——一堆浮点数,每行对应一个 token:

kv_cache = {
    "K": [[0.42, 0.56],     # 3 行
          [0.64, 0.84],
          [0.97, 1.20]],
    "V": [[0.38, 0.32],     # 3 行
          [0.72, 0.58],
          [0.90, 0.79]]
}

Decode 阶段,模型生成了新 token “吃”,embedding 为 [0.5, 0.3, 0.7, 0.2]。只对它做投影:

k_new = [0.5, 0.3, 0.7, 0.2] @ W_k = [0.51, 0.66]    ← "吃" 的 key
v_new = [0.5, 0.3, 0.7, 0.2] @ W_v = [0.38, 0.33]    ← "吃" 的 value

“追加到 cache” 就是把新算出的这一行拼到矩阵末尾

kv_cache["K"] = [[0.42, 0.56],     ← "我"(已有)
                  [0.64, 0.84],     ← "喜"(已有)
                  [0.97, 1.20],     ← "欢"(已有)
                  [0.51, 0.66]]     ← "吃"(新增)  ← 这就是"追加"
# 从 [3, 2] 变成了 [4, 2]

然后 “吃” 的 query 向量去和 cache 里全部 4 行 K 做点积,得到对所有历史 token 的注意力分数,再加权求和 V,完成 attention。下一步生成时 cache 变成 5 行,再下一步 6 行……每步多一行。

总结:KV Cache 本质上就是两个不断”长高”的矩阵。 考虑完整模型(多层 × 多头),其形状为 [n_layers, 2, n_kv_heads, seq_len, d_head],”追加”操作就是在 seq_len 维度上 +1。

2.5 两个阶段

实际推理分为两个明确阶段:

  • Prefill 阶段:把整个 prompt 一次性灌入,并行计算所有 prompt token 的 K/V,一次性填满 cache。这是计算密集型(compute-bound),GPU 算力是瓶颈。
  • Decode 阶段:逐 token 生成。每步只处理 1 个 token,但需要读取整个 cache 做 attention。这是访存密集型(memory-bound),显存带宽是瓶颈。

三、KV Cache 的伪代码实现

# ============================================================
# 基础定义
# ============================================================

class AttentionLayer:
    """单个 Attention 层,包含 Q/K/V 的投影权重"""
    W_q: Matrix  # [d_model, d_model]
    W_k: Matrix  # [d_model, d_model]
    W_v: Matrix  # [d_model, d_model]
    W_o: Matrix  # [d_model, d_model]

class TransformerModel:
    layers: List[AttentionLayer]  # n_layers 个 attention 层
    embed: Embedding             # token → 向量
    lm_head: Linear              # 向量 → vocab logits


# ============================================================
# 方案一:朴素推理(无 KV Cache)—— 每步重算一切
# ============================================================

def generate_naive(model, prompt_tokens, max_new_tokens):
    tokens = prompt_tokens

    for step in range(max_new_tokens):
        # 每步把完整序列从头到尾跑一遍
        x = model.embed(tokens)          # [seq_len, d_model]

        for layer in model.layers:
            Q = x @ layer.W_q            # [seq_len, d_model]
            K = x @ layer.W_k            # [seq_len, d_model]  ← 重复计算!
            V = x @ layer.W_v            # [seq_len, d_model]  ← 重复计算!

            scores = (Q @ K.T) / sqrt(d_k)
            scores = causal_mask(scores)
            attn = softmax(scores, dim=-1)
            x = attn @ V
            x = x @ layer.W_o

        logits = model.lm_head(x[-1])
        next_token = sample(logits)
        tokens.append(next_token)

    return tokens


# ============================================================
# 方案二:KV Cache 推理 —— 增量计算,缓存复用
# ============================================================

def generate_with_kv_cache(model, prompt_tokens, max_new_tokens):
    n_layers = len(model.layers)
    kv_cache = [{"K": None, "V": None} for _ in range(n_layers)]

    # ===== 阶段一:Prefill =====
    # 整个 prompt 一次性灌入
    x = model.embed(prompt_tokens)       # [prompt_len, d_model]

    for i, layer in enumerate(model.layers):
        Q = x @ layer.W_q                # [prompt_len, d_model]
        K = x @ layer.W_k                # [prompt_len, d_model]
        V = x @ layer.W_v                # [prompt_len, d_model]

        kv_cache[i]["K"] = K             # 存入 cache
        kv_cache[i]["V"] = V

        scores = (Q @ K.T) / sqrt(d_k)
        scores = causal_mask(scores)
        attn = softmax(scores, dim=-1)
        x = attn @ V
        x = x @ layer.W_o

    logits = model.lm_head(x[-1])
    next_token = sample(logits)
    generated = [next_token]

    # ===== 阶段二:Decode(逐 token 生成)=====
    for step in range(max_new_tokens - 1):
        x = model.embed([next_token])     # [1, d_model] ← 只处理 1 个 token

        for i, layer in enumerate(model.layers):
            q = x @ layer.W_q             # [1, d_model]
            k_new = x @ layer.W_k         # [1, d_model]
            v_new = x @ layer.W_v         # [1, d_model]

            # 追加到 cache
            kv_cache[i]["K"] = concat(kv_cache[i]["K"], k_new)
            kv_cache[i]["V"] = concat(kv_cache[i]["V"], v_new)

            K_full = kv_cache[i]["K"]     # [seq_so_far, d_model]
            V_full = kv_cache[i]["V"]

            # 1 个 query 对所有历史 key 做 attention
            scores = (q @ K_full.T) / sqrt(d_k)  # [1, seq_so_far]
            attn = softmax(scores, dim=-1)
            x = attn @ V_full             # [1, d_model]
            x = x @ layer.W_o

        logits = model.lm_head(x[0])
        next_token = sample(logits)
        generated.append(next_token)

    return prompt_tokens + generated

计算量对比(直观数字)

假设 prompt 长度 P=100,生成 N=50 个 token,模型 L 层:

朴素方法 KV Cache
投影计算 L × 3 × (100+101+…+149) = L × 3 × 6225 Prefill: L×3×100 + Decode: L×3×50 = L×3×150
节省比例 ~97.6%

代价是显存中维护 KV Cache,大小为 2(K+V) × L × seq_len × d_model × dtype_size


四、KV Cache 的显存开销

以 LLaMA-2 70B 为例(n_layers=80,d_model=8192,FP16):

每个 token 的 KV Cache = 2 × 80 × 8192 × 2 bytes ≈ 2.5 MB
4096 长度序列 ≈ 10 GB
128K 上下文 × batch_size=8 ≈ 2.5TB(远超单卡显存)

这就是 KV Cache 最大的矛盾:它节省了计算,但吃掉了显存。


五、KV Cache 的优化方案

5.1 减少 KV 头数

Multi-Query Attention (MQA):所有 Q head 共享一组 K/V,KV Cache 缩小到 1/n_heads。 Grouped-Query Attention (GQA):折中方案,heads 分组,组内共享 K/V。LLaMA-2 70B 用 GQA(8 个 KV head vs 64 个 Q head),KV Cache 缩小到 1/8。

5.2 显存管理优化

PagedAttention(vLLM):借鉴 OS 虚拟内存分页。KV Cache 按固定大小的”页”动态分配,消除碎片,不同序列可共享相同前缀的页(copy-on-write),显存利用率大幅提升。

5.3 量化压缩

KV Cache 量化:用 INT8/INT4 存储 K/V,显存减半或更多。注意 K 的量化通常比 V 更敏感。

5.4 窗口与淘汰策略

  • Sliding Window Attention(Mistral):只保留最近 W 个 token 的 KV Cache,超出丢弃
  • StreamingLLM:保留开头几个 token(attention sink)+ 最近窗口,维持流式生成质量
  • H2O(Heavy Hitter Oracle):动态淘汰不重要的 KV

六、为什么各家的 Prompt Cache 有效期很短?

各家的缓存 TTL 普遍不长:OpenAI 约 5-10 分钟,Anthropic 约 5 分钟,DeepSeek 硬盘级缓存稍长但命中率低。原因是多方面的:

显存是最核心的瓶颈。 GPU 显存(HBM)是数据中心最昂贵的资源。一个 128K 上下文的 70B 模型单请求的 KV Cache 就要几十 GB。云服务需要同时服务数百万用户,不可能把每个人的 KV Cache 都长期驻留在显存中。缓存空间是零和博弈——留给用户 A 的 cache 就意味着用户 B 要重新计算。

分布式调度的复杂性。 大模型通常跨多张 GPU 甚至多台机器做张量并行。KV Cache 和模型权重绑定在特定 GPU 上。如果下一次请求被路由到了不同的 GPU 集群,之前的 cache 就无法直接使用。保持”同用户请求路由到同一组 GPU”(会话亲和性)会严重限制负载均衡的灵活性。

缓存命中率的经济学。 缓存越久,边际命中率越低——大部分复用发生在短时间内(比如多轮对话的连续请求)。长时间不活跃的 cache 命中概率很低,但持续占用宝贵的显存/存储资源。从成本效益角度看,短 TTL + LRU 淘汰是最优解。

落盘方案的延迟问题。 理论上可以把 KV Cache 从显存卸载到主存甚至 SSD。DeepSeek 确实这么做了(硬盘缓存)。但从 SSD 加载 KV Cache 到显存的延迟远大于直接重新计算(PCIe 带宽瓶颈),只有在 prompt 非常长(prefill 开销巨大)时才划算。对于中短 prompt,直接重算反而更快。

简而言之,不只是”显存不够”这一个原因,而是显存成本、调度复杂性、命中率衰减、落盘延迟这四个因素共同决定了短 TTL 是当前的最佳工程折中。


整理日期:2026-04-29

发表回复