文章目录(Table of Contents)
简介
在实际中,我们希望有一个这样的效果:给定相同的 queries、keys 和 values,我们希望模型可以学习到不同的内容,然后将这些内容给组合起来。
为了实现「多头注意力」,我们可以使用不同的变化去将 queries、keys 和 values 变为不同的值,接着利用变换后的去计算注意力。最后将 N 个注意力的输出拼接在一起,并通过另外一个可以学习的线性投影进行变化,以产生最终的输出。下图展示一个「多头注意力」模型结构:
可以看到 queries、keys 和 values 通过不同的「全连接层」计算出不同的值,来计算「注意力」。最后将注意力 concat 在一起,在通过一个「全连接层」。
参考资料
- Multi-Head Attention,D2L 英文版 Multi-Head Attention 的内容;
- 多头注意力,D2L 中文版「多头注意力」的内容;
- 68 Transformer【动手学深度学习v2】, B 站视频,介绍「多头注意力」;
- Attention-mechanisms-and-transformers,本文对应的代码,multihead_attention 开头的;
多头注意力模型
模型描述--图解
下面我们以 2 个 heads 作为例子,看一下「多头注意力」模型。如下图所示,我们首先将同一个 query 乘不同的系数(q_i = W*q)得到不同的 q。keys 和 values 也进行类似的处理。此时相当于一个相同的 query,我们通过乘不同的系数得到了两个新的 query。
上面我们就得到了两组 queries、keys 和 values。接着我们分别计算注意力。下图首先使用一组进行注意力的计算。注意这里的 b 是加权之后的结果。
接下来计算另外一组的注意力。如下图所示,此时我们就有了 b1 和 b2。
在有了 b1 和 b2 之后,我们将其合并,在通过一个「全连接层」来作为输出。
模型解释-数学语言
上面我们非常直观的看了「多头注意力」的计算方式。下面我们用数学语言将这个模型形式化地描述出来。给定「查询 q」,「键 k」和「值 v」,给个「注意力头 h」 的计算方式如下:
这里可以学习的参数有三个 W,以及注意力汇聚函数 f(这里 f 是加性注意力和缩放点积注意力)。多头注意力的输出需要经过另一个线性转换,如下所示,这里也有一个可以学习的参数 W。
基于这种设计,每个头都可能会关注输入的不同部分, 可以表示比简单加权平均值更复杂的函数。
多头注意力代码实现
上面我们解释了「多头注意力」的机制,下面看一下代码的实现。具体计算的时候,为了可以并行运算,我们不会定义多组 W,而是会定义一个大的 W,然后进行拆分。
我们首先定义 query 和 key-value pair。共有 4 个 queries, 6 个 key-value pairs。且我们设置 heads 的数量为 5。
- batch_size = 2 # batch 数量
- num_heads = 5 # multi head 的数量
- num_queries = 4 # query 的数量为 4
- num_kvpairs = 6 # kvpair 的数量为 6
- valid_lens = torch.tensor([3, 2])
- X = torch.ones((batch_size, num_queries, 10)) # query
- Y = torch.ones((batch_size, num_kvpairs, 20)) # key, value
- print(X.shape, Y.shape)
- # torch.Size([2, 4, 10]) torch.Size([2, 6, 20])
上面我们提到「多头注意力」是需要使用不同的参数 W 去乘 「查询 q」,「键 k」和「值 v」。但是为了可以并行运算,比如一个参数 W 可以将其转换为 20 维的特征,一共有 5 个 heads,那么我们就直接使用转换为维度为 100,然后再进行切分,将 head 数据和 batch 数据放在一个维度。
我们首先定义函数 _transpose_qkv,目的是将 head 维度拆出来。例如本来是 (batch_size,查询或者“键-值”对的个数,num_hiddens),这里 num_hiddens 其实是包含所有 heads 的,于是函数 _transpose_qkv 可以将其转换为 (batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads),然后把 heads 和 batch 和在一起,变为 (batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)。
- def _transpose_qkv(X, num_heads):
- """为了多注意力头的并行计算而变换形状, 不要因为 multi-head, 而多一个循环
- """
- # 输入 X 的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)
- # 输出 X 的形状:(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)
- # 有多少个 head, 做一下切分, 比如会从 torch.Size([2, 4, 100]) --> torch.Size([2, 4, 5, 20])
- X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)
- # 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
- # 交换维度, 将 num_heads 和 查询或者“键-值”对的个数 交换
- # 此时的大小为 torch.Size([2, 5, 4, 20])
- X = X.permute(0, 2, 1, 3)
- # 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数, num_hiddens/num_heads)
- # 最终大小变为 torch.Size([10, 4, 20])
- return X.reshape(-1, X.shape[2], X.shape[3])
具体「查询 q」,「键 k」和「值 v」的转换如下。batchsize=2,num_heads=5,最终输出的大小第一维为 10,也就是 batch_size*num_heads。
可以理解为,共有 10 组数据,每组数据有 4 个 query 和 6 和 key-value pairs。这些值的特征维度都是 20。要对这 4 个 query 计算特征。
- num_hiddens = 100
- W_q = nn.LazyLinear(num_hiddens, bias=False)
- W_k = nn.LazyLinear(num_hiddens, bias=False)
- W_v = nn.LazyLinear(num_hiddens, bias=False)
- queries = _transpose_qkv(X=W_q(X), num_heads=num_heads) # torch.Size([10, 4, 20])
- keys = _transpose_qkv(W_k(Y), num_heads=num_heads) # torch.Size([10, 6, 20])
- values = _transpose_qkv(W_v(Y), num_heads=num_heads) # torch.Size([10, 6, 20])
- print(queries.shape, keys.shape, values.shape)
- # torch.Size([10, 4, 20]) torch.Size([10, 6, 20]) torch.Size([10, 6, 20])
接下来计算这 10 组数据的注意力。最终 out 的大小为 (batch_size * num_heads, no. of queries, num_hiddens / num_heads)。
- attention = DotProductAttention(0.1)
- valid_lens = torch.repeat_interleave(valid_lens, repeats=num_heads, dim=0)
- output = attention(queries, keys, values, valid_lens)
- print(output.shape) # 计算每个 query 的值, 为 (10, 4, 20)
- # torch.Size([10, 4, 20])
我们相当于是将 multihead 的当作不同 batch 在计算,现在需要把 mutil-head 放回去。形状改变的代码如下所示,也就是将输入形状从 (batch_size * num_heads, no. of queries, num_hiddens / num_heads) 转换为 (batch_size , no. of queries, num_hiddens)。这里相当于已经是将不同 head 得到的结果进行了拼接了。
- def _transpose_output(X, num_heads):
- """逆转transpose_qkv函数的操作
- """
- # 将 num_heads 的维度信息从 batch size 中拿出来
- # 也就是 torch.Size([10, 4, 20]) --> torch.Size([2, 5, 4, 20]), 这里 num_head = 5
- X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])
- # 交换 num_heads 与 查询或者“键-值”对的个数
- # 此时从 torch.Size([2, 5, 4, 20]) --> torch.Size([2, 4, 5, 20])
- X = X.permute(0, 2, 1, 3)
- # 最后将 num_head 的信息合并起来
- # 从 torch.Size([2, 4, 5, 20]) --> torch.Size([2, 4, 100])
- return X.reshape(X.shape[0], X.shape[1], -1)
我们对前面输出的 out 进行形状的变换。最终的大小从 (10, 4, 20) 变为 (2, 4, 100)。
- output_concat = _transpose_output(output, num_heads=num_heads)
- # 将 num_head 从 batch_size 中拿出来, torch.Size([10, 4, 20]) --> torch.Size([2, 4, 100])
- print(output_concat.shape)
- # 这里的 100 相当于是 multi-head 集和后的结果
最后再通过一个线性变换就可以得到最终的结果。也就是对应下面的式子:
实现起来就是多加一个全连接层即可。
- W_o = nn.LazyLinear(32, bias=False)
- result = W_o(output_concat)
- print(result.shape) # torch.Size([2, 4, 100]) --> torch.Size([2, 4, 32])
- # torch.Size([2, 4, 32])
- 微信公众号
- 关注微信公众号
-
- QQ群
- 我们的QQ群号
-









评论