迁移学习是指对提前训练过的神经网络进行调整,以用于新的不同数据集。
1.Keras实现迁移学习实例
先看个Keras实现的迁移学习案例
https://keras-cn.readthedocs.io/en/latest/other/application/
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import pandas as pd
import sklearn
import os
import random
import csv
import tensorflow as tf
from numpy import expand_dims
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from PIL import Image
from keras.preprocessing import image
import keras as
1.1准备训练数据
BATCH_SIZE = 16
# 迭代50次
EPOCHS = 50
# 依照模型规定,图片大小被设定为224
IMAGE_SIZE = 224
TRAIN_PATH = './17flowerclasses/train'
TEST_PATH = './17flowerclasses/test'
#FLOWER_CLASSES = ['Bluebell', 'ButterCup', 'ColtsFoot', 'Cowslip', 'Crocus', 'Daffodil', 'Daisy','Dandelion', 'Fritillary', 'Iris', 'LilyValley', 'Pansy', 'Snowdrop', 'Sunflower','Tigerlily', 'tulip', 'WindFlower']
# 使用数据增强
train_datagen = ImageDataGenerator(
rotation_range=40,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True,
fill_mode='nearest')
# 可指定输出图片大小,因为深度学习要求训练图片大小保持一致
train_generator = train_datagen.flow_from_directory(directory=TRAIN_PATH,
target_size=(IMAGE_SIZE, IMAGE_SIZE),
batch_size = BATCH_SIZE)#,
#classes=FLOWER_CLASSES)
test_datagen = ImageDataGenerator()
test_generator = test_datagen.flow_from_directory(directory=TEST_PATH,
target_size=(IMAGE_SIZE, IMAGE_SIZE))#,
#classes=FLOWER_CLASSES)
Found 1190 images belonging to 17 classes.
Found 170 images belonging to 17 classes.
1.2设计迁移学习网络
这里运用resnet50模型,把输出层改为链接两个全连接层,组成新的网络。
import keras
from keras.models import Model, Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout,GlobalAveragePooling2D
IMAGE_SIZE = 224
# weights='imagenet'
keras_resnet = keras.applications.resnet50.ResNet50(include_top=False, weights='imagenet',input_shape = (IMAGE_SIZE, IMAGE_SIZE, 3))
print(keras_resnet.output_shape[1:])
resnet_flower = Sequential()
resnet_flower.add(Flatten(input_shape=keras_resnet.output_shape[1:]))
resnet_flower.add(Dense(1024, activation="relu"))
resnet_flower.add(Dropout(0.5))
resnet_flower.add(Dense(17, activation='softmax'))
resnet_flower_model = Model(inputs=keras_resnet.input, outputs=resnet_flower(keras_resnet.output))
resnet_flower_model.summary()
/home/leon/anaconda3/lib/python3.7/site-packages/keras_applications/resnet50.py:265: UserWarning: The output shape of `ResNet50(include_top=False)` has been changed since Keras 2.2.0.
warnings.warn('The output shape of `ResNet50(include_top=False)` '
(7, 7, 2048)
Model: "model_9"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
input_11 (InputLayer) (None, 224, 224, 3) 0
__________________________________________________________________________________________________
conv1_pad (ZeroPadding2D) (None, 230, 230, 3) 0 input_11[0][0]
__________________________________________________________________________________________________
conv1 (Conv2D) (None, 112, 112, 64) 9472 conv1_pad[0][0]
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization) (None, 112, 112, 64) 256 conv1[0][0]
__________________________________________________________________________________________________
activation_189 (Activation) (None, 112, 112, 64) 0 bn_conv1[0][0]
__________________________________________________________________________________________________
pool1_pad (ZeroPadding2D) (None, 114, 114, 64) 0 activation_189[0][0]
__________________________________________________________________________________________________
max_pooling2d_9 (MaxPooling2D) (None, 56, 56, 64) 0 pool1_pad[0][0]
__________________________________________________________________________________________________
res2a_branch2a (Conv2D) (None, 56, 56, 64) 4160 max_pooling2d_9[0][0]
__________________________________________________________________________________________________
bn2a_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2a[0][0]
__________________________________________________________________________________________________
activation_190 (Activation) (None, 56, 56, 64) 0 bn2a_branch2a[0][0]
__________________________________________________________________________________________________
res2a_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_190[0][0]
__________________________________________________________________________________________________
bn2a_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2a_branch2b[0][0]
__________________________________________________________________________________________________
activation_191 (Activation) (None, 56, 56, 64) 0 bn2a_branch2b[0][0]
__________________________________________________________________________________________________
res2a_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_191[0][0]
__________________________________________________________________________________________________
res2a_branch1 (Conv2D) (None, 56, 56, 256) 16640 max_pooling2d_9[0][0]
__________________________________________________________________________________________________
bn2a_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2a_branch2c[0][0]
__________________________________________________________________________________________________
bn2a_branch1 (BatchNormalizatio (None, 56, 56, 256) 1024 res2a_branch1[0][0]
__________________________________________________________________________________________________
add_1 (Add) (None, 56, 56, 256) 0 bn2a_branch2c[0][0]
bn2a_branch1[0][0]
__________________________________________________________________________________________________
activation_192 (Activation) (None, 56, 56, 256) 0 add_1[0][0]
__________________________________________________________________________________________________
res2b_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_192[0][0]
__________________________________________________________________________________________________
bn2b_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2a[0][0]
__________________________________________________________________________________________________
activation_193 (Activation) (None, 56, 56, 64) 0 bn2b_branch2a[0][0]
__________________________________________________________________________________________________
res2b_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_193[0][0]
__________________________________________________________________________________________________
bn2b_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2b_branch2b[0][0]
__________________________________________________________________________________________________
activation_194 (Activation) (None, 56, 56, 64) 0 bn2b_branch2b[0][0]
__________________________________________________________________________________________________
res2b_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_194[0][0]
__________________________________________________________________________________________________
bn2b_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2b_branch2c[0][0]
__________________________________________________________________________________________________
add_2 (Add) (None, 56, 56, 256) 0 bn2b_branch2c[0][0]
activation_192[0][0]
__________________________________________________________________________________________________
activation_195 (Activation) (None, 56, 56, 256) 0 add_2[0][0]
__________________________________________________________________________________________________
res2c_branch2a (Conv2D) (None, 56, 56, 64) 16448 activation_195[0][0]
__________________________________________________________________________________________________
bn2c_branch2a (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2a[0][0]
__________________________________________________________________________________________________
activation_196 (Activation) (None, 56, 56, 64) 0 bn2c_branch2a[0][0]
__________________________________________________________________________________________________
res2c_branch2b (Conv2D) (None, 56, 56, 64) 36928 activation_196[0][0]
__________________________________________________________________________________________________
bn2c_branch2b (BatchNormalizati (None, 56, 56, 64) 256 res2c_branch2b[0][0]
__________________________________________________________________________________________________
activation_197 (Activation) (None, 56, 56, 64) 0 bn2c_branch2b[0][0]
__________________________________________________________________________________________________
res2c_branch2c (Conv2D) (None, 56, 56, 256) 16640 activation_197[0][0]
__________________________________________________________________________________________________
bn2c_branch2c (BatchNormalizati (None, 56, 56, 256) 1024 res2c_branch2c[0][0]
__________________________________________________________________________________________________
add_3 (Add) (None, 56, 56, 256) 0 bn2c_branch2c[0][0]
activation_195[0][0]
__________________________________________________________________________________________________
activation_198 (Activation) (None, 56, 56, 256) 0 add_3[0][0]
__________________________________________________________________________________________________
res3a_branch2a (Conv2D) (None, 28, 28, 128) 32896 activation_198[0][0]
__________________________________________________________________________________________________
bn3a_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2a[0][0]
__________________________________________________________________________________________________
activation_199 (Activation) (None, 28, 28, 128) 0 bn3a_branch2a[0][0]
__________________________________________________________________________________________________
res3a_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_199[0][0]
__________________________________________________________________________________________________
bn3a_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3a_branch2b[0][0]
__________________________________________________________________________________________________
activation_200 (Activation) (None, 28, 28, 128) 0 bn3a_branch2b[0][0]
__________________________________________________________________________________________________
res3a_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_200[0][0]
__________________________________________________________________________________________________
res3a_branch1 (Conv2D) (None, 28, 28, 512) 131584 activation_198[0][0]
__________________________________________________________________________________________________
bn3a_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3a_branch2c[0][0]
__________________________________________________________________________________________________
bn3a_branch1 (BatchNormalizatio (None, 28, 28, 512) 2048 res3a_branch1[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 28, 28, 512) 0 bn3a_branch2c[0][0]
bn3a_branch1[0][0]
__________________________________________________________________________________________________
activation_201 (Activation) (None, 28, 28, 512) 0 add_4[0][0]
__________________________________________________________________________________________________
res3b_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_201[0][0]
__________________________________________________________________________________________________
bn3b_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2a[0][0]
__________________________________________________________________________________________________
activation_202 (Activation) (None, 28, 28, 128) 0 bn3b_branch2a[0][0]
__________________________________________________________________________________________________
res3b_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_202[0][0]
__________________________________________________________________________________________________
bn3b_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3b_branch2b[0][0]
__________________________________________________________________________________________________
activation_203 (Activation) (None, 28, 28, 128) 0 bn3b_branch2b[0][0]
__________________________________________________________________________________________________
res3b_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_203[0][0]
__________________________________________________________________________________________________
bn3b_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3b_branch2c[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 28, 28, 512) 0 bn3b_branch2c[0][0]
activation_201[0][0]
__________________________________________________________________________________________________
activation_204 (Activation) (None, 28, 28, 512) 0 add_5[0][0]
__________________________________________________________________________________________________
res3c_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_204[0][0]
__________________________________________________________________________________________________
bn3c_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2a[0][0]
__________________________________________________________________________________________________
activation_205 (Activation) (None, 28, 28, 128) 0 bn3c_branch2a[0][0]
__________________________________________________________________________________________________
res3c_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_205[0][0]
__________________________________________________________________________________________________
bn3c_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3c_branch2b[0][0]
__________________________________________________________________________________________________
activation_206 (Activation) (None, 28, 28, 128) 0 bn3c_branch2b[0][0]
__________________________________________________________________________________________________
res3c_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_206[0][0]
__________________________________________________________________________________________________
bn3c_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3c_branch2c[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 28, 28, 512) 0 bn3c_branch2c[0][0]
activation_204[0][0]
__________________________________________________________________________________________________
activation_207 (Activation) (None, 28, 28, 512) 0 add_6[0][0]
__________________________________________________________________________________________________
res3d_branch2a (Conv2D) (None, 28, 28, 128) 65664 activation_207[0][0]
__________________________________________________________________________________________________
bn3d_branch2a (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2a[0][0]
__________________________________________________________________________________________________
activation_208 (Activation) (None, 28, 28, 128) 0 bn3d_branch2a[0][0]
__________________________________________________________________________________________________
res3d_branch2b (Conv2D) (None, 28, 28, 128) 147584 activation_208[0][0]
__________________________________________________________________________________________________
bn3d_branch2b (BatchNormalizati (None, 28, 28, 128) 512 res3d_branch2b[0][0]
__________________________________________________________________________________________________
activation_209 (Activation) (None, 28, 28, 128) 0 bn3d_branch2b[0][0]
__________________________________________________________________________________________________
res3d_branch2c (Conv2D) (None, 28, 28, 512) 66048 activation_209[0][0]
__________________________________________________________________________________________________
bn3d_branch2c (BatchNormalizati (None, 28, 28, 512) 2048 res3d_branch2c[0][0]
__________________________________________________________________________________________________
add_7 (Add) (None, 28, 28, 512) 0 bn3d_branch2c[0][0]
activation_207[0][0]
__________________________________________________________________________________________________
activation_210 (Activation) (None, 28, 28, 512) 0 add_7[0][0]
__________________________________________________________________________________________________
res4a_branch2a (Conv2D) (None, 14, 14, 256) 131328 activation_210[0][0]
__________________________________________________________________________________________________
bn4a_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2a[0][0]
__________________________________________________________________________________________________
activation_211 (Activation) (None, 14, 14, 256) 0 bn4a_branch2a[0][0]
__________________________________________________________________________________________________
res4a_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_211[0][0]
__________________________________________________________________________________________________
bn4a_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4a_branch2b[0][0]
__________________________________________________________________________________________________
activation_212 (Activation) (None, 14, 14, 256) 0 bn4a_branch2b[0][0]
__________________________________________________________________________________________________
res4a_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_212[0][0]
__________________________________________________________________________________________________
res4a_branch1 (Conv2D) (None, 14, 14, 1024) 525312 activation_210[0][0]
__________________________________________________________________________________________________
bn4a_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4a_branch2c[0][0]
__________________________________________________________________________________________________
bn4a_branch1 (BatchNormalizatio (None, 14, 14, 1024) 4096 res4a_branch1[0][0]
__________________________________________________________________________________________________
add_8 (Add) (None, 14, 14, 1024) 0 bn4a_branch2c[0][0]
bn4a_branch1[0][0]
__________________________________________________________________________________________________
activation_213 (Activation) (None, 14, 14, 1024) 0 add_8[0][0]
__________________________________________________________________________________________________
res4b_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_213[0][0]
__________________________________________________________________________________________________
bn4b_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2a[0][0]
__________________________________________________________________________________________________
activation_214 (Activation) (None, 14, 14, 256) 0 bn4b_branch2a[0][0]
__________________________________________________________________________________________________
res4b_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_214[0][0]
__________________________________________________________________________________________________
bn4b_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4b_branch2b[0][0]
__________________________________________________________________________________________________
activation_215 (Activation) (None, 14, 14, 256) 0 bn4b_branch2b[0][0]
__________________________________________________________________________________________________
res4b_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_215[0][0]
__________________________________________________________________________________________________
bn4b_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4b_branch2c[0][0]
__________________________________________________________________________________________________
add_9 (Add) (None, 14, 14, 1024) 0 bn4b_branch2c[0][0]
activation_213[0][0]
__________________________________________________________________________________________________
activation_216 (Activation) (None, 14, 14, 1024) 0 add_9[0][0]
__________________________________________________________________________________________________
res4c_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_216[0][0]
__________________________________________________________________________________________________
bn4c_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2a[0][0]
__________________________________________________________________________________________________
activation_217 (Activation) (None, 14, 14, 256) 0 bn4c_branch2a[0][0]
__________________________________________________________________________________________________
res4c_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_217[0][0]
__________________________________________________________________________________________________
bn4c_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4c_branch2b[0][0]
__________________________________________________________________________________________________
activation_218 (Activation) (None, 14, 14, 256) 0 bn4c_branch2b[0][0]
__________________________________________________________________________________________________
res4c_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_218[0][0]
__________________________________________________________________________________________________
bn4c_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4c_branch2c[0][0]
__________________________________________________________________________________________________
add_10 (Add) (None, 14, 14, 1024) 0 bn4c_branch2c[0][0]
activation_216[0][0]
__________________________________________________________________________________________________
activation_219 (Activation) (None, 14, 14, 1024) 0 add_10[0][0]
__________________________________________________________________________________________________
res4d_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_219[0][0]
__________________________________________________________________________________________________
bn4d_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2a[0][0]
__________________________________________________________________________________________________
activation_220 (Activation) (None, 14, 14, 256) 0 bn4d_branch2a[0][0]
__________________________________________________________________________________________________
res4d_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_220[0][0]
__________________________________________________________________________________________________
bn4d_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4d_branch2b[0][0]
__________________________________________________________________________________________________
activation_221 (Activation) (None, 14, 14, 256) 0 bn4d_branch2b[0][0]
__________________________________________________________________________________________________
res4d_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_221[0][0]
__________________________________________________________________________________________________
bn4d_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4d_branch2c[0][0]
__________________________________________________________________________________________________
add_11 (Add) (None, 14, 14, 1024) 0 bn4d_branch2c[0][0]
activation_219[0][0]
__________________________________________________________________________________________________
activation_222 (Activation) (None, 14, 14, 1024) 0 add_11[0][0]
__________________________________________________________________________________________________
res4e_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_222[0][0]
__________________________________________________________________________________________________
bn4e_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2a[0][0]
__________________________________________________________________________________________________
activation_223 (Activation) (None, 14, 14, 256) 0 bn4e_branch2a[0][0]
__________________________________________________________________________________________________
res4e_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_223[0][0]
__________________________________________________________________________________________________
bn4e_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4e_branch2b[0][0]
__________________________________________________________________________________________________
activation_224 (Activation) (None, 14, 14, 256) 0 bn4e_branch2b[0][0]
__________________________________________________________________________________________________
res4e_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_224[0][0]
__________________________________________________________________________________________________
bn4e_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4e_branch2c[0][0]
__________________________________________________________________________________________________
add_12 (Add) (None, 14, 14, 1024) 0 bn4e_branch2c[0][0]
activation_222[0][0]
__________________________________________________________________________________________________
activation_225 (Activation) (None, 14, 14, 1024) 0 add_12[0][0]
__________________________________________________________________________________________________
res4f_branch2a (Conv2D) (None, 14, 14, 256) 262400 activation_225[0][0]
__________________________________________________________________________________________________
bn4f_branch2a (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2a[0][0]
__________________________________________________________________________________________________
activation_226 (Activation) (None, 14, 14, 256) 0 bn4f_branch2a[0][0]
__________________________________________________________________________________________________
res4f_branch2b (Conv2D) (None, 14, 14, 256) 590080 activation_226[0][0]
__________________________________________________________________________________________________
bn4f_branch2b (BatchNormalizati (None, 14, 14, 256) 1024 res4f_branch2b[0][0]
__________________________________________________________________________________________________
activation_227 (Activation) (None, 14, 14, 256) 0 bn4f_branch2b[0][0]
__________________________________________________________________________________________________
res4f_branch2c (Conv2D) (None, 14, 14, 1024) 263168 activation_227[0][0]
__________________________________________________________________________________________________
bn4f_branch2c (BatchNormalizati (None, 14, 14, 1024) 4096 res4f_branch2c[0][0]
__________________________________________________________________________________________________
add_13 (Add) (None, 14, 14, 1024) 0 bn4f_branch2c[0][0]
activation_225[0][0]
__________________________________________________________________________________________________
activation_228 (Activation) (None, 14, 14, 1024) 0 add_13[0][0]
__________________________________________________________________________________________________
res5a_branch2a (Conv2D) (None, 7, 7, 512) 524800 activation_228[0][0]
__________________________________________________________________________________________________
bn5a_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2a[0][0]
__________________________________________________________________________________________________
activation_229 (Activation) (None, 7, 7, 512) 0 bn5a_branch2a[0][0]
__________________________________________________________________________________________________
res5a_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_229[0][0]
__________________________________________________________________________________________________
bn5a_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5a_branch2b[0][0]
__________________________________________________________________________________________________
activation_230 (Activation) (None, 7, 7, 512) 0 bn5a_branch2b[0][0]
__________________________________________________________________________________________________
res5a_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_230[0][0]
__________________________________________________________________________________________________
res5a_branch1 (Conv2D) (None, 7, 7, 2048) 2099200 activation_228[0][0]
__________________________________________________________________________________________________
bn5a_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5a_branch2c[0][0]
__________________________________________________________________________________________________
bn5a_branch1 (BatchNormalizatio (None, 7, 7, 2048) 8192 res5a_branch1[0][0]
__________________________________________________________________________________________________
add_14 (Add) (None, 7, 7, 2048) 0 bn5a_branch2c[0][0]
bn5a_branch1[0][0]
__________________________________________________________________________________________________
activation_231 (Activation) (None, 7, 7, 2048) 0 add_14[0][0]
__________________________________________________________________________________________________
res5b_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_231[0][0]
__________________________________________________________________________________________________
bn5b_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2a[0][0]
__________________________________________________________________________________________________
activation_232 (Activation) (None, 7, 7, 512) 0 bn5b_branch2a[0][0]
__________________________________________________________________________________________________
res5b_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_232[0][0]
__________________________________________________________________________________________________
bn5b_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5b_branch2b[0][0]
__________________________________________________________________________________________________
activation_233 (Activation) (None, 7, 7, 512) 0 bn5b_branch2b[0][0]
__________________________________________________________________________________________________
res5b_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_233[0][0]
__________________________________________________________________________________________________
bn5b_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5b_branch2c[0][0]
__________________________________________________________________________________________________
add_15 (Add) (None, 7, 7, 2048) 0 bn5b_branch2c[0][0]
activation_231[0][0]
__________________________________________________________________________________________________
activation_234 (Activation) (None, 7, 7, 2048) 0 add_15[0][0]
__________________________________________________________________________________________________
res5c_branch2a (Conv2D) (None, 7, 7, 512) 1049088 activation_234[0][0]
__________________________________________________________________________________________________
bn5c_branch2a (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2a[0][0]
__________________________________________________________________________________________________
activation_235 (Activation) (None, 7, 7, 512) 0 bn5c_branch2a[0][0]
__________________________________________________________________________________________________
res5c_branch2b (Conv2D) (None, 7, 7, 512) 2359808 activation_235[0][0]
__________________________________________________________________________________________________
bn5c_branch2b (BatchNormalizati (None, 7, 7, 512) 2048 res5c_branch2b[0][0]
__________________________________________________________________________________________________
activation_236 (Activation) (None, 7, 7, 512) 0 bn5c_branch2b[0][0]
__________________________________________________________________________________________________
res5c_branch2c (Conv2D) (None, 7, 7, 2048) 1050624 activation_236[0][0]
__________________________________________________________________________________________________
bn5c_branch2c (BatchNormalizati (None, 7, 7, 2048) 8192 res5c_branch2c[0][0]
__________________________________________________________________________________________________
add_16 (Add) (None, 7, 7, 2048) 0 bn5c_branch2c[0][0]
activation_234[0][0]
__________________________________________________________________________________________________
activation_237 (Activation) (None, 7, 7, 2048) 0 add_16[0][0]
__________________________________________________________________________________________________
sequential_9 (Sequential) (None, 17) 102778897 activation_237[0][0]
==================================================================================================
Total params: 126,366,609
Trainable params: 126,313,489
Non-trainable params: 53,120
__________________________________________________________________________________________________
1.3编译网络
from keras.callbacks import ModelCheckpoint
# train the model
checkpointer = ModelCheckpoint(filepath='flowers.weights.best.hdf5', verbose=1,
save_best_only=True)
resnet_flower_model.compile(optimizer='rmsprop',loss='categorical_crossentropy', metrics=['accuracy'])
1.4训练网络
#model.fit(X_train,y_train,validation_split=0.2,shuffle=True,epochs=20)
history_object = resnet_flower_model.fit_generator(train_generator,validation_data=test_generator, epochs=EPOCHS)
Epoch 1/50
75/75 [==============================] - 303s 4s/step - loss: 7.7597 - accuracy: 0.2034 - val_loss: 85776.1953 - val_accuracy: 0.0588
Epoch 2/50
75/75 [==============================] - 297s 4s/step - loss: 3.2533 - accuracy: 0.2672 - val_loss: 1551.2539 - val_accuracy: 0.0588
Epoch 3/50
75/75 [==============================] - 300s 4s/step - loss: 2.7691 - accuracy: 0.3202 - val_loss: 460.1471 - val_accuracy: 0.0588
Epoch 4/50
75/75 [==============================] - 299s 4s/step - loss: 2.2212 - accuracy: 0.3748 - val_loss: 6.4425 - val_accuracy: 0.1647
Epoch 5/50
75/75 [==============================] - 299s 4s/step - loss: 2.0888 - accuracy: 0.4277 - val_loss: 6.9810 - val_accuracy: 0.2412
Epoch 6/50
75/75 [==============================] - 300s 4s/step - loss: 1.8929 - accuracy: 0.4672 - val_loss: 71.1725 - val_accuracy: 0.1294
Epoch 7/50
75/75 [==============================] - 299s 4s/step - loss: 1.7380 - accuracy: 0.4815 - val_loss: 22.3084 - val_accuracy: 0.1529
Epoch 8/50
75/75 [==============================] - 301s 4s/step - loss: 1.6882 - accuracy: 0.5235 - val_loss: 6.6994 - val_accuracy: 0.4706
Epoch 9/50
75/75 [==============================] - 301s 4s/step - loss: 1.5664 - accuracy: 0.5555 - val_loss: 3.4593 - val_accuracy: 0.4176
Epoch 10/50
56/75 [=====================>........] - ETA: 1:13 - loss: 1.4389 - accuracy: 0.5700
---------------------------------------------------------------------------
2.迁移学习分类主要取决于以下两个条件:
1.新数据集的大小,以及
2.新数据集与原始数据集的相似程度
使用迁移学习的方法将各不相同。有以下四大主要情形:
新数据集很小,新数据与原始数据相似
新数据集很小,新数据不同于原始训练数据
新数据集很大,新数据与原始训练数据相似
新数据集很大,新数据不同于原始训练数据
大型数据集可能具有 100 万张图片。小型数据集可能有 2000 张图片。大型数据集与小型数据集之间的界限比较主观。对小型数据集使用迁移学习需要考虑过拟合现象。
狗的图片和狼的图片可以视为相似的图片;这些图片具有共同的特征。鲜花图片数据集不同于狗类图片数据集。
四个迁移学习情形均具有自己的方法。
演示网络
为了解释每个情形的工作原理,我们将以一个普通的预先训练过的卷积神经网络开始,并解释如何针对每种情形调整该网络。我们的示例网络包含三个卷积层和三个完全连接层:
下面是卷积神经网络的作用一般概述:
第一层级将检测图片中的边缘
第二层级将检测形状
第三个卷积层将检测更高级的特征
每个迁移学习情形将以不同的方式使用预先训练过的神经网络。
2.1 情形 1:小数据集,相似数据
如果新数据集很小,并且与原始训练数据相似:
删除神经网络的最后层级
添加一个新的完全连接层,与新数据集中的类别数量相匹配
随机化设置新的完全连接层的权重;冻结预先训练过的网络中的所有权重
训练该网络以更新新连接层的权重;
为了避免小数据集出现过拟合现象,原始网络的权重将保持不变,而不是重新训练这些权重。
因为数据集比较相似,每个数据集的图片将具有相似的更高级别特征。因此,大部分或所有预先训练过的神经网络层级已经包含关于新数据集的相关信息,应该保持不变。
以下是如何可视化此方法的方式:
2.2 情形 2:小型数据集、不同的数据
如果新数据集很小,并且与原始训练数据不同:
将靠近网络开头的大部分预先训练过的层级删掉
向剩下的预先训练过的层级添加新的完全连接层,并与新数据集的类别数量相匹配
随机化设置新的完全连接层的权重;冻结预先训练过的网络中的所有权重
训练该网络以更新新连接层的权重
因为数据集很小,因此依然需要注意过拟合问题。要解决过拟合问题,原始神经网络的权重应该保持不变,就像第一种情况那样。
但是原始训练集和新的数据集并不具有相同的更高级特征。在这种情况下,新的网络仅使用包含更低级特征的层级。
以下是如何可视化此方法的方式:
2.3 情形 3:大型数据集、相似数据
如果新数据集比较大型,并且与原始训练数据相似:
删掉最后的完全连接层,并替换成与新数据集中的类别数量相匹配的层级
随机地初始化新的完全连接层的权重
使用预先训练过的权重初始化剩下的权重
重新训练整个神经网络
训练大型数据集时,过拟合问题不严重;因此,你可以重新训练所有权重。
因为原始训练集和新的数据集具有相同的更高级特征,因此使用整个神经网络。
以下是如何可视化此方法的方式:
2.4 情形 4:大型数据集、不同的数据
如果新数据集很大型,并且与原始训练数据不同:
删掉最后的完全连接层,并替换成与新数据集中的类别数量相匹配的层级
使用随机初始化的权重重新训练网络
或者,你可以采用和“大型相似数据”情形的同一策略
虽然数据集与训练数据不同,但是利用预先训练过的网络中的权重进行初始化可能使训练速度更快。因此这种情形与大型相似数据集这一情形完全相同。
如果使用预先训练过的网络作为起点不能生成成功的模型,另一种选择是随机地初始化卷积神经网络权重,并从头训练网络。
以下是如何可视化此方法的方式:
参考资料
参阅这篇 研究论文,该论文系统地分析了预先训练过的 CNN 中的特征的可迁移性。
https://arxiv.org/pdf/1411.1792.pdf
阅读这篇详细介绍 Sebastian Thrun 的癌症检测 CNN 的《自然》论文!
这是提议将 GAP 层级用于对象定位的首篇研究论文。
http://cnnlocalization.csail.mit.edu/Zhou_Learning_Deep_Features_CVPR_2016_paper.pdf
参阅这个使用 CNN 进行对象定位的资源库。
https://github.com/alexisbcook/ResNetCAM-keras
观看这个关于使用 CNN 进行对象定位的视频演示(Youtube链接,国内网络可能打不开)。
https://www.youtube.com/watch?v=fZvOy0VXWAI
参阅这个使用可视化机器更好地理解瓶颈特征的资源库。
https://github.com/alexisbcook/keras_transfer_cifar10