概述

MNIST 包含 0~9 的手写数字图片, 训练集共有 60000 个样本, 测试集共有 10000 个样本. 每张图片都是单通道 28*28 的灰度图.

TensorFlow2 手把手教你训练 MNIST 数据集 part 2_JAVA

get_data 函数

TensorFlow2 手把手教你训练 MNIST 数据集 part 2_原力计划_02

def get_data():
    """
    Load MNIST through Keras and build the training and test pipelines.

    Both pipelines are shuffled with a fixed seed, grouped into batches of
    ``batch_size`` and then passed through ``pre_processing``.

    :return: tuple ``(train_db, test_db)`` of batched datasets
    """

    def _pipeline(features, labels, buffer_size):
        # Slice per sample -> shuffle -> batch -> preprocess whole batches
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.shuffle(buffer_size, seed=0)
        return dataset.batch(batch_size).map(pre_processing)

    # Raw arrays as returned by Keras
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

    # Shuffle buffers are sized to the full splits (60000 / 10000 samples)
    train_db = _pipeline(train_x, train_y, 60000)
    test_db = _pipeline(test_x, test_y, 10000)

    return train_db, test_db
pre_processing 函数
def pre_processing(x, y):
    """
    Convert a batch of raw images and labels into model-ready tensors.

    :param x: image features; cast to float32, divided by 255 and reshaped
        to rows of 784 values
    :param y: class labels; cast to int32 and one-hot encoded over 10 classes
    :return: tuple ``(x, y)`` of the processed features and targets
    """
    # Features: float32, scaled by 1/255, flattened to [-1, 784]
    features = tf.reshape(tf.cast(x, tf.float32) / 255, [-1, 784])

    # Targets: integer labels -> one-hot vectors of depth 10
    targets = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)

    return features, targets
train 函数
def train(train_db):
    """
    Run one pass over the training data, updating ``model`` with ``optimizer``.

    The optimised loss is the categorical cross-entropy (computed from
    logits) summed over each batch. MSE is reported for monitoring only.

    :param train_db: batched dataset yielding ``(features, one-hot targets)``
    :return: None
    """
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:

            # Forward pass: the last Dense layer has no activation, so these
            # are unnormalised class scores
            logits = model(x)

            # Cross-entropy loss, summed over the batch
            cross_entropy = tf.losses.categorical_crossentropy(y, logits, from_logits=True)
            cross_entropy = tf.reduce_sum(cross_entropy)

        # Back-propagate the loss
        grads = tape.gradient(cross_entropy, model.trainable_variables)

        # Update the parameters
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Log the errors every 100 batches
        if step % 100 == 0:
            # MSE is only a monitoring metric, so it is computed outside the
            # tape and only when it is printed. It is measured on softmax
            # probabilities: comparing one-hot targets with raw logits yields
            # a value that grows as the model becomes more confident.
            mse = tf.reduce_mean(tf.losses.MSE(y, tf.nn.softmax(logits)))
            print("step:", step, "Cross_Entropy:", float(cross_entropy), "MSE:", float(mse))
test 函数
def test(epoch, test_db):
    """
    Evaluate ``model`` on the test data and print the accuracy.

    :param epoch: epoch number, used only in the printed message
    :param test_db: batched dataset yielding ``(features, one-hot targets)``
    :return: None
    """
    hits = 0  # correctly classified samples so far
    seen = 0  # samples evaluated so far

    for features, targets in test_db:
        # Predicted class = index of the largest model output
        predicted = tf.argmax(model(features), axis=1)

        # Recover the integer label from the one-hot target
        expected = tf.argmax(targets, axis=1)

        # Count the matches in this batch
        matches = tf.cast(tf.equal(predicted, expected), dtype=tf.int32)
        hits += int(tf.reduce_sum(matches))
        seen += features.shape[0]

    # Fraction of correct predictions over the whole dataset
    accuracy = hits / seen

    # Report
    print("epoch:", epoch, "Accuracy:", accuracy * 100, "%")
