识别各类简单图片中的数字
前言
在上一篇笔记,介绍了如何用MNIST训练和测试数据库来训练和测试LeNet-5神经网络,在现实中没有多大实际作用,此篇笔记将介绍如何将任意一个简单图片中数字在Windows平台上识别出来。
本文不用opencv这类繁重的第三方库,只用基本系统函数和硬件加速库处理图片,以便降低学习难度,降低系统复杂性,提高运行速度和提供更轻量部署。
基本原理
由于上文中提到的训练好的网络,需要提供1x1x28x28的张量输入,每个像素范围为[-1.0, 1.0](-1.0为白色,1.0为黑色),所以需要对图片做如下处理:
加载图片,将图片的pixel format转化为易于做硬件加速处理的pixel format,比如Pre-multiplexed B8G8R8A8,基于硬件加速做图片的缩放,并转化为28x28的图片,然后将图片从彩色转化为单色,也就是3通道转化为1通道的图像,然后进行normalize处理,并输入训练好的网络。
(*) 当然如果能对背景噪音进行处理,效果会更好。
准备工作
智能指针
主要使用<wrl/client.h>中定义的Microsoft::WRL::ComPtr
Windows Image Component
具体请参考Windows Image Component对应的描述,以及参考MSDN提供的例子。
需要引用头文件
#include <wincodec.h>
#include <wincodecsdk.h>
并链接library: Windowscodecs.lib
Direct2D
具体参考微软MSDNDirect2D 需要引用头文件:
#include <d3d.h>
#include <d2d1.h>
#include <d2d1_2.h>
需要链接library: D2d1.lib
开始代码
具体网络构建,训练,训练结果保存请参考上一篇文章,这里就不再列出
#include <stdio.h>
#include <torch/torch.h>
#include <torch/data/datasets/mnist.h>
#include <iostream>
#include <tuple>
#include <chrono>
#define NOMINMAX
#include <wincodec.h>
#include <wincodecsdk.h>
#include <wrl/client.h>
#include <d3d.h>
#include <d2d1.h>
#include <d2d1_2.h>
#include <shlwapi.h>
using namespace Microsoft::WRL;
void FreeBlob(void* p)
{
printf("%s(), %s: %d.\n", __FUNCTION__, __FILE__, __LINE__);
free(p);
}
// 省略上节中提到代码
int main()
{
auto tm_start = std::chrono::system_clock::now();
auto tm_end = std::chrono::system_clock::now();
HRESULT hr = S_OK;
ComPtr<ID2D1Factory> spD2D1Factory; // D2D1 factory
ComPtr<IWICImagingFactory> spWICImageFactory; // Image codec factory
ComPtr<IWICBitmapDecoder> spDecoder; // Image decoder
ComPtr<IWICBitmapFrameDecode> spBitmapFrameDecode; // Decoded image
ComPtr<IWICBitmapSource> spConverter; // Converted image
ComPtr<ID2D1RenderTarget> pRenderTarget; // Render target to scale image
ComPtr<IWICBitmap> spNetInputBitmap; // The final bitmap 1x28x28
ComPtr<IWICBitmap> spHandWrittenBitmap; // The original bitmap
ComPtr<ID2D1Bitmap> spD2D1Bitmap; // D2D1 bitmap
UINT uiFrameCount = 0;
UINT uiWidth = 0, uiHeight = 0;
WICPixelFormatGUID pixelFormat;
tm_start = std::chrono::system_clock::now();
CoInitializeEx(NULL, COINIT_MULTITHREADED);
// Create D2D1 factory to create the related render target and D2D1 objects
D2D1_FACTORY_OPTIONS options;
ZeroMemory(&options, sizeof(D2D1_FACTORY_OPTIONS));
#if defined(_DEBUG)
// If the project is in a debug build, enable Direct2D debugging via SDK Layers.
options.debugLevel = D2D1_DEBUG_LEVEL_INFORMATION;
#endif
if (FAILED(hr = D2D1CreateFactory(D2D1_FACTORY_TYPE_MULTI_THREADED,
__uuidof(ID2D1Factory2), &options, &spD2D1Factory)))
goto done;
// Create the image factory
if (FAILED(hr = CoCreateInstance(CLSID_WICImagingFactory,
nullptr,
CLSCTX_INPROC_SERVER,
IID_IWICImagingFactory,
(LPVOID*)&spWICImageFactory)))
goto done;
// 加载图片I:\\4.png, 并为其创建图像解码器
if (FAILED(spWICImageFactory->CreateDecoderFromFilename(L"I:\\4.png", NULL,
GENERIC_READ, WICDecodeMetadataCacheOnDemand, &spDecoder)))
goto done;
// 得到多少帧图像在图片文件中,如果无可解帧,结束程序
if (FAILED(hr = spDecoder->GetFrameCount(&uiFrameCount)) || uiFrameCount == 0)
goto done;
// 得到第一帧图片
if (FAILED(hr = hr = spDecoder->GetFrame(0, &spBitmapFrameDecode)))
goto done;
// 得到图片大小
if (FAILED(hr = spBitmapFrameDecode->GetSize(&uiWidth, &uiHeight)))
goto done;
// 得到图片像素格式
if (FAILED(hr = spBitmapFrameDecode->GetPixelFormat(&pixelFormat)))
goto done;
// 如果图片不是Pre-multiplexed BGRA格式,转化成这个格式,以便用D2D硬件处理图形转换
if (!IsEqualGUID(pixelFormat, GUID_WICPixelFormat32bppPBGRA))
{
if (FAILED(hr = WICConvertBitmapSource(GUID_WICPixelFormat32bppPBGRA,
spBitmapFrameDecode.Get(), &spConverter)))
goto done;
}
else
spConverter = spBitmapFrameDecode;
// 转化为Pre-multiplexed BGRA格式的WICBitmap
if (FAILED(hr = spWICImageFactory->CreateBitmapFromSource(
spConverter.Get(), WICBitmapCacheOnDemand, &spHandWrittenBitmap)))
goto done;
// 创建一个Pre-multiplexed BGRA的28x28的WICBitmap
if (FAILED(hr = spWICImageFactory->CreateBitmap(28, 28, GUID_WICPixelFormat32bppPBGRA,
WICBitmapCacheOnDemand, &spNetInputBitmap)))
goto done;
// 在此WICBitmap上创建D2D1 Render Target
{
D2D1_RENDER_TARGET_PROPERTIES props = D2D1::RenderTargetProperties(D2D1_RENDER_TARGET_TYPE_DEFAULT,
D2D1::PixelFormat(DXGI_FORMAT_B8G8R8A8_UNORM, D2D1_ALPHA_MODE_PREMULTIPLIED), 96, 96);
if (FAILED(hr = spD2D1Factory->CreateWicBitmapRenderTarget(spNetInputBitmap.Get(), props, &pRenderTarget)))
goto done;
}
// 将转化为Pre-multiplexed BGRA格式的WICBitmap的原始图片转换到D2D1Bitmap对象中来,以便后面的缩放处理
if (FAILED(hr = pRenderTarget->CreateBitmapFromWicBitmap(spHandWrittenBitmap.Get(), &spD2D1Bitmap)))
goto done;
// 将图片进行缩放处理,转化为28x28的图片
{
pRenderTarget->BeginDraw();
D2D1_RECT_F dst_rect = { 0, 0, 28, 28 };
pRenderTarget->DrawBitmap(spD2D1Bitmap.Get(), &dst_rect);
pRenderTarget->EndDraw();
}
//SaveAs(spNetInputBitmap, L"I:\\test.png");
// 将3x28x28的图片转化为灰度图片
if (FAILED(hr = WICConvertBitmapSource(GUID_WICPixelFormat8bppGray, spNetInputBitmap.Get(), &spConverter)))
goto done;
// 并将灰度图像转化为[-1.0(白色)1.0(黑色)]的raw data
{
uint8_t gray8bbp[28 * 28] = { 0 };
WICRect rect = { 0, 0, 28, 28 };
hr = spConverter->CopyPixels(&rect, 28, 1*28*28, gray8bbp);
float* res_data = (float*)malloc(1 * 28 * 28 * sizeof(float));
for (int i = 0; i < 28; i++)
{
for (int j = 0; j < 28; j++)
{
//printf("%02X ", gray8bbp[i*28 + j]);
res_data[i*28 + j] = ((255 - gray8bbp[i * 28 + j]) / 255.0f - 0.5f) / 0.5f;
}
//printf("\n");
}
// 创建向量,并从上面处理的原始数据中转化为4阶张量
torch::Tensor res_tensor;
res_tensor = torch::from_blob(res_data, { 1, 1, 28, 28 }, FreeBlob);
//std::cout << res_tensor << '\n';
// 加载训练好的网络
LeNet5 net1(2);
torch::serialize::InputArchive archive;
archive.load_from("I:\\mnist.pt");
net1.load(archive);
// 处理数据,并得到预测结果
auto outputs = net1.forward(res_tensor);
auto predicted = torch::max(outputs, 1);
tm_end = std::chrono::system_clock::now();
printf("predicted label: %d, cost %lld msec.\n",
std::get<1>(predicted).item<int>(),
std::chrono::duration_cast<std::chrono::milliseconds>(tm_end - tm_start).count());
}
done:
CoUninitialize();
return 0;
}
代码解读
在上面代码中已经给出了一些具体的注释。
输出结果
假设有这样一副图片4.png需要识别:
图片是Paint3D画出来的:
运行结果为:
识别正确,33毫秒加载网络和识别搞定,速度还可以!