Mean shift 算法是基于核密度估计的爬山算法,可用于聚类、图像分割、跟踪等,其在声呐图像数据处理也有广泛的应用,笔者在网上找了一遍也没有找到关于Mean shift的matlab实现代码,找到的都是关于它的文字描述,无奈笔者只能根据网上找到的文字描述自己动手编写相关的matlab代码,现分享给大家。
1、均值漂移的基本形式
对于N维空间中给定的点集
,则对于空间中的任意点
与点集
中距离小于r的点
的mean shift向量为:
,
而漂移的过程,就是通过计算偏移量,然后不断的更新球心的位置,更新公式为:
直到偏移量的值很小时停止更新。
2、mean shift算法流程文字描述
假设多维空间中的数据点类别数未知,选定搜素半径r,执行如下步骤:
1、在未被标记的数据点中随机选择一个点作为中心
;2、找出所有离
距离小于r的点,记作集合M,并认为这些点属于类别c,同时将这些点在类别c上的访问次数加1;3、以
为中心点,计算
到集合M中每个元素的向量,将这些向量相加,得到漂移向量
。4、更新中心点,
。表示
沿着方向
移动了距离
。5、重复步骤2-4,直到
的大小很小,小于设置的阈值后,停止迭代,记住此时的
,在这个迭代过程中的遇到的所有的点都属于类别c。6、如果收敛时当前的类别c的中心于之前已经存在的类别
的中心小于阈值,那么当前的c应该和
属于同一类,并合并成
,否则把c作为新的类别,增加一类。
7、重复1-6直到所有的点都被标记访问。
8、分类:根据每个点找出其被访问次数最多的那一类,并将其归属到此类中。
以上就是均值漂移聚类算法流程。
3、mean shift算法matlab实现
下面笔者将给出均值漂移算法的matlab程序:
function [out,category] = mean_shift(radius,threshould,data)
% 均值漂移聚类分析
% 输入参数
% radius 聚类半径
% data K-by-N k个N维数据点集
% 输出参数
%%
r2 = radius*radius;
threshould2 = threshould*threshould;
[k,N] = size(data);
access_cnt = zeros(k,1); %每个点被不同类访问次数计数
center = data(1,:);
dir = zeros(k,N);
cluster_cnt = 1;
density_l = 0;
cnt = 0;
theta = (0:1:360)/180*pi;
circle_x = radius*cos(theta);
circle_y = radius*sin(theta);
figure;
h = axes;
plot(h,data(:,1),data(:,2),'k.');
hold(h,'on');grid(h,'on');
while 1
cnt = cnt + 1;
for i = 1:N
dir(:,i) = data(:,i)-center(i);
end
dis = sum(dir.^2,2); %按行求和
indx = find(dis < r2); %找到半径r内的数据点
density = length(indx);
shift = sum(dir(indx,:))/density; %求飘移值
access_cnt(indx,cluster_cnt) = access_cnt(indx,cluster_cnt) + 1; %当前类访问次数累加
% if cnt > 1
% delete(h1);
% delete(h2);
% end
h1 = plot(h,circle_x+center(1),circle_y+center(2),'g');
h2 = plot(h,data(indx,1),data(indx,2),'r.');
% if shift*shift' < threshould2 %判断是否满足停止收敛条件
if density_l >= density
density_l = 0;
if cluster_cnt == 1
out(cluster_cnt,:) = center;
else
dir_t = out;
for kk = 1:cluster_cnt-1
dir_t(kk,:) = out(kk,:)-center; %将当前的收敛中心于之前的计算距离
end
dis_t = sum(dir_t.^2,2);
[min_dis,min_indx] = min(dis_t);
if min_dis < threshould2 %判断当前的中心离之前已有的中心距离是否小于阈值
access_cnt(:,min_indx)= access_cnt(:,min_indx) + access_cnt(:,cluster_cnt);
access_cnt(:,cluster_cnt) = 0; %清零之前的分类访问
cluster_cnt = cluster_cnt - 1;
else
out(cluster_cnt,:) = center;
end
end
cluster_cnt = cluster_cnt +1; %类别计数
acc_cnt_p = sum(access_cnt,2); %求每个点已被访问的次数
no_acc_p = find(acc_cnt_p == 0); %找出还没有被访问的点
if size(no_acc_p,1) > 0
center = data(no_acc_p(1),:); %初始化成没有被访问点
else
break;
end
if size(access_cnt,2) < cluster_cnt %判断有没有新增的类,有的话添加一类
access_cnt = [access_cnt,zeros(k,1)];
end
else
density_l = density;
center = center + shift; %更新中心值
pause(0.02);
end
end
category = zeros(k,1);
for kk = 1:k
[max_acc,max_indx] = max(access_cnt(kk,:)); %找出当前点的最大访问次数及其类别
category(kk) = max_indx; %将对应的点表上类别
end
笔者生成了一组二维的数据点进行了测试,代码如下:
clear all;
close all;
clc;
%%
num = 500;
radius = 3;
threshould = 0.2;
data1 = [randn(num,1),randn(num,1)];
data2 = [randn(num,1)+6,randn(num,1)+6];
data3 = [randn(num,1)-6,randn(num,1)+6];
data4 = [randn(num,1)-6,randn(num,1)-6];
data5 = [randn(num,1)+6,randn(num,1)-6];
data = [data1;data2;data3;data4;data5];
[out,category] = mean_shift(radius, threshould, data);
category_num = size(out,1);
figure;
plot(data(:,1),data(:,2),'k.');
hold on; grid on;
plot(out(:,1),out(:,2),'r*');
if category_num == 1
plot(data(:,1),data(:,2),'c.');
elseif category_num == 2
category1 = data(find(category == 1),:);
category2 = data(find(category == 2),:);
plot(category1(:,1),category1(:,2),'c.');
plot(category2(:,1),category2(:,2),'g.');
elseif category_num == 3
category1 = data(find(category == 1),:);
category2 = data(find(category == 2),:);
category3 = data(find(category == 3),:);
plot(category1(:,1),category1(:,2),'c.');
plot(category2(:,1),category2(:,2),'g.');
plot(category3(:,1),category3(:,2),'y.');
elseif category_num == 4
category1 = data(find(category == 1),:);
category2 = data(find(category == 2),:);
category3 = data(find(category == 3),:);
category4 = data(find(category == 4),:);
plot(category1(:,1),category1(:,2),'c.');
plot(category2(:,1),category2(:,2),'g.');
plot(category3(:,1),category3(:,2),'y.');
plot(category4(:,1),category4(:,2),'b.');
else
category1 = data(find(category == 1),:);
category2 = data(find(category == 2),:);
category3 = data(find(category == 3),:);
category4 = data(find(category == 4),:);
category5 = data(find(category == 5),:);
plot(category1(:,1),category1(:,2),'c.');
plot(category2(:,1),category2(:,2),'g.');
plot(category3(:,1),category3(:,2),'y.');
plot(category4(:,1),category4(:,2),'b.');
plot(category5(:,1),category5(:,2),'k.');
end
运行一下代码出现如下结果:
图1
图2
图1为算法运行时,数据点被访问的过程图。图2位分类的结果图,可以看到一共有5类数据被准确的分出来。