main 函数
def main():
    """
    Entry point: build the datasets, then train and evaluate once per epoch.

    :return: None
    """

    # Build the batched training and test datasets
    train_db, test_db = get_data()

    # Epochs are numbered from 1; range() excludes its upper bound, so the
    # "+ 1" is required to actually run all ``iteration_num`` epochs
    for epoch in range(1, iteration_num + 1):
        train(train_db)
        test(epoch, test_db)
完整代码

TensorFlow2 手把手教你训练 MNIST 数据集 part 2_JAVA_03

import tensorflow as tf

# Hyperparameters
batch_size = 256  # number of samples per training batch
learning_rate = 0.001  # learning rate
iteration_num = 20  # number of training epochs

# Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# Model: fully connected network; the last layer has no activation, so it
# outputs raw logits for the 10 classes
model = tf.keras.Sequential([
    tf.keras.layers.Dense(256, activation=tf.nn.relu),
    tf.keras.layers.Dense(128, activation=tf.nn.relu),
    tf.keras.layers.Dense(64, activation=tf.nn.relu),
    tf.keras.layers.Dense(32, activation=tf.nn.relu),
    tf.keras.layers.Dense(10)
])

# Build the model for flattened 28*28 inputs and show its structure.
# summary() prints the table itself and returns None, so it is not wrapped in
# print() (that would emit a stray "None" line after the table).
model.build(input_shape=[None, 28*28])
model.summary()


def pre_processing(x, y):
    """
    Convert a batch of raw images and labels into model-ready tensors.

    :param x: image features; cast to float32, divided by 255 and reshaped
        to rows of 784 values
    :param y: class labels; cast to int32 and one-hot encoded over 10 classes
    :return: tuple ``(x, y)`` of the processed features and targets
    """
    # Features: float32, scaled by 1/255, flattened to [-1, 784]
    features = tf.reshape(tf.cast(x, tf.float32) / 255, [-1, 784])

    # Targets: integer labels -> one-hot vectors of depth 10
    targets = tf.one_hot(tf.cast(y, dtype=tf.int32), depth=10)

    return features, targets


def get_data():
    """
    Load MNIST through Keras and build the training and test pipelines.

    Both pipelines are shuffled with a fixed seed, grouped into batches of
    ``batch_size`` and then passed through ``pre_processing``.

    :return: tuple ``(train_db, test_db)`` of batched datasets
    """

    def _pipeline(features, labels, buffer_size):
        # Slice per sample -> shuffle -> batch -> preprocess whole batches
        dataset = tf.data.Dataset.from_tensor_slices((features, labels))
        dataset = dataset.shuffle(buffer_size, seed=0)
        return dataset.batch(batch_size).map(pre_processing)

    # Raw arrays as returned by Keras
    (train_x, train_y), (test_x, test_y) = tf.keras.datasets.mnist.load_data()

    # Shuffle buffers are sized to the full splits (60000 / 10000 samples)
    train_db = _pipeline(train_x, train_y, 60000)
    test_db = _pipeline(test_x, test_y, 10000)

    return train_db, test_db


def train(train_db):
    """
    Run one pass over the training data, updating ``model`` with ``optimizer``.

    The optimised loss is the categorical cross-entropy (computed from
    logits) summed over each batch. MSE is reported for monitoring only.

    :param train_db: batched dataset yielding ``(features, one-hot targets)``
    :return: None
    """
    for step, (x, y) in enumerate(train_db):
        with tf.GradientTape() as tape:

            # Forward pass: the last Dense layer has no activation, so these
            # are unnormalised class scores
            logits = model(x)

            # Cross-entropy loss, summed over the batch
            cross_entropy = tf.losses.categorical_crossentropy(y, logits, from_logits=True)
            cross_entropy = tf.reduce_sum(cross_entropy)

        # Back-propagate the loss
        grads = tape.gradient(cross_entropy, model.trainable_variables)

        # Update the parameters
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Log the errors every 100 batches
        if step % 100 == 0:
            # MSE is only a monitoring metric, so it is computed outside the
            # tape and only when it is printed. It is measured on softmax
            # probabilities: comparing one-hot targets with raw logits yields
            # a value that grows as the model becomes more confident.
            mse = tf.reduce_mean(tf.losses.MSE(y, tf.nn.softmax(logits)))
            print("step:", step, "Cross_Entropy:", float(cross_entropy), "MSE:", float(mse))


