Mean shift 算法是基于核密度估计的爬山算法,可用于聚类、图像分割、跟踪等,其在声呐图像数据处理也有广泛的应用,笔者在网上找了一遍也没有找到关于Mean shift的matlab实现代码,找到的都是关于它的文字描述,无奈笔者只能根据网上找到的文字描述自己动手编写相关的matlab代码,现分享给大家。

1、均值漂移的基本形式

对于N维空间中给定的点集

基于均值漂移图像分割代码 python 均值漂移聚类matlab_mean shift

,则对于空间中的任意点

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_02

与点集

基于均值漂移图像分割代码 python 均值漂移聚类matlab_mean shift

中距离小于r的点

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_04

的mean shift向量为:

基于均值漂移图像分割代码 python 均值漂移聚类matlab_算法_05

,     

基于均值漂移图像分割代码 python 均值漂移聚类matlab_聚类_06

而漂移的过程,就是通过计算偏移量,然后不断的更新球心的位置,更新公式为:

                                          

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_07

直到偏移量的值很小时停止更新。

2、mean shift算法流程文字描述

假设多维空间中的数据点类别数未知,选定搜素半径r,执行如下步骤:

1、在未被标记的数据点中随机选择一个点作为中心

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_08

;2、找出所有离

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_08

距离小于r的点,记作集合M,并认为这些点属于类别c,同时将这些点在类别c上的访问次数加1;3、以

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_08

为中心点,计算

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_08

到集合M中每个元素的向量,将这些向量相加,得到漂移向量

基于均值漂移图像分割代码 python 均值漂移聚类matlab_matlab_12

。4、更新中心点,

基于均值漂移图像分割代码 python 均值漂移聚类matlab_聚类_13

。表示

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_08

沿着方向

基于均值漂移图像分割代码 python 均值漂移聚类matlab_matlab_12

移动了距离

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_16

。5、重复步骤2-4,直到

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_16

的大小很小,小于设置的阈值后,停止迭代,记住此时的

基于均值漂移图像分割代码 python 均值漂移聚类matlab_基于均值漂移图像分割代码 python_08

,在这个迭代过程中的遇到的所有的点都属于类别c。6、如果收敛时当前的类别c的中心于之前已经存在的类别

基于均值漂移图像分割代码 python 均值漂移聚类matlab_mean shift_19

的中心小于阈值,那么当前的c应该和

基于均值漂移图像分割代码 python 均值漂移聚类matlab_mean shift_19

属于同一类,并合并成

基于均值漂移图像分割代码 python 均值漂移聚类matlab_mean shift_19

,否则把c作为新的类别,增加一类。

7、重复1-6直到所有的点都被标记访问。

8、分类:根据每个点找出其被访问次数最多的那一类,并将其归属到此类中。

以上就是均值漂移聚类算法流程。

3、mean shift算法matlab实现

下面笔者将给出均值漂移算法的matlab程序:

function [out,category] = mean_shift(radius,threshould,data)
% 均值漂移聚类分析
% 输入参数
% radius    聚类半径
% data      K-by-N    k个N维数据点集
% 输出参数
%%
r2 = radius*radius;
threshould2 = threshould*threshould;
[k,N] = size(data);
access_cnt = zeros(k,1);  %每个点被不同类访问次数计数
center = data(1,:);
dir = zeros(k,N);
cluster_cnt = 1;
density_l = 0;
cnt = 0;

theta = (0:1:360)/180*pi;
circle_x = radius*cos(theta);
circle_y = radius*sin(theta);

figure;
h = axes;
plot(h,data(:,1),data(:,2),'k.');
hold(h,'on');grid(h,'on');

while 1
    cnt = cnt + 1;
    
    for i = 1:N
        dir(:,i) = data(:,i)-center(i);
    end
    dis = sum(dir.^2,2);  %按行求和

    indx = find(dis < r2);  %找到半径r内的数据点
    density = length(indx);
    shift = sum(dir(indx,:))/density;  %求飘移值
    access_cnt(indx,cluster_cnt) = access_cnt(indx,cluster_cnt) + 1; %当前类访问次数累加
    
%     if cnt > 1
%         delete(h1);
%         delete(h2);
%     end
    h1 = plot(h,circle_x+center(1),circle_y+center(2),'g');
    h2 = plot(h,data(indx,1),data(indx,2),'r.');
    
%      if shift*shift' < threshould2   %判断是否满足停止收敛条件
    if density_l >= density
        density_l = 0;
        if cluster_cnt == 1
            out(cluster_cnt,:) = center;
        else
            dir_t = out;
            for kk = 1:cluster_cnt-1
                dir_t(kk,:) = out(kk,:)-center;   %将当前的收敛中心于之前的计算距离
            end
            dis_t = sum(dir_t.^2,2);
            [min_dis,min_indx] = min(dis_t);
            if min_dis < threshould2        %判断当前的中心离之前已有的中心距离是否小于阈值
                access_cnt(:,min_indx)= access_cnt(:,min_indx) + access_cnt(:,cluster_cnt);
                access_cnt(:,cluster_cnt) = 0;   %清零之前的分类访问
                cluster_cnt = cluster_cnt - 1;
            else
                out(cluster_cnt,:) = center;
            end
        end
        cluster_cnt = cluster_cnt +1;      %类别计数
        acc_cnt_p = sum(access_cnt,2);     %求每个点已被访问的次数
        no_acc_p = find(acc_cnt_p == 0);   %找出还没有被访问的点
        if size(no_acc_p,1) > 0
            center = data(no_acc_p(1),:);  %初始化成没有被访问点
        else
            break;
        end
        if size(access_cnt,2) < cluster_cnt      %判断有没有新增的类,有的话添加一类
            access_cnt = [access_cnt,zeros(k,1)];
        end
     else
        density_l = density;
        center = center + shift;  %更新中心值
        pause(0.02);        
    end
end

category = zeros(k,1);
for kk = 1:k
    [max_acc,max_indx] = max(access_cnt(kk,:));  %找出当前点的最大访问次数及其类别
    category(kk) = max_indx;                     %将对应的点表上类别
end

 笔者生成了一组二维的数据点进行了测试,代码如下:

clear all;
close all;
clc;
%%
num = 500;
radius = 3;
threshould = 0.2;
data1 = [randn(num,1),randn(num,1)];
data2 = [randn(num,1)+6,randn(num,1)+6];
data3 = [randn(num,1)-6,randn(num,1)+6];
data4 = [randn(num,1)-6,randn(num,1)-6];
data5 = [randn(num,1)+6,randn(num,1)-6];
data = [data1;data2;data3;data4;data5];
[out,category] = mean_shift(radius, threshould, data);

category_num = size(out,1);
figure;
plot(data(:,1),data(:,2),'k.');
hold on; grid on;
plot(out(:,1),out(:,2),'r*');
if category_num == 1
    plot(data(:,1),data(:,2),'c.');
elseif category_num == 2
    category1 = data(find(category == 1),:);
    category2 = data(find(category == 2),:);
    plot(category1(:,1),category1(:,2),'c.');
    plot(category2(:,1),category2(:,2),'g.');
elseif category_num == 3
    category1 = data(find(category == 1),:);
    category2 = data(find(category == 2),:);
    category3 = data(find(category == 3),:);
    plot(category1(:,1),category1(:,2),'c.');
    plot(category2(:,1),category2(:,2),'g.');    
    plot(category3(:,1),category3(:,2),'y.'); 
elseif category_num == 4
    category1 = data(find(category == 1),:);
    category2 = data(find(category == 2),:);
    category3 = data(find(category == 3),:);
    category4 = data(find(category == 4),:);
    plot(category1(:,1),category1(:,2),'c.');
    plot(category2(:,1),category2(:,2),'g.');    
    plot(category3(:,1),category3(:,2),'y.');     
    plot(category4(:,1),category4(:,2),'b.'); 
else
    category1 = data(find(category == 1),:);
    category2 = data(find(category == 2),:);
    category3 = data(find(category == 3),:);
    category4 = data(find(category == 4),:);
    category5 = data(find(category == 5),:);
    plot(category1(:,1),category1(:,2),'c.');
    plot(category2(:,1),category2(:,2),'g.');    
    plot(category3(:,1),category3(:,2),'y.');     
    plot(category4(:,1),category4(:,2),'b.');    
    plot(category5(:,1),category5(:,2),'k.');
end

 运行一下代码出现如下结果:

基于均值漂移图像分割代码 python 均值漂移聚类matlab_聚类_22

图1

基于均值漂移图像分割代码 python 均值漂移聚类matlab_matlab_23

 

图2

图1为算法运行时,数据点被访问的过程图。图2位分类的结果图,可以看到一共有5类数据被准确的分出来。