Stable Diffusion如何对齐人类偏好?
论文尝试将人类偏好引入Stable Diffusion的模型训练中,证明了人类偏好信息可以提升Stable Diffusion生成的图像质量。值得注意的是,该方法在处理Stable Diffusion一些经典的failure case,例如人体、四肢时,展示出了优于Stable Diffusion的效果。
熟悉文生图模型的人大概都经历过挑图的过程,同样的提示词下,模型用不同随机种子生成的图在质量上常常有很大差别。一般大家会让模型一次性多生成几张,然后挑一张自己最喜欢的。
显然文生图模型可以生成出“好”图,但常常也会生成不那么好看的图,那么是否可以通过少量的微调让模型更容易生成符合人偏好的图?
近期,一篇来自香港中文大学、商汤科技等研究团队的论文尝试将人类偏好引入Stable Diffusion的模型训练中,证明了人类偏好信息可以提升Stable Diffusion生成的图像质量。值得注意的是,该方法在处理Stable Diffusion一些经典的failure case,例如人体、四肢时,展示出了优于Stable Diffusion的效果。
论文链接:https://arxiv.org/abs/2303.14420
代码地址:https://github.com/tgxs002/align_sd
项目主页:https://tgxs002.github.io/align_sd_web/
本文的方法受到了InstructGPT [1]的启发,InstructGPT利用人对于GPT生成文本的反馈训练了一个reward model,用来评价语言模型的输出,然后用强化学习的方式来微调GPT,使得生成文本更符合人的喜好。
而本文将这个思想应用到了图像生成领域,通过人对图像的反馈来提升文生图模型。在这个项目进行的时候,尚不存在一个生成图像的human feedback数据集。此外由于图像和文本模态不同,InstructGPT的RL算法并不能直接应用到diffusion model的图像生成任务上,本文着重介绍这两个问题如何进行处理。
数据收集
首先,如何获得偏好数据呢?作者发现Stable Diffusion的Discord server中有大量的聊天记录,其中一个典型的模式如下图所示:
用户输入一个prompt,channel bot会根据prompt生成多张图,接下来人可以手动选择出其中的一张进行后续操作,例如图生图、超分辨率等等。整个交互过程会被完整地记录在channel里,并且完全公开。
作者爬取了25k组这样的交互数据,其中每组数据包含一个prompt、2-4张图,以及人的选择。
虽然在这个数据集中人的选择不一定完全出于个人偏好,很可能会有很多随机的因素引入的噪声,但整体上来讲被选中的图像大多会有更高的质量,因此可以反映出人的偏好。
目前这个数据集是开源的,下载链接可以在这个推送底部的代码地址中找到。此外,更大更general的HPD v2数据集将在近期release,欢迎大家多多试用。
训练reward model
作者利用收集到的偏好数据微调CLIP model[3],将CLIP对齐到人的偏好上。借助CLIP的置信度可以定义出Human Preference Score (HPS),用作人对图像质量评价的代理,类似于RLHF中的训练reward model。
HPS的训练使用了数据集的20k,剩下的5k用于测试,测试结果如下图所示:
实验证明,相比一系列baseline,微调过后的CLIP (HPS)能够更好地预测人的偏好。
微调Stable Diffusion
有了reward model,如何将其偏好判断“注入”到文生图模型中,以指导图像的生成?
DiffusionDB [2]是一个大规模的生成图像数据集,其中包含了大量由Stable Diffusion生成的图像与prompt对。本文用HPS来对DiffusionDB中的图像进行打分,分别筛选出其中分数较高以及较低的图片作为正、负样本,然后进行类似于DreamBooth的训练,如下图所示:
上图的训练能够将“不好的图“的概念绑定到给定的Identifier中,在inference时将这个Identifier设为负样本可以提升图像质量。
实验结果
上图对比了原始的Stable Diffusion 1.4,只用正则化图像训练,以及正则化图像 + 本文方法的结果。经过微调的模型可以解决一些局部的图像质量问题。
20个人的user study证明了经过微调的模型会有更大概率生成符合人喜好的图像。
文章对比了微调前后的FiD、Aesthetic Score、CLIP Score以及HPS。其中FiD是在LAION数据集上计算的。Aesthetic Score和CLIP Score分别用来衡量图像的美学指标以及图文的匹配程度。结果证明四个指标上均有提升,说明经过微调的模型的图像质量和可控性都得到了提升。
总结
本文收集了一个大规模的生成图片的人类偏好数据集,并且借助这个数据集训练reward model,进而提升了文生图模型的生成质量。