利用Pytorch进行CNN详细剖析

本文缘起于一次CNN作业中的一道题,这道题涉及到了基本的CNN网络搭建,在MNIST数据集上的分类结果,Batch
Normalization的影响,Dropout的影响,卷积核大小的影响,数据集大小的影响,不同部分数据集的影响,随机数种子的影响,以及不同激活单元的影响等,能够让人比较全面地对CNN有一个了解,所以想做一下,于是有了本文。

工具

开源深度学习库: PyTorch

数据集: MNIST

实现

初始要求

首先建立基本的BASE网络,在Pytorch中有如下code:


  1. class Net(nn.Module): 
  2.     def __init__(self): 
  3.         super(Net, self).__init__() 
  4.         self.conv1 = nn.Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), padding=0) 
  5.         self.conv2 = nn.Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1), padding=0) 
  6.         self.fc1 = nn.Linear(4*4*50, 500) 
  7.         self.fc2 = nn.Linear(500, 10) 
  8.  
  9.     def forward(self, x): 
  10.         x = F.max_pool2d(self.conv1(x), 2) 
  11.         x = F.max_pool2d(self.conv2(x), 2) 
  12.         x = x.view(-1, 4*4*50) 
  13.         x = F.relu(self.fc1(x)) 
  14.         x = self.fc2(x) 
  15.         return F.log_softmax(x) 

这部分代码见 base.py 。

问题A:预处理

即要求将MNIST数据集按照规则读取并且tranform到适合处理的格式。这里读取的代码沿用了BigDL Python Support的读取方式,无需细说,根据MNIST主页上的数据格式可以很快读出,关键block有读取32位比特的函数:


  1. def _read32(bytestream): 
  2.     dt = numpy.dtype(numpy.uint32).newbyteorder('>')    # 大端模式读取,最高字节在前(MSB first) 
  3.     return numpy.frombuffer(bytestream.read(4), dtype=dt)[0] 

读出后是(N, 1, 28,
28)的tensor,每个像素是0-255的值,首先做一下归一化,将所有值除以255,得到一个0-1的值,然后再Normalize,训练集和测试集的均值方差都已知,直接做即可。由于训练集和测试集的均值方差都是针对归一化后的数据来说的,所以刚开始没做归一化,所以forward输出和grad很离谱,后来才发现是这里出了问题。

这部分代码见 preprocessing.py 。

问题B:BASE模型

将random seed设置为0,在前10000个训练样本上学习参数,最后看20个epochs之后的测试集错误率。最后结果为:


  1. Test set: Average loss: 0.0014, Accuracy: 9732/10000 (97.3%) 

可以看到,BASE模型准确率并不是那么的高。

问题C:Batch Normalization v.s BASE

在前三个block的卷积层之后加上Batch Normalization层,简单修改网络结构如下即可:


  1. class Net(nn.Module): 
  2.     def __init__(self): 
  3.         super(Net, self).__init__() 
  4.         self.conv1 = nn.Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1), padding=0) 
  5.         self.conv2 = nn.Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1), padding=0) 
  6.         self.fc1 = nn.Linear(4*4*50, 500) 
  7.         self.fc2 = nn.Linear(500, 10) 
  8.         self.bn1 = nn.BatchNorm2d(20) 
  9.         self.bn2 = nn.BatchNorm2d(50) 
  10.         self.bn3 = nn.BatchNorm1d(500) 
  11.  
  12.     def forward(self, x): 
  13.         x = self.conv1(x) 
  14.         x = F.max_pool2d(self.bn1(x), 2) 
  15.         x = self.conv2(x) 
  16.         x = F.max_pool2d(self.bn2(x), 2) 
  17.         x = x.view(-1, 4*4*50) 
  18.         x = self.fc1(x) 
  19.         x = F.relu(self.bn3(x)) 
  20.         x = self.fc2(x) 
  21.         return F.log_softmax(x) 

同样的参数run一下,得出加了BN的结果为:


  1. Test set: Average loss: 0.0009, Accuracy: 9817/10000 (98.2%) 

