前言

之前通过百度AI接口实现了图像识别,这次的目标是在图像识别的基础上再给出垃圾分类提示。于是我在网上查找垃圾分类的数据集,发现各大网站上的很多数据集都要收费,让我很反感。下面推荐两个比较好用的开源数据来源(开源让技术进步!):

百度AI社区 | 上传者:Thomas-yanxin,快速链接: 垃圾分类数据集ImageNet格式,用于训练效果是很不错的

知乎 | 作者:谢伟,他用Go语言实现了垃圾分类查询的后端,其中包含csv格式的垃圾分类数据集。

ps. 后文使用的是上述csv格式的数据集(代码中用openpyxl读取的是由它另存而来的xlsx文件 waste.xlsx),仅供学习使用。

正文

1.数据集

数据集包含2534条数据,大致分为4类。参照上述Go项目的4类划分,我重新处理了一下数据集:

基于图像分类网络VGG实现垃圾分类识别 图像识别 垃圾分类_python

参考如下:

基于图像分类网络VGG实现垃圾分类识别 图像识别 垃圾分类_基于图像分类网络VGG实现垃圾分类识别_02

2.表格匹配规则

读取表格B列存储的名称,与识别出的名称进行对比。对比规则:若两个字符串完全相等,则 sort = 1;若不相等但B列名称包含识别出的名称,则 sort 为识别名称的字符数占B列名称字符数的比例,即 sort = len(name)/len(cell.value)。最终取 sort 最大的一行作为匹配结果。

import openpyxl
from openpyxl import Workbook


def excel(excel_file, name):
    """Look up *name* in column B of the first sheet of *excel_file* and print
    the best-matching waste category.

    Matching rules:
      * the cell text equals *name*      -> confidence ("sort") 1;
      * the cell text contains *name*    -> confidence len(name) / len(cell text).
    The highest-confidence row wins. Column A of that row is read as the id and
    column D as the category ("sortId").

    :param excel_file: path of the .xlsx waste table.
    :param name: object name to classify (e.g. an image-recognition result).
    :return: the result dict (also printed); ``result["sort"] == 0`` means no
        row matched.
    """
    # Default record; its fields are overwritten when a row matches.
    result = {"id":0,"name":"test","imageUrl":"","sortId":"可回收垃圾","result":"","sort":0}
    if not name:
        # Nothing to look up (empty / None name).
        print("未识别到是什么垃圾")
        return result
    wb = openpyxl.load_workbook(excel_file)  # load the workbook
    sheet = wb.worksheets[0]  # first worksheet
    for cell in sheet["B"]:  # column B holds the item names
        value = cell.value
        # Empty cells are None and numeric cells are not str; `name in value`
        # would raise TypeError on them, so skip anything that isn't text.
        if not isinstance(value, str) or name not in value:
            continue
        if value == name:
            sort = 1
        else:
            sort = len(name) / len(value)
            if sort <= result["sort"]:
                continue
        print(value, sort)
        result["sort"] = sort
        result["name"] = name
        result["result"] = value
        # cell.row is 1-based like sheet.cell(); an enumerate() index is
        # 0-based and would read the previous row (or raise on row 1).
        result["id"] = sheet.cell(cell.row, 1).value
        result["sortId"] = sheet.cell(cell.row, 4).value

    if result["sort"] == 0:
        print("未识别到是什么垃圾")
    else:
        print("====识别结果=======")
        print(result)
        print(result["name"],"的识别结果是:",result["result"],"是",result["sortId"],"可信度为:",result["sort"])
    return result


# Demo: classify "花生" (peanut) against the local waste table.
excel(
    excel_file=r'/Users/wangyu/Desktop/waste.xlsx',
    name=r'花生',
)

基于图像分类网络VGG实现垃圾分类识别 图像识别 垃圾分类_opencv_03

3.图像识别+垃圾分类

图像识别部分参考上一篇关于图像识别的文章。

基于图像分类网络VGG实现垃圾分类识别 图像识别 垃圾分类_百度_04

# coding=utf-8

import requests
import json
import base64

# Work around HTTPS "certificate verify failed" errors by replacing the ssl
# module's default HTTPS context factory with one that does NOT verify
# certificates.
# NOTE(review): this disables TLS verification process-wide for clients that
# use ssl's default HTTPS context (presumably urllib.request) — a security
# risk. Nothing in this script calls urllib directly, and `requests`
# presumably manages its own verification, so confirm whether this is needed.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context


"""
    获取token
"""

def getToken(AccessKey,SecretKey):
    # client_id 为官网获取的AK, client_secret 为官网获取的SK
    host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id='+AccessKey+'&client_secret='+SecretKey
    headers = {
        'Content-Type': 'application/json;charset=UTF-8',
        'grant_type': 'client_credentials',
    }
    access_token = ''
    response = requests.get(url=host, headers=headers)
    if response:
        res = response.json()
        access_token = res['access_token']
    return access_token

def getResult(url,access_token):
    """Recognise the main object in a local image with Baidu's plant-recognition API.

    :param url: path of the local image file.
    :param access_token: token returned by ``getToken``.
    :return: the name of the top result, or "未识别该图片" when the API returned
        nothing usable.
    """
    # Plant recognition; switch to the animal endpoint below for animals.
    request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/plant"
    # request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/animal"
    # Read the image in binary mode and base64-encode it; `with` closes the file.
    with open(url, 'rb') as f:
        img = base64.b64encode(f.read())
    params = {"image": img, "top_num": 1, "baike_num": 1}

    request_url = request_url + "?access_token=" + access_token
    # `data=params` is sent form-encoded, so declare that content type
    # rather than JSON.
    headers = {'content-type': 'application/x-www-form-urlencoded'}
    response = requests.post(request_url, data=params, headers=headers)
    result = response.json()
    if result:
        print(json.dumps(result, indent=1, ensure_ascii=False))
        # Treat a payload without a "result" list (e.g. an error response)
        # like "nothing recognised" instead of raising KeyError / returning None.
        for data in result.get("result") or []:
            print(u"  菜品名称: " + data["name"])
            return data["name"]
    return "未识别该图片"

def getWaste(name, excel_file="/Users/wangyu/Desktop/waste.xlsx"):
    """Classify *name* by matching it against column B of a waste table.

    Matching rules:
      * the cell text equals *name*      -> confidence ("sort") 1;
      * the cell text contains *name*    -> confidence len(name) / len(cell text).
    The highest-confidence row wins. Column A of that row is read as the id and
    column D as the category ("sortId").

    :param name: object name (e.g. the image-recognition result).
    :param excel_file: path of the .xlsx waste table (its first sheet is used).
    :return: the best-match dict (keys id, name, imageUrl, sortId, result, sort),
        or the string '未识别到该垃圾的分类' when nothing matches.
    """
    waste = '未识别到该垃圾的分类'
    if not name:
        # Nothing to look up (e.g. recognition failed).
        print("未识别到是什么垃圾")
        return waste
    # Best match so far; fields are overwritten as better rows are found.
    result = {"id":0,"name":"test","imageUrl":"","sortId":"可回收垃圾","result":"","sort":0}
    wb = openpyxl.load_workbook(excel_file)
    sheet = wb.worksheets[0]  # first worksheet
    for cell in sheet["B"]:  # column B holds the item names
        value = cell.value
        # Skip empty (None) and non-text cells: `name in value` needs a str.
        if not isinstance(value, str) or name not in value:
            continue
        if value == name:
            sort = 1
        else:
            sort = len(name) / len(value)
            if sort <= result["sort"]:
                continue
        print(value, sort)
        result["sort"] = sort
        result["name"] = name
        result["result"] = value
        # cell.row is 1-based like sheet.cell(); an enumerate() index is
        # 0-based and would read the previous row (or raise on row 1).
        result["id"] = sheet.cell(cell.row, 1).value
        result["sortId"] = sheet.cell(cell.row, 4).value

    if result["sort"] == 0:
        print("未识别到是什么垃圾")
        return waste
    print("====识别结果=======")
    print(result["name"],"的识别结果是:",result["result"],"是",result["sortId"],"可信度为:",result["sort"])
    return result



