Sampling-Bias-Corrected Neural Modeling 论文阅读笔记
written on October 19, 2022
一、概述(出发点)
- 双塔模型通过采样负样本来优化损失函数,但是这种方式会受到采样偏差影响,影响模型性能,特别是在样本分布极度倾斜的情况下
- YouTube中的视频数据是流数据,新增的Item无法包含在固定的语料库,因此需要在batch中进行负采样并计算in-batch softmax
- 从流数据中估计item的采样概率,并应用到采样偏差的修正上是改善模型性能的关键
二、具体贡献
2.1 针对batch negative sampling的bias修正
在双塔模型中模型的输出用user embeding和item embeding的点乘来表示,用于衡量二者的相似度
双塔模型的训练方式分为三种:point-wise、piar-wise和list-wise。YuTube的双塔模型采用的是list-wise方式来进行训练。list-wise的训练可以看做是一个经典的多分类问题,我们通常采用softmax计算概率,然后使用交叉熵作为损失函数,优化目标是让正样本的分数尽可能高
作者采用了加权对数似然作为损失函数,这里的表示的奖励,每一个label都是同等重要的在分类任务中,表示正样例,表示负样例;在推荐系统中,的含义则可以进一步拓展,例如若用户在一个视频上观看的时间较长,则可以设置得较大,表示用户更喜欢这些视频。
由于YouTube这类产品面对的数据量非常庞大,因此不可能针对整个语料库中的item进行softmax,然后再计算交叉熵损失。因此一般会采样出一个样本子集,在这个样本子集上计算softmax(sampled softmax),但是不同于MLP模型负样本从固定的语料库中采样,面对实际业务中的流数据,负采样只能在batch中进行(batch negative sampling)
但是这种采样方式会造成较大的bias,由于batch negative sampling隐式地使用了基于item出现频率的采样分布,因此对于热门的item,它被采样的概率更大,会被更多地作为负样本,从而热门样本会被过度惩罚。引用sampled softmax的做法,作者对user embeding 和item embeding的内积进行修正
表示batch中item j的采样概率,而的估计则是后续工作的重点,在引入修正项之后,模型的训练就可以通过SGD来进行优化
Tricks
作者在内积计算的部分还采用了两个tricks
- Normalization
对user侧和item侧的输出进行L2标准化 - Temperature
引入温度参数,对输出进行平滑处理
如何理解温度参数τ?
假设向量s=[1,2,3],的情形就是直接进行softmax,得到的概率分布为[0.09,0.24,0.67],逐渐提高的值,可以发现概率分布越来越平滑,反之则越陡峭。文章后面也针对不同温度系数进行了实验
一般来讲,如果温度系数设的越大,logits分布变得越平滑,那么log-softmax损失会对所有的负样本一视同仁,导致模型学习没有轻重。如果温度系数设的过小,则模型会越关注特别困难的负样本,但其实那些负样本很可能是潜在的正样本,这样会导致模型很难收敛或者泛化能力差。
总之,温度系数的作用就是它控制了模型对负样本的区分度。
τ | s0 | s1 | s2 | sum |
1 | 0.090031 | 0.244728 | 0.665241 | 1 |
2 | 0.186324 | 0.307196 | 0.50648 | 1 |
4 | 0.254275 | 0.326496 | 0.419229 | 1 |
0.5 | 0.015876 | 0.11731 | 0.866813 | 1 |
2.2 in-batch item采样概率的估计
该部分是论文的重点,作者提出的Streaming Frequency Estimation算法通过计算item连续两次被采样的平均间隔,得到采样概率的估计
算法使用了两个数组A、B和一个哈希函数h,数组B记录当前平均采样间隔,数组A记录上次一采样的时间,利用A来辅助更新B
注意到哈希函数的输出空间大小为H,当H<M时会存在hash冲突,从而导致item采样概率的过度估计(因为A[h(y)]更新得很频繁,t-A[h(y)]也就偏小),论文的改进方法是使用multi-hash,即使用多组A、B和哈希函数h,最终的计算结果取最大的B[h(y)]
三、实验结果
作者首先测试了不同学习率和multi-hash对于概率估计算法的影响,并在Wikipedia dataset和YouTube上验证了引入修正项后moxing的有效性
3.1 Simulation on Frequency Estimation
实验表明:
- 较高的学习率导致更快的收敛时间,但误差相对较高
- multi-hash可以有效降低误差,即使在相同数量的参数下
3.2 Wikipedia Page Retrieval
经过修正的模型Recall@K表现明显优于未经过修正的模型和mse-gramian
3.3 YouTube Experiment
相较于基线模型,经过修正的模型在YuTube视频推荐中的表现更好
除了离线实验,作者还进行了在线实验,并利用了reward来训练模型,从而真实反映用户对于视频的参与程度