由此可见,有明显的效果提升。

关于Batch Normalization的更多资料参见[2],[5]。

问题D: Dropout Layer

在最后一层即 fc2 层后加一个 Dropout(p=0.5) 后,在BASE和BN上的结果分别为:


  1. BASE:Test set: Average loss: 0.0011, Accuracy: 9769/10000 (97.7%) 
  2. BN:  Test set: Average loss: 0.0014, Accuracy: 9789/10000 (97.9%) 

观察得知,dropout能够对BASE模型起到一定提升作用,但是对BN模型却效果不明显反而降低了。

原因可能在于,BN模型中本身即包含了正则化的效果,再加一层Dropout显得没有必要反而可能影响结果。

问题E:SK model

SK model: Stacking two 3x3 conv. layers to replace 5x5 conv. layer

如此一番改动后,搭建的SK模型如下:


  1. class Net(nn.Module): 
  2.     def __init__(self): 
  3.         super(Net, self).__init__() 
  4.         self.conv1_1 = nn.Conv2d(1, 20, kernel_size=(3, 3), stride=(1, 1), padding=0) 
  5.         self.conv1_2 = nn.Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=0) 
  6.         self.conv2 = nn.Conv2d(20, 50, kernel_size=(3, 3), stride=(1, 1), padding=0) 
  7.         self.fc1 = nn.Linear(5*5*50, 500) 
  8.         self.fc2 = nn.Linear(500, 10) 
  9.         self.bn1_1 = nn.BatchNorm2d(20) 
  10.         self.bn1_2 = nn.BatchNorm2d(20) 
  11.         self.bn2 = nn.BatchNorm2d(50) 
  12.         self.bn3 = nn.BatchNorm1d(500) 
  13.         self.drop = nn.Dropout(p=0.5) 
  14.  
  15.     def forward(self, x): 
  16.         x = F.relu(self.bn1_1(self.conv1_1(x))) 
  17.         x = F.relu(self.bn1_2(self.conv1_2(x))) 
  18.         x = F.max_pool2d(x, 2) 
  19.         x = self.conv2(x) 
  20.         x = F.max_pool2d(self.bn2(x), 2) 
  21.         x = x.view(-1, 5*5*50) 
  22.         x = self.fc1(x) 
  23.         x = F.relu(self.bn3(x)) 
  24.         x = self.fc2(x) 
  25.         return F.log_softmax(x) 

在20个epoch后,结果如下,


  1. SK: Test set: Average loss: 0.0008, Accuracy: 9848/10000 (98.5%) 

测试集准确率得到了少许的提高。

这里利用2个3x3的卷积核来代替大的5x5卷积核,参数个数由5x5=25变为了2x3x3=18。实践表明,这样使得计算更快了,并且小的卷积层之间的ReLU也很有帮助。

VGG中就使用了这种方法。

问题F:Change Number of channels

通过将特征图大小乘上一个倍数,再通过shell程序执行,得到如下结果:


  1. SK0.2:  97.7% 
  2. SK0.5:  98.2% 
  3. SK1:    98.5% 
  4. SK1.5:  98.6% 
  5. SK2:    98.5%  (max 98.7%) 

在特征图分别为4,10, 30, 40时,最终的准确度基本是往上提升的。这在一定程度上说明,在没有达到过拟合前,增大特征图的个数,即相当于提取了更多的特征,提取特征数的增加有助于精度的提高。

这部分代码见 SK_s.py 和 runSK.sh 。

问题G:Use different training set sizes

同样通过脚本运行,增加参数


  1. parser.add_argument('--usedatasize', type=int, default=60000, metavar='SZ', 
  2.                     help='use how many training data to train network') 

表示使用的数据大小,从前往后取 usebatchsize 个数据。

这部分程序见 SK_s.py 和 runTrainingSize.sh 。

运行的结果如下:


  1. 500:   84.2% 
  2. 1000:  92.0% 
  3. 2000:  94.3% 
  4. 5000:  95.5% 
  5. 10000: 96.6% 
  6. 20000: 98.4% 
  7. 60000: 99.1% 