def test(epoch, test_db):
    """
    Evaluate ``model`` on the test data and print the accuracy.

    :param epoch: epoch number, used only in the printed message
    :param test_db: batched dataset yielding ``(features, one-hot targets)``
    :return: None
    """
    hits = 0  # correctly classified samples so far
    seen = 0  # samples evaluated so far

    for features, targets in test_db:
        # Predicted class = index of the largest model output
        predicted = tf.argmax(model(features), axis=1)

        # Recover the integer label from the one-hot target
        expected = tf.argmax(targets, axis=1)

        # Count the matches in this batch
        matches = tf.cast(tf.equal(predicted, expected), dtype=tf.int32)
        hits += int(tf.reduce_sum(matches))
        seen += features.shape[0]

    # Fraction of correct predictions over the whole dataset
    accuracy = hits / seen

    # Report
    print("epoch:", epoch, "Accuracy:", accuracy * 100, "%")


def main():
    """
    Entry point: build the datasets, then train and evaluate once per epoch.

    :return: None
    """

    # Build the batched training and test datasets
    train_db, test_db = get_data()

    # Epochs are numbered from 1; range() excludes its upper bound, so the
    # "+ 1" is required to actually run all ``iteration_num`` epochs
    for epoch in range(1, iteration_num + 1):
        train(train_db)
        test(epoch, test_db)


# Start training only when this file is executed as a script
if __name__ == "__main__":
    main()

输出结果:

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense (Dense)                (None, 256)               200960    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_3 (Dense)              (None, 32)                2080      
_________________________________________________________________
dense_4 (Dense)              (None, 10)                330       
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
_________________________________________________________________
None

step: 0 Cross_Entropy: 589.905029296875 MSE: 0.1215171366930008
step: 100 Cross_Entropy: 61.73141098022461 MSE: 16.245494842529297
step: 200 Cross_Entropy: 46.609832763671875 MSE: 17.865381240844727
epoch: 1 Accuracy: 95.44 %
step: 0 Cross_Entropy: 47.514892578125 MSE: 20.183507919311523
step: 100 Cross_Entropy: 35.65019226074219 MSE: 18.90221405029297
step: 200 Cross_Entropy: 33.837703704833984 MSE: 16.84846305847168
epoch: 2 Accuracy: 96.61 %
step: 0 Cross_Entropy: 17.38262939453125 MSE: 18.48729133605957
step: 100 Cross_Entropy: 27.96572494506836 MSE: 21.008562088012695
step: 200 Cross_Entropy: 27.25030517578125 MSE: 21.703704833984375
epoch: 3 Accuracy: 97.22 %
step: 0 Cross_Entropy: 21.492198944091797 MSE: 22.19614028930664
step: 100 Cross_Entropy: 11.623129844665527 MSE: 27.867923736572266
step: 200 Cross_Entropy: 7.261983394622803 MSE: 25.641494750976562
epoch: 4 Accuracy: 97.41 %
step: 0 Cross_Entropy: 11.380800247192383 MSE: 26.688203811645508
step: 100 Cross_Entropy: 10.21794319152832 MSE: 27.864110946655273
step: 200 Cross_Entropy: 14.44814682006836 MSE: 31.53815460205078
epoch: 5 Accuracy: 97.18 %
step: 0 Cross_Entropy: 5.241445541381836 MSE: 30.080406188964844
step: 100 Cross_Entropy: 3.1642959117889404 MSE: 33.59324645996094
step: 200 Cross_Entropy: 9.680063247680664 MSE: 34.96605682373047
epoch: 6 Accuracy: 97.95 %
step: 0 Cross_Entropy: 11.292088508605957 MSE: 36.604915618896484
step: 100 Cross_Entropy: 4.599205017089844 MSE: 38.455101013183594
step: 200 Cross_Entropy: 13.383275032043457 MSE: 41.19858932495117
epoch: 7 Accuracy: 97.65 %
step: 0 Cross_Entropy: 6.985865592956543 MSE: 33.687713623046875
step: 100 Cross_Entropy: 5.281797409057617 MSE: 44.13557815551758
step: 200 Cross_Entropy: 6.665032863616943 MSE: 44.898216247558594
epoch: 8 Accuracy: 97.72 %
step: 0 Cross_Entropy: 1.8101396560668945 MSE: 42.560211181640625
step: 100 Cross_Entropy: 4.517214298248291 MSE: 46.41954803466797
step: 200 Cross_Entropy: 5.113927364349365 MSE: 47.692081451416016
epoch: 9 Accuracy: 97.84 %
step: 0 Cross_Entropy: 5.45690393447876 MSE: 44.61886978149414
step: 100 Cross_Entropy: 6.035201549530029 MSE: 51.11096954345703
step: 200 Cross_Entropy: 7.727978229522705 MSE: 50.56428527832031
epoch: 10 Accuracy: 97.78 %
step: 0 Cross_Entropy: 6.566008567810059 MSE: 53.64844512939453
step: 100 Cross_Entropy: 12.636188507080078 MSE: 59.566192626953125
step: 200 Cross_Entropy: 0.9305715560913086 MSE: 63.96886444091797
epoch: 11 Accuracy: 97.68 %
step: 0 Cross_Entropy: 3.799677610397339 MSE: 57.57715606689453
step: 100 Cross_Entropy: 7.782512664794922 MSE: 63.94820785522461
step: 200 Cross_Entropy: 6.952803611755371 MSE: 59.19414138793945
epoch: 12 Accuracy: 97.85000000000001 %
step: 0 Cross_Entropy: 1.316650152206421 MSE: 57.405555725097656
step: 100 Cross_Entropy: 3.3630568981170654 MSE: 65.93612670898438
step: 200 Cross_Entropy: 2.8188657760620117 MSE: 63.6553955078125
epoch: 13 Accuracy: 97.71 %
step: 0 Cross_Entropy: 1.0694936513900757 MSE: 73.58941650390625
step: 100 Cross_Entropy: 1.1532164812088013 MSE: 72.19602966308594
step: 200 Cross_Entropy: 4.054533958435059 MSE: 66.22490692138672
epoch: 14 Accuracy: 97.69 %
step: 0 Cross_Entropy: 0.5501946806907654 MSE: 67.73658752441406
step: 100 Cross_Entropy: 1.6239964962005615 MSE: 75.26908874511719
step: 200 Cross_Entropy: 0.25266233086586 MSE: 79.37750244140625
epoch: 15 Accuracy: 97.96000000000001 %
step: 0 Cross_Entropy: 0.5946800112724304 MSE: 78.45301818847656
step: 100 Cross_Entropy: 3.876664638519287 MSE: 86.45103454589844
step: 200 Cross_Entropy: 13.129545211791992 MSE: 70.39665222167969
epoch: 16 Accuracy: 97.67 %
step: 0 Cross_Entropy: 4.019548416137695 MSE: 66.26248168945312
step: 100 Cross_Entropy: 0.7121025323867798 MSE: 67.56402587890625
step: 200 Cross_Entropy: 3.106649875640869 MSE: 77.95216369628906
epoch: 17 Accuracy: 97.71 %
step: 0 Cross_Entropy: 0.797190248966217 MSE: 70.34780883789062
step: 100 Cross_Entropy: 5.868640422821045 MSE: 74.68391418457031
step: 200 Cross_Entropy: 2.415027141571045 MSE: 85.03378295898438
epoch: 18 Accuracy: 97.54 %
step: 0 Cross_Entropy: 1.5692293643951416 MSE: 90.47661590576172
step: 100 Cross_Entropy: 0.6557420492172241 MSE: 81.88681030273438
step: 200 Cross_Entropy: 5.726837158203125 MSE: 76.24435424804688
epoch: 19 Accuracy: 98.0 %