Batch Normalization技术介绍

  • A+
所属分类:深度学习
摘要这一部分介绍一下Batch Normalization,这个可以帮助我们在模型train不起来的时候,帮助模型的训练。十分有用的一个工具。

Batch Normalization工作原理

首先,我们输入的是一个batch,下面的例子中,我们可以将(x1,x2,x3)看成一个Batch

Batch Normalization技术介绍

接着,我们计算在经过一层网络后, 输出值(z1,z2,z3)的均值和标准差,如下图所示:

Batch Normalization技术介绍

接着,对(z1,z2,z3)进行标准化,标准化后z的均值为0,方差为1;

Batch Normalization技术介绍

最后,我们的输出为 z=A*z' + B其中A和B是需要学习的参数, 这里A和B是包含在Batch Normalization层的参数; 例如我们直接查看Pytorch中BN层的参数, 可以发现是有weight和bias的.

  1. n = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6))
  2. n[1].state_dict()
  3. """
  4. OrderedDict([('weight', tensor([1., 1., 1., 1., 1., 1.])),
  5.              ('bias', tensor([0., 0., 0., 0., 0., 0.])),
  6.              ('running_mean', tensor([0., 0., 0., 0., 0., 0.])),
  7.              ('running_var', tensor([1., 1., 1., 1., 1., 1.])),
  8.              ('num_batches_tracked', tensor(0))])
  9. """

我们注意到,当A=标准差, B=均值, 则相当于z_new = z,即相当于没有进行batch normalization.

Batch Normalization技术介绍

需要注意

  • Batch Normalization不能在小数据集上进行,因为均值方差的估计会不准确
  • 在test上时,不会计算test的均值与方差,会使用trian时候均值方差的移动平均来代替(下面会有一个例子);

Batch Normalization实际使用

在batch normalization的时候,我们在train和test的时候进行的操作是不同的这是由于在test的时候, 输入数据可能只有一个data, 故不能计算均值和标准差;

所以, 在test的时候, 会使用之前计算得到的均值和标准差做标准化。

初始化参数

首先,我们先初始化一些参数,下面初始化的有我们的测试数据X,测试数据的均值方差,迭代式子里最初的均值方差(均值0,方差1),和迭代的公式。

Batch Normalization技术介绍

迭代计算过程

下面看一下在输入是X的情况下,在训练模式下的输出结果和在测试模式下的输出结果。

Batch Normalization技术介绍

可以看到,在训练模式下,pytorch会对原始数据进行标准化,同时更新均值与方差(这个均值与方差是在test的时候,对test数据进行标准化的)。

测试模式下,会使用在训练模式下反复计算的均值和方差来进行标准化。具体的计算过程可以结合上面的图进行推导。

第一行是第一次输出X后不同的输出,第二行是第二次输入X后不同的输出。

测试代码

