1. 首先是数据集的下载和使用
下载地址: http://www.cs.toronto.edu/~kriz/cifar.html
下载完成后无需解压
直接调用语句即可读取数据集
#具体语句如下 cifar10 = tf.keras.datasets.cifar10 #使用内置API keras下载数据集(速度缓慢) (x_train, y_train), (x_test, y_test) = cifar10.load_data() # 直接读取数据即可
讯享网
2.构建CNN模型
CIFA_10数据集是由 60000张RGB彩色图片构成
一共有10类分别为["airplane","automobile","bird","cat","deer","dog","frog","horse","ship","truck"]
其中50000张图片为训练集,10000张为测试集图片
在读取数据后我们可以输出测试集和训练集的数组大小
讯享网print(x_train.shape, y_train.shape) print(x_test.shape, y_test.shape)
运行结果为

讯享网
在送入卷积神经网络前先进行处理,使图片中的每个像素值都处于0~1之间
并转化为tf张量
x_train, x_test = tf.cast(x_train, dtype=tf.float32)/255.0, tf.cast(x_test, dtype=tf.float32)/255.0 y_train, y_test = tf.cast(y_train, dtype=tf.int32), tf.cast(y_test, dtype=tf.int32)
构建训练模型
讯享网# 构建Sequential模型 #建立模型 model model = Sequential([ #特征提取层1 layers.Conv2D(16, kernel_size=(3,3), padding="same", activation=tf.nn.relu,input_shape=x_train.shape[1:]), layers.Conv2D(16, kernel_size=(3,3), padding="same", activation=tf.nn.relu), layers.MaxPool2D(pool_size=(2,2)), layers.Dropout(0.2), #特征提取层2 layers.Conv2D(32, kernel_size=(3,3), padding="same", activation=tf.nn.relu), layers.Conv2D(32, kernel_size=(3,3), padding="same", activation=tf.nn.relu), layers.MaxPool2D(pool_size=(2,2)), layers.Dropout(0.2), #全连接层 layers.Flatten(), layers.Dropout(0.2), layers.Dense(128,activation='relu'), layers.Dropout(0.2), layers.Dense(10,activation="softmax"), ])
结构如图所示


4. 配置训练方法
使用 model.complie来配置训练的方法
# 配置方法 model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
5.训练模型
讯享网history = model.fit(x_train, y_train, batch_size=64, epochs=5, validation_split=0.2)
结果如图所示:

可见准确率达到了 65%
6.评估模型
对测试集使用 evaluate函数进行模型的评估 可以得到该模型参数对测试集的图片准确率
model.evaluate(x_test, y_test, verbose=2)

准确率为 68%
7.测试
从测试集中随机抽取几张图片进行识别
讯享网plt.figure(figsize=(10,10)) for i in range(4): num = np.random.randint(1,10000) plt.subplot(1,4,i+1) plt.axis('off') plt.imshow(x_test[num],cmap='gray') demo = tf.reshape(x_test[num],(1,32,32,3)) y_pred = name[np.argmax(model.predict(demo))] plt.title('Original: ' + name[(y_test.numpy())[num,0]] + '\nPredict: ' + y_pred) plt.show()

可以看到此次的识别 对了两个错了两个
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/27898.html