Stable Diffusion算法代码详解
原创
©著作权归作者所有:来自51CTO博客作者陈亦新的原创作品,请联系作者获取转载授权,否则将追究法律责任
0 综述
Stable Diffusion是一种文本到图像的模型。相对于Disco Diffusion,SD的速度会更快一些。经过我的比较发现,Disco Diffusion对于现存的需求会比SD少一些,而且DD在生成风景上有一些更好的效果。SD会更注重对与prompt内容的还原。
这一类文本到图像的算法都是基于CLIP这个文本图像预训练模型的,之前已经讲过了。
Stable Diffusion模型本身建立潜在扩散模型的基础上,并且结合了Dall-E2和Imagen的条件扩散模型的看法。核心数据在LAION-Aesthetics上训练。
1 代码仓库
repo:github.com/CompVis/sta…
上面讲的很具体,很容易就跑起来了。当然,作为一名算法从事人员,还是有必要把里面的具体细节剖开看一看。个人之前不是从事这个diffusion领域的,所以有一些见解可能存在错误。
其中关键是一个叫做txt2img.py的文件。
1.1 模型加载
def load_model_from_config(config, ckpt, verbose=False):
print(f"Loading model from {ckpt}")
# 加载模型的经典语句
pl_sd = torch.load(ckpt, map_location="cpu")
if "global_step" in pl_sd:
print(f"Global Step: {pl_sd['global_step']}")
sd = pl_sd["state_dict"]
# 上面我们加载出来了模型的参数,还需要一个模型来接着参数。
# 这里是根据配置文件生成对应的模型,配置文件可见附图1
model = instantiate_from_config(config.model)
m, u = model.load_state_dict(sd, strict=False)
if len(m) > 0 and verbose:
print("missing keys:")
print(m)
if len(u) > 0 and verbose:
print("unexpected keys:")
print(u)
model.cuda()
model.eval()
return
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
# 这里的config["target"] = "ldm.models.diffusion.ddpm.LatentDiffusion"
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
# module = "ldm.models.diffusion.ddpm"
# cls = "LatentDiffusion"
# 我们可以自然的猜测出,module是py文件的路径,cls是该py文件当中的模型module类
module, cls = string.rsplit(".", 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
至此,我们了解了该算法如何通过配置文件yaml来控制模型类的加载。
1.2 模型结构
打开模型文件后,发现LatentDiffusion模型文件有1k多行,而且继承了DDPM类。此处暂时搁置。
附图