把测试代码放在下面,方便自己测试。

  1. # Without Learnable Parameters(没有学习参数)
  2. # 这里是momentum=0.5的情况
  3. m = nn.BatchNorm1d(2, affine=False, eps=0, momentum=0.5, track_running_stats=True# 2为输出的特征数
  4. print("初始化 : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  5. # --------
  6. # 第一轮
  7. # --------
  8. # 训练模式
  9. m.train()
  10. output = m(inputData)
  11. print("TrainModeP : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  12. print("TrainMode:\n{}".format(output))
  13. # 测试模式
  14. print('------------')
  15. m.eval()
  16. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  17. output = m(inputData)
  18. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  19. print("EvalMode:\n{}".format(output))
  20. # --------
  21. # 第二轮
  22. # --------
  23. print('\n=====\n')
  24. # 训练模式
  25. m.train()
  26. output = m(inputData)
  27. print("TrainModeP : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  28. print("TrainMode:\n{}".format(output))
  29. # 测试模式
  30. print('------------')
  31. m.eval()
  32. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  33. output = m(inputData)
  34. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  35. print("EvalMode:\n{}".format(output))
  36. # --------
  37. # 第三轮
  38. # --------
  39. print('\n=====\n')
  40. # 训练模式
  41. m.train()
  42. output = m(inputData)
  43. print("TrainModeP : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  44. print("TrainMode:\n{}".format(output))
  45. # 测试模式
  46. print('------------')
  47. m.eval()
  48. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  49. output = m(inputData)
  50. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  51. print("EvalMode:\n{}".format(output))

Batch Normalization的手动实现

为了能够更好的说明Batch Normalization, 我们手动实现一下Batch Norm的功能. 首先是实现计算Batch Norm的函数, 包含:

  • moving_mean和moving_var的计算;
  • 包含在训练模式和测试模式下对X进行标准化;
  • 同时, 这里会有两个参数, 分别是gamma和beta, 是需要在训练的时候进行更新的;
  1. def batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):
  2.     # Use torch.is_grad_enabled() to determine whether the current mode is
  3.     # training mode or prediction mode
  4.     if not torch.is_grad_enabled(): # 预测模式就不需要计算均值和方差
  5.         # If it is the prediction mode, directly use the mean and variance
  6.         # obtained from the incoming moving average
  7.         X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)
  8.     else:
  9.         assert len(X.shape) in (2, 4) # 长度=2, 全连接; 长度=4, 卷积;
  10.         if len(X.shape) == 2:
  11.             # When using a fully connected layer, calculate the mean and
  12.             # variance on the feature dimension
  13.             mean = X.mean(dim=0)
  14.             var = ((X - mean) ** 2).mean(dim=0)
  15.         else:
  16.             # When using a two-dimensional convolutional layer, calculate the
  17.             # mean and variance on the channel dimension (axis=1). Here we
  18.             # need to maintain the shape of `X`, so that the broadcast
  19.             # operation can be carried out later
  20.             mean = X.mean(dim=(0, 2, 3), keepdim=True)
  21.             var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)
  22.         # In training mode, the current mean and variance are used for the
  23.         # standardization
  24.         X_hat = (X - mean) / torch.sqrt(var + eps)
  25.         # Update the mean and variance of the moving average
  26.         moving_mean = momentum * moving_mean + (1.0 - momentum) * mean
  27.         moving_var = momentum * moving_var + (1.0 - momentum) * var
  28.     Y = gamma * X_hat + beta  # Scale and shift
  29.     return Y, moving_mean, moving_var

定义了上面计算的函数之后, 下面就是定义BatchNorm layer. 这里会存储moving_mean和moving_var, 会存储参数gamma和beta.

  1. class BatchNorm(nn.Module):
  2.     # num_features: the number of outputs for a fully-connected layer
  3.     #   or the number of output channels for a convolutional layer.
  4.     # num_dims: 2 for a fully-connected layer and 4 for a convolutional layer.
  5.     def __init__(self, num_features, num_dims):
  6.         super().__init__()
  7.         if num_dims == 2:
  8.             shape = (1, num_features)
  9.         else:
  10.             shape = (1, num_features, 1, 1)
  11.         # The scale parameter and the shift parameter involved in gradient
  12.         # finding and iteration are initialized to 0 and 1 respectively
  13.         self.gamma = nn.Parameter(torch.ones(shape))
  14.         self.beta = nn.Parameter(torch.zeros(shape))
  15.         # All the variables not involved in gradient finding and iteration are
  16.         # initialized to 0 on the CPU
  17.         self.moving_mean = torch.zeros(shape)
  18.         self.moving_var = torch.zeros(shape)
  19.     def forward(self, X):
  20.         # If X is not on the CPU, copy `moving_mean` and `moving_var` to the
  21.         # device where `X` is located
  22.         if self.moving_mean.device != X.device:
  23.             self.moving_mean = self.moving_mean.to(X.device)
  24.             self.moving_var = self.moving_var.to(X.device)
  25.         # Save the updated `moving_mean` and `moving_var`
  26.         Y, self.moving_mean, self.moving_var = batch_norm(
  27.             X, self.gamma, self.beta, self.moving_mean,
  28.             self.moving_var, eps=1e-5, momentum=0.9)
  29.         return Y

上面计算过程与实际的layer分开的过程, 是非常好的. (This pattern enables a clean separation of math from boilerplate code.)

 

直接使用Pytorch中的BatchNorm

上面我们自己动手实现了BatchNorm的相关操作, 但是在实际情况中, 我们会直接使用框架中现成的函数, 这样使用起来就会方便很多. 下面我们将LeNet加上BatchNorm, 注意我们在全连接层的部分也是可以增加一维的BatchNorm的,  整体的代码如下所示:

  1. net = nn.Sequential(
  2.     nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),
  3.     nn.MaxPool2d(kernel_size=2, stride=2),
  4.     nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),
  5.     nn.MaxPool2d(kernel_size=2, stride=2), nn.Flatten(),
  6.     nn.Linear(7056, 120), nn.BatchNorm1d(120), nn.Sigmoid(),
  7.     nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),
  8.     nn.Linear(84, 10))

 

Batch Normalization实验

之前在dropout的实验里,我们可以看到使用dropout后,准确率变化如下所示:

Batch Normalization技术介绍

在使用batch normalization后,我们可以看到准确率有了进一步的提升:

Batch Normalization技术介绍

原始模型train不起来

我们看一下原始模型训练不起来的情况,如下图所示,在训练集上准确率只有10%左右

Batch Normalization技术介绍

在使用batch normalization后,模型在训练集上的准确率可以提升到80%以上;

Batch Normalization技术介绍

结语

数据标准化是很重要的一步,会对模型最终的好坏起到至关重要的作用。Batch Normalization在model无法train起来的时候,也就是在训练集上结果很差的时候,会起到比较好的作用。

Batch Normalization技术介绍
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: