import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.preprocessing.image import ImageDataGenerator
# Label files: one "<image_path>,<class_id>" pair per line (see generateds()).
train_txt = './second/train_label.txt'
# Cached .npy copies of the parsed dataset, written on first run and reloaded after.
x_train_savepath = './second/model_x_train.npy'
y_train_savepath = './second/model_y_train.npy'
test_txt = './second/test_label.txt'
x_test_savepath = './second/model_x_test.npy'
y_test_savepath = './second/model_y_test.npy'
# Weight-checkpoint prefix; TF writes "<prefix>.index" plus data shards here.
checkpoint_save_path = "./checkpoint/model_data.ckpt"

# On-the-fly training-set augmentation: random rotations up to 90 degrees,
# +/-15% horizontal/vertical shifts, and up to 50% zoom.
image_gen_train = ImageDataGenerator(
    rotation_range=90,
    width_shift_range=.15,
    height_shift_range=.15,
    zoom_range=0.5
)
def generateds(txt):
    """Load an image dataset described by a label file.

    Each non-empty line of *txt* is "<image_path>,<class_id>". Every image is
    opened, converted to single-channel grayscale, and scaled to [0, 1].

    Args:
        txt: path to the label file.

    Returns:
        (x, y_) where x is a numpy array of float grayscale images and y_ is
        a numpy int64 array of class labels, aligned by index.
    """
    # Context manager guarantees the label file is closed even on error.
    with open(txt, 'r') as f:
        contents = f.readlines()

    x, y_ = [], []
    for content in contents:
        line = content.strip()
        if not line:
            continue  # tolerate blank/trailing lines in the label file
        parts = line.split(",")
        img_path = parts[0]
        img = Image.open(img_path)
        img = np.array(img.convert('L')) / 255.  # grayscale, scaled to [0, 1]
        x.append(img)
        # strip() above already removed the trailing newline from the label
        y_.append(parts[1])
        print('loading : ' + content)

    x = np.array(x)
    y_ = np.array(y_).astype(np.int64)
    return x, y_
# Load the cached .npy dataset when all four files exist; otherwise parse the
# label files, reshape to NHWC, and write the cache for subsequent runs.
if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(
        x_test_savepath) and os.path.exists(y_test_savepath):
    print('-------------Load Datasets-----------------')
    x_train_save = np.load(x_train_savepath)
    y_train = np.load(y_train_savepath)
    x_test_save = np.load(x_test_savepath)
    y_test = np.load(y_test_savepath)
    # Ensure the model-input shape (N, 256, 256, 1) regardless of cache layout.
    x_train = np.reshape(x_train_save, (len(x_train_save), 256, 256, 1))
    x_test = np.reshape(x_test_save, (len(x_test_save), 256, 256, 1))
else:
    print('-------------Generate Datasets-----------------')
    x_train, y_train = generateds(train_txt)
    x_test, y_test = generateds(test_txt)
    x_train = np.reshape(x_train, (len(x_train), 256, 256, 1))
    x_test = np.reshape(x_test, (len(x_test), 256, 256, 1))

    print('-------------Save Datasets-----------------')
    # x_train/x_test are already (N, 256, 256, 1); save them directly instead
    # of reshaping a second time as the original code did.
    np.save(x_train_savepath, x_train)
    np.save(y_train_savepath, y_train)
    np.save(x_test_savepath, x_test)
    np.save(y_test_savepath, y_test)
# CNN classifier for 256x256x1 grayscale inputs: two Conv-BN-ReLU stages,
# one 2x2 max-pool, then a dropout-regularized dense head ending in a
# 4-way softmax (one probability per class).
model = tf.keras.models.Sequential([
    tf.keras.layers.Conv2D(filters=36, kernel_size=(3, 3), padding='same'),  # feature extraction, spatial size preserved
    tf.keras.layers.BatchNormalization(),  # normalize pre-activation for training stability
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), padding='same'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=2, padding='same'),  # halve spatial resolution
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(4, activation='softmax')  # 4-class probability output
])
# Compile with Adam; `learning_rate=` replaces the deprecated `lr=` keyword
# (removed in recent TF2/Keras releases). Labels are integer class ids, so we
# use the sparse cross-entropy loss/metric; from_logits=False matches the
# softmax output layer.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])
# Shrink the learning rate when validation loss plateaus for 10 epochs
# (uses the callback's default reduction factor).
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')
# Persist only the best weights (by validation loss) seen during training.
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True,
)

# Resume from the last checkpoint when its index file is present on disk.
if os.path.exists(checkpoint_save_path + '.index'):
    print('------------------ restore model successfully------------------')
    model.load_weights(checkpoint_save_path)

# Train on augmented batches of 24; evaluate on the raw test split each epoch.
history = model.fit(
    image_gen_train.flow(x_train, y_train, batch_size=24),
    epochs=200,
    validation_data=(x_test, y_test),
    validation_freq=1,
    callbacks=[reduce_lr, cp_callback],
)
# Pull the per-epoch training curves out of the fit history for plotting.
hist = history.history
acc = hist['sparse_categorical_accuracy']
val_acc = hist['val_sparse_categorical_accuracy']
loss = hist['loss']
val_loss = hist['val_loss']
# Side-by-side accuracy and loss panels, each showing train vs. validation.
panels = [
    (1, 'Training and Validation Accuracy',
     [(acc, 'Training Accuracy'), (val_acc, 'Validation Accuracy')]),
    (2, 'Training and Validation Loss',
     [(loss, 'Training Loss'), (val_loss, 'Validation Loss')]),
]
for position, title, series in panels:
    plt.subplot(1, 2, position)
    for values, label in series:
        plt.plot(values, label=label)
    plt.title(title)
    plt.legend()
plt.show()