在胳膊骨折修养的这段期间,系统的学习了一下强化学习相关的知识。正好今天看到了黄民烈老师团队2018年在AAAI上发表的paper《Learning Structured Representation for Text Classification via Reinforcement Learning》。

这篇paper以文本分类作为主要的任务,运用强化学习提出了两个模型:ID-LSTM和HS-LSTM,其中ID-LSTM用来自动地把一句话中和目标任务无关的词语删除掉,起到简化句子的作用;HS-LSTM则是用来自动地从句子中抽取phrase,从而更好的帮助模型进行文本分类。

首先这两个模型都由三个部分组成:Policy Network(PNet),structured representation models和Classification Network (CNet)。这里的PNet对应t时刻采用动作NLP网络定位什么含义 nlp net_文本分类动作的概率为:NLP网络定位什么含义 nlp net_Network_02。公式中的NLP网络定位什么含义 nlp net_文本分类_03NLP网络定位什么含义 nlp net_文本分类_04都是PNet的网络参数。在训练阶段,采用的动作NLP网络定位什么含义 nlp net_NLP网络定位什么含义_05由上述公式根据概率分布采样得到;而在预测阶段,则是直接挑选概率最大的NLP网络定位什么含义 nlp net_NLP网络定位什么含义_05所对应的动作。很显然,在该场景下的reward其实就是经过处理之后,该句子被CNet预测为正确label的概率。最终PNet网络的梯度计算公式为:NLP网络定位什么含义 nlp net_强化学习_07.

Information Distilled LSTM (ID-LSTM)

在该模型中,action集合中总共有两个动作:Retain和Delete,使用的基础模型是基于LSTM的,只不过不同时刻针对不同动作LSTM的运作方式和传统的比稍有不同:

if NLP网络定位什么含义 nlp net_强化学习_08 then NLP网络定位什么含义 nlp net_Network_09

if NLP网络定位什么含义 nlp net_Network_10 then NLP网络定位什么含义 nlp net_Network_11

具体流程如下图所示:

NLP网络定位什么含义 nlp net_Network_12


而对于state的定义则有NLP网络定位什么含义 nlp net_NLP网络定位什么含义_13

那么很简单,CNet对应的分类公式为NLP网络定位什么含义 nlp net_强化学习_14

最终PNet部分的Reward定义为:NLP网络定位什么含义 nlp net_Network_15,其中NLP网络定位什么含义 nlp net_强化学习_16代表了被删除的word数量(意思是鼓励机器去多删除一些word),NLP网络定位什么含义 nlp net_Network_17用来权衡这个力度。

Hierarchically Structured LSTM (HS-LSTM)
在该模型中,action集合中总共有两个动作:Inside和End,使用的模型是2个层次化的LSTM,一个用来将word转化成phrase向量,另一个将生成的phrase向量转化成Sentence 向量。
针对phrase向量生成器来说,其LSTM运作公式如下:
if NLP网络定位什么含义 nlp net_文本分类_18 then NLP网络定位什么含义 nlp net_强化学习_19
if NLP网络定位什么含义 nlp net_Network_20 then NLP网络定位什么含义 nlp net_强化学习_21

对于Sentence向量生成器来说,其LSTM运作公式如下:

if NLP网络定位什么含义 nlp net_文本分类_18 then NLP网络定位什么含义 nlp net_Network_23

if NLP网络定位什么含义 nlp net_Network_20 then NLP网络定位什么含义 nlp net_文本分类_25

具体方式如下图所示:

NLP网络定位什么含义 nlp net_Network_26


在该场景下,对于state的定义则有NLP网络定位什么含义 nlp net_NLP网络定位什么含义_27

CNet对应的分类公式为NLP网络定位什么含义 nlp net_文本分类_28

最终PNet部分的Reward定义为:NLP网络定位什么含义 nlp net_文本分类_29,其中NLP网络定位什么含义 nlp net_强化学习_16代表了被删除的word数量,当NLP网络定位什么含义 nlp net_NLP网络定位什么含义_31 NLP网络定位什么含义 nlp net_文本分类_32会取到最大的Reward。作者之所以这样设置参数,是因为他发现在他的语料库中,一个包含了L个word的一句话中,平均的phrase个数为0.316L。

和所有的深度强化学习网络一样,这样的网络是十分难以训练的(即直接训练的话,网络损失函数机会不会收敛)。因此在训练的时候,需要一定的技巧,作者分为以下3个步骤:

1 预训练CNet网络和分类网络参数;对于ID-LSTM直接使用原始的不经过删减的句子进行预训练;对于HD-LSTM则先使用简单的启发式算法对原始句子中的word进行划分phrase处理;
2 固定住CNet部分网络参数,对PNet网络参数进行预训练操作;
3 Jointly 训练整个网络参数。

这篇paper提出的模型其实可以用到任何序列处理的任务中去(比方说对于CTR预估场景下的用户一系列行为的建模),这也是未来值得探索的方向之一。