1. 背景

OpenCV 提供了基于像素的模板匹配函数 matchTemplate,但该函数不支持带旋转角度的匹配;而且如果使用该函数的 mask 参数,匹配结果可能与预期不符。

2. 模板训练

对模板图像按不同角度进行旋转,得到每个角度下的旋转图像与旋转掩膜图像;然后分别以这些旋转图像作为模板进行匹配,取其中最优的结果作为最终匹配结果。

// Alias for a list of contours; each contour is a list of integer pixel points.
using CV_CONTOURS = std::vector<std::vector<cv::Point>>;

// 显示图像
static void showImage(cv::Mat& image, const std::string& window_name, int timeout = 0, int width=800, int height = 600)
{
    cv::namedWindow(window_name, cv::WINDOW_NORMAL);
    cv::resizeWindow(window_name, cv::Size(width, height));
    cv::imshow(window_name, image);
    cv::waitKey(timeout);
}

// One rotated variant of the trained template, together with the best match
// found for it in the most recently searched image.
struct RotatedTemplate
{
    double rotation = 0.0;      // rotation angle in degrees (as passed to cv::getRotationMatrix2D)
    cv::Mat image;              // rotated grayscale template image
    cv::Mat mask;               // rotated mask; only non-zero pixels take part in matching
    CV_CONTOURS contours;       // template contour, expressed relative to the template centre

    double score = 0.0;         // best matching score (meaning depends on the metric used)
    cv::Point position;         // top-left corner of the best match in the searched image
};

/// <summary>
/// Compute the zero-based centre pixel of an image.
/// For an odd extent n the centre index is n / 2; for an even extent it is
/// n / 2 - 1 (the upper/left of the two middle pixels).
/// </summary>
void get_image_center(const cv::Mat& src, int& row, int& col)
{
    const auto centre_index = [](int extent)
    {
        return (extent % 2 == 1) ? extent / 2 : extent / 2 - 1;
    };

    row = centre_index(src.rows);
    col = centre_index(src.cols);
}

/// <summary>
/// 训练模板
/// </summary>
void train(const std::string& file_name, std::vector<RotatedTemplate>& templates)
{
    // 读取模板图像
    cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);
    //showImage(src, "src");

    // 灰度化
    cv::Mat gray;
    cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

    // 二值化
    cv::Mat binary;
    cv::threshold(gray, binary, 200, 255, cv::THRESH_BINARY_INV);
    //showImage(binary, "binary");

    // 外轮廓提取
    CV_CONTOURS contours;
    cv::findContours(binary, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
    cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
    //showImage(src, "contours");

    // 掩模
    cv::Rect rect = cv::boundingRect(contours[0]);
    cv::Mat mask = cv::Mat::zeros(gray.size(), CV_8U);
    cv::rectangle(mask, rect, cv::Scalar(255), -1);
    //showImage(mask, "mask");

    // 创建多角度的模板集合
    templates.clear();
    for (int rotation = -45; rotation <= 45; rotation += 5)
    {
        cv::Mat rotated_gray;
        cv::Point2f center(gray.cols * 0.5f, gray.rows * 0.5f);
        cv::Mat rot_mat = cv::getRotationMatrix2D(center, rotation, 1.0);
        cv::warpAffine(gray, rotated_gray, rot_mat, gray.size());

        cv::Mat rotated_mask;
        cv::warpAffine(mask, rotated_mask, rot_mat, gray.size());

        // 映射轮廓
        CV_CONTOURS output_contours;
        output_contours.resize(1);
        cv::transform(contours[0], output_contours[0], rot_mat);

        // 轮廓及其属性显示
        cv::Mat rotated_color;
        cv::cvtColor(rotated_gray, rotated_color, cv::COLOR_GRAY2BGR);
        cv::drawContours(rotated_color, output_contours, -1, cv::Scalar(0, 255, 0), 3);
        //showImage(rotated_color, "rotated_color", 300);
        //showImage(rotated_mask, "rotated_mask", 300);

        // 轮廓修正到以模板中心坐标为原点
        int row, col;
        get_image_center(rotated_gray, row, col);
        for (size_t i = 0; i < output_contours[0].size(); i++)
        {
            output_contours[0][i].x -= col;
            output_contours[0][i].y -= row;
        }

        RotatedTemplate templ;
        templ.image = rotated_gray;
        templ.mask = rotated_mask;
        templ.rotation = rotation;
        templ.contours = output_contours;
        templates.push_back(templ);
    }

    cv::destroyAllWindows();
}

