使用注意力机制的 seq2seq

王 茂南 2022年10月15日07:24:51
评论
5475字阅读18分15秒
摘要这一篇中,我们介绍使用注意力机制的 Seq2Seq。我们会将「注意力机制」用在 Seq2Seq 模型上面。这样在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。我们会介绍模型的整体结构,和实现的相关代码。

简介

前面我们介绍了注意力机制注意力分数的内容。在这一部分,我们会将「注意力机制」用在 Seq2Seq 模型上面。这样在预测词元时,如果不是所有输入词元都相关,模型将仅对齐(或参与)输入序列中与当前预测相关的部分。

注意力机制的关键就是搞明白,什么作为 keyvalue,什么作为 query。在这里 keyvalueencoder 网络对每个词的 output(这里文档中说的是使用 encoderhidden state,但是代码中好像是 output)。querydecoder 将输入的词的 hidden state

参考资料

 

Seq2Seq with Attention

模型结构介绍

整体的模型结构如下所示。注意看下面「注意力模块」的三个输入,两个来自「编码器」,作为 keyvalue(这个是编码器对每个字的 hidden state)。一个来自「解码器」,作为 query(这个是解码器对输入字的 hidden state)。

使用注意力机制的 seq2seq

于是将数据流的表达转换为公式,如下所示。其中 hencoder 对输入的每个词计算的 hidden state。而 sdecoder 对输入的上个词计算的 hidden state(作为 query,注意这里 st-1,也就是计算的上一个词)。

使用注意力机制的 seq2seq

上面式子的含义就是,我们将当前词的 hidden stateencoder 里面每一个词的hidden state 计算相似度,然后在乘encoder 里面上一个词的hidden state,最终得到的值作为当前这个字的表示。最后再会和这个词 embedding 的结果拼接起来。

 

Seq2Seq with Attention 代码实现

上面我们介绍了使用注意力机制的 seq2seq 的模型结构。这里我们来看一下相关的代码实现。可以看到上面的结构部分,加入「注意力机制」之后,encoder 部分是没有变化的,需要变化的是 decoder 的部分。下面是完整的 decoder 部分的代码(几乎都写了非常详细的注释了):

  1. class Seq2SeqAttentionDecoder(AttentionDecoder):
  2.     """Encoder 部分是没有改变的, Decoder 部分加入了 Attention.
  3.     """
  4.     def __init__(
  5.                 self,
  6.                 vocab_size,
  7.                 embed_size,
  8.                 num_hiddens,
  9.                 num_layers,
  10.                 dropout=0
  11.             ):
  12.         super().__init__()
  13.         self.attention = AdditiveAttention(num_hiddens, dropout) # 加性 Attention
  14.         self.embedding = nn.Embedding(vocab_size, embed_size) # Embedding 层
  15.         self.rnn = nn.GRU(
  16.             embed_size + num_hiddens,
  17.             num_hiddens,
  18.             num_layers,
  19.             dropout=dropout
  20.         )
  21.         self.dense = nn.LazyLinear(vocab_size) # 输出每一个词的概率
  22.         self.apply(init_seq2seq)
  23.     def init_state(self, enc_outputs, enc_valid_lens):
  24.         # Shape of outputs: (num_steps, batch_size, num_hiddens). 为 encoder 中的 output, 大小为, torch.Size([9, 128, 256])
  25.         # Shape of hidden_state: (num_layers, batch_size, num_hiddens), 初始化为 encoder 中的 hidden state, 大小为 torch.Size([2, 128, 256])
  26.         # enc_valid_lens, 表示 encode 句子中哪些是 padding 的
  27.         outputs, hidden_state = enc_outputs
  28.         return (outputs.permute(1, 0, 2), hidden_state, enc_valid_lens) # output 大小转换后为 torch.Size([128, 9, 256])
  29.     def forward(self, X, state):
  30.         # 这里 num_steps 是原始句子的长度
  31.         # Shape of enc_outputs: (batch_size, num_steps, num_hiddens). 每一的 output --> 作为 key 和 value, torch.Size([128, 9, 256])
  32.         # Shape of hidden_state: (num_layers, batch_size, num_hiddens), torch.Size([2, 128, 256])
  33.         enc_outputs, hidden_state, enc_valid_lens = state
  34.         # Shape of the output X: (num_steps, batch_size, embed_size)
  35.         X = self.embedding(X).permute(1, 0, 2) # torch.Size([9, 128, 256])
  36.         # #########################
  37.         # 下面用到了 Attention 的机制
  38.         # #########################
  39.         outputs, self._attention_weights = [], []
  40.         for x in X: # decoder 中每一个字需要依次输入
  41.             # Shape of query: (batch_size, 1, num_hiddens), torch.Size([128, 1, 256])
  42.             # query 是上一个时间的 RNN 的输出, 其实就是上一个字 Embedding 的结果
  43.             query = torch.unsqueeze(hidden_state[-1], dim=1) # 只取最后一个 layer 的结果
  44.             # Shape of enc_outputs: (batch size, num_steps, h), torch.Size([128, 9, 256])
  45.             # Shape of context: (batch_size, 1, num_hiddens), torch.Size([128, 1, 256]), 加权之后的结果
  46.             # query, key, value
  47.             context = self.attention(query, enc_outputs, enc_outputs, enc_valid_lens) # Attention 的关键, context 的计算
  48.             # Concatenate on the feature dimension
  49.             x = torch.cat((context, torch.unsqueeze(x, dim=1)), dim=-1) # context 和 x 合并起来
  50.             # 最终 x 的大小为, torch.Size([128, 1, 512])
  51.             # Reshape x from (batch_size, 1, embed_size + num_hiddens) to (1, batch_size, embed_size + num_hiddens), 此时是为了适应 gru 输入的大小
  52.             out, hidden_state = self.rnn(x.permute(1, 0, 2), hidden_state) # 这里 out 的大小为 torch.Size([1, 128, 256])
  53.             outputs.append(out)
  54.             self._attention_weights.append(self.attention.attention_weights) # 存储 attention weights
  55.         outputs = self.dense(torch.cat(outputs, dim=0)) # cat 之后大小为 torch.Size([9, 128, 256])
  56.         # After fully connected layer transformation, shape of outputs:
  57.         # (num_steps, batch_size, vocab_size), 也就是 torch.Size([9, 128, 214])
  58.         return outputs.permute(1, 0, 2), [enc_outputs, hidden_state, enc_valid_lens]
  59.     @property
  60.     def attention_weights(self):
  61.         return self._attention_weights

