相关变量声明:
static final char[] bos = {'\b', 'x'};
标签状态集合:
static final char[] id2tag = new char[]{'b', 'm', 'e', 's'};
/**
* 2阶隐马的三个参数
*/
double l1, l2, l3;
/**
* 频次统计
*/
Probability tf; //用于统计每个
/**
* 让模型观测一个句子, 进行B M E S序列标注
* @param wordList
*/
public void learn(List<Word> wordList)
{
LinkedList<char[]> sentence = new LinkedList<char[]>();
for (IWord iWord : wordList)
{
String word = iWord.getValue();
if (word.length() == 1)
{
sentence.add(new char[]{word.charAt(0), 's'});
}
else
{
sentence.add(new char[]{word.charAt(0), 'b'});
for (int i = 1; i < word.length() - 1; ++i)
{
sentence.add(new char[]{word.charAt(i), 'm'});
}
sentence.add(new char[]{word.charAt(word.length() - 1), 'e'});
}
} // 转换完毕,开始统计
第一个字和第二字有些特殊,它们的前2个状态不存在,会产生。所以使用bos状态替代不存在的状态计算概率:
char[][] now = new char[3][]; // 定长3的队列
now[1] = bos;
now[2] = bos;
tf.add(1, bos, bos);
tf.add(2, bos);
for (char[] i : sentence)
{
System.arraycopy(now, 1, now, 0, 2);
now[2] = i;
tf.add(1, i); // uni 单个字符出现的次数
tf.add(1, now[1], now[2]); // bi 两个字符共同出现的次数
tf.add(1, now); // tri 三个字符共同出现的次数
}
}
/**
* 观测结束,开始训练
*/
public void train()
{
double tl1 = 0.0;
double tl2 = 0.0;
double tl3 = 0.0;
for (String key : tf.d.keySet())
{
if (key.length() != 6) continue; // tri samples
char[][] now = new char[][]
{
{key.charAt(0), key.charAt(1)},
{key.charAt(2), key.charAt(3)},
{key.charAt(4), key.charAt(5)},
};
double c3 = div(tf.get(now) - 1, tf.get(now[0], now[1]) - 1);
//c2 = sum_(t3,t2) / sum_(t2) = P(t3 | t2)
double c2 = div(tf.get(now[1], now[2]) - 1, tf.get(now[1]) - 1);
//c1 = P(t3) = sum_t3 / sum_total
double c1 = div(tf.get(now[2]) - 1, tf.getsum() - 1);
if (c3 >= c1 && c3 >= c2)
tl3 += tf.get(now);
else if (c2 >= c1 && c2 >= c3)
tl2 += tf.get(now);
else if (c1 >= c2 && c1 >= c3)
tl1 += tf.get(now);
}
l1 = div(tl1, tl1 + tl2 + tl3);
l2 = div(tl2, tl1 + tl2 + tl3);
l3 = div(tl3, tl1 + tl2 + tl3);
}
/**
* 求概率
* @param s1 前2个状态
* @param s2 前1个状态
* @param s3 当前状态
* @return 序列的概率
*/
double log_prob(char[] s1, char[] s2, char[] s3)
{
double uni = l1 * tf.freq(s3);
double bi = div(l2 * tf.get(s2, s3), tf.get(s2));
double tri = div(l3 * tf.get(s1, s2, s3), tf.get(s1, s2));
if (uni + bi + tri == 0)
return inf;
return Math.log(uni + bi + tri);
}
/**
* 序列标注
* @param charArray 观测序列, 对于观测序列 求出最有可能的状态标注序列
* @return 标注序列
*/
public char[] tag(char[] charArray)
{
if (charArray.length == 0) return new char[0];
if (charArray.length == 1) return new char[]{'s'}; char[] tag = new char[charArray.length];
double[][] now = new double[4][4];
double[] first = new double[4];
// link[i][s][t] := 第i个节点在前一个t-1状态是s,当前t状态是t时,t-2节点状态的tag的值
int[][][] link = new int[charArray.length][4][4]; // 第一个字,只可能是bs 求首字的标注序列状态概率
for (int s = 0; s < 4; ++s)
{
double p = (s == 1 || s == 2) ? inf : log_prob(bos, bos, new char[]{charArray[0], id2tag[s]});
first[s] = p;
}
// 第二个字,尚不能完全利用TriGram
for (int f = 0; f < 4; ++f)
{
for (int s = 0; s < 4; ++s)
{
double p = first[f] + log_prob(bos, new char[]{charArray[0], id2tag[f]},
);
now[f][s] = p; //首字状态为f,第二个字状态为s的时候,now保存所有可能的状态路径,4 * 4 = 16条路径
link[1][f][s] = f; //link保存了 当前第二个节点状态为s的时候,前一个节点状态为f 一共有16种情况
}
}
// 第三个字开始,利用TriGram标注
double[][] pre = new double[4][4];
for (int i = 2; i < charArray.length; i++)
{
// swap(now, pre)
double[][] _ = pre; //当前第i个字符节点的时候,
pre = now;
now = _;
// end of swap
for (int s = 0; s < 4; ++s)
{
for (int t = 0; t < 4; ++t)
{
now[s][t] = -1e20;
for (int f = 0; f < 4; ++f)
{
double p = pre[f][s] + log_prob(new char[]{charArray[i - 2], id2tag[f]},
new char[]{charArray[i - 1], id2tag[s]},
new char[]{charArray[i], id2tag[t]});//当前节点时,如果该节点状态为t,前面两个节点分别是t-1,t-2,
//pre[f][s]:表示到达t-2时刻状态为f 和t-1时刻状态为s的路径的概率值
//log_prob : 表示t-2 和 t-1时刻 状态分别为f和s的情况下,当前时刻状态为t的转移概率
//转移概率值为P(t3 | t2,t1) = log_prob算出的值(在原有概率计算的基础上取了对数)
//now[s][t] :这样就保存了 到达前一时刻和当前时刻状态分别为s和t的所有路径的概率值
if (p > now[s][t]) //当前节点状态为t, 前一个节点状态为s的时候的最优路径
{
now[s][t] = p;
link[i][s][t] = f; //f就是到达 s,t 两个状态的最优的节点状态,就是t-2时刻的最优状态
} }
}
}
}
double score = inf;
int s = 0;
int t = 0;
for (int i = 0; i < 4; i++) //找出最后两个字的最有可能的状态组合
{
for (int j = 0; j < 4; j++)
{
if (now[i][j] > score)
{
score = now[i][j];
s = i;
t = j;
}
}
}
//t:表示最末尾的字的最优状态
//s:表示当最末尾的字状态为t时,前一个节点的字的最有可能的状态
//这样依次从最后面往前推,不断更新t和s,找出最优的那条路径
for (int i = link.length - 1; i >= 0; --i)
{
tag[i] = id2tag[t];
int f = link[i][s][t]; //当确定最后一个字的状态为t,倒数第二个字最优状态为s时,通过link找出倒数第三个字的最优状态f
t = s; //序列依次从后往前推,不断更新t,s,然后获取f
s = f;
}
return tag;
}