由此可以明显地看出,数据越多,结果的精度越大。

太少的数据无法准确反映数据的整体分布情况,而且容易过拟合,数据多到一定程度效果也会不明显,不过,大多数时候我们总还是嫌数据太少,而且更多的数据获取起来也有一定难度。

问题H:Use different training sets

采用脚本完成,这部分程序见 SK_0.2.py 和 diffTrainingSets.sh 。

运行结果如下:


  1.  0-10000: 98.0% 
  2. 10000-20000: 97.8% 
  3. 20000-30000: 97.8% 
  4. 30000-40000: 97.4% 
  5. 40000-50000: 97.5% 
  6. 50000-60000: 97.7% 

由此可见,采用不同的训练样本集合训练出来的网络有一定的差异,虽不是很大,但是毕竟显示出了不稳定的结果。

问题I:Random Seed’s effects

采用 runSeed.sh 脚本完成,用到了全部60000个训练集。

运行的结果如下:


  1. Seed      0:  98.9% 
  2. Seed      1:  99.0% 
  3. Seed     12:  99.1% 
  4. Seed    123:  99.0% 
  5. Seed   1234:  99.1% 
  6. Seed  12345:  99.0% 
  7. Seed 123456:  98.9% 

事实上在用上整个训练集的时候,随机数生成器的种子设置对于最后结果的影响不大。

问题J:ReLU or Sigmoid?

将ReLU全部换成Sigmoid后,用全部60000个训练集训练,有对比结果如下:


  1. ReLU SK_0.2:  99.0% 
  2. igmoid SK_0.2:  98.6% 

由此可以看出,在训练CNN时,使用ReLU激活单元比Sigmoid激活单元要更好一些。原因可能在于二者机制的差别,sigmoid在神经元输入值较大或者较小时,输出值会近乎0或者1,这使得许多地方的梯度几乎为0,权重几乎得不到更新。而ReLU虽然增加了计算的负担,但是它能够显著加速收敛过程,并且也不会有梯度饱和问题。

作者:佚名

来源:51CTO

时间: 2024-10-21 13:22:21

利用Pytorch进行CNN详细剖析的相关文章

《XSS跨站脚本攻击剖析与防御》—第6章6.4节利用Flash进行XSS攻击剖析

6.4 利用Flash进行XSS攻击剖析 XSS跨站脚本攻击剖析与防御 利用嵌入Web页面中的Flash进行XSS有一个决定因素:allowScriptAccess属性.allowScriptAccess是使用或 下面是一个简单的示例: allowScriptAccess属性控制着Flash与HTML页面的通信,可选的值有3个: always:允许随时执行脚本操作 never:禁止所有脚本执行操作 samedomain:只有在Flash 应用程序来自与HTML页相同的域时才允许执行脚本操作 其属

【ANDROID游戏开发之十】(优化处理)详细剖析ANDROID TRACEVIEW效率检视工具,分析程序运行速度!并讲解两种创建SDCARD方式!

本站文章均为 李华明Himi 原创,转载务必在明显处注明:  转载自[黑米GameDev街区] 原文链接: http://www.himigame.com/android-game/316.html ----------------------- 『很多童鞋说我的代码运行后,点击home或者back后会程序异常,如果你也这样遇到过,那么你肯定没有仔细读完Himi的博文,第十九篇Himi专门写了关于这些错误的原因和解决方法,这里我在博客都补充说明下,省的童鞋们总疑惑这一块:请点击下面联系进入阅读:

XMLHttpRequest对象详细剖析

XMLHttpRequest对象是当今所有AJAX和Web 2.0应用程序的技术基础.尽管软件经销商和开源社团现在都在提供各种AJAX框架以进一步简化XMLHttpRequest对象的使用:但是,我们仍然很有必要理解这个对象的详细工作机制. 一. 引言 异步JavaScript与XML(AJAX)是一个专用术语,用于实现在客户端脚本与服务器之间的数据交互过程.这一技术的优点在于,它向开发者提供了一种从Web服务器检索数据而不必把用户当前正在观察的页面回馈给服务器.与现代浏览器的通过存取浏览器DO