[插图:OpenCV 多角度模板训练示意图]

3. 模板查找

将模板在输入图像上沿水平/垂直方向逐像素滑动,每次截取与模板图像尺寸相同的子区域,并计算该子区域与模板之间的匹配度。

// 定义函数指针
typedef double (*p_function)(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask);

/// <summary>
/// Squared-difference metric over the masked pixels: sum((S - T)^2), where S
/// is the image patch and T the template. Lower means more similar.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double sqdiff(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double total = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        const uchar* m = mask.ptr<uchar>(y);
        for (int x = 0; x < src.cols; ++x)
        {
            if (m[x] == 0)
                continue;   // pixel outside the mask

            const double d = static_cast<double>(static_cast<int>(s[x]) - static_cast<int>(t[x]));
            total += d * d;
        }
    }
    return total;
}

/// <summary>
/// Normalised squared-difference metric over the masked pixels:
///     sum((S - T)^2) / sqrt(sum(S^2) * sum(T^2))
/// where S is the image patch and T the template. Lower is better (0 means
/// identical).
/// If the denominator is zero (all masked pixels of S or T are 0, or the
/// mask selects no pixel), the ratio is undefined. In that case 1.0 is
/// returned instead of NaN/inf, mirroring the fallback OpenCV's
/// matchTemplate uses for TM_SQDIFF_NORMED.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double sqdiff_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_diff = 0.0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];
                const double pixel_diff = pixel_src - pixel_templ;

                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_diff += pixel_diff * pixel_diff;
            }
        }
    }

    // Avoid 0/0 (NaN) or x/0 (inf), which would poison min/max searches.
    const double denominator = std::sqrt(sum_src * sum_templ);
    if (denominator <= std::numeric_limits<double>::epsilon())
        return 1.0;
    return sum_diff / denominator;
}

/// <summary>
/// Cross-correlation metric over the masked pixels: sum(S * T), where S is
/// the image patch and T the template. Larger values mean stronger correlation.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double ccorr(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double total = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        const uchar* m = mask.ptr<uchar>(y);
        for (int x = 0; x < src.cols; ++x)
        {
            if (m[x] != 0)
                total += static_cast<int>(s[x]) * static_cast<int>(t[x]);
        }
    }
    return total;
}

/// <summary>
/// Normalised cross-correlation metric over the masked pixels:
///     sum(S * T) / sqrt(sum(S^2) * sum(T^2))
/// where S is the image patch and T the template. Result lies in [0, 1] for
/// non-negative pixels; higher is better.
/// If the denominator is zero (all masked pixels of S or T are 0, or the
/// mask selects no pixel), 0.0 is returned instead of NaN, mirroring the
/// fallback OpenCV's matchTemplate uses for TM_CCORR_NORMED.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double ccorr_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_multi = 0.0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];

                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_multi += pixel_src * pixel_templ;
            }
        }
    }

    // Avoid 0/0 (NaN), which would poison min/max searches.
    const double denominator = std::sqrt(sum_src * sum_templ);
    if (denominator <= std::numeric_limits<double>::epsilon())
        return 0.0;
    return sum_multi / denominator;
}

/// <summary>
/// Score one image patch against a template using the given metric.
/// When is_show is true, the patch, template and mask are first shown side
/// by side in the "concat_mat" window (10 ms wait) for debugging.
/// </summary>
double match(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, p_function p_func, bool is_show = false)
{
    if (!is_show)
        return p_func(src, templ, mask);

    cv::Mat side_by_side;
    cv::hconcat(src, templ, side_by_side);
    cv::hconcat(side_by_side, mask, side_by_side);
    showImage(side_by_side, "concat_mat", 10, 2400, 600);

    return p_func(src, templ, mask);
}

