简单的 Denoising Diffusion 模型实现

本仓库包含一个使用 PyTorch 实现的极简版 denoising diffusion 模型 [1,2]。其大部分代码来自 The Annotated DiffusionPhil Wang's diffusion repository。这两个资源对于 diffusion 模型入门都很有帮助,但当我刚开始学习 diffusion 模型时,它们对我来说仍然有些复杂。因此,我重构了 The Annotated Diffusion 的大部分实现,并将其简化为一个极简的版本,并将函数和类在逻辑上分离到不同的文件中,作为一项学习练习。我的目标是理解 diffusion 模型的基本构建块,以便在即将到来的项目中中使用它们。我分享这个仓库,希望我的练习能帮助你理解更复杂的实现。

代码概览

代码按照以下方式组织在 src 文件夹下:

示例

来自数据集的示例:

由 diffusion 模型生成的示例 (旋转是因为数据增强,虽然很搞笑):

示例 diffusion 过程:

总结

正如你所看到的,生成的图像不如数据集中的图像清晰。可以加入许多改进来提高图像质量,但每种改进都会增加复杂性。Phil Wang's diffusion repository 是发现其中一些方法的好地方。

依赖

torch
torchvision
datasets
PIL
numpy

参考文献

[1] Song and Ermon, Generative Modeling by Estimating Gradients of the Data Distribution [2] Ho et al., Denoising Diffusion Probabilistic Models