if __name__ == '__main__':
    import os

    # Image to classify; the commented paths are other sample images.
    # url = '/Users/wangyu/Desktop/shicai.jpg'
    # url = '/Users/wangyu/Desktop/fish.jpg'
    url = '/Users/wangyu/Desktop/hongzao.jpg'
    # Baidu AI credentials (API Key / Secret Key), read from the environment
    # so they need not be hard-coded; both default to ''.
    AccessKey = os.environ.get('BAIDU_API_KEY', '')
    SecretKey = os.environ.get('BAIDU_SECRET_KEY', '')
    # Recognise the object in the image, then look up its waste category.
    access_token = getToken(AccessKey, SecretKey)
    img_result = getResult(url, access_token)
    # Only classify when recognition actually produced a name.
    if img_result and img_result != "未识别该图片":
        waste_result = getWaste(img_result)

4.整合显示

将图像识别结果和垃圾分类结果以文字形式标注在图片上并显示:

基于图像分类网络VGG实现垃圾分类识别 图像识别 垃圾分类_百度_05

# coding=utf-8
# 网络数据请求
import requests
import json
import base64
import numpy as np
# 表格处理
import openpyxl
from openpyxl import Workbook
# 图像处理
import cv2
from matplotlib import pyplot as plt
from urllib import request
from PIL import Image, ImageDraw, ImageFont

# Work around HTTPS "certificate verify failed" errors by replacing the ssl
# module's default HTTPS context factory with one that does NOT verify
# certificates.
# NOTE(review): this disables TLS verification process-wide for clients that
# use ssl's default HTTPS context (presumably urllib.request, which is imported
# above) — a security risk; `requests` presumably keeps verifying. Confirm
# whether this is still needed.
import ssl
ssl._create_default_https_context = ssl._create_unverified_context


"""
    获取token
"""

def getToken(AccessKey,SecretKey):
    # client_id 为官网获取的AK, client_secret 为官网获取的SK
    host = 'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id='+AccessKey+'&client_secret='+SecretKey
    headers = {
        'Content-Type': 'application/json;charset=UTF-8',
        'grant_type': 'client_credentials',
    }
    access_token = ''
    response = requests.get(url=host, headers=headers)
    if response:
        res = response.json()
        access_token = res['access_token']
    return access_token
"""
    图像识别结果
    输入:本地图片地址,token
    输出:识别结果,识别分数
"""
def getResult(url,access_token):
    img = cv2.imread(url)
    plt.figure(figsize=(5,5))
    plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))#BGR转RGB
    plt.xlabel(u'img')
    plt.show()
#     植物
    request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/plant"
#     动物
#     request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/animal"
    # 二进制方式打开图片文件
    f = open(url, 'rb')
    img = base64.b64encode(f.read())
    params = {"image":img,"top_num":1,"baike_num":1}

    request_url = request_url + "?access_token=" + access_token
    headers = {'content-type': 'application/json'}
    response = requests.post(request_url, data=params, headers=headers)
    result = response.json()
    if result:
        print (json.dumps(result,indent=1,ensure_ascii=False))
        # 打印图片结果
        for data in result["result"]:
            print(u"  菜品名称: " + data["name"])
            if data["baike_info"]["image_url"]:
                print(u"  图片为" + data["baike_info"]["image_url"])
                plt.figure(figsize=(5,5))
                response = requests.get(data["baike_info"]["image_url"])
                resp = request.urlopen(data["baike_info"]["image_url"])
                image = np.asarray(bytearray(resp.read()), dtype="uint8")
                image = cv2.imdecode(image, cv2.IMREAD_COLOR)
                plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))#BGR转RGB
                plt.xlabel("score is"+str(data["score"]))
                plt.show()
            return data["name"],data["score"]
    else:
        return "未识别该图片",0