/// <summary>
/// Slide a template over an image and score every placement.
/// result must be preallocated as CV_64FC1. result(row, col) receives the
/// score of the patch whose top-left corner is (col, row), so result may be
/// at most (src.cols - templ.cols + 1) x (src.rows - templ.rows + 1).
/// Throws cv::Exception if result has the wrong type or is too large.
/// </summary>
/// <param name="p_func">Similarity metric; defaults to sqdiff (lower is better).</param>
void find(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, cv::Mat& result, p_function p_func = &sqdiff)
{
    if (result.empty())
        return;

    CV_Assert(p_func != nullptr);
    CV_Assert(result.type() == CV_64FC1);
    CV_Assert(result.rows <= src.rows - templ.rows + 1 && result.cols <= src.cols - templ.cols + 1);

    for (int row = 0; row < result.rows; row++)
    {
        double* result_row = result.ptr<double>(row);
        for (int col = 0; col < result.cols; col++)
        {
            const cv::Mat patch = src(cv::Rect(col, row, templ.cols, templ.rows));
            result_row[col] = match(patch, templ, mask, p_func);
        }
    }
}

int main(int argc, char** argv)
{
    cv::utils::logging::setLogLevel(cv::utils::logging::LOG_LEVEL_SILENT);
    std::vector<RotatedTemplate> templates;
    train("../../template.png", templates);

    std::vector<std::string> file_names;
    file_names.push_back("../../sample.png");
    file_names.push_back("../../sample0.png");
    file_names.push_back("../../sample1.png");
    file_names.push_back("../../sample2.png");

    // 遍历所有图像
    for (std::string& file_name : file_names)
    {
        // 读取彩色图像
        std::cout << "The path of image is : " << file_name << std::endl;
        cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);

        // 灰度化
        cv::Mat gray;
        cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

        // 初始化输出,多个模板
        int result_row = src.cols - templates[0].image.cols + 1;
        int result_col = src.rows - templates[0].image.rows + 1;
        cv::Mat result_multi(cv::Size(result_row, result_col), CV_64FC1, cv::Scalar(DBL_MAX));

        // 遍历所有模板
        for (RotatedTemplate& templ : templates)
        {
            std::cout << "The rotation of templ is : " << templ.rotation << std::endl;

            cv::Mat result_single(result_multi.size(), CV_64FC1, cv::Scalar(DBL_MAX));
            find(gray, templ.image, templ.mask, result_single);

            double minVal = DBL_MAX;
            cv::Point minLoc;
            cv::minMaxLoc(result_single, &minVal, NULL, &minLoc, NULL);
            templ.score = minVal;
            templ.position = minLoc;

            result_multi.at<double>(templ.position.y, templ.position.x) = templ.score;
        }

        double minVal = DBL_MAX;
        cv::Point minLoc;
        cv::minMaxLoc(result_multi, &minVal, NULL, &minLoc, NULL);

        for (RotatedTemplate& templ : templates)
        {
            int templ_center_row, templ_center_col;
            get_image_center(templ.image, templ_center_row, templ_center_col);
            if (templ.position == minLoc && std::abs(templ.score - minVal) < 1e-3)
            {
                std::cout << "The matched position is : " << minLoc << std::endl;
                std::cout << "The matched rotation is : " << templ.rotation << std::endl;
                //修正轮廓坐标
                std::vector<std::vector<cv::Point>> contours(templ.contours);
                for (size_t i = 0; i < contours[0].size(); i++)
                {
                    contours[0][i] = contours[0][i] + minLoc+cv::Point(templ_center_col, templ_center_row);
                }
                cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
                cv::circle(src, minLoc+cv::Point(templ_center_col, templ_center_row), 10, cv::Scalar(0, 255, 0), -1);
                showImage(src, "result");
            }
        }
    }

    system("PAUSE");
    cv::destroyAllWindows();
    return EXIT_SUCCESS;
}

[插图:匹配结果示例 1]


[插图:匹配结果示例 2]


[插图:匹配结果示例 3]


[插图:匹配结果示例 4]

4. 完整源码

#include <opencv2\opencv.hpp>
#include <opencv2\highgui\highgui.hpp>
#include <opencv2\core\utils\logger.hpp>
#include <iostream>
#include <string>
#include <cstdio>
#include <limits.h>
#include <algorithm>
#include <limits>

