识别各类简单图片中的数字

前言

在上一篇笔记,介绍了如何用MNIST训练和测试数据库来训练和测试LeNet-5神经网络,在现实中没有多大实际作用,此篇笔记将介绍如何将任意一个简单图片中数字在Windows平台上识别出来。
本文不用opencv这类繁重的第三方库,只用基本系统函数和硬件加速库处理图片,以便降低学习难度,降低系统复杂性,提高运行速度和提供更轻量部署。

基本原理

由于上文中提到的训练好的网络,需要提供1x1x28x28的张量输入,每个像素范围为[-1.0, 1.0](-1.0为白色,1.0为黑色),所以需要对图片做如下处理:

图像识别数米粒去除不必要边框 识别图像中的数字_神经网络


加载图片,将图片的pixel format转化为易于做硬件加速处理的pixel format,比如Pre-multiplexed B8G8R8A8,基于硬件加速做图片的缩放,并转化为28x28的图片,然后将图片从彩色转化为单色,也就是3通道转化为1通道的图像,然后进行normalize处理,并输入训练好的网络。

(*) 当然如果能对背景噪音进行处理,效果会更好。

准备工作

智能指针

主要使用<wrl/client.h>中定义的Microsoft::WRL::ComPtr

Windows Image Component

具体请参考Windows Image Component对应的描述,以及参考MSDN提供的例子
需要引用头文件

#include <wincodec.h>
#include <wincodecsdk.h>

并链接library: Windowscodecs.lib

Direct2D

具体参考微软MSDNDirect2D 需要引用头文件:

#include <d3d.h>
#include <d2d1.h>
#include <d2d1_2.h>

需要链接library: D2d1.lib

开始代码

具体网络构建,训练,训练结果保存请参考上一篇文章,这里就不再列出

#include <stdio.h>
#include <torch/torch.h>
#include <torch/data/datasets/mnist.h>
#include <iostream>
#include <tuple>
#include <chrono>

#define NOMINMAX

#include <wincodec.h>
#include <wincodecsdk.h>
#include <wrl/client.h>
#include <d3d.h>
#include <d2d1.h>
#include <d2d1_2.h>
#include <shlwapi.h>

using namespace Microsoft::WRL;

