-- coding: utf-8 --

“”"
@Time : 19-9-20 下午8:12
@Author : lei
@Site :
@File : captcha_train.py
@Software: PyCharm
“”"

import tensorflow as tf
import os

定义一个初始化权重的函数

def weight_variables(shape):
weight = tf.Variable(tf.random_normal(shape=shape, mean=0.0, stddev=1.0))
return weight

def bias_variables(shape):
bias = tf.Variable(tf.constant(0.0, shape=shape))
return bias

读取数据文件

def read_and_decode():
“”"
读取验证码数据api 返回
:return: image_batch, label_batch
“”"
# 1.构建文件队列
file_queue = tf.train.string_input_producer(["./data/capt_tfrecords"])

# 2.构建阅读器,读取文件内容,默认一个样本
reader = tf.TFRecordReader()

# 3.读取内容
key, value = reader.read(file_queue)

# tfrecords 格式需要example, 需要解析
features = tf.parse_single_example(value, features={
    # 存入什么,解析出来的还是什么
    "image": tf.FixedLenFeature([], tf.string),
    "label": tf.FixedLenFeature([], tf.string),
})

# 4.解码内容,字符串内容
# 1.先解析图片的特征值
image = tf.decode_raw(features["image"], tf.uint8)
# 2.解码图片的目标值
label = tf.decode_raw(features["label"], tf.uint8)

# 还没有形状
# print(image, label)

# 改变形状  高20 宽80 彩色图片  改变形状因为要批处理
image_reshape = tf.reshape(image, [20, 80, 3])

# 目标值会有一个验证码会有四个
label_reshape = tf.reshape(label, [4])

# 特征值和目标值都有形状,可以进行批处理,批处理必须要有明确形状
# print(image_reshape, label_reshape)

# 进行批处理,每批次训练的样本数 100,
image_batch, label_batch = tf.train.batch([image_reshape, label_reshape], batch_size=100, num_threads=1, capacity=100)

return image_batch, label_batch

def fc_model(image_batch):
“”"
进行预测结果
:param image_batch: 图片特征值 [100, 20, 80, 3]
:return: y_predict[100, 426]
“”"
# 建立模型
with tf.variable_scope(“model”):
# 将图片数据转换成二维的形状 不知道有多少个选择-1
image_reshape = tf.reshape(image_batch, [-1, 20
80*3])

# 1.随机初始化权重,偏置
    # matrix [100, 20*80*3] * [20*80*3, 4*26] + [4*26]
    weight = weight_variables([20 * 80 * 3, 4 * 26])
    bias = bias_variables([104])

    # 进行全连接层计算  [100, 4*26]      数据类型进行转换
    y_predict = tf.matmul(tf.cast(image_reshape, tf.float32), weight) + bias

return y_predict

def predict_to_onhot(label):
“”"
将读取文件当中的目标值转换成one-hot编码
:param label: [100, 4] 100426
:return:
“”"
# 进行one_hot编码转换,提供给交叉熵损失计算,精确率计算
# depth 产生的类别=26 某个位置是 就为1.0 axis=对三维中的第二部分转换
label_onehot = tf.one_hot(label, depth=26, on_value=1.0, axis=2)

return label_onehot

def captcharec():
“”"
验证码识别程序
:return:
“”"
# 1.读取验证码的数据文件 label_batch [100, 4]
image_batch, label_batch = read_and_decode()

# 2.通过输入图片特征数据,建立模型,得出预测结果
# 一层全连接神经网络进行预测
# matrix [100, 20*80*3] * [20*80*3, 4*26] + [4*26]
# 100个样本  每个样本产生104个输出  [100, 4*26]
y_predict = fc_model(image_batch)
print(y_predict)

# 3.先把目标值转换成one-hot编码  [100, 4, 26]
y_true = predict_to_onhot(label_batch)

# softmax 计算,交叉上损失计算
with tf.variable_scope("soft_cross"):
    # 求平均交叉熵损失 , y_true [100, 4, 26] --> [100, 4*26]
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
        labels=tf.reshape(y_true, [-1, 4*26]),
        logits=y_predict,
    ))

# 5.梯度下降优化损失
with tf.variable_scope("optimizer"):
   train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)

# 6.求出样本的每批次预测的准确率是多少  三维比较
with tf.variable_scope("acc"):
    # 比较预测值和目标值是否位置(4个)一样  如果相等则置为1
    # 对y_true的第二个位置求最大值 [100, 4, 26]
    print(y_true)
    print(y_predict)
    equal_list = tf.equal(tf.argmax(y_true, 2), tf.argmax(tf.reshape(y_predict, [100, 4, 26]), 2))

    # 求平均值
    accuracy = tf.reduce_mean(tf.cast(equal_list, tf.float32))

init_op = tf.global_variables_initializer()

# 开启会话
with tf.Session() as sess:
    sess.run(init_op)

    # 定义线程协调器和开启线程(有数据在文件当中读取 从而提供给模型)
    cord = tf.train.Coordinator()

    # 开启线程去运行文件操作
    threads = tf.train.start_queue_runners(sess, coord=cord)

    # 训练识别程序
    for i in range(5000):
        sess.run(train_op)
        print("第{}次的准确率为:{}".format(i, accuracy.eval()))
    # 回收线程
    cord.request_stop()
    cord.join(threads)

return None

if name == “main”:
captcharec()