// 定义函数指针
typedef double (*p_function)(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask);

// Alias for a list of contours; each contour is a list of integer pixel points.
using CV_CONTOURS = std::vector<std::vector<cv::Point>>;

// 显示图像
static void showImage(cv::Mat& image, const std::string& window_name, int timeout = 0, int width=800, int height = 600)
{
    cv::namedWindow(window_name, cv::WINDOW_NORMAL);
    cv::resizeWindow(window_name, cv::Size(width, height));
    cv::imshow(window_name, image);
    cv::waitKey(timeout);
}

// One rotated variant of the trained template, together with the best match
// found for it in the most recently searched image.
struct RotatedTemplate
{
    double rotation = 0.0;      // rotation angle in degrees (as passed to cv::getRotationMatrix2D)
    cv::Mat image;              // rotated grayscale template image
    cv::Mat mask;               // rotated mask; only non-zero pixels take part in matching
    CV_CONTOURS contours;       // template contour, expressed relative to the template centre

    double score = 0.0;         // best matching score (meaning depends on the metric used)
    cv::Point position;         // top-left corner of the best match in the searched image
};

/// <summary>
/// Compute the zero-based centre pixel of an image.
/// For an odd extent n the centre index is n / 2; for an even extent it is
/// n / 2 - 1 (the upper/left of the two middle pixels).
/// </summary>
void get_image_center(const cv::Mat& src, int& row, int& col)
{
    const auto centre_index = [](int extent)
    {
        return (extent % 2 == 1) ? extent / 2 : extent / 2 - 1;
    };

    row = centre_index(src.rows);
    col = centre_index(src.cols);
}

/// <summary>
/// 训练模板
/// </summary>
void train(const std::string& file_name, std::vector<RotatedTemplate>& templates)
{
    // 读取模板图像
    cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);
    //showImage(src, "src");

    // 灰度化
    cv::Mat gray;
    cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

    // 二值化
    cv::Mat binary;
    cv::threshold(gray, binary, 200, 255, cv::THRESH_BINARY_INV);
    //showImage(binary, "binary");

    // 外轮廓提取
    CV_CONTOURS contours;
    cv::findContours(binary, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_SIMPLE);
    cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
    //showImage(src, "contours");

    // 掩模
    cv::Rect rect = cv::boundingRect(contours[0]);
    cv::Mat mask = cv::Mat::zeros(gray.size(), CV_8U);
    cv::rectangle(mask, rect, cv::Scalar(255), -1);
    //showImage(mask, "mask");

    // 创建多角度的模板集合
    templates.clear();
    for (int rotation = -45; rotation <= 45; rotation += 5)
    {
        cv::Mat rotated_gray;
        cv::Point2f center(gray.cols * 0.5f, gray.rows * 0.5f);
        cv::Mat rot_mat = cv::getRotationMatrix2D(center, rotation, 1.0);
        cv::warpAffine(gray, rotated_gray, rot_mat, gray.size());

        cv::Mat rotated_mask;
        cv::warpAffine(mask, rotated_mask, rot_mat, gray.size());

        // 映射轮廓
        CV_CONTOURS output_contours;
        output_contours.resize(1);
        cv::transform(contours[0], output_contours[0], rot_mat);

        // 轮廓及其属性显示
        cv::Mat rotated_color;
        cv::cvtColor(rotated_gray, rotated_color, cv::COLOR_GRAY2BGR);
        cv::drawContours(rotated_color, output_contours, -1, cv::Scalar(0, 255, 0), 3);
        //showImage(rotated_color, "rotated_color", 300);
        //showImage(rotated_mask, "rotated_mask", 300);

        // 轮廓修正到以模板中心坐标为原点
        int row, col;
        get_image_center(rotated_gray, row, col);
        for (size_t i = 0; i < output_contours[0].size(); i++)
        {
            output_contours[0][i].x -= col;
            output_contours[0][i].y -= row;
        }

        RotatedTemplate templ;
        templ.image = rotated_gray;
        templ.mask = rotated_mask;
        templ.rotation = rotation;
        templ.contours = output_contours;
        templates.push_back(templ);
    }

    cv::destroyAllWindows();
}