从N层到.NET详细剖析原理(1)

简介 如今,N 层应用程序已经成为构建企业软件的标准.对于大多数人来说,N 层应用程序就是被分成多个独立的逻辑部分的应用程序.最常见的选择是分为三个部分:表示.业务逻辑和数据,当然还可能存在其他的划分方法.N 层应用程序最初是为了解决与传统的客户端/服务器应用程序相关的问题而出现的,但是,随着 Web 时代的到来,这一体系结构开始成为新开发项目的主流. Microsoft Windows? DNA 技术已成为 N 层应用程序的非常成功的基础.Microsoft .NET 框架也为构建 N 层应用

网友原创:从N层到.NET详细剖析原理

原创 摘要:讨论 Microsoft .net 的应用程序设计和所需的更改:检验从使用 Microsoft Windows DNA 构建 N 层应用程序中学到的结构知识,以及如何将这些知识应用到使用 Microsoft .NET 框架构建的应用程序,并且为使用 XML Web Services 的应用程序提供体系结构方面的建议. 简介 如今,N 层应用程序已经成为构建企业软件的标准.对于大多数人来说,N 层应用程序就是被分成多个独立的逻辑部分的应用程序.最常见的选择是分为三个部分:表示.业务逻辑

python进程池详细剖析教程

python中两个常用来处理进程的模块分别是subprocess和multiprocessing,其中subprocess通常用于执行外部程序,比如一些第三方应用程序,而不是Python程序.如果需要实现调用外部程序的功能,python的psutil模块是更好的选择,它不仅支持subprocess提供的功能,而且还能对当前主机或者启动的外部程序进行监控,比如获取网络.cpu.内存等信息使用情况,在做一些自动化运维工作时支持的更加全面.multiprocessing是python的多进程模块,主要

MySQL 5.7在高并发下性能劣化问题的详细剖析

TL;DR MySQL 5.7为了提升只读事务的性能改进了MVCC机制,虽然在只读场景下能获得很好的收益,但是在读写混合的高并发场景下却带来了性能劣化,导致的结果就是rt飙升和业务端超时.本文剖析了此问题背后的原因,并给出了解决办法. 引言 MySQL 5.7自发布以来备受关注,不仅是因为5.7的在功能特性上大大丰富,它的读写性能上相对于之前的版本也有了很大提升.正是由于5.7卓越的表现,我们自去年起就开始着手将AliSQL整体搬迁到5.7上.然而经过一年多的整合测试我们发现,5.7宣称的有些能

利用表格制作网页详细介绍

将一定的内容按特定的行.列规则进行排列就构成了表格.无论在日常生活和工作中,还是在网页设计中,表格通常都可以使信息更容易理解.HTML 具有很强的表格功能,使用户可以方便地创建出各种规格的表格,并能对表格进行特定的修饰,从而使网页更加生动活泼.HTML 表格模型使用户可以将各种数据(包括文本.预格式化文本.图像.链接.表单.表单域以及其他表格等)排成行和列,从而获得特定的表格效果. 表格在网页设计中的地位非常重要,可以说如果您表格用不好的话,就不可能设计出出色的网页.大多数初学者一开始就接触表格

【ANDROID游戏开发之三】详细剖析 SURFACEVIEW ! CALLBACK以及SURFACEHOLDER!!

本站文章均为 李华明Himi 原创,转载务必在明显处注明:  转载自[黑米GameDev街区] 原文链接: http://www.himigame.com/android-game/296.html ----------------------- 『很多童鞋说我的代码运行后,点击home或者back后会程序异常,如果你也这样遇到过,那么你肯定没有仔细读完Himi的博文,第十九篇Himi专门写了关于这些错误的原因和解决方法,这里我在博客都补充说明下,省的童鞋们总疑惑这一块:请点击下面联系进入阅读: