个人小作业,虽说做的很差,也算是一个学习的转化;主要用于分类自己下载的壁纸

1 背景

学期末需要一个学习成果的展示,高难度的自己做不来,模型也跑不动(电脑有点渣),刚好自己也有图片分类的需求,最后决定做了这个,确实也算做了一个自己用得到的小程序

2 项目说明

2.1 项目需求

需要自动加载指定目录所有图片,自行迁移至指定目录并存入不同的文件夹

2.2 实现思路

  1. 数据来源于各大壁纸网站,通过下载分类好的图片免去了自己手动分类的痛苦
  2. 将图片进行微缩处理,将1920 python怎么把图片分类_tensorflow 1080的图片转化为192 python怎么把图片分类_tensorflow
  3. 第二步可以将图片转化为单通道,数据量会小很多,但是测试过程中发现数据集较小时准确率比直接使用三通道要高一些,但是数据集大之后三通道的图片识别更加准确
  4. 目前数据集是共10000多张图片共五个分类(差不多自己电脑的上限),通过第二步、第三步的三通道缩小处理后,所有数据集大小约600MB,还在接受范围内。
  5. 模型的搭建与其他模型搭建基本一致

3 项目说明

3.1 项目结构

│  colorUi.ui	正在使用的UI界面文件
│  fun.py		对于模型函数的初步封装,为PyQt界面提供支持
│  main.py		入口部分
│  model.py		模型的训练、加载
│  ui.py		正在使用的UI界面py文件
│  ui.ui		老的UI界面文件
│  utils.py		一些读取图片处理图片的函数
├─fun_test			内含各类图片共100张,用于最后的功能测试
├─make_data_set		用于处理制作数据集
├─model				训练好的模型存储的路径
├─test				内含处理好的数据集的测试集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组
├─test_pic		测试集原始数据目录,路径下各种图片独占一个目录,用于通过make_data_set制作数据集,目录应与train_pic对应
│  ├─dongman	其中一个分类
│  ├─dongwu		其中一个分类
│  ├─fengjing	其中一个分类
│  ├─meinv		其中一个分类
│  └─youxi		其中一个分类
├─train			内含处理好的数据集的训练集,存储格式是是numpy数组的序列化,三通道维度信息(N,108.,192,3);标签一维数组
└─train_pic
    ├─dongman	其中一个分类
    ├─dongwu	其中一个分类
    ├─fengjing	其中一个分类
    ├─meinv		其中一个分类
    └─youxi		其中一个分类

3.2 源码说明

3.2.1 模型的创建、加载、训练
import json
import os

import cv2
import numpy
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tqdm import tqdm

from utils import img_resize

def init_network():
    """
    初始化神经网络,支持五种类型
    :return: 模型
    """

    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(filters=48, kernel_size=(3, 3), padding='same', activation='relu', strides=1,
                               input_shape=(108, 192, 3)),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        # 抑制过拟合
        tf.keras.layers.Dropout(rate=0.6),
        tf.keras.layers.Conv2D(filters=24, kernel_size=(
            3, 3), padding='same', activation='relu', strides=1),
        # 2*2池化取最大值
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        # 抑制过拟合
        tf.keras.layers.Dropout(rate=0.6),
        # 维度拉伸成1维
        tf.keras.layers.Flatten(),
        # 第二层隐藏层,使用relu激活函数
        tf.keras.layers.Dense(256, activation='relu'),
        # 抑制过拟合
        tf.keras.layers.Dropout(rate=0.6),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(rate=0.5),
        tf.keras.layers.Dense(256, activation='relu'),
        tf.keras.layers.Dropout(rate=0.5),
        # 输出层
        tf.keras.layers.Dense(5, activation='softmax')
    ])
    model.compile(loss='categorical_crossentropy',
                  optimizer='adam', metrics=['accuracy'])
    model.summary()
    return model

def getTrainData():
    """
    获取训练集数据
    :return: train_images, train_labels, class_names
    """
    fp = open('./train/train.json', 'r', encoding='utf8')
    class_names = json.load(fp)['support']
    fp.close()

    # 返回加载来的数据集
    pic_train_images = numpy.load('./train/train_pic.npy')
    train_images = pic_train_images.reshape(
        pic_train_images.shape[0], 108, 192, 3) / 255.0
    print(train_images.shape)
    train_labels = numpy.load('./train/train_labels.npy')
    print(numpy.load('./train/train_labels.npy').shape)
    return train_images, train_labels, class_names

def getTestData():
    """
    获取测试集包的数据
    :return: train_images, train_labels, class_names
    """
    fp = open('./test/test.json', 'r', encoding='utf8')
    class_names = json.load(fp)['support']
    fp.close()

    # 返回加载来的数据集
    pic_test_images = numpy.load('./test/test_pic.npy')
    test_images = pic_test_images.reshape(pic_test_images.shape[0], 108, 192, 3) / 255.0
    print(test_images.shape)
    test_labels = numpy.load('./test/test_labels.npy')
    print(numpy.load('./test/test_labels.npy').shape)
    return test_images, test_labels, class_names

def getTestImages():
    """
    加载测试集1920*1080的壁纸
    """
    path = './test_pic'
    imgs = []
    labels = []
    k = 0
    paths = os.listdir(path)
    paths.sort()
    for j in paths:
        pbar = tqdm(total=100)
        for i in os.listdir(path + '/' + j):
            pbar.update(100.0 / len(os.listdir(path + '/' + j)))
            pic_path = path + '/' + j + '/' + i
            # img = img_resize(cv2.imread(pic_path, cv2.IMREAD_GRAYSCALE))
            img = img_resize(cv2.imread(pic_path))
            if img.shape[0] != 108 or img.shape[1] != 192:
                os.remove(pic_path)
                continue
            imgs.append(img)
            labels.append(k)
        pbar.close()
        k = k + 1
    pic_test_images = np.array(imgs)
    test_images = pic_test_images.reshape(
        pic_test_images.shape[0], 108, 192, 3) / 255.0
    return test_images, np.array(labels)

def getModel(train_mode=False):
    """
    获取模型
    :param train_mode: 是否训练
    :return: 模型
    """
    # 如果训练
    if train_mode:
        # 初始化神经网络
        model = init_network()
        # 加载数据集
        train_images, train_labels, _ = getTrainData()
        test_images, test_labels, _ = getTestData()
        print(train_images.shape)
        print(train_labels.shape)
        print(test_images.shape)
        print(test_labels.shape)
        # 开始训练,训练二十次,显示日志信息
        model.fit(train_images, keras.utils.to_categorical(
            train_labels), batch_size=128, epochs=100, verbose=2)
        # 评估模型,不输出预测结果
        test_loss, test_acc = model.evaluate(
            test_images, keras.utils.to_categorical(test_labels), verbose=2)
        # 输出损失值
        print('测试集损失:', test_loss)
        # 输出正确率
        print('测试集正确率:', test_acc)
        # 保存模型
        model.save('.\\model\\expll.h5')
        return model, test_loss, test_acc
    else:
        # 加载模型
        model = tf.keras.models.load_model('.\\model\\780_3x3_1_3_100_expll.h5')
        # 打印模型信息
        model.summary()

        test_images, test_labels, _ = getTestData()
        # 评估模型,不输出预测结果
        test_loss, test_acc = model.evaluate(
            test_images, keras.utils.to_categorical(test_labels), verbose=2)
        # print([np.where(i == np.max(i))[0][0] for i in model.predict(test_images)])
        return model, test_loss, test_acc

# 训练模型
# if __name__ == '__main__':
#     model = getModel(True)
3.2.2 模型功能的封装,用于支持PyQt功能界面逻辑
import json
import os
import shutil

import numpy as np
from PyQt5.QtCore import *

import utils
from model import getModel

def getModelSupportTypes(data):
    """
    获取模型支持的分类
    :return:
    """
    temp = ''
    for i in data:
        temp = temp + ' ' + i
    return temp