"""
    垃圾分类识别
    输入:物体名称
    输出:垃圾分类结果,分数
"""
def getWaste(name):
    waste = '未识别到该垃圾的分类'
    # 定义列表result存储所有读取数据
    result = {"id":0,"name":"test","imageUrl":"","sortId":"可回收垃圾","result":"","sort":0}
    excel_file = "/Users/wangyu/Desktop/waste.xlsx"
    wb = openpyxl.load_workbook(excel_file)  # 读取excel文件
    sheet = wb.worksheets[0]#读取第一个表
    col = sheet["B"]#读取B列
    for index,cell in enumerate(col):
        if cell.value == name:
            result["sort"] = 1
            result["name"] = name
            result["result"] = cell.value
            result["id"] = sheet.cell(index,1).value
            result["sortId"] = sheet.cell(index,4).value
            print(cell.value,1)
        else:
            if name in cell.value:
                sort = len(name)/len(cell.value)
                print(cell.value,sort)
                if sort > result["sort"]:
                    result["sort"] = sort
                    result["name"] = name
                    result["result"] = cell.value
                    result["id"] = sheet.cell(index,1).value
                    result["sortId"] = sheet.cell(index,4).value
                    print(cell.value,sort)

    if result["sort"] == 0:
        print("未识别到是什么垃圾")
        return waste,0
    else:
        print("====识别结果=======")
        print(result["name"],"的识别结果是:",result["result"],"是",result["sortId"],"可信度为:",result["sort"])
        return result["sortId"],result["sort"]

"""
    在图像上标记中文
    输入:图片(cv2格式),文字,写到图片上的位置(x,y),文字颜色,文字大小
    输出:图片
"""
def cv2AddChineseText(img, text, position=(0,0), textColor=(0, 255, 0), textSize=30):
    if (isinstance(img, np.ndarray)):  # 判断是否OpenCV图片类型
        img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    # 创建一个可以在给定图像上绘图的对象
    draw = ImageDraw.Draw(img)
    # 字体的格式,需要下载
    fontStyle = ImageFont.truetype(
        "simsun/simsun.ttc", textSize, encoding="utf-8")
    # 绘制文本
    draw.text(position, text, textColor, font=fontStyle)
    # 转换回OpenCV格式
    return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)

if __name__ == '__main__':
    import os

    # Image to classify; the commented paths are other sample images.
    # url = '/Users/wangyu/Desktop/shicai.jpg'
    # url = '/Users/wangyu/Desktop/fish.jpg'
    url = '/Users/wangyu/Desktop/hongzao.jpg'
    img = cv2.imread(url)
    # Baidu AI credentials (API Key / Secret Key), read from the environment
    # so they need not be hard-coded; both default to ''.
    AccessKey = os.environ.get('BAIDU_API_KEY', '')
    SecretKey = os.environ.get('BAIDU_SECRET_KEY', '')
    # Recognise the object in the image (fall back to "not recognised" if
    # getResult returns nothing).
    access_token = getToken(AccessKey, SecretKey)
    img_result, score = getResult(url, access_token) or ("未识别该图片", 0)
    if img_result != "未识别该图片":
        waste_result, sort = getWaste(img_result)
        # Annotate the loaded image: recognition result + score on the first
        # line, waste category + match confidence on the second. str() guards
        # against an empty category cell (None).
        image = cv2AddChineseText(img, "图片识别:"+img_result, (30, 50), (255,0,0), 50)
        image = cv2AddChineseText(image, "score: {:.2f}".format(score), (500, 50), (255,0,255), 30)
        image = cv2AddChineseText(image, "垃圾分类:"+str(waste_result), (30, 120), (255,0,0), 50)
        image = cv2AddChineseText(image, "sort: {:.2f}".format(sort), (500, 120), (255,0,255), 30)
        plt.figure(figsize=(5,5))
        plt.imshow(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))  # BGR -> RGB for matplotlib
        plt.xlabel(u'image')
        plt.show()

总结

通过百度API图像接口实现了单一类别(植物)的图像识别,取得分最高的一个结果;再通过与excel表格匹配,给出垃圾分类提示。垃圾分类匹配数据达2k+条。

图像识别目前只用了一个类别(植物)的接口。百度的组合API也试了一下,暂时没有跑通,接口返回400错误,应该是传参方式还有问题,仍在排查中。

不限于API识别图像,也可以使用其他方案识别图像。