DTW 是 Dynamic Time Warping,可以动态扭曲时间轴,来计算两个不同长度的序列之间的相似性。
百科中关于原理的叙述
如果路径已经通过了格点(n ,m ),那么下一个通过的格点(n ,m )只可能是下列三种情况之一:
(n ,m )=(n +1,m +2)
(n ,m )=(n +1,m +1)
(n ,m )=(n +1,m )
错误的,其实应该是
如果路径已经通过了格点(n ,m ),那么下一个通过的格点(n ,m )只可能是下列三种情况之一:
(n ,m )=(n ,m +1)
(n ,m )=(n +1,m +1)
(n ,m )=(n +1,m )
算法原理的叙述错误,导致了后面matlab程序的错误,更改后的写法如下:
function dist = dtw(cols,rows)
%% 直接距离矩阵
colSize = max(size(cols));
rowSize = max(size(rows));
% 帧匹配距离矩阵
dirtDist = zeros(rowSize,colSize);
for i=1:rowSize
for j= 1:colSize
dirtDist(i,j)=(rows(i)-cols(j))^2;
end
end
%% 累积距离矩阵
accumDist = zeros(rowSize,colSize);
% DTW的第一行
accumDist(1,:) = cumsum(dirtDist(1,:));
%% 动态规划
for i = 2:rowSize
for j = 1:colSize
upper = accumDist(i-1,j);
if j>1
left = accumDist(i,j-1);
else
left = realmax;
end
if j>1
diag = accumDist(i-1,j-1);
else
diag = realmax;
end
accumDist(i,j) = dirtDist(i,j) + min([upper,left,diag]);
end
end
%% 最后的值
dist = sqrt(accumDist(end));
下面给出Java的代码实现:
package cn.edu.xjtu.utils;
import java.util.List;
/**
* DTW(Danymic Time Wrapping,动态时间弯曲)算法的Java实现<br>
*
* DTW距离包括2部分:A)直接距离 B)累加距离,在本实现中直接距离采用的是Euclidean距离
*
* @author lenovo
*/
public class JavaDTWOpti {
private double[][] distanceMatrix;
private List<Double> aList;
private List<Double> bList;
private void spreadCalc(int startColumn, int startRow, int endColumn,
int endRow) {
distanceMatrix = new double[bList.size()][aList.size()];
if (!(startColumn < endColumn) && !(startRow < endRow)) {
throw new IllegalArgumentException("Error Occure!");
}
// 将所有的数值转换成java的从0开始
startColumn = startColumn - 1;
startRow = startRow - 1;
endColumn = endColumn - 1;
endRow = endRow - 1;
// DTW距离第一部分:直接距离,这里计算采用Euclidean距离
double diredist = 0.0;
// DTW距离第二部分:累积距离
double acumdist = 0.0;
boolean stopIndex = false;
do {
// 计算某一顶点开始 横向 的所有DTW距离
for (int i = startColumn; i < endColumn + 1; i++) {
diredist = aList.get(i) - bList.get(startRow);
acumdist = minAcumDist(startRow, i);
distanceMatrix[startRow][i] = diredist * diredist + acumdist;
}
// 计算某一顶点开始 纵向 的所有DTW距离
for (int i = startRow + 1; i < endRow + 1; i++) {
diredist = aList.get(startColumn) - bList.get(i);
acumdist = minAcumDist(i, startColumn);
distanceMatrix[i][startColumn] = diredist * diredist + acumdist;
}
// 下一次推进的顶点
startColumn = startColumn + 1;
startRow = startRow + 1;
// 达到边界后,最多再推进一次;
if (stopIndex) {
break;
}
// 因为DTW经常处理的是两不同长度的向量,所以在推进的过程中会遇到长度问题
// 横向达到边界
if (!(startColumn < endColumn)) {
stopIndex = true;
}
// 纵向达到边界
if (!(startRow < endRow)) {
stopIndex = true;
}
} while (true);
}
private double[][] calcDTWDistance(int rowIndex, int columnIndex) {
spreadCalc(1, 1, columnIndex, rowIndex);
return distanceMatrix;
}
/**
* 返回距离矩阵,其中的值都是平方后的,方便查看,以备后续处理<br>
* aList={1,2,3},b={4,5,6,7},则矩阵是<br>
* 9.0000 13.0000 14.0000 <br>
* 25.0000 18.0000 17.0000 <br>
* 50.0000 34.0000 26.0000 <br>
* 86.0000 59.0000 42.0000 <br>
*
* @note 该方法当且仅当距离矩阵为空的时候,才会调用计算方法;否则,该方法只是返回上一次计算结果
*
* @see calcDtwValue,主动计算距离矩阵
*
* @return 距离平方的二维矩阵,行号是bList的元素确定,列号由aList的元素确定
*/
public double[][] getDistanceMatrix() {
if (this.distanceMatrix == null) {
calcDTWDistance(bList.size(), aList.size());
}
return this.distanceMatrix;
}
/**
* 打印出距离矩阵,横向是aList,纵向是bList<br>
* aList={1,2,3},b={4,5,6,7},则打印出来的是<br>
* 9.0000 13.0000 14.0000 <br>
* 25.0000 18.0000 17.0000 <br>
* 50.0000 34.0000 26.0000 <br>
* 86.0000 59.0000 42.0000 <br>
*
* @note 该方法不会调用计算方法;使用前请主动调用calcDtwValue计算距离矩阵
*
* @see calcDtwValue
*/
public void printDistanceMatrix() {
for (int i = 0; i < bList.size(); i++) {
for (int k = 0; k < aList.size(); k++) {
System.out.printf("%10.4f", distanceMatrix[i][k]);
System.out.print("\t");
}
System.out.println();
}
}
/**
* 计算向量aList和bList的DTW距离
*
* @return DTW距离,开平方后的
*/
public double calcDtwValue() {
this.calcDTWDistance(bList.size(), aList.size());
return Math.sqrt(distanceMatrix[bList.size() - 1][aList.size() - 1]);
}
private double minAcumDist(int rowIndex, int columnIndex) {
double pre = getNumber(rowIndex - 1, columnIndex - 1);
double left = getNumber(rowIndex, columnIndex - 1);
double up = getNumber(rowIndex - 1, columnIndex);
return min(pre, left, up);
}
private double getNumber(int rowIndex, int columnIndex) {
if (rowIndex >= 0 && columnIndex >= 0) {
return distanceMatrix[rowIndex][columnIndex];
} else if (rowIndex == -1 && columnIndex == -1) {
return 0.0;
} else {
return Double.MAX_VALUE;
}
}
private double min(double... inputs) {
double minValue = inputs[0];
for (double each : inputs) {
minValue = Math.min(minValue, each);
}
return minValue;
}
public List<Double> getaList() {
return aList;
}
public void setaList(List<Double> aList) {
this.aList = aList;
}
public List<Double> getbList() {
return bList;
}
public void setbList(List<Double> bList) {
this.bList = bList;
}
}