def getModelInfo(loss, acc):
    """
    获取模型信息
    :return: 模型测试准确度
    """
    return '测试集损失:{:.3f}\n测试集准确率:{:.3f}%'.format(loss, acc * 100)

class Service(QObject):
    signalRunTime = pyqtSignal(str, bool)
    model = None
    signalWorking = pyqtSignal(bool)
    loadModelStatus = False
    signalModelInfo = pyqtSignal(str)
    signalModelSupportTypes = pyqtSignal(str)

    def __init__(self):
        super().__init__()

    def predict(self, imgs: np.array):
        """
        预测
        :param imgs: 预测图片集
        :return: 预测结果
        """
        rs = self.model.predict(imgs)
        return [np.where(i == np.max(i))[0][0] for i in rs]

    def iniModel(self):
        """
        初始化加载模型
        """
        if self.loadModelStatus:
            self.signalRunTime.emit('模型加载中···', False)
            return
        self.loadModelStatus = True
        self.signalRunTime.emit('正在加载模型···', False)
        self.model, loss, acc = getModel()
        with open('model/model.json', 'r', encoding='utf8') as fp:
            info = json.load(fp)
            self.signalModelInfo.emit('方法:' + info['way'] + '\n' + getModelInfo(loss, acc))
            self.signalModelSupportTypes.emit(getModelSupportTypes(info['support']))
        self.signalRunTime.emit('模型加载完成', False)
        self.loadModelStatus = False

    def startRun(self, window):
        """
        开始进行分类
        :param window: 窗口对象
        """
        if len(window.getFromPath()) == 0 or len(window.getTargetPath()) == 0:
            self.signalRunTime.emit('\n存在路径为空\n', False)
            self.signalWorking.emit(False)
            return
        list_path = []
        self.signalRunTime.emit('\n检索中······\n', False)
        utils.getListDir(window.fromPath.toPlainText(), window.getRecursionPathStatus(), list_path, imageCallback=None,
                         dirCallback=lambda x: self.signalRunTime.emit('检索检索到目录: {0}\n'.format(x), False))
        self.signalRunTime.emit('检索完成,共计{0}张图片\n'.format(len(list_path)), False)
        if len(list_path) == 0:
            self.signalWorking.emit(False)
            return
        self.signalRunTime.emit('开始读取图片······', False)
        img = utils.get_data(list_path, lambda x: self.signalRunTime.emit('已加载: {0}\n'.format(x), False))
        self.signalRunTime.emit('读取图片完成', False)
        self.signalRunTime.emit('维度信息:{0}'.format(img.shape), False)
        self.signalRunTime.emit('进行分类识别中······', False)
        rs = self.predict(img)
        self.signalRunTime.emit('分类识别完成\n***********\n识别结果:\n***********\n***********\n***********\n', False)

        with open('.\\model\\model.json', encoding='utf8') as fp:
            supportTypes = json.load(fp)['support']
            outRunInfo = '\n'
            for i in zip(list_path, rs):
                outRunInfo = outRunInfo + '路径: {0}; 结果:{1}\n\n'.format(i[0], supportTypes[i[1]])
            self.signalRunTime.emit(outRunInfo + '\n\n***********\n***********\n识别结果输出结束\n***********\n***********\n',
                                    False)
            targetPathRoot = window.getTargetPath()
            for i in supportTypes:
                if not os.path.exists(targetPathRoot + '/' + i):
                    os.mkdir(targetPathRoot + '/' + i)
        self.signalRunTime.emit('\n\n开始进行分类迁移······', False)

        onlyMoveMax = window.getOnlyNumber()
        with open('.\\model\\model.json', encoding='utf8') as fp:
            supportTypes = json.load(fp)['support']
            for j in range(0, int(len(list_path) * 1.0 / onlyMoveMax + 1)):
                for i in list(zip(list_path, rs))[onlyMoveMax * j:onlyMoveMax * (j + 1)]:
                    try:
                        self.signalRunTime.emit(
                            '来源: {0}; 迁移至:{1}\n\n'.format(i[0], (targetPathRoot + '/' + supportTypes[i[1]])), False)
                        shutil.move(i[0], targetPathRoot + '/' + supportTypes[i[1]])
                    except Exception as e:
                        self.signalRunTime.emit(
                            'ERROR: {0}'.format(e, False))
        self.signalRunTime.emit('\n\n迁移结束,任务完成\n\n', False)
        self.signalWorking.emit(False)
3.2.3 入口部分
# -*- coding: utf-8 -*-
import os
import sys
from concurrent.futures import ThreadPoolExecutor

from PyQt5.QtWidgets import *

import fun
from ui import Ui_Form

threadPool = ThreadPoolExecutor(max_workers=20)

def openPath(callback):
    # 选择图片
    path = QFileDialog.getExistingDirectory(None, "选择存储文件夹", os.getcwd())
    if path == "":
        return 0
    callback(path)

class MainWindow(QWidget, Ui_Form):
    service = None
    img = None
    working = False

    def __init__(self, service_):
        super(MainWindow, self).__init__()
        self.service = service_
        self.setupUi(self)

    def openFromPath(self):
        """
        选择来源路径
        """
        openPath(callback=lambda x: self.fromPath.setText(x))

    def openTargetPath(self):
        """
        选择输出路径
        """
        openPath(callback=lambda x: self.targetPath.setText(x))

    def outRuntimeInfo(self, data, refresh=True):
        """
        输出运行时
        :param data: 日志
        :param refresh: 追加或清空再输出
        """
        if refresh:
            self.runtimeInfor.setText(data)
        else:
            self.runtimeInfor.setText(self.runtimeInfor.toPlainText() + '\n' + data)
        self.runtimeInfor.moveCursor(self.runtimeInfor.textCursor().End)

    def getFromPath(self):
        """
        获取源路径
        :return: 源路径
        """
        return self.fromPath.toPlainText()

    def getTargetPath(self):
        """
        获取输出路径
        :return: 输出路径
        """
        return self.targetPath.toPlainText()

    def outSupportTypes(self, data):
        """
        输出模型支持的类型
        :param data: 类型串
        """
        self.modelType.setText(data)

    def outModelInfo(self, data):
        """
        输出模型信息
        :param data: 模型信息
        """
        self.modelInfor.setText(data)

    def getOnlyNumber(self):
        """
        单次处理图片数量
        :return: 数量
        """
        return self.onlyNumber.value()

    def getRecursionPathStatus(self):
        """
        是否递归目录
        """
        return self.recursionPath.checkState() == 2

    def startRun(self):
        """
        开始分类
        """
        if self.working:
            self.outRuntimeInfo('任务执行中', False)
            return
        try:
            threadPool.submit(service.startRun, self)
        except Exception as e:
            print(e)

    def setWorking(self, status):
        self.working = status

if __name__ == '__main__':
    service = fun.Service()
    app = QApplication(sys.argv)
    # 初始化窗口
    m = MainWindow(service)
    m.btu_selectFromPath.clicked.connect(m.openFromPath)
    m.btu_selectTargetPath.clicked.connect(m.openTargetPath)
    m.btu_startRun.clicked.connect(m.startRun)
    m.setWindowTitle('1920*1080壁纸分类')
    m.show()
    service.signalRunTime.connect(m.outRuntimeInfo)
    service.signalWorking.connect(m.setWorking)
    service.signalModelInfo.connect(m.outModelInfo)
    service.signalModelSupportTypes.connect(m.outSupportTypes)
    threadPool.submit(service.iniModel)
    sys.exit(app.exec_())
3.2.4 UI界面

python怎么把图片分类_python怎么把图片分类_03

4 结语

  虽说很简单,或许显得很那么······没用,但是也是自己的一个小成果,也算是又做了一个对自己有用的工具吧!

项目文件所在地址,内含训练好的模型,目前支持五种:https://github.com/WindSnowLi/picture-classify