void FreeBlob(void* p)
{
	printf("%s(), %s: %d.\n", __FUNCTION__, __FILE__, __LINE__);
	free(p);
}
// 省略上节中提到代码
int main()
{
	auto tm_start = std::chrono::system_clock::now();
	auto tm_end = std::chrono::system_clock::now();
	
	HRESULT hr = S_OK;
	ComPtr<ID2D1Factory> spD2D1Factory;					// D2D1 factory
	ComPtr<IWICImagingFactory> spWICImageFactory;		// Image codec factory
	ComPtr<IWICBitmapDecoder> spDecoder;				// Image decoder
	ComPtr<IWICBitmapFrameDecode> spBitmapFrameDecode;	// Decoded image
	ComPtr<IWICBitmapSource> spConverter;				// Converted image
	ComPtr<ID2D1RenderTarget> pRenderTarget;			// Render target to scale image
	ComPtr<IWICBitmap> spNetInputBitmap;				// The final bitmap 1x28x28
	ComPtr<IWICBitmap> spHandWrittenBitmap;				// The original bitmap
	ComPtr<ID2D1Bitmap> spD2D1Bitmap;					// D2D1 bitmap
	UINT uiFrameCount = 0;
	UINT uiWidth = 0, uiHeight = 0;
	WICPixelFormatGUID pixelFormat;

	tm_start = std::chrono::system_clock::now();

	CoInitializeEx(NULL, COINIT_MULTITHREADED);

	// Create D2D1 factory to create the related render target and D2D1 objects
	D2D1_FACTORY_OPTIONS options;
	ZeroMemory(&options, sizeof(D2D1_FACTORY_OPTIONS));
#if defined(_DEBUG)
	// If the project is in a debug build, enable Direct2D debugging via SDK Layers.
	options.debugLevel = D2D1_DEBUG_LEVEL_INFORMATION;
#endif
	if (FAILED(hr = D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED, 
		__uuidof(ID2D1Factory2), &options, &spD2D1Factory)))
		goto done;

	// Create the image factory
	if (FAILED(hr = CoCreateInstance(CLSID_WICImagingFactory,
		nullptr,
		CLSCTX_INPROC_SERVER,
		IID_IWICImagingFactory,
		(LPVOID*)&spWICImageFactory)))
		goto done;

	// 加载图片I:\\4.png, 并为其创建图像解码器
	if (FAILED(spWICImageFactory->CreateDecoderFromFilename(L"I:\\4.png", NULL, 
		GENERIC_READ, WICDecodeMetadataCacheOnDemand, &spDecoder)))
		goto done;
			
	// 得到多少帧图像在图片文件中,如果无可解帧,结束程序
	if (FAILED(hr = spDecoder->GetFrameCount(&uiFrameCount)) || uiFrameCount == 0)
		goto done;

	// 得到第一帧图片
	if (FAILED(hr = hr = spDecoder->GetFrame(0, &spBitmapFrameDecode)))
		goto done;

	// 得到图片大小
	if (FAILED(hr = spBitmapFrameDecode->GetSize(&uiWidth, &uiHeight)))
		goto done;

	// 得到图片像素格式
	if (FAILED(hr = spBitmapFrameDecode->GetPixelFormat(&pixelFormat)))
		goto done;

	// 如果图片不是Pre-multiplexed BGRA格式,转化成这个格式,以便用D2D硬件处理图形转换
	if (!IsEqualGUID(pixelFormat, GUID_WICPixelFormat32bppPBGRA))
	{
		if (FAILED(hr = WICConvertBitmapSource(GUID_WICPixelFormat32bppPBGRA, 
			spBitmapFrameDecode.Get(), &spConverter)))
			goto done;
	}
	else
		spConverter = spBitmapFrameDecode;

	// 转化为Pre-multiplexed BGRA格式的WICBitmap
	if (FAILED(hr = spWICImageFactory->CreateBitmapFromSource(
		spConverter.Get(), WICBitmapCacheOnDemand, &spHandWrittenBitmap)))
		goto done;

	// 创建一个Pre-multiplexed BGRA的28x28的WICBitmap
	if (FAILED(hr = spWICImageFactory->CreateBitmap(28, 28, GUID_WICPixelFormat32bppPBGRA, 
		WICBitmapCacheOnDemand, &spNetInputBitmap)))
		goto done;

	// 在此WICBitmap上创建D2D1 Render Target
	{
		D2D1_RENDER_TARGET_PROPERTIES props = D2D1::RenderTargetProperties(D2D1_RENDER_TARGET_TYPE_DEFAULT,
			D2D1::PixelFormat(DXGI_FORMAT_B8G8R8A8_UNORM, D2D1_ALPHA_MODE_PREMULTIPLIED), 96, 96);
		if (FAILED(hr = spD2D1Factory->CreateWicBitmapRenderTarget(spNetInputBitmap.Get(), props, &pRenderTarget)))
			goto done;
	}

	// 将转化为Pre-multiplexed BGRA格式的WICBitmap的原始图片转换到D2D1Bitmap对象中来,以便后面的缩放处理
	if (FAILED(hr = pRenderTarget->CreateBitmapFromWicBitmap(spHandWrittenBitmap.Get(), &spD2D1Bitmap)))
		goto done;

	// 将图片进行缩放处理,转化为28x28的图片
	{
		pRenderTarget->BeginDraw();
		D2D1_RECT_F dst_rect = { 0, 0, 28, 28 };
		pRenderTarget->DrawBitmap(spD2D1Bitmap.Get(), &dst_rect);
		pRenderTarget->EndDraw();
	}

	//SaveAs(spNetInputBitmap, L"I:\\test.png");

	// 将3x28x28的图片转化为灰度图片
	if (FAILED(hr = WICConvertBitmapSource(GUID_WICPixelFormat8bppGray, spNetInputBitmap.Get(), &spConverter)))
		goto done;

	// 并将灰度图像转化为[-1.0(白色)1.0(黑色)]的raw data
	{
		uint8_t gray8bbp[28 * 28] = { 0 };
		WICRect rect = { 0, 0, 28, 28 };
		hr = spConverter->CopyPixels(&rect, 28, 1*28*28, gray8bbp);

		float* res_data = (float*)malloc(1 * 28 * 28 * sizeof(float));
		for (int i = 0; i < 28; i++)
		{
			for (int j = 0; j < 28; j++)
			{
				//printf("%02X ", gray8bbp[i*28 + j]);
				res_data[i*28 + j] = ((255 - gray8bbp[i * 28 + j]) / 255.0f - 0.5f) / 0.5f;
			}
			//printf("\n");
		}

		// 创建向量,并从上面处理的原始数据中转化为4阶张量
		torch::Tensor res_tensor;
		res_tensor = torch::from_blob(res_data, { 1, 1, 28, 28 }, FreeBlob);

		//std::cout << res_tensor << '\n';

		// 加载训练好的网络
		LeNet5 net1(2);
		torch::serialize::InputArchive archive;
		archive.load_from("I:\\mnist.pt");

		net1.load(archive);

		// 处理数据,并得到预测结果
		auto outputs = net1.forward(res_tensor);
		auto predicted = torch::max(outputs, 1);

		tm_end = std::chrono::system_clock::now();

		printf("predicted label: %d, cost %lld msec.\n", 
			std::get<1>(predicted).item<int>(),
			std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
	}

done:
	CoUninitialize();
	return 0;
}

代码解读

在上面代码中已经给出了一些具体的注释。

输出结果

假设有这样一副图片4.png需要识别:

图像识别数米粒去除不必要边框 识别图像中的数字_计算机视觉_02


图片是Paint3D画出来的:

图像识别数米粒去除不必要边框 识别图像中的数字_神经网络_03


运行结果为:

图像识别数米粒去除不必要边框 识别图像中的数字_神经网络_04


识别正确,33毫秒加载网络和识别搞定,速度还可以!