为啥现在的大模型,生成文字的时候,总得一个字一个字往外蹦?
你输入个几百字,它输出就得慢慢挤牙膏。
是模型本身算力不够吗?
不全是。
这里面其实藏着一个非常基础的效率问题,而解决这个问题的核心技术,就是今天要跟大家聊明白的 KV Cache。
1. 先铺垫一下:这些基础术语你得懂
聊KV Cache之前,得先把一些最基础的术语给你掰扯清楚。
不然上来就说key、value,你肯定听得云里雾里。
这篇文章哪怕是非机器学习背景的兄弟也能看懂,放心往下走。
标记(Tokens)
机器学习这玩意儿,只能看得懂数字,看不懂你说的人话。
所以你输入一句自然语言,第一步就得把它切分成一小块一小块,每一块就叫一个标记。
说白了,你可以大概理解成"拆成单词",当然实际分词比这个复杂,但你这么理解就够用了。
比如原句 This is a blog,分词完了就是 This is a blog 四个标记。
嵌入(Embeddings)
每个标记都得转成一个数字向量才能让模型算,这个向量就叫嵌入向量。
嵌入向量能干啥呢?它能把单词的语义给"编码"进去。简单说,向量的每个维度,其实都藏着某种潜在特征。
我给你举个例子你就明白了:
有一个维度专门管"是不是动物":"猫"这个词在这个维度得分就很高(1.0)
有一个维度专门管"可爱不可爱":"猫"在这儿得分也不低(3.6)
还有一个维度管"是不是城市":这跟猫八竿子打不着,所以得分极低(-3.0)
实际情况中,嵌入向量一般有几千个维度,我们虽然不知道每个维度具体对应啥,但模型自己心里门儿清,它能靠这些特征把不同的单词给区分开。
你可以这么理解:每个单词都是高维空间里的一个点——
语义相近的单词,比如"猫"和"狗",它们在空间里离得就近;
语义八竿子打不着的,比如"马"和"建筑",它们在空间里离得就远。
就这么简单。
另外还有个位置编码,就是给每个单词加个位置信息,让模型知道谁在前谁在后。这个细节咱今天就不展开了。
看到这儿估计你都懂了,催我赶紧进正题——行,满足你!
2. 纯解码器模型到底是个啥?
今天讲的KV Cache,主要用在ChatGPT这种纯解码器架构的大模型上。所以咱得先对解码器有个基础认知。
我今天把那些花里胡哨的细节全给你剥了,只留最核心的:
解码器干的就一件事儿——根据你已经输入的文字,猜下一个字最可能是什么。
猜完把这个字加到你输入的尾巴上,再猜下一个,循环往复,直到模型说"我说完了"。
比如你输入 the cat sat on the,模型如果训练得没问题,大概率猜下一个字是 mat。
那具体怎么猜呢?
原始的嵌入向量,只带了单词自己的语义,没带这个单词在当前句子里的"语境信息"。
举个经典例子:
1. Only Saad wants coffee.(只有萨阿德想喝咖啡——其他人都不想喝)
2. Saad only wants coffee.(萨阿德只想要咖啡——他不想要别的)
同样一个单词only,放在不同位置,意思完全不一样。
这就得靠注意力机制把原始嵌入转换成带语境的"上下文嵌入"。
转换完了,再经过一些后续处理,模型就能输出每个单词的概率,你挑概率最大的当下一个字就完事了。
以上就是解码器最核心的逻辑,听懂这些足够用了。
3. 一分钟搞懂:注意力机制到底在算什么?
好,现在来到最关键的部分——注意力机制到底是怎么工作的?
我给你用最生活化的例子讲明白:
每个单词进来,都要算出三个向量:查询(Query)、键(Key)、值(Value)。
我把整个序列里的每个单词都想象成一个人:
你(当前单词)就是那个发问的人,你身上带的就是Query——你问"谁跟我关系最密切?"
别人回答你,每个人身上带的就是Key——"我跟你关系有多密切"
最后每个人给你的信息本身,就是Value
就这么个事儿。
比如萨阿德在那喊"谁对我最重要?"
他朋友立马回应"我超级重要!"
他邻居说"我大概有那么一点点重要"
两个路人直接沉默不说话。
看到了吗?
这就是注意力机制
用下面这个公式表示:
这就是大模型的第一公式,搞懂了这个公式,也就搞懂了大模型。
有了上面那个萨阿德谁对我最重要的示意图
再看数学计算公式就不吓人了:
1. 每个单词的嵌入乘以三个矩阵WQ、WK、WV,分别得到Q、K、V
2. Q乘以K的转置,得到原始注意力分数——分数越高,说明两个单词关系越近
3. 除以根号dk做缩放,让训练更稳定
4. 过一遍Softmax,把分数归一化到0-1之间,每行加起来等于1
5. 最后用这个分数给V加权求和,得到最终的上下文嵌入
说白了就是:关系越近的单词,对你最终结果影响越大。
就这么简单。
那掩码自注意力又是啥?
因为模型是预测下一个单词,你不能让它"作弊"看到未来的单词对吧?
所以在算注意力的时候,你得把当前单词后面的所有位置都挡上,不让它看。具体做法就是在Softmax之前,把要挡的位置分数设成负无穷,这样Softmax完了就是0,这些位置就不产生影响了。
这种只能看过去和当前,不能看未来的机制,就叫因果注意力,用这种机制的语言模型就叫因果语言模型。
好了,到这儿你已经把注意力最核心的逻辑搞懂了。咱们接着往下走。
4. 大模型推理到底是怎么跑的?
"推理"这俩字听起来玄乎,说白了就是"用训练好的模型生成文字"。
你问"法国首都在哪",模型答"巴黎",这个过程就是推理。
那具体流程是啥样的?
有两个特殊记号你得先记住:
<SOS>:放在句子开头,告诉模型"从这儿开始"
<EOS>:放在句子结尾,告诉模型"到这儿结束"
流程走一遍:
假设你输入 I am:
1. 分词,得到两个标记
2. 开头加,现在序列是 I am
3. 每个标记转嵌入
4. 扔给模型,算完注意力出logits
5. 拿最后一个标记的logits,转成概率
6. 选概率最大的,假设是 drinking,加到序列后面,现在变成 I am drinking
7. 再来一遍,模型说不定下一个是 coffee,加上之后变成 I am drinking coffee
8. 再来一遍,模型输出,完事,停止生成
就这么循环往复,一句话就出来了。
看到这儿,你应该明白大模型生成文字的基本流程了。
那KV Cache到底在哪儿?
它解决了啥问题?
5. 没有KV Cache的时候,低效在哪?
我问你一个问题:每一步推理,我们真正关心的是什么?
只有最后一个标记的预测结果。
前面那些标记,都已经算完了,是已知的了。
但是!如果不用KV Cache,模型每一步都得把前面所有标记的QKV全部重新算一遍。
你生成第10个token的时候,前面9个token的K和V你之前不算过吗?
为啥还要再算一遍?
这不是纯纯瞎耽误功夫吗?
我给你捋一遍:
第一步:输入1个token,算1次 → 共1次
第二步:输入2个token,全部重算 → 共1+2=3次
第三步:输入3个token,全部重算 → 共1+2+3=6次
...
第N步:总共算 1+2+...+N = N(N+1)/2 次
时间复杂度直接干到O(n²)!
你生成100个token就要算5050次,生成1000个token就是50万次。
能不慢吗?
这就是没有KV Cache的时候,推理效率低下的根因——一直在重复计算已经算过的东西。
6. KV Cache:一句话就能讲明白核心思想
问题找到了,解决办法其实特别简单——
算过的就别再算了,存起来下次用不行吗?
这就是KV Cache全部的核心精髓了。
具体怎么做?
也不复杂:
每个新来的token,你只需要干三件事:
1. 只算这个新token自己的Key和Value
2. 把这两个向量加到缓存里
3. 只用这个新token的Query,去和缓存里所有的Key+Value算注意力
完事儿。
我给你走一遍流程你就懂了:
t=1:输入 I,算这俩的K和V,存进缓存 → 预测出下一个token can
t=2:输入新token can,只算can的K和V,追加到缓存 → 用can的Q配缓存里所有K+V算注意力 → 预测出 cook
t=3:输入新token cook,只算它的K和V,追加缓存 → 用它的Q算注意力 → 预测下一个
... 一直循环,直到输出 结束
整个过程下来,每个token的K和V一辈子只算一次。时间复杂度直接从 O(n²) 干到 O(n)。
所以,有KVcache后,每个token的计算,只是计算最新的token,不会重复计算之前的token时,所需要的K和V。
生成100个token,原来算5050次,现在只算100次。差50倍!
输入越长,差距越大。就问你这买卖划算不划算?
KV Cache推理其实分两个阶段,说出来更好理解:
1. 预填充阶段(Prefilling)
用户输入的提示词一般不止一个token对吧?第一步就是把提示词里所有token的K和V一口气全算完,一股脑存进缓存里。
2. 标记生成阶段(Token Generation)
提示词搞定了,接下来一个一个往外生成新token。每来一个新token,只算它自己的K和V,往缓存里一塞,就算完事。
就这么简单。
7. 有缺点吗?有,典型的空间换时间
KV Cache有巨大优势(计算快),它也有代价——占显存。
每个token的K和V都是向量,你都存在GPU显存里,输入越长,占的空间越多。
所以说白了,这就是拿空间换时间:
你多占点显存,省下大把重复计算
推理速度上去了,单位时间能出更多字
你说划算不划算?
现在GPU显存多金贵啊,但是你不开KV Cache,速度慢到根本没法用,显存再大也白搭。所以大部分情况下,这买卖血赚。
8. 最后给你划一遍重点
最后,咱们快速走一遍完整流程,帮你把知识点串起来:
1. 用户给你一个提示词,分词,开头加
2. 每个token转嵌入向量,加位置编码
3. 注意力机制把原始嵌入转成上下文嵌入
4. 用KV Cache存下已经算好的K和V,不再重复计算
5. 模型出logits,拿最后一个转概率,选最大的当下一个token
6. 新token加进去,重复步骤4-5,直到模型输出
完事儿。
就这么一套流程下来,推理速度直接提升几十倍。
这就是KV Cache对大模型推理的意义——它不是什么锦上添花的优化,是刚需。你现在用的所有大模型,基本上都开了KV Cache。
搞懂了这个最基础的优化,以后再看到什么"KV Cache量化"、"PagedAttention"这些新技术,你一下子就能明白它们到底在优化啥——要么帮你省更多显存,要么帮你跑得更快。
这就是基础知识的力量。
有收获的兄弟点个赞,咱们下期再唠。
本文摘自《KV Cache Explained Intuitively》
https://medium.com/@saad.ahmed1926q/kv-cache-explained-intuitively-2b425a36dfc7
475