接下来我们来过一下上面的代码。通过一个具体的例子,看一下每行输入的是什么。首先先说明一下相关的参数:

  • embedding size = 256,每个字符变为长度为 256 的向量;
  • 每句话都填充为 9 个词(最大长度为 9);
  • 使用 GRU,其中输出大小为 512,共有两层;
  • batch size128

 

Encoder 部分数据解释

选择一个 batch 的数据,经过 encoder 之后,分别输出 outputhidden state,大小分别是(这里 output 包含 9 个词的信息,之后会作为 keyvalue):

  • output:torch.Size([9, 128, 512]),其中 (num_steps, batch_size, num_hiddens);
  • hidden state:torch.Size([2, 128, 512]),其中 (num_layers, batch_size, num_hiddens);

 

Decoder 部分数据解释

接着就到了 decoder 的部分。首先是 init_state,这里的作用就是得到 encoder 输出的 outputhidden state

接着到了 forward 部分。传入两个信息

  • 一个是 X,也就是 decoder 的输入的字(比如是英文翻译法语,那这里就是翻译后的每个法语的单词)。此时 X 的大小为 torch.Size([128, 9]),因为句子长度为 9,所以有 9 个单词;
  • 一个就是 state,包含 init_state 返回的值,用于计算权重。
    • enc_outputs,(batch_size, num_steps, num_hiddens),torch.Size([128, 9, 512])。这里是包含 9 个单词的信息的,后面会作为 keyvalue
    • hidden_state,(num_layers, batch_size, num_hiddens),torch.Size([2, 128, 512])。作为第一个 query

接着对于 X 中的每一个字,分别计算权重:

  • 首先得到 query,就是上面的 hidden_state[-1],变换大小后为 torch.Size([128, 1, 512]);
  • enc_outputs 作为 keyvalue,计算 attention,得到 context,torch.Size([128, 1, 512]),也就是 (batch_size, 1, num_hiddens),这是这个词的结果。
  • context 与这个 xembedding 的结果合并,得到信息,此时的大小为 torch.Size([128, 1, 768]),也就是 (batch_size, 1, embed_size + num_hiddens)。
  • 将得到的信息放入 rnn 中,得到 output(在过一个全连接层,得到对词的预测)和 hidden state,作为下一个的 query(替换上面的 hidden state)。这样循环往复,得到所有词的预测。

  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南
  • 本文由 发表于 2022年10月15日07:24:51
  • 转载请务必保留本文链接:https://mathpretty.com/15205.html
匿名

发表评论

匿名网友 填写信息

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: