DTW 是 Dynamic Time Warping,可以动态扭曲时间轴,来计算两个不同长度的序列之间的相似性。

百科中关于原理的叙述



如果路径已经通过了格点(n ,m ),那么下一个通过的格点(n ,m )只可能是下列三种情况之一:
  (n ,m )=(n +1,m +2)
  (n ,m )=(n +1,m +1)
  (n ,m )=(n +1,m )



错误的,其实应该是

如果路径已经通过了格点(n ,m ),那么下一个通过的格点(n ,m )只可能是下列三种情况之一: 
  
  (n ,m )=(n ,m +1) 
  
  (n ,m )=(n +1,m +1) 
  
  (n ,m )=(n +1,m )

 

算法原理的叙述错误,导致了后面matlab程序的错误,更改后的写法如下:



function dist = dtw(cols,rows)
%% 直接距离矩阵
colSize = max(size(cols));
rowSize = max(size(rows));

% 帧匹配距离矩阵
dirtDist = zeros(rowSize,colSize);
for i=1:rowSize
    for j= 1:colSize
        dirtDist(i,j)=(rows(i)-cols(j))^2;
    end
end

%% 累积距离矩阵
accumDist = zeros(rowSize,colSize);
% DTW的第一行
accumDist(1,:) = cumsum(dirtDist(1,:));

%% 动态规划
for i = 2:rowSize
    for j = 1:colSize
        upper = accumDist(i-1,j);
        if j>1
            left = accumDist(i,j-1);
        else
            left = realmax;
        end
        if j>1
            diag = accumDist(i-1,j-1);
        else
            diag = realmax;
        end
        accumDist(i,j) = dirtDist(i,j) + min([upper,left,diag]);
    end
end

%% 最后的值
dist = sqrt(accumDist(end));


 

下面给出Java的代码实现:

 

package cn.edu.xjtu.utils;

import java.util.List;

/**
 * DTW(Danymic Time Wrapping,动态时间弯曲)算法的Java实现<br>
 * 
 * DTW距离包括2部分:A)直接距离 B)累加距离,在本实现中直接距离采用的是Euclidean距离
 * 
 * @author lenovo
 */
public class JavaDTWOpti {

	private double[][] distanceMatrix;

	private List<Double> aList;
	private List<Double> bList;

	private void spreadCalc(int startColumn, int startRow, int endColumn,
			int endRow) {

		distanceMatrix = new double[bList.size()][aList.size()];

		if (!(startColumn < endColumn) && !(startRow < endRow)) {
			throw new IllegalArgumentException("Error Occure!");
		}

		// 将所有的数值转换成java的从0开始
		startColumn = startColumn - 1;
		startRow = startRow - 1;
		endColumn = endColumn - 1;
		endRow = endRow - 1;

		// DTW距离第一部分:直接距离,这里计算采用Euclidean距离
		double diredist = 0.0;
		// DTW距离第二部分:累积距离
		double acumdist = 0.0;
		boolean stopIndex = false;

		do {
			// 计算某一顶点开始 横向 的所有DTW距离
			for (int i = startColumn; i < endColumn + 1; i++) {
				diredist = aList.get(i) - bList.get(startRow);
				acumdist = minAcumDist(startRow, i);
				distanceMatrix[startRow][i] = diredist * diredist + acumdist;
			}
			// 计算某一顶点开始 纵向 的所有DTW距离
			for (int i = startRow + 1; i < endRow + 1; i++) {
				diredist = aList.get(startColumn) - bList.get(i);
				acumdist = minAcumDist(i, startColumn);
				distanceMatrix[i][startColumn] = diredist * diredist + acumdist;
			}
			// 下一次推进的顶点
			startColumn = startColumn + 1;
			startRow = startRow + 1;

			// 达到边界后,最多再推进一次;
			if (stopIndex) {
				break;
			}

			// 因为DTW经常处理的是两不同长度的向量,所以在推进的过程中会遇到长度问题
			// 横向达到边界
			if (!(startColumn < endColumn)) {
				stopIndex = true;
			}
			// 纵向达到边界
			if (!(startRow < endRow)) {
				stopIndex = true;
			}

		} while (true);

	}

	private double[][] calcDTWDistance(int rowIndex, int columnIndex) {

		spreadCalc(1, 1, columnIndex, rowIndex);

		return distanceMatrix;
	}

	/**
	 * 返回距离矩阵,其中的值都是平方后的,方便查看,以备后续处理<br>
	 * aList={1,2,3},b={4,5,6,7},则矩阵是<br>
	 * 9.0000 13.0000 14.0000 <br>
	 * 25.0000 18.0000 17.0000 <br>
	 * 50.0000 34.0000 26.0000 <br>
	 * 86.0000 59.0000 42.0000 <br>
	 * 
	 * @note 该方法当且仅当距离矩阵为空的时候,才会调用计算方法;否则,该方法只是返回上一次计算结果
	 * 
	 * @see calcDtwValue,主动计算距离矩阵
	 * 
	 * @return 距离平方的二维矩阵,行号是bList的元素确定,列号由aList的元素确定
	 */
	public double[][] getDistanceMatrix() {

		if (this.distanceMatrix == null) {
			calcDTWDistance(bList.size(), aList.size());
		}

		return this.distanceMatrix;
	}

	/**
	 * 打印出距离矩阵,横向是aList,纵向是bList<br>
	 * aList={1,2,3},b={4,5,6,7},则打印出来的是<br>
	 * 9.0000 13.0000 14.0000 <br>
	 * 25.0000 18.0000 17.0000 <br>
	 * 50.0000 34.0000 26.0000 <br>
	 * 86.0000 59.0000 42.0000 <br>
	 * 
	 * @note 该方法不会调用计算方法;使用前请主动调用calcDtwValue计算距离矩阵
	 * 
	 * @see calcDtwValue
	 */
	public void printDistanceMatrix() {
		for (int i = 0; i < bList.size(); i++) {
			for (int k = 0; k < aList.size(); k++) {
				System.out.printf("%10.4f", distanceMatrix[i][k]);
				System.out.print("\t");
			}
			System.out.println();
		}
	}

	/**
	 * 计算向量aList和bList的DTW距离
	 * 
	 * @return DTW距离,开平方后的
	 */
	public double calcDtwValue() {

		this.calcDTWDistance(bList.size(), aList.size());

		return Math.sqrt(distanceMatrix[bList.size() - 1][aList.size() - 1]);
	}

	private double minAcumDist(int rowIndex, int columnIndex) {

		double pre = getNumber(rowIndex - 1, columnIndex - 1);
		double left = getNumber(rowIndex, columnIndex - 1);
		double up = getNumber(rowIndex - 1, columnIndex);

		return min(pre, left, up);
	}

	private double getNumber(int rowIndex, int columnIndex) {
		if (rowIndex >= 0 && columnIndex >= 0) {
			return distanceMatrix[rowIndex][columnIndex];
		} else if (rowIndex == -1 && columnIndex == -1) {
			return 0.0;
		} else {
			return Double.MAX_VALUE;
		}
	}

	private double min(double... inputs) {

		double minValue = inputs[0];
		for (double each : inputs) {
			minValue = Math.min(minValue, each);
		}
		return minValue;
	}

	public List<Double> getaList() {
		return aList;
	}

	public void setaList(List<Double> aList) {
		this.aList = aList;
	}

	public List<Double> getbList() {
		return bList;
	}

	public void setbList(List<Double> bList) {
		this.bList = bList;
	}

}