CIFAR-10 dataset: https://www.cs.toronto.edu/~kriz/cifar.html
import numpy as np
import pandas as pd
import keras
from keras import layers
import matplotlib.pyplot as plt
import keras.datasets.cifar10 as cifar
1. Prepare the data:
(train_image, train_label), (test_image, test_label) = cifar.load_data()
# scale pixel values from [0, 255] down to [0, 1]
train_image = train_image / 255
test_image = test_image / 255
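The division by 255 rescales the `uint8` pixel values into the [0, 1] range, which keeps the inputs small and numerically well-behaved for training. A minimal sketch of what the scaling does:

```python
import numpy as np

# CIFAR-10 pixels load as uint8 in [0, 255]; dividing by 255
# promotes them to floats in [0, 1]
img = np.array([[0, 128, 255]], dtype=np.uint8)
scaled = img / 255
print(scaled.dtype, scaled.min(), scaled.max())  # -> float64 0.0 1.0
```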
(10000, 32, 32, 3)
plt.imshow(train_image[2])
<matplotlib.image.AxesImage at 0x26981d7c220>
[image output: train_image[2]]
array([9], dtype=uint8)
2. Build the model:

model = keras.Sequential()
model.add(layers.Conv2D(64, (3, 3), input_shape=(32, 32, 3), activation='relu'))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPool2D())
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPool2D())
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dropout(0.5))
model.add(layers.Dense(10, activation='softmax'))
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 30, 30, 64) 1792
conv2d_1 (Conv2D) (None, 28, 28, 64) 36928
max_pooling2d (MaxPooling2D) (None, 14, 14, 64) 0
conv2d_2 (Conv2D) (None, 12, 12, 64) 36928
conv2d_3 (Conv2D) (None, 10, 10, 64) 36928
max_pooling2d_1 (MaxPooling2D) (None, 5, 5, 64) 0
flatten (Flatten) (None, 1600) 0
dense (Dense) (None, 256) 409856
dropout (Dropout) (None, 256) 0
dense_1 (Dense) (None, 256) 65792
dropout_1 (Dropout) (None, 256) 0
dense_2 (Dense) (None, 10) 2570
=================================================================
Total params: 590,794
Trainable params: 590,794
Non-trainable params: 0
_________________________________________________________________
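The parameter counts in the summary can be checked by hand with the usual Conv2D and Dense parameter formulas; a quick arithmetic sketch:

```python
# Conv2D params = kernel_h * kernel_w * in_channels * filters + filters (bias)
conv1 = 3 * 3 * 3 * 64 + 64      # first conv, 3 input channels  -> 1792
conv2 = 3 * 3 * 64 * 64 + 64     # later convs, 64 -> 64 channels -> 36928
# Dense params = in_features * out_features + out_features (bias)
dense1 = 5 * 5 * 64 * 256 + 256  # Flatten gives 5*5*64 = 1600    -> 409856
dense2 = 256 * 256 + 256         #                                -> 65792
dense3 = 256 * 10 + 10           #                                -> 2570
total = conv1 + 3 * conv2 + dense1 + dense2 + dense3
print(total)  # -> 590794, matching "Total params: 590,794"
```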
3. Compile the model:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['acc'])
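`sparse_categorical_crossentropy` is the appropriate loss here because the labels are integer class indices (e.g. `array([9])`) rather than one-hot vectors, so no encoding step is needed. A small numpy sketch of what this loss computes for a single sample (the probabilities below are made up):

```python
import numpy as np

# with integer labels, the loss is -log of the softmax probability
# assigned to the true class
y_true = 9                    # integer label, no one-hot encoding needed
probs = np.full(10, 0.05)     # pretend softmax output over 10 classes
probs[y_true] = 0.55
loss = -np.log(probs[y_true])
print(round(float(loss), 4))  # -> 0.5978
```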
4. Train the model:
model.fit(train_image, train_label,
          epochs=10,
          batch_size=512,
          validation_data=(test_image, test_label))
Epoch 1/10
98/98 [==============================] - 166s 2s/step - loss: 1.9493 - acc: 0.2648 - val_loss: 193.4561 - val_acc: 0.3232
Epoch 2/10
98/98 [==============================] - 170s 2s/step - loss: 1.5473 - acc: 0.4304 - val_loss: 148.5261 - val_acc: 0.4404
Epoch 3/10
98/98 [==============================] - 176s 2s/step - loss: 1.3480 - acc: 0.5138 - val_loss: 146.2150 - val_acc: 0.4511
Epoch 4/10
98/98 [==============================] - 168s 2s/step - loss: 1.2286 - acc: 0.5620 - val_loss: 200.3967 - val_acc: 0.4550
Epoch 5/10
98/98 [==============================] - 165s 2s/step - loss: 1.1045 - acc: 0.6099 - val_loss: 187.7144 - val_acc: 0.4452
Epoch 6/10
98/98 [==============================] - 164s 2s/step - loss: 1.0141 - acc: 0.6465 - val_loss: 208.9634 - val_acc: 0.4655
Epoch 7/10
98/98 [==============================] - 178s 2s/step - loss: 0.9519 - acc: 0.6722 - val_loss: 105.2857 - val_acc: 0.5780
Epoch 8/10
98/98 [==============================] - 171s 2s/step - loss: 0.8804 - acc: 0.6943 - val_loss: 176.2543 - val_acc: 0.4892
Epoch 9/10
98/98 [==============================] - 173s 2s/step - loss: 0.8249 - acc: 0.7135 - val_loss: 150.6743 - val_acc: 0.5347
Epoch 10/10
98/98 [==============================] - 174s 2s/step - loss: 0.7784 - acc: 0.7310 - val_loss: 111.6608 - val_acc: 0.5940
<keras.callbacks.History at 0x269ee9bd7c0>
5. Evaluate the model:
model.evaluate(test_image, test_label)
313/313 [==============================] - 19s 59ms/step - loss: 2.3031 - acc: 0.1072
[2.3030948638916016, 0.10719999670982361]
model.evaluate(train_image, train_label)
1563/1563 [==============================] - 77s 49ms/step - loss: 2.3031 - acc: 0.1059
[2.3030762672424316, 0.10586000233888626]
6. Make predictions:
pred_label = model.predict(test_image)
pred = np.argmax(pred_label[0])
if pred == 4:
    print("deer")
deer
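Rather than checking one image at a time, `np.argmax` over axis 1 decodes every prediction at once, and comparing against the labels gives the accuracy directly. A minimal sketch, using a tiny made-up prediction matrix in place of `model.predict(test_image)`:

```python
import numpy as np

# fake softmax outputs for two images, standing in for pred_label
pred_label = np.array([[0.1, 0.0, 0.0, 0.0, 0.8, 0.0, 0.0, 0.0, 0.1, 0.0],
                       [0.0, 0.9, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.1, 0.0]])
test_label = np.array([[4], [1]])      # labels as loaded: shape (n, 1)

pred = np.argmax(pred_label, axis=1)   # predicted class per image
acc = np.mean(pred == test_label.flatten())
print(pred, acc)                       # -> [4 1] 1.0
```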
plt.figure(figsize=(1, 1))
plt.imshow(test_image[3])
<matplotlib.image.AxesImage at 0x269823c8280>
[image output: test_image[3]]