/// <summary>
/// Squared-difference metric over the masked pixels: sum((S - T)^2), where S
/// is the image patch and T the template. Lower means more similar.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double sqdiff(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double total = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        const uchar* m = mask.ptr<uchar>(y);
        for (int x = 0; x < src.cols; ++x)
        {
            if (m[x] == 0)
                continue;   // pixel outside the mask

            const double d = static_cast<double>(static_cast<int>(s[x]) - static_cast<int>(t[x]));
            total += d * d;
        }
    }
    return total;
}

/// <summary>
/// Normalised squared-difference metric over the masked pixels:
///     sum((S - T)^2) / sqrt(sum(S^2) * sum(T^2))
/// where S is the image patch and T the template. Lower is better (0 means
/// identical).
/// If the denominator is zero (all masked pixels of S or T are 0, or the
/// mask selects no pixel), the ratio is undefined. In that case 1.0 is
/// returned instead of NaN/inf, mirroring the fallback OpenCV's
/// matchTemplate uses for TM_SQDIFF_NORMED.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double sqdiff_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_diff = 0.0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];
                const double pixel_diff = pixel_src - pixel_templ;

                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_diff += pixel_diff * pixel_diff;
            }
        }
    }

    // Avoid 0/0 (NaN) or x/0 (inf), which would poison min/max searches.
    const double denominator = std::sqrt(sum_src * sum_templ);
    if (denominator <= std::numeric_limits<double>::epsilon())
        return 1.0;
    return sum_diff / denominator;
}

/// <summary>
/// Cross-correlation metric over the masked pixels: sum(S * T), where S is
/// the image patch and T the template. Larger values mean stronger correlation.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double ccorr(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double total = 0.0;
    for (int y = 0; y < src.rows; ++y)
    {
        const uchar* s = src.ptr<uchar>(y);
        const uchar* t = templ.ptr<uchar>(y);
        const uchar* m = mask.ptr<uchar>(y);
        for (int x = 0; x < src.cols; ++x)
        {
            if (m[x] != 0)
                total += static_cast<int>(s[x]) * static_cast<int>(t[x]);
        }
    }
    return total;
}

/// <summary>
/// Normalised cross-correlation metric over the masked pixels:
///     sum(S * T) / sqrt(sum(S^2) * sum(T^2))
/// where S is the image patch and T the template. Result lies in [0, 1] for
/// non-negative pixels; higher is better.
/// If the denominator is zero (all masked pixels of S or T are 0, or the
/// mask selects no pixel), 0.0 is returned instead of NaN, mirroring the
/// fallback OpenCV's matchTemplate uses for TM_CCORR_NORMED.
/// Reads every input as 8-bit single-channel data of the patch's size.
/// </summary>
double ccorr_normed(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask)
{
    double sum_src = 0.0, sum_templ = 0.0, sum_multi = 0.0;
    for (int row = 0; row < src.rows; row++)
    {
        const uchar* src_row = src.ptr(row);
        const uchar* mask_row = mask.ptr(row);
        const uchar* templ_row = templ.ptr(row);
        for (int col = 0; col < src.cols; col++)
        {
            if (mask_row[col])
            {
                const double pixel_src = src_row[col];
                const double pixel_templ = templ_row[col];

                sum_src += pixel_src * pixel_src;
                sum_templ += pixel_templ * pixel_templ;
                sum_multi += pixel_src * pixel_templ;
            }
        }
    }

    // Avoid 0/0 (NaN), which would poison min/max searches.
    const double denominator = std::sqrt(sum_src * sum_templ);
    if (denominator <= std::numeric_limits<double>::epsilon())
        return 0.0;
    return sum_multi / denominator;
}

/// <summary>
/// Score one image patch against a template using the given metric.
/// When is_show is true, the patch, template and mask are first shown side
/// by side in the "concat_mat" window (10 ms wait) for debugging.
/// </summary>
double match(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, p_function p_func, bool is_show = false)
{
    if (!is_show)
        return p_func(src, templ, mask);

    cv::Mat side_by_side;
    cv::hconcat(src, templ, side_by_side);
    cv::hconcat(side_by_side, mask, side_by_side);
    showImage(side_by_side, "concat_mat", 10, 2400, 600);

    return p_func(src, templ, mask);
}

/// <summary>
/// Slide a template over an image and score every placement.
/// result must be preallocated as CV_64FC1. result(row, col) receives the
/// score of the patch whose top-left corner is (col, row), so result may be
/// at most (src.cols - templ.cols + 1) x (src.rows - templ.rows + 1).
/// Throws cv::Exception if result has the wrong type or is too large.
/// </summary>
/// <param name="p_func">Similarity metric; defaults to sqdiff (lower is better).</param>
void find(const cv::Mat& src, const cv::Mat& templ, const cv::Mat& mask, cv::Mat& result, p_function p_func = &sqdiff)
{
    if (result.empty())
        return;

    CV_Assert(p_func != nullptr);
    CV_Assert(result.type() == CV_64FC1);
    CV_Assert(result.rows <= src.rows - templ.rows + 1 && result.cols <= src.cols - templ.cols + 1);

    for (int row = 0; row < result.rows; row++)
    {
        double* result_row = result.ptr<double>(row);
        for (int col = 0; col < result.cols; col++)
        {
            const cv::Mat patch = src(cv::Rect(col, row, templ.cols, templ.rows));
            result_row[col] = match(patch, templ, mask, p_func);
        }
    }
}

int main(int argc, char** argv)
{
    cv::utils::logging::setLogLevel(cv::utils::logging::LOG_LEVEL_SILENT);
    std::vector<RotatedTemplate> templates;
    train("../../template.png", templates);

    std::vector<std::string> file_names;
    file_names.push_back("../../sample.png");
    file_names.push_back("../../sample0.png");
    file_names.push_back("../../sample1.png");
    file_names.push_back("../../sample2.png");

    // 遍历所有图像
    for (std::string& file_name : file_names)
    {
        // 读取彩色图像
        std::cout << "The path of image is : " << file_name << std::endl;
        cv::Mat src = cv::imread(file_name, cv::IMREAD_ANYCOLOR);

        // 灰度化
        cv::Mat gray;
        cv::cvtColor(src, gray, cv::COLOR_BGR2GRAY);

        // 初始化输出,多个模板
        int result_row = src.cols - templates[0].image.cols + 1;
        int result_col = src.rows - templates[0].image.rows + 1;
        cv::Mat result_multi(cv::Size(result_row, result_col), CV_64FC1, cv::Scalar(DBL_MAX));

        // 遍历所有模板
        for (RotatedTemplate& templ : templates)
        {
            std::cout << "The rotation of templ is : " << templ.rotation << std::endl;

            cv::Mat result_single(result_multi.size(), CV_64FC1, cv::Scalar(DBL_MAX));
            find(gray, templ.image, templ.mask, result_single);

            double minVal = DBL_MAX;
            cv::Point minLoc;
            cv::minMaxLoc(result_single, &minVal, NULL, &minLoc, NULL);
            templ.score = minVal;
            templ.position = minLoc;

            result_multi.at<double>(templ.position.y, templ.position.x) = templ.score;
        }

        double minVal = DBL_MAX;
        cv::Point minLoc;
        cv::minMaxLoc(result_multi, &minVal, NULL, &minLoc, NULL);

        for (RotatedTemplate& templ : templates)
        {
            int templ_center_row, templ_center_col;
            get_image_center(templ.image, templ_center_row, templ_center_col);
            if (templ.position == minLoc && std::abs(templ.score - minVal) < 1e-3)
            {
                std::cout << "The matched position is : " << minLoc << std::endl;
                std::cout << "The matched rotation is : " << templ.rotation << std::endl;
                //修正轮廓坐标
                std::vector<std::vector<cv::Point>> contours(templ.contours);
                for (size_t i = 0; i < contours[0].size(); i++)
                {
                    contours[0][i] = contours[0][i] + minLoc+cv::Point(templ_center_col, templ_center_row);
                }
                cv::drawContours(src, contours, -1, cv::Scalar(0, 255, 0), 3);
                cv::circle(src, minLoc+cv::Point(templ_center_col, templ_center_row), 10, cv::Scalar(0, 255, 0), -1);
                showImage(src, "result");
            }
        }
    }

    system("PAUSE");
    cv::destroyAllWindows();
    return EXIT_SUCCESS;
}