大家好,我是写代码的中年人!
自注意力(Self-Attention)是大模型里最常让人“眼花”的魔术道具:看起来只是一堆矩阵乘法和 softmax,可是组合起来就能学到“句子里谁重要、谁次要”的规则,甚至能学到某些头只盯标点、某些头专盯主谓关系。
今天我想把这块魔术板拆开来给你看个究竟:如何把单头注意力改成多头注意力,让每个头能学会自己的注意力分布。
01
回顾单头自注意力机制
假设你在开会,桌上有一堆文件,你想找跟“项目进度”相关的内容。
你心里有个问题(Query):“项目进度在哪儿?
”每份文件上有个标签(Key),写着它的主题,比如“预算”“进度”“人员”。
你会先挑出标签里跟“进度”相关的文件(匹配),然后重点看这些文件的内容(Value),最后把这些内容总结成你的理解。
自注意力就像是给每个词都做了一次这样的“信息筛选和总结”,让每个词都能根据上下文更好地表达自己。
02
理解多头自注意力机制
继续用开会的场景:
桌上还是那堆文件(代表句子里的词),但现在你不是一个人干活,而是找了3个助手(假设3头注意力)。每个助手都有自己的“专长”,他们会从不同的角度问问题、匹配标签和提取内容。
每个头独立工作(多视角筛选):
头1(进度专家):他的问题(Query)是“进度怎么样?”他只关注标签里跟“进度”“时间表”相关的文件,忽略其他。挑出匹配的文件后,他总结出一份“进度报告”。
头2(预算专家):他的问题是“预算超支了吗?”他匹配标签里的“预算”“开销”,然后从那些文件的内容里提炼“预算分析”。
头3(风险专家):问题是“有什么隐患?”他找“风险”“问题”相关的标签,输出一份“风险评估”。
每个头都像单头注意力一样:生成自己的问题、钥匙和内容,计算匹配度,加权总结。但他们用的“眼镜”不同(在机器里,这通过不同的线性变换实现),所以捕捉的信息侧重点不一样。
把多头结果合起来(综合决策):
一旦每个头都给出自己的总结,你就把这些报告拼在一起(或简单平均一下),形成一份完整的“项目概览”。现在,你的理解不只是“进度”,而是进度+预算+风险的全方位视图。万一某个头漏了什么,其他头能补上,确保没死角。
03
用代码实现多头自注意力机制
我们使用水浒传的内容进行演示,使用前三回各 100 字的文本,并按“字”切分成模型可用的格式。
import torchimport torch.nn as nnimport torch.nn.functional as Fimport torch.optim as optimimport matplotlib.pyplot as pltimport numpy as npplt.rcParams['font.sans-serif'] = ['SimHei']plt.rcParams['axes.unicode_minus'] = False# ====== 准备水浒传真实语料 ======raw_texts = ["話說大宋仁宗天子在位,嘉祐三年三月三日五更三點,天子駕坐紫宸殿,受百官朝賀。但見:祥雲迷鳳閣,瑞氣罩龍樓。含煙御柳拂旌旗,帶露宮花迎劍戟。天香影裏,玉簪珠履聚丹墀。仙樂聲中,繡襖錦衣扶御駕。珍珠廉卷,黃金殿上現金輿。鳳尾扇開,白玉階前停寶輦。隱隱凈鞭三下響,層層文武兩班齊。","那高俅在臨淮州,因得了赦宥罪犯,思鄉要回東京。這柳世權卻和東京城里金梁橋下開生藥鋪的董將士是親戚,寫了一封書札,收拾些人事盤纏,赍發高俅回東京,投奔董將士家過活。","話說當時史進道:「卻怎生是好?」朱武等三個頭領跪下答道:「哥哥,你是乾淨的人,休為我等連累了大郎。可把索來綁縛我三個,出去請賞,免得負累了你不好看。」"]# ====== 按字切分 ======def char_tokenize(text):return [ch for ch in text if ch.strip()] # 去掉空格、换行sentences = [char_tokenize(t) for t in raw_texts]# 构建词表vocab = {}for sent in sentences:for ch in sent:if ch not in vocab:vocab[ch] = len(vocab)# ====== 转成索引形式并做 padding ======max_len = max(len(s) for s in sentences)PAD_TOKEN = "<PAD>"vocab[PAD_TOKEN] = len(vocab)input_ids = []for sent in sentences:ids = [vocab[ch] for ch in sent]# paddingids += [vocab[PAD_TOKEN]] * (max_len - len(ids))input_ids.append(ids)input_ids = torch.tensor(input_ids) # (batch_size, seq_len)# ====== 多头自注意力模块 ======class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, dropout=0.1):super().__init__()assert embed_dim % num_heads == 0, "embed_dim 必须能整除 num_heads"self.embed_dim = embed_dimself.num_heads = num_headsself.head_dim = embed_dim // num_headsself.q_proj = nn.Linear(embed_dim, embed_dim)self.k_proj = nn.Linear(embed_dim, embed_dim)self.v_proj = nn.Linear(embed_dim, embed_dim)self.out_proj = nn.Linear(embed_dim, embed_dim)self.dropout = dropoutself.last_attn_weights = None # 保存最后一次注意力权重 (batch, heads, seq, seq)def forward(self, x):B, T, C = x.size()Q = self.q_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)K = self.k_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)V = self.v_proj(x).view(B, T, self.num_heads, self.head_dim).transpose(1, 2)scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)attn_weights = F.softmax(scores, dim=-1)attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)self.last_attn_weights = attn_weights.detach() # (B, heads, T, T)out = torch.matmul(attn_weights, V)out = out.transpose(1, 2).contiguous().view(B, T, C)out = self.out_proj(out)return out# ====== 模型训练 ======embed_dim = 32num_heads = 4vocab_size = len(vocab)embedding = nn.Embedding(vocab_size, embed_dim)model = MultiHeadSelfAttention(embed_dim, num_heads)criterion = nn.MSELoss()optimizer = optim.Adam(list(model.parameters()) + list(embedding.parameters()), lr=1e-3)epochs = 200for epoch in range(epochs):model.train()x = embedding(input_ids)target = x.clone()out = model(x)loss = criterion(out, target)optimizer.zero_grad()loss.backward()optimizer.step()if (epoch + 1) % 20 == 0:print(f"Epoch {epoch+1:3d}, Loss: {loss.item():.6f}")# ====== 可视化注意力热图 ======for idx, sent in enumerate(sentences):attn = model.last_attn_weights[idx] # (heads, seq, seq)sent_len = len(sent)for head in range(num_heads):plt.figure(figsize=(8, 6))plt.imshow(attn[head, :sent_len, :sent_len].numpy(), cmap='viridis')plt.title(f"第{idx+1}句 第{head+1}头 注意力矩阵")plt.xticks(ticks=np.arange(sent_len), labels=sent, rotation=90)plt.yticks(ticks=np.arange(sent_len), labels=sent)plt.xlabel("Key (字)")plt.ylabel("Query (字)")plt.colorbar(label="Attention Strength")for i in range(sent_len):for j in range(sent_len):plt.text(j, i, f"{attn[head, i, j]:.2f}", ha="center", va="center", color="white", fontsize=6)plt.tight_layout()plt.savefig(f"attention_sentence{idx+1}_head{head+1}.png")plt.close()print("注意力热图已保存。")
这些多头自注意力(Multi-Head Self-Attention)的热图,其实是一个“谁在关注谁”的可视化工具,用来直观展示模型在处理文本时的注意力分布。
热图上的颜色:横轴(Key):表示句子中被关注的字,纵轴(Query):表示当前在思考的字,颜色深浅:表示注意力强度,越亮的地方代表这个 Query 在计算时更关注这个 Key。
例如,如果“宋”字在看“天”字时颜色很亮,说明模型觉得“天”这个字对理解“宋”有重要信息。因为是古文,有时模型会捕捉到常见的修辞搭配,比如“天子”“鳳閣”,这时候相邻的字之间注意力会很高。
为什么会有多张图:每一行热图对应一句文本(水浒前三回的一个片段)每句话会画多个头的热图:多头机制的设计就是让不同的头学习到不同的关注模式举个例子:Head 1 可能更多关注相邻的字(局部模式)Head 2 可能更关注句首或特定关键词(全局模式)Head 3 可能专注某个语法结构Head 4 可能专注韵律、排比等古文特性
多头机制就像多双眼睛,从不同角度观察同一句话。
举个大家都能理解的例子:
学生(Query):举手发言
老师(Attention):环顾四周,看看应该关注哪个学生(Key)
不同的老师(Head)关注点不同:一个老师喜欢看前排学生(局部依赖)一个老师总是看坐在角落的安静同学(远距离依赖)还有老师会特别注意那些名字里有“天”“龙”这些关键字的学生
(关键触发词)颜色越亮,表示老师对这个学生说的话越感兴趣。
结束语
回到开头我们的问题:多头自注意力到底在看什么?通过水浒传这样真实、结构独特的古文片段,我们不仅看到了模型如何在字与字之间建立联系,还直观感受了不同“注意力头”各自的关注模式。有人关注近邻字,有人专注关键字,有人把目光投向整句的节奏与意境。
这就像课堂上不同的老师一样——他们的视角不同,但共同构成了对整篇文章的完整理解。这种可视化,不只是为了“看个热闹”,而是把模型内部的决策过程摊开给人看,让深度学习的“黑箱”多了一点可解释性。
至此,我们用水浒的诗意古文,让多头自注意力的数学公式“活”了起来。接下来,我们将整合所有已学过的文章,去实现一个生成模型。
1538