CRF:条件随机场,一种机器学习技术。给定一组输入随机变量条件下,另一组输出随机变量的条件概率分布模型。
以一组词性标注为例,给定输入X={我,喜欢,学习},那么输出为Y={名词,动词,名词}的概率应该为最大。输入序列X又称为观测序列,输出序列Y又称为状态序列。这个状态序列构成马尔可夫随机场,所以根据观测序列,得出状态序列的概率就包括,前一个状态转化为后一状态的概率(即转移概率)和状态变量到观测变量的概率(即发射概率)。
CRF分词原理
1. CRF把分词当做字的词位分类问题,通常定义字的词位信息如下:
- 词首,常用B表示;
- 词中,常用M表示;
- 词尾,常用E表示;
- 单子词,常用S表示;
2. CRF分词的过程就是对词位标注后,将B和E之间的字,以及S单字构成分词;
3. CRF分词实例:
- 原始例句:我爱北京天安门
- CRF标注后:我/S 爱/S 北/B 京/E 天/B 安/M 门/E
- 分词结果:我/爱/北京/天安门
语料截图如下:
由于语料很小,下面程序中创建的映射字典也小,所以预测时不能出现字典外的字,否则报KeyError。
链接:https://pan.baidu.com/s/1SUd-QwlD-WlfqGvo7ElhDw
提取码:v0hx
1.config.py
存放一些超参数。
1 filename='word.txt'
2 EMBEDDING_DIM = 5
3 HIDDEN_DIM = 4
4 epochs=100
2.data_process.py
预处理数据
1 import re
2 import torch
3 START_TAG = "<START>"
4 STOP_TAG = "<STOP>"
5 tag_to_ix = {"B": 0, "M": 1, "E": 2,"S":3, START_TAG: 4, STOP_TAG: 5}
6
7 def prepare_sequence(seq, to_ix): #seq是字序列,to_ix是字和序号的字典
8 idxs = [to_ix[w] for w in seq] #idxs是字序列对应的向量
9 return torch.tensor(idxs, dtype=torch.long)
10
11 #将句子转换为字序列
12 def get_word(sentence):
13 word_list = []
14 sentence = ''.join(sentence.split(' '))
15 for i in sentence:
16 word_list.append(i)
17 return word_list
18
19 #将句子转换为BMES序列
20 def get_str(sentence):
21 output_str = []
22 sentence = re.sub(' ', ' ', sentence) #发现有些句子里面,有两格空格在一起
23 list = sentence.split(' ')
24 for i in range(len(list)):
25 if len(list[i]) == 1:
26 output_str.append('S')
27 elif len(list[i]) == 2:
28 output_str.append('B')
29 output_str.append('E')
30 else:
31 M_num = len(list[i]) - 2
32 output_str.append('B')
33 output_str.extend('M'* M_num)
34 output_str.append('E')
35 return output_str
36
37 def read_file(filename):
38 word, content, label = [], [], []
39 text = open(filename, 'r', encoding='utf-8')
40 for eachline in text:
41 eachline = eachline.strip('\n')
42 eachline = eachline.strip(' ')
43 word_list = get_word(eachline)
44 letter_list = get_str(eachline)
45 word.extend(word_list)
46 content.append(word_list)
47 label.append(letter_list)
48 return word, content, label #word是单列表,content和label是双层列表
查看下数据内容:
1 text, content, label = read_file('word.txt')
2 print(text)
3 print(content)
4 print(label)
1 ['十', '亿', '中', '华', '儿', '女', '踏', '上', '新', '的', '征', '程', '。', '过', '去', '的', '一', '年', ',', '是', '全', '国', '各', '族', '人', '民', '在', '中', '国', '共', '产', '党', '领', '导', '下', ',', '在', '建', '设', '有', '中', '国', '特', '色', '的', '社', '会', '主', '义', '道', '路', '上', ',', '坚', '持', '改', '革', '、', '开', '放', ',', '团', '结', '奋', '斗', '、', '胜', '利', '前', '进', '的', '一', '年', '。', '城', '乡', '经', '济', '体', '制', '改', '革', '向', '纵', '深', '稳', '步', '发', '展', ',', '对', '外', '开', '放', '迈', '出', '了', '新', '的', '步', '伐', ',', '工', '农', '业', '生', '产', '和', '其', '它', '各', '项', '建', '设', '事', '业', '全', '面', '完', '成', '了', '“', '七', '五', '”', '计', '划', '第', '一', '年', '的', '任', '务', ',', '人', '民', '生', '活', '继', '续', '有', '所', '改', '善', '。', '政', '治', '上', '安', '定', '团', '结', ',', '端', '正', '党', '风', '和', '社', '会', '风', '气', '的', '工', '作', '取', '得', '了', '新', '的', '进', '展', ',', '社', '会', '主', '义', '民', '主', '和', '法', '制', '建', '设', '不', '断', '加', '强', '。', '在', '党', '的', '十', '二', '届', '六', '中', '全', '会', '通', '过', '的', '《', '关', '于', '社', '会', '主', '义', '精', '神', '文', '明', '建', '设', '指', '导', '方', '针', '的', '决', '议', '》', '指', '引', '下', ',', '我', '国', '两', '个', '文', '明', '的', '建', '设', '正', '在', '向', '新', '的', '水', '平', '迈', '步', '。', '从', '党', '的', '十', '一', '届', '三', '中', '全', '会', '实', '现', '伟', '大', '历', '史', '转', '折', '到', '现', '在', ',', '我', '国', '政', '治', '安', '定', '团', '结', ',', '经', '济', '稳', '定', '、', '持', '续', '、', '协', '调', '发', '展', '已', '经', '八', '年', '了', ',', '这', '是', '建', '国', '以', '来', '稳', '步', '发', '展', '持', '续', '时', '间', '最', '长', '的', '时', '期', '。', '在', '十', '年', '动', '乱', '之', '后', ',', '取', '得', '这', '样', '一', '个', '大', '好', '局', '面', '是', '不', '容', '易', '的', '。']
2 [['十', '亿', '中', '华', '儿', '女', '踏', '上', '新', '的', '征', '程', '。'], ['过', '去', '的', '一', '年', ',', '是', '全', '国', '各', '族', '人', '民', '在', '中', '国', '共', '产', '党', '领', '导', '下', ','], ['在', '建', '设', '有', '中', '国', '特', '色', '的', '社', '会', '主', '义', '道', '路', '上', ',', '坚', '持', '改', '革', '、', '开', '放', ',', '团', '结', '奋', '斗', '、', '胜', '利', '前', '进', '的', '一', '年', '。'], ['城', '乡', '经', '济', '体', '制', '改', '革', '向', '纵', '深', '稳', '步', '发', '展', ',', '对', '外', '开', '放', '迈', '出', '了', '新', '的', '步', '伐', ',', '工', '农', '业', '生', '产', '和', '其', '它', '各', '项', '建', '设', '事', '业', '全', '面', '完', '成', '了', '“', '七', '五', '”', '计', '划', '第', '一', '年', '的', '任', '务', ',', '人', '民', '生', '活', '继', '续', '有', '所', '改', '善', '。'], ['政', '治', '上', '安', '定', '团', '结', ',', '端', '正', '党', '风', '和', '社', '会', '风', '气', '的', '工', '作', '取', '得', '了', '新', '的', '进', '展', ',', '社', '会', '主', '义', '民', '主', '和', '法', '制', '建', '设', '不', '断', '加', '强', '。'], ['在', '党', '的', '十', '二', '届', '六', '中', '全', '会', '通', '过', '的', '《', '关', '于', '社', '会', '主', '义', '精', '神', '文', '明', '建', '设', '指', '导', '方', '针', '的', '决', '议', '》', '指', '引', '下', ',', '我', '国', '两', '个', '文', '明', '的', '建', '设', '正', '在', '向', '新', '的', '水', '平', '迈', '步', '。'], ['从', '党', '的', '十', '一', '届', '三', '中', '全', '会', '实', '现', '伟', '大', '历', '史', '转', '折', '到', '现', '在', ',', '我', '国', '政', '治', '安', '定', '团', '结', ',', '经', '济', '稳', '定', '、', '持', '续', '、', '协', '调', '发', '展', '已', '经', '八', '年', '了', ',', '这', '是', '建', '国', '以', '来', '稳', '步', '发', '展', '持', '续', '时', '间', '最', '长', '的', '时', '期', '。'], ['在', '十', '年', '动', '乱', '之', '后', ',', '取', '得', '这', '样', '一', '个', '大', '好', '局', '面', '是', '不', '容', '易', '的', '。']]
3 [['B', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'S'], ['B', 'E', 'S', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'M', 'E', 'B', 'E', 'S', 'S'], ['S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S'], ['B', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'S', 'B', 'E', 'S', 'B', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'M', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S'], ['B', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'S', 'S', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'S'], ['S', 'S', 'S', 'B', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'S', 'S', 'B', 'E', 'B', 'E', 'S'], ['B', 'E', 'S', 'B', 'M', 'M', 'M', 'M', 'M', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'B', 'E', 'B', 'E', 'S', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'E', 'S', 'S', 'B', 'E', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'S'], ['S', 'B', 'M', 'M', 'E', 'B', 'E', 'S', 'B', 'E', 'B', 'E', 'B', 'E', 'B', 'M', 'M', 'E', 'S', 'S', 'B', 'E', 'S', 'S']]
View Code
3.BiLSTM_CRF.py
关于BiLSTM+CRF的详细理解:https://zhuanlan.zhihu.com/p/97676647
转移概率矩阵transitions,transitionsij表示t时刻隐状态为qi,t+1时刻隐状态转换为qj的概率,即P(it+1=qj|it=qi)
1 import torch
2 from data_process import START_TAG,STOP_TAG
3 from torch import nn
4
5 def argmax(vec): #返回每一行最大值的索引
6 _, idx = torch.max(vec, 1)
7 return idx.item()
8
9
10 def prepare_sequence(seq, to_ix): #seq是字序列,to_ix是字和序号的字典
11 idxs = [to_ix[w] for w in seq] #idxs是字序列对应的向量
12 return torch.tensor(idxs, dtype=torch.long)
13
14
15 #LSE函数,模型中经常用到的一种路径运算的实现
16 def log_sum_exp(vec): #vec.shape=[1, target_size]
17 max_score = vec[0, argmax(vec)] #每一行的最大值
18 max_score_broadcast = max_score.view(1, -1).expand(1, vec.size()[1])
19 return max_score + torch.log(torch.sum(torch.exp(vec - max_score_broadcast)))
20
21
22 class BiLSTM_CRF(nn.Module):
23
24 def __init__(self, vocab_size, tag_to_ix, embedding_dim, hidden_dim):
25 super(BiLSTM_CRF, self).__init__()
26 self.embedding_dim = embedding_dim
27 self.hidden_dim = hidden_dim
28 self.vocab_size = vocab_size
29 self.tag_to_ix = tag_to_ix
30 self.tagset_size = len(tag_to_ix)
31
32 self.word_embeds = nn.Embedding(vocab_size, embedding_dim)
33 self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)
34
35 self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size) # Maps the output of the LSTM into tag space
36
37 self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size)) #随机初始化转移矩阵
38
39 self.transitions.data[tag_to_ix[START_TAG], :] = -10000 #tag_to_ix[START_TAG]: 3(第三行,即其他状态到START_TAG的概率)
40 self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000 #tag_to_ix[STOP_TAG]: 4(第四列,即STOP_TAG到其他状态的概率)
41 self.hidden = self.init_hidden()
42
43 def init_hidden(self):
44 return (torch.randn(2, 1, self.hidden_dim // 2), torch.randn(2, 1, self.hidden_dim // 2))
45
46 #所有路径的得分,CRF的分母
47 def _forward_alg(self, feats):
48 init_alphas = torch.full((1, self.tagset_size), -10000.) #初始隐状态概率,第1个字是O1的实体标记是qi的概率
49 init_alphas[0][self.tag_to_ix[START_TAG]] = 0.
50
51 forward_var = init_alphas #初始状态的forward_var,随着step t变化
52
53 for feat in feats: #feat的维度是[1, target_size]
54 alphas_t = []
55 for next_tag in range(self.tagset_size): #给定每一帧的发射分值,按照当前的CRF层参数算出所有可能序列的分值和
56
57 emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size) #发射概率[1, target_size] 隐状态到观测状态的概率
58 trans_score = self.transitions[next_tag].view(1, -1) #转移概率[1, target_size] 隐状态到下一个隐状态的概率
59 next_tag_var = forward_var + trans_score + emit_score #本身应该相乘求解的,因为用log计算,所以改为相加
60
61 alphas_t.append(log_sum_exp(next_tag_var).view(1))
62
63 forward_var = torch.cat(alphas_t).view(1, -1)
64
65 terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]] #最后转到[STOP_TAG],发射分值为0,转移分值为列向量 self.transitions[:, [self.tag2ix[END_TAG]]]
66 return log_sum_exp(terminal_var)
67
68 #得到feats,维度=len(sentence)*tagset_size,表示句子中每个词是分别为target_size个tag的概率
69 def _get_lstm_features(self, sentence):
70 self.hidden = self.init_hidden()
71 embeds = self.word_embeds(sentence).view(len(sentence), 1, -1)
72 lstm_out, self.hidden = self.lstm(embeds, self.hidden)
73 lstm_out = lstm_out.view(len(sentence), self.hidden_dim)
74 lstm_feats = self.hidden2tag(lstm_out)
75 return lstm_feats
76
77 #正确路径的分数,CRF的分子
78 def _score_sentence(self, feats, tags):
79 score = torch.zeros(1)
80 tags = torch.cat([torch.tensor([self.tag_to_ix[START_TAG]], dtype=torch.long), tags])
81 for i, feat in enumerate(feats):
82 #self.transitions[tags[i + 1], tags[i]] 是从标签i到标签i+1的转移概率
83 #feat[tags[i+1]], feat是step i的输出结果,有5个值,对应B, I, E, START_TAG, END_TAG, 取对应标签的值
84 score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]] # 沿途累加每一帧的转移和发射
85 score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]] # 加上到END_TAG的转移
86 return score
87
88
89 #解码,得到预测序列的得分,以及预测的序列
90 def _viterbi_decode(self, feats):
91 backpointers = [] #回溯路径;backpointers[i][j]=第i帧到达j状态的所有路径中, 得分最高的那条在i-1帧是什么状态
92
93 # Initialize the viterbi variables in log space
94 init_vvars = torch.full((1, self.tagset_size), -10000.)
95 init_vvars[0][self.tag_to_ix[START_TAG]] = 0
96
97 forward_var = init_vvars
98 for feat in feats:
99 bptrs_t = []
100 viterbivars_t = []
101
102 for next_tag in range(self.tagset_size):
103
104 next_tag_var = forward_var + self.transitions[next_tag] #其他标签(B,I,E,Start,End)到标签next_tag的概率
105 best_tag_id = argmax(next_tag_var) #选择概率最大的一条的序号
106 bptrs_t.append(best_tag_id)
107 viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
108
109 forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1) #从step0到step(i-1)时5个序列中每个序列的最大score
110 backpointers.append(bptrs_t)
111
112
113 terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]] #其他标签到STOP_TAG的转移概率
114 best_tag_id = argmax(terminal_var)
115 path_score = terminal_var[0][best_tag_id]
116
117 best_path = [best_tag_id]
118 for bptrs_t in reversed(backpointers): #从后向前走,找到一个best路径
119 best_tag_id = bptrs_t[best_tag_id]
120 best_path.append(best_tag_id)
121
122 start = best_path.pop()
123 assert start == self.tag_to_ix[START_TAG] #安全性检查
124 best_path.reverse() #把从后向前的路径倒置
125 return path_score, best_path
126
127 #求负对数似然,作为loss
128 def neg_log_likelihood(self, sentence, tags):
129 feats = self._get_lstm_features(sentence) #emission score
130 forward_score = self._forward_alg(feats) #所有路径的分数和,即b
131 gold_score = self._score_sentence(feats, tags) #正确路径的分数,即a
132 return forward_score - gold_score #注意取负号 -log(a/b) = -[log(a) - log(b)] = log(b) - log(a)
133
134
135 def forward(self, sentence):
136 lstm_feats = self._get_lstm_features(sentence)
137 score, tag_seq = self._viterbi_decode(lstm_feats)
138 return score, tag_seq
4.training.py
1 from data_process import read_file, tag_to_ix
2 from config import *
3 from BiLSTM_CRF import *
4 import torch
5 from torch import nn
6 from torch import optim
7
8 _, content, label = read_file(filename)
9
10 def train_data(content, label):
11 train_data = []
12 for i in range(len(label)):
13 train_data.append((content[i], label[i]))
14 return train_data
15 data = train_data(content,label)
16
17 word_to_ix = {}
18 for sentence, tags in data:
19 for word in sentence:
20 if word not in word_to_ix:
21 word_to_ix[word] = len(word_to_ix) #单词映射,字到序号
22
23
24 model = BiLSTM_CRF(len(word_to_ix), tag_to_ix, EMBEDDING_DIM, HIDDEN_DIM)
25 optimizer = optim.Adam(model.parameters(), lr=1e-3)
26
27 #训练
28 for epoch in range(epochs):
29 for sentence, tags in data:
30 model.zero_grad()
31
32 sentence_in = prepare_sequence(sentence, word_to_ix)
33 targets = torch.tensor([tag_to_ix[t] for t in tags], dtype=torch.long)
34 loss = model.neg_log_likelihood(sentence_in, targets)
35
36 loss.backward()
37 optimizer.step()
38 if epoch%10 == 0:
39 print('epoch/epochs: {}/{}, loss:{:.6f}'.format(epoch+1, epochs, loss.data[0]))
40
41 #保存模型
42 torch.save(model,'cws.model')
43 torch.save(model.state_dict(),'cws_all.model')
5.test_model.py
调用上面保存的模型,进行预测。
1 from trainning import word_to_ix
2 from data_process import prepare_sequence
3 import torch
4
5 net = torch.load('cws.model')
6 net.eval()
7 stri="改善人民生活水平,建设社会主义政治经济。"
8 precheck_sent = prepare_sequence(stri, word_to_ix)
9 #precheck_sent= tensor([ 45, 102, 23, 24, 80, 98, 140, 141, 17, 32, 33, 37, 38, 39, 40, 103, 104, 60, 61, 12])
10
11 label = net(precheck_sent)[1]
12 #net(precheck_sent)= (tensor(32.3123, grad_fn=<SelectBackward>), [0, 2, 0, 2, 0, 2, 0, 2, 3, 0, 2, 0, 2, 0, 2, 0, 2, 0, 2, 2])
13
14 cws=[]
15 for i in range(len(label)):
16 cws.extend(stri[i])
17 if label[i]==2 or label==3:
18 cws.append('/')
19 #cws= ['改', '善', '/', '人', '民', '/', '生', '活', '/', '水', '平', '/', ',', '建', '设', '/', '社', '会', '/', '主', '义', '/', '政', '治', '/', '经', '济', '/', '。']
20
21 print('输入未分词语句:', stri)
22 print('分词结果:', ''.join(cws))
1 epoch/epochs: 1/100, loss:33.839325
2 epoch/epochs: 11/100, loss:31.749798
3 epoch/epochs: 21/100, loss:29.822870
4 epoch/epochs: 31/100, loss:27.391972
5 epoch/epochs: 41/100, loss:26.033567
6 epoch/epochs: 51/100, loss:24.467463
7 epoch/epochs: 61/100, loss:22.403660
8 epoch/epochs: 71/100, loss:20.725002
9 epoch/epochs: 81/100, loss:18.280849
10 epoch/epochs: 91/100, loss:16.049187
输入未分词语句: 改善人民生活水平,建设社会主义政治经济。
分词结果: 改善/人民/生活/水平/,建设/社会/主义/政治/经济/。