1. 背景
OpenCV 提供了基于像素的模板匹配函数 matchTemplate，但该函数不支持带旋转角度的匹配；而且在使用该函数的 mask 参数时，匹配结果可能与预期不符。
2. 模板训练
对模板图像按一定的角度步长进行旋转，得到各个角度下的旋转图像与旋转掩膜图像；然后分别以这些旋转图像作为模板进行匹配，取得分最优的结果作为最终匹配结果。
// Alias for a list of contours; each contour is a polyline of integer points
// (the container type filled by cv::findContours in this file).
using CV_CONTOURS = std::vector<std::vector<cv::Point>>;
// 显示图像
static void showImage(cv::Mat& image, const std::string& window_name, int timeout = 0, int width=800, int height = 600)
{
cv::namedWindow(window_name, cv::WINDOW_NORMAL);
cv::resizeWindow(window_name, cv::Size(width, height));
cv::imshow(window_name, image);
cv::waitKey(timeout);
}
// One rotated variant of the trained template, plus the result of searching for it.
// Scalar members are initialized so a template that was never matched has defined values.
struct RotatedTemplate
{
    double rotation = 0.0;                              // rotation angle (degrees) given to cv::getRotationMatrix2D
    cv::Mat image;                                      // rotated grayscale template image
    cv::Mat mask;                                       // rotated mask; only non-zero pixels take part in matching
    CV_CONTOURS contours;                               // rotated template contour, relative to the template center
    double score = std::numeric_limits<double>::max();  // best match score (lower is better for sqdiff); max = not matched yet
    cv::Point position;                                 // top-left corner of the best match in the searched image
};
/// <summary>
/// Zero-based "center" pixel of an image.
/// For an odd extent n the center index is n / 2; for an even extent it is n / 2 - 1,
/// i.e. the upper/left one of the two middle pixels.
/// </summary>
/// <param name="src">Image whose size is inspected.</param>
/// <param name="row">Receives the center row index.</param>
/// <param name="col">Receives the center column index.</param>
void get_image_center(const cv::Mat& src, int& row, int& col)
{
    const auto middle = [](int extent) {
        const int half = extent / 2;
        if (extent % 2 == 1)
            return half;
        return half - 1;
    };
    // Column is assigned before row, as in the original contract.
    col = middle(src.cols);
    row = middle(src.rows);
}
/// <summary>
/// 训练模板
/// </summary>
void train(const std::string& file_name, std::vector<RotatedTemplate>& templates)
{
// 读取模板图像
cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);
//showImage(src, "src");
// 灰度化
cv::Mat gray;
cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);
// 二值化
cv::Mat binary;
cv::threshold(gray, binary, 200, 255, cv::THRESH_BINARY_INV);
//showImage(binary, "binary");
// 外轮廓提取
CV_CONTOURS contours;
cv::findContours(binary, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
//showImage(src, "contours");
// 掩模
cv::Rect rect = cv::boundingRect(contours[0]);
cv::Mat mask = cv::Mat::zeros(gray.size(), CV_8U);
cv::rectangle(mask, rect, cv::Scalar(255), -1);
//showImage(mask, "mask");
// 创建多角度的模板集合
templates.clear();
for (int rotation = -45; rotation <= 45; rotation += 5)
{
cv::Mat rotated_gray;
cv::Point2f center(gray.cols * 0.5f, gray.rows * 0.5f);
cv::Mat rot_mat = cv::getRotationMatrix2D(center, rotation, 1.0);
cv::warpAffine(gray, rotated_gray, rot_mat, gray.size());
cv::Mat rotated_mask;
cv::warpAffine(mask, rotated_mask, rot_mat, gray.size());
// 映射轮廓
CV_CONTOURS output_contours;
output_contours.resize(1);
cv::transform(contours[0], output_contours[0], rot_mat);
// 轮廓及其属性显示
cv::Mat rotated_color;
cv::cvtColor(rotated_gray, rotated_color, cv::COLOR_GRAY2BGR);
cv::drawContours(rotated_color, output_contours, -1, cv::Scalar(0, 255, 0), 3);
//showImage(rotated_color, "rotated_color", 300);
//showImage(rotated_mask, "rotated_mask", 300);
// 轮廓修正到以模板中心坐标为原点
int row, col;
get_image_center(rotated_gray, row, col);
for (size_t i = 0; i < output_contours[0].size(); i++)
{
output_contours[0][i].x -= col;
output_contours[0][i].y -= row;
}
RotatedTemplate templ;
templ.image = rotated_gray;
templ.mask = rotated_mask;
templ.rotation = rotation;
templ.contours = output_contours;
templates.push_back(templ);
}
cv::destroyAllWindows();
}
3. 模板查找
将模板在输入图像中沿水平/垂直方向逐像素滑动，每次截取与模板图像尺寸相同的子区域，并计算该子区域与模板之间的匹配度。
// 定义函数指针
typedef double (*p_function)(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask);
/// <summary>
/// Sum of squared differences (SQDIFF) between a window and a template, over the
/// pixels whose mask value is non-zero. Lower means a better match.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double sqdiff(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_sq = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        const uchar* m = mask.ptr<uchar>(y);
        for (const uchar* const s_end = s + src.cols; s != s_end; ++s, ++t, ++m)
        {
            if (*m != 0)
            {
                const double d = *s - *t;
                sum_sq += d * d;
            }
        }
    }
    return sum_sq;
}
/// <summary>
/// Normalized sum of squared differences over the masked pixels:
///   R = sum((I - T)^2) / sqrt(sum(I^2) * sum(T^2))
/// Lower means a better match.
/// The denominator is zero when the mask is empty or the masked window/template is all zero.
/// In that case the function returns a defined value instead of NaN/inf:
///   - 0 if there were masked pixels and they were all identical;
///   - otherwise std::numeric_limits<double>::max(), so a minimum search never selects the window.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double sqdiff_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_diff = 0.0;
    long long masked_pixels = 0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];
                const double pixel_diff = pixel_src - pixel_templ;
                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_diff += pixel_diff * pixel_diff;
                ++masked_pixels;
            }
        }
    }
    const double denominator = std::sqrt(sum_src * sum_templ);
    if (denominator > 0.0)
        return sum_diff / denominator;
    // Degenerate 0/0 or x/0 case (see summary).
    return (masked_pixels > 0 && sum_diff == 0.0) ? 0.0 : std::numeric_limits<double>::max();
}
/// <summary>
/// Plain cross-correlation (CCORR): the sum of src * templ over the pixels where the mask
/// is non-zero. Larger values mean stronger correlation. The score is unnormalized, so
/// brighter windows also score higher.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double ccorr(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double correlation = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* m = mask.ptr<uchar>(y);
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        for (int x = 0; x < src.cols; ++x)
        {
            if (!m[x])
                continue;
            const int product = s[x] * t[x];
            correlation += product;
        }
    }
    return correlation;
}
/// <summary>
/// Normalized cross-correlation over the masked pixels:
///   R = sum(I * T) / sqrt(sum(I^2) * sum(T^2))
/// For non-negative 8-bit pixels R lies in [0, 1]; higher means a better match.
/// If the mask is empty, or the masked window/template is all zero, no correlation can be
/// measured: the function returns 0 instead of NaN.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double ccorr_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_multi = 0.0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];
                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_multi += pixel_src * pixel_templ;
            }
        }
    }
    const double denominator = std::sqrt(sum_src * sum_templ);
    return denominator > 0.0 ? sum_multi / denominator : 0.0;
}
/// <summary>
/// Evaluate one similarity metric on a window/template pair.
/// When is_show is true, the window, template and mask are first displayed side by side
/// (window "concat_mat", 10 ms wait) as a debugging aid.
/// </summary>
double match(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, p_function p_func, bool is_show = false)
{
    if (!is_show)
        return p_func(src, templ, mask);

    cv::Mat side_by_side;
    cv::hconcat(std::vector<cv::Mat>{ src, templ, mask }, side_by_side);
    showImage(side_by_side, "concat_mat", 10, 2400, 600);
    return p_func(src, templ, mask);
}
/// <summary>
/// Slide one template over an image one pixel at a time and score every window.
/// result(row, col) receives the metric of the window whose top-left corner is (col, row).
/// </summary>
/// <param name="src">Search image (the metrics read it as single-channel 8-bit).</param>
/// <param name="templ">Template image, compared against each window of the same size.</param>
/// <param name="mask">Mask with templ's size; only non-zero pixels are compared.</param>
/// <param name="result">Pre-allocated CV_64FC1 matrix, at most
/// (src.rows - templ.rows + 1) x (src.cols - templ.cols + 1).</param>
/// <param name="p_func">Metric to evaluate; defaults to sqdiff.</param>
void find(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, cv::Mat& result, p_function p_func = &sqdiff)
{
    CV_Assert(p_func != nullptr);
    CV_Assert(result.type() == CV_64FC1);
    // The metrics index mask row/column with the template's extent; a smaller mask would be read out of bounds.
    CV_Assert(mask.size() == templ.size());
    // Every window must lie inside src.
    CV_Assert(result.rows <= src.rows - templ.rows + 1 && result.cols <= src.cols - templ.cols + 1);

    for (int row = 0; row < result.rows; row++)
    {
        double* out = result.ptr<double>(row);
        for (int col = 0; col < result.cols; col++)
        {
            const cv::Mat window = src(cv::Rect(col, row, templ.cols, templ.rows));
            out[col] = match(window, templ, mask, p_func);
        }
    }
}
int main(int argc, char** argv)
{
cv::utils::logging::setLogLevel(cv::utils::logging::LOG_LEVEL_SILENT);
std::vector<RotatedTemplate> templates;
train("../../template.png", templates);
std::vector<std::string> file_names;
file_names.push_back("../../sample.png");
file_names.push_back("../../sample0.png");
file_names.push_back("../../sample1.png");
file_names.push_back("../../sample2.png");
// 遍历所有图像
for (std::string& file_name : file_names)
{
// 读取彩色图像
std::cout << "The path of image is : " << file_name << std::endl;
cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);
// 灰度化
cv::Mat gray;
cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);
// 初始化输出,多个模板
int result_row = src.cols - templates[0].image.cols + 1;
int result_col = src.rows - templates[0].image.rows + 1;
cv::Mat result_multi(cv::Size(result_row, result_col), CV_64FC1, cv::Scalar(DBL_MAX));
// 遍历所有模板
for (RotatedTemplate& templ : templates)
{
std::cout << "The rotation of templ is : " << templ.rotation << std::endl;
cv::Mat result_single(result_multi.size(), CV_64FC1, cv::Scalar(DBL_MAX));
find(gray, templ.image, templ.mask, result_single);
double minVal = DBL_MAX;
cv::Point minLoc;
cv::minMaxLoc(result_single, &minVal, NULL, &minLoc, NULL);
templ.score = minVal;
templ.position = minLoc;
result_multi.at<double>(templ.position.y, templ.position.x) = templ.score;
}
double minVal = DBL_MAX;
cv::Point minLoc;
cv::minMaxLoc(result_multi, &minVal, NULL, &minLoc, NULL);
for (RotatedTemplate& templ : templates)
{
int templ_center_row, templ_center_col;
get_image_center(templ.image, templ_center_row, templ_center_col);
if (templ.position == minLoc && std::abs(templ.score - minVal) < 1e-3)
{
std::cout << "The matched position is : " << minLoc << std::endl;
std::cout << "The matched rotation is : " << templ.rotation << std::endl;
//修正轮廓坐标
std::vector<std::vector<cv::Point>> contours(templ.contours);
for (size_t i = 0; i < contours[0].size(); i++)
{
contours[0][i] = contours[0][i] + minLoc+cv::Point(templ_center_col, templ_center_row);
}
cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
cv::circle(src, minLoc+cv::Point(templ_center_col, templ_center_row), 10, cv::Scalar(0, 255, 0), -1);
showImage(src, "result");
}
}
}
system("PAUSE");
cv::destroyAllWindows();
return EXIT_SUCCESS;
}
4. 完整源码
#include <opencv2\opencv.hpp>
#include <opencv2\highgui\highgui.hpp>
#include <opencv2\core\utils\logger.hpp>
#include <iostream>
#include <string>
#include <cstdio>
#include <limits.h>
#include <algorithm>
#include <limits>
// 定义函数指针
typedef double (*p_function)(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask);
// Alias for a list of contours; each contour is a polyline of integer points
// (the container type filled by cv::findContours in this file).
using CV_CONTOURS = std::vector<std::vector<cv::Point>>;
// 显示图像
static void showImage(cv::Mat& image, const std::string& window_name, int timeout = 0, int width=800, int height = 600)
{
cv::namedWindow(window_name, cv::WINDOW_NORMAL);
cv::resizeWindow(window_name, cv::Size(width, height));
cv::imshow(window_name, image);
cv::waitKey(timeout);
}
// One rotated variant of the trained template, plus the result of searching for it.
// Scalar members are initialized so a template that was never matched has defined values.
struct RotatedTemplate
{
    double rotation = 0.0;                              // rotation angle (degrees) given to cv::getRotationMatrix2D
    cv::Mat image;                                      // rotated grayscale template image
    cv::Mat mask;                                       // rotated mask; only non-zero pixels take part in matching
    CV_CONTOURS contours;                               // rotated template contour, relative to the template center
    double score = std::numeric_limits<double>::max();  // best match score (lower is better for sqdiff); max = not matched yet
    cv::Point position;                                 // top-left corner of the best match in the searched image
};
/// <summary>
/// Zero-based "center" pixel of an image.
/// For an odd extent n the center index is n / 2; for an even extent it is n / 2 - 1,
/// i.e. the upper/left one of the two middle pixels.
/// </summary>
/// <param name="src">Image whose size is inspected.</param>
/// <param name="row">Receives the center row index.</param>
/// <param name="col">Receives the center column index.</param>
void get_image_center(const cv::Mat& src, int& row, int& col)
{
    const auto middle = [](int extent) {
        const int half = extent / 2;
        if (extent % 2 == 1)
            return half;
        return half - 1;
    };
    // Column is assigned before row, as in the original contract.
    col = middle(src.cols);
    row = middle(src.rows);
}
/// <summary>
/// 训练模板
/// </summary>
void train(const std::string& file_name, std::vector<RotatedTemplate>& templates)
{
// 读取模板图像
cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);
//showImage(src, "src");
// 灰度化
cv::Mat gray;
cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);
// 二值化
cv::Mat binary;
cv::threshold(gray, binary, 200, 255, cv::THRESH_BINARY_INV);
//showImage(binary, "binary");
// 外轮廓提取
CV_CONTOURS contours;
cv::findContours(binary, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
//showImage(src, "contours");
// 掩模
cv::Rect rect = cv::boundingRect(contours[0]);
cv::Mat mask = cv::Mat::zeros(gray.size(), CV_8U);
cv::rectangle(mask, rect, cv::Scalar(255), -1);
//showImage(mask, "mask");
// 创建多角度的模板集合
templates.clear();
for (int rotation = -45; rotation <= 45; rotation += 5)
{
cv::Mat rotated_gray;
cv::Point2f center(gray.cols * 0.5f, gray.rows * 0.5f);
cv::Mat rot_mat = cv::getRotationMatrix2D(center, rotation, 1.0);
cv::warpAffine(gray, rotated_gray, rot_mat, gray.size());
cv::Mat rotated_mask;
cv::warpAffine(mask, rotated_mask, rot_mat, gray.size());
// 映射轮廓
CV_CONTOURS output_contours;
output_contours.resize(1);
cv::transform(contours[0], output_contours[0], rot_mat);
// 轮廓及其属性显示
cv::Mat rotated_color;
cv::cvtColor(rotated_gray, rotated_color, cv::COLOR_GRAY2BGR);
cv::drawContours(rotated_color, output_contours, -1, cv::Scalar(0, 255, 0), 3);
//showImage(rotated_color, "rotated_color", 300);
//showImage(rotated_mask, "rotated_mask", 300);
// 轮廓修正到以模板中心坐标为原点
int row, col;
get_image_center(rotated_gray, row, col);
for (size_t i = 0; i < output_contours[0].size(); i++)
{
output_contours[0][i].x -= col;
output_contours[0][i].y -= row;
}
RotatedTemplate templ;
templ.image = rotated_gray;
templ.mask = rotated_mask;
templ.rotation = rotation;
templ.contours = output_contours;
templates.push_back(templ);
}
cv::destroyAllWindows();
}
/// <summary>
/// Sum of squared differences (SQDIFF) between a window and a template, over the
/// pixels whose mask value is non-zero. Lower means a better match.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double sqdiff(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_sq = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        const uchar* m = mask.ptr<uchar>(y);
        for (const uchar* const s_end = s + src.cols; s != s_end; ++s, ++t, ++m)
        {
            if (*m != 0)
            {
                const double d = *s - *t;
                sum_sq += d * d;
            }
        }
    }
    return sum_sq;
}
/// <summary>
/// Normalized sum of squared differences over the masked pixels:
///   R = sum((I - T)^2) / sqrt(sum(I^2) * sum(T^2))
/// Lower means a better match.
/// The denominator is zero when the mask is empty or the masked window/template is all zero.
/// In that case the function returns a defined value instead of NaN/inf:
///   - 0 if there were masked pixels and they were all identical;
///   - otherwise std::numeric_limits<double>::max(), so a minimum search never selects the window.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double sqdiff_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_diff = 0.0;
    long long masked_pixels = 0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];
                const double pixel_diff = pixel_src - pixel_templ;
                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_diff += pixel_diff * pixel_diff;
                ++masked_pixels;
            }
        }
    }
    const double denominator = std::sqrt(sum_src * sum_templ);
    if (denominator > 0.0)
        return sum_diff / denominator;
    // Degenerate 0/0 or x/0 case (see summary).
    return (masked_pixels > 0 && sum_diff == 0.0) ? 0.0 : std::numeric_limits<double>::max();
}
/// <summary>
/// Plain cross-correlation (CCORR): the sum of src * templ over the pixels where the mask
/// is non-zero. Larger values mean stronger correlation. The score is unnormalized, so
/// brighter windows also score higher.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double ccorr(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double correlation = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* m = mask.ptr<uchar>(y);
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        for (int x = 0; x < src.cols; ++x)
        {
            if (!m[x])
                continue;
            const int product = s[x] * t[x];
            correlation += product;
        }
    }
    return correlation;
}
/// <summary>
/// Normalized cross-correlation over the masked pixels:
///   R = sum(I * T) / sqrt(sum(I^2) * sum(T^2))
/// For non-negative 8-bit pixels R lies in [0, 1]; higher means a better match.
/// If the mask is empty, or the masked window/template is all zero, no correlation can be
/// measured: the function returns 0 instead of NaN.
/// The per-row uchar access assumes single-channel 8-bit inputs of identical size.
/// </summary>
double ccorr_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_multi = 0.0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];
                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_multi += pixel_src * pixel_templ;
            }
        }
    }
    const double denominator = std::sqrt(sum_src * sum_templ);
    return denominator > 0.0 ? sum_multi / denominator : 0.0;
}
/// <summary>
/// Evaluate one similarity metric on a window/template pair.
/// When is_show is true, the window, template and mask are first displayed side by side
/// (window "concat_mat", 10 ms wait) as a debugging aid.
/// </summary>
double match(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, p_function p_func, bool is_show = false)
{
    if (!is_show)
        return p_func(src, templ, mask);

    cv::Mat side_by_side;
    cv::hconcat(std::vector<cv::Mat>{ src, templ, mask }, side_by_side);
    showImage(side_by_side, "concat_mat", 10, 2400, 600);
    return p_func(src, templ, mask);
}
/// <summary>
/// Slide one template over an image one pixel at a time and score every window.
/// result(row, col) receives the metric of the window whose top-left corner is (col, row).
/// </summary>
/// <param name="src">Search image (the metrics read it as single-channel 8-bit).</param>
/// <param name="templ">Template image, compared against each window of the same size.</param>
/// <param name="mask">Mask with templ's size; only non-zero pixels are compared.</param>
/// <param name="result">Pre-allocated CV_64FC1 matrix, at most
/// (src.rows - templ.rows + 1) x (src.cols - templ.cols + 1).</param>
/// <param name="p_func">Metric to evaluate; defaults to sqdiff.</param>
void find(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, cv::Mat& result, p_function p_func = &sqdiff)
{
    CV_Assert(p_func != nullptr);
    CV_Assert(result.type() == CV_64FC1);
    // The metrics index mask row/column with the template's extent; a smaller mask would be read out of bounds.
    CV_Assert(mask.size() == templ.size());
    // Every window must lie inside src.
    CV_Assert(result.rows <= src.rows - templ.rows + 1 && result.cols <= src.cols - templ.cols + 1);

    for (int row = 0; row < result.rows; row++)
    {
        double* out = result.ptr<double>(row);
        for (int col = 0; col < result.cols; col++)
        {
            const cv::Mat window = src(cv::Rect(col, row, templ.cols, templ.rows));
            out[col] = match(window, templ, mask, p_func);
        }
    }
}
int main(int argc, char** argv)
{
cv::utils::logging::setLogLevel(cv::utils::logging::LOG_LEVEL_SILENT);
std::vector<RotatedTemplate> templates;
train("../../template.png", templates);
std::vector<std::string> file_names;
file_names.push_back("../../sample.png");
file_names.push_back("../../sample0.png");
file_names.push_back("../../sample1.png");
file_names.push_back("../../sample2.png");
// 遍历所有图像
for (std::string& file_name : file_names)
{
// 读取彩色图像
std::cout << "The path of image is : " << file_name << std::endl;
cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);
// 灰度化
cv::Mat gray;
cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);
// 初始化输出,多个模板
int result_row = src.cols - templates[0].image.cols + 1;
int result_col = src.rows - templates[0].image.rows + 1;
cv::Mat result_multi(cv::Size(result_row, result_col), CV_64FC1, cv::Scalar(DBL_MAX));
// 遍历所有模板
for (RotatedTemplate& templ : templates)
{
std::cout << "The rotation of templ is : " << templ.rotation << std::endl;
cv::Mat result_single(result_multi.size(), CV_64FC1, cv::Scalar(DBL_MAX));
find(gray, templ.image, templ.mask, result_single);
double minVal = DBL_MAX;
cv::Point minLoc;
cv::minMaxLoc(result_single, &minVal, NULL, &minLoc, NULL);
templ.score = minVal;
templ.position = minLoc;
result_multi.at<double>(templ.position.y, templ.position.x) = templ.score;
}
double minVal = DBL_MAX;
cv::Point minLoc;
cv::minMaxLoc(result_multi, &minVal, NULL, &minLoc, NULL);
for (RotatedTemplate& templ : templates)
{
int templ_center_row, templ_center_col;
get_image_center(templ.image, templ_center_row, templ_center_col);
if (templ.position == minLoc && std::abs(templ.score - minVal) < 1e-3)
{
std::cout << "The matched position is : " << minLoc << std::endl;
std::cout << "The matched rotation is : " << templ.rotation << std::endl;
//修正轮廓坐标
std::vector<std::vector<cv::Point>> contours(templ.contours);
for (size_t i = 0; i < contours[0].size(); i++)
{
contours[0][i] = contours[0][i] + minLoc+cv::Point(templ_center_col, templ_center_row);
}
cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
cv::circle(src, minLoc+cv::Point(templ_center_col, templ_center_row), 10, cv::Scalar(0, 255, 0), -1);
showImage(src, "result");
}
}
}
system("PAUSE");
cv::destroyAllWindows();
return EXIT_SUCCESS;
}