Batch Normalization技术介绍

  • A+
所属分类:深度学习
摘要这一部分介绍一下Batch Normalization,这个可以帮助我们在模型train不起来的时候,帮助模型的训练。十分有用的一个工具。

Batch Normalization工作原理

首先,我们输入的是一个batch,下面的例子中,我们可以将(x1,x2,x3)看成一个Batch

Batch Normalization技术介绍

接着,我们计算在经过一层网络后, 输出值(z1,z2,z3)的均值和标准差,如下图所示:

Batch Normalization技术介绍

接着,对(z1,z2,z3)进行标准化,标准化后z的均值为0,方差为1;

Batch Normalization技术介绍

最后,我们的输出为 z=A*z' + B,其中A和B是需要学习的参数;

我们注意到,当A=标准差, B=均值, 则相当于z_new = z,即相当于没有进行batch normalization.

Batch Normalization技术介绍

需要注意

  • Batch Normalization不能在小数据集上进行,因为均值方差的估计会不准确
  • 在test上时,不会计算test的均值与方差,会使用trian时候均值方差的移动平均来代替(下面会有一个例子);

Batch Normalization实际使用

在batch normalization的时候,我们在train和test的时候进行的操作是不同的这是由于在test的时候, 输入数据可能只有一个data, 故不能计算均值和标准差;

所以, 在test的时候, 会使用之前计算得到的均值和标准差做标准化。

初始化参数

首先,我们先初始化一些参数,下面初始化的有我们的测试数据X,测试数据的均值方差,迭代式子里最初的均值方差(均值0,方差1),和迭代的公式。

Batch Normalization技术介绍

迭代计算过程

下面看一下在输入是X的情况下,在训练模式下的输出结果和在测试模式下的输出结果。

Batch Normalization技术介绍

可以看到,在训练模式下,pytorch会对原始数据进行标准化,同时更新均值与方差(这个均值与方差是在test的时候,对test数据进行标准化的)。

在测试模式下,会使用在训练模式下反复计算的均值和方差来进行标准化。具体的计算过程可以结合上面的图进行推导。

第一行是第一次输出X后不同的输出,第二行是第二次输入X后不同的输出。

测试代码

把测试代码放在下面,方便自己测试。

  1. # Without Learnable Parameters(没有学习参数)
  2. # 这里是momentum=0.5的情况
  3. m = nn.BatchNorm1d(2, affine=False, eps=0, momentum=0.5, track_running_stats=True# 2为输出的特征数
  4. print("初始化 : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  5. # --------
  6. # 第一轮
  7. # --------
  8. # 训练模式
  9. m.train()
  10. output = m(inputData)
  11. print("TrainModeP : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  12. print("TrainMode:\n{}".format(output))
  13. # 测试模式
  14. print('------------')
  15. m.eval()
  16. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  17. output = m(inputData)
  18. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  19. print("EvalMode:\n{}".format(output))
  20. # --------
  21. # 第二轮
  22. # --------
  23. print('\n=====\n')
  24. # 训练模式
  25. m.train()
  26. output = m(inputData)
  27. print("TrainModeP : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  28. print("TrainMode:\n{}".format(output))
  29. # 测试模式
  30. print('------------')
  31. m.eval()
  32. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  33. output = m(inputData)
  34. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  35. print("EvalMode:\n{}".format(output))
  36. # --------
  37. # 第三轮
  38. # --------
  39. print('\n=====\n')
  40. # 训练模式
  41. m.train()
  42. output = m(inputData)
  43. print("TrainModeP : {}, :mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  44. print("TrainMode:\n{}".format(output))
  45. # 测试模式
  46. print('------------')
  47. m.eval()
  48. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  49. output = m(inputData)
  50. print("EvalMode : {}, mean:{},var:{}".format(m.training,m.running_mean,m.running_var))
  51. print("EvalMode:\n{}".format(output))

Batch Normalization实验

之前在dropout的实验里,我们可以看到使用dropout后,准确率变化如下所示:

Batch Normalization技术介绍

在使用batch normalization后,我们可以看到准确率有了进一步的提升:

Batch Normalization技术介绍

原始模型train不起来

我们看一下原始模型训练不起来的情况,如下图所示,在训练集上准确率只有10%左右

Batch Normalization技术介绍

在使用batch normalization后,模型在训练集上的准确率可以提升到80%以上;

Batch Normalization技术介绍

结语

数据标准化是很重要的一步,会对模型最终的好坏起到至关重要的作用。Batch Normalization在model无法train起来的时候,也就是在训练集上结果很差的时候,会起到比较好的作用。

Batch Normalization技术介绍
  • 微信公众号
  • 关注微信公众号
  • weinxin
  • QQ群
  • 我们的QQ群号
  • weinxin
王 茂南

发表评论

:?: :razz: :sad: :evil: :!: :smile: :oops: :grin: :eek: :shock: :???: :cool: :lol: :mad: :twisted: :roll: :wink: :idea: :arrow: :neutral: :cry: :mrgreen: