前面一篇内容讲解了如何利用Pytorch实现ResNet,这一篇我们用ResNet18实现一个二分类。接下来从模型、数据及训练三个方面展开。
利用ResNet18将以下数据分为两类
- class_0
- class_1

ResNet系列的模型在上一篇已经详细介绍了,这里采用ResNet18。
1. 模型导入
在库中已经有一些常用模型,我们这里直接引入即可。
讯享网
2. 修改输出层
可以看到库里面自带的ResNet模型最后经过全局最大池化后接的输出是1000类,但这里只有两类,所以需要对最后输出层进行修改。
讯享网
3. 模型可视化
为了更直观理解网络,这里采用查看网络结构。下图是网络前面几层的结构图。


模型确定以后,我们接下来依据模型输入,制作数据集。如下图所示,原始论文中输入大小为224*224,经过5次卷积后特征图大小依次变为112 * 112 ==> 56 * 56 ==> 28 * 28 ==> 14 * 14 ==> 7 * 7,最后经过全局池化变为 1 * 1 共512维。由于这里设计了全局池化层,所以对输入不一定限制为224 * 224的大小。

这里考虑到数据本身比较小,因此输入大小统一为64 * 64。接下来依据以上内容一步一步实现数据集制作。
1. 原始数据分文件存储
将原始图片按类型分别存在不同的文件夹下,其目录结构如下
2. 数据预处理
通道转换,将图片转为RGB格式,(png图片读取会变成RGBA)
讯享网
考虑到原始图片可能大小不一,这里需要进行缩放,将其变为64 * 64
为了训练时更快的收敛,这里对输入图片进行归一化处理,即减去均值后除以方差。
讯享网
3. 数据增强
由于数据量较少,这里对数据集进行增强处理,进行旋转和裁剪
4. 数据加载器
pytorch提供了数据加载器,定义自己数据集的时候只需要继承Dataset类,然后重写,和三个方法即可,其中可以用来初始化一些变量,返回数据集大小, 返回指定索引对应的数据。

讯享网
接下来我们依据数据集编写数据类
- mydataset.py
5. 测试及可视化
验证数据类是否正确,指定索引后,利用进行绘图,并打印出相应标签
讯享网

经过上面的讨论,已经定义好模型和数据集,接下来实现模型训练。按照pytorch框架,需要有优化器以及损失函数,这里依次展开。
1. 定义损失函数
这里采用交叉熵损失,也可以根据实际需求进行修改。
2. 定义优化器
优化器这里采用Adam
讯享网
3. 数据集划分及加载器
将数据集划分为训练集和验证集,这里依据给定比例进行随机划分。
4. 训练
讯享网
5. 可视化训练结果
借助工具监控训练过程,也可以采用等工具。

五、模型测试
模型训练好之后我们得到模型权重, 要实现测试,只需要准备好测试图片,执行以下脚本即可。
小结
借助pytorch训练模型,大体可以分为三个步骤,第一步先确定好数据集,第二步依据数据集定义好模型的输入输出,第三步定义好损失函数和优化器后进行训练,这三个步骤都要用好可视化工具,便于检查及监控训练过程。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/186133.html