Simple Denoising Diffusion
该文章介绍了一个使用 PyTorch 实现的极简版 denoising diffusion 模型。作者参考了 [The Annotated Diffusion](https://github.com/utkuozbulak/) 和 [Phil Wang's diffusion repository](https://github.com/utkuozbulak/),简化了实现,并将其分解为不同的文件,包括 diffusion 过程函数、数据集、模型和训练/生成脚本。文章展示了生成的金鱼图像示例,并指出图像质量有待提高。最后,作者列出了依赖项和参考文献。
简单的 Denoising Diffusion 模型实现
本仓库包含一个使用 PyTorch 实现的极简版 denoising diffusion 模型 [1,2]。其大部分代码来自 The Annotated Diffusion 和 Phil Wang's diffusion repository。这两个资源对于 diffusion 模型入门都很有帮助,但当我刚开始学习 diffusion 模型时,它们对我来说仍然有些复杂。因此,我重构了 The Annotated Diffusion 的大部分实现,并将其简化为一个极简的版本,并将函数和类在逻辑上分离到不同的文件中,作为一项学习练习。我的目标是理解 diffusion 模型的基本构建块,以便在即将到来的项目中中使用它们。我分享这个仓库,希望我的练习能帮助你理解更复杂的实现。
代码概览
代码按照以下方式组织在 src
文件夹下:
_funct_diffusion.py
- 包含正向和反向 diffusion 过程所需的所有必要函数,包括 scheduler。_cls_dataset.py
- 包含与数据相关的函数和类。我使用了一个带有增强 (例如,旋转和翻转) 的单一类别(n01443537 - Carassius auratus - Goldfish,金鱼),这就是为什么生成的图像中有几条倒置的金鱼。_cls_model.py
- 包含模型。本仓库中的模型基本上是 The Annotated Diffusion 实现的复制粘贴,只是dim_mults=(1, 2, 4, 8)
和channe=3
(RGB)。_main_train_diffusion.py
- 我想将训练和生成分离到两个不同的文件中,以便理解每个参数的作用。该文件用于训练 diffusion 模型。_main_generate_images.py
- 使用训练好的模型生成图像。
示例
来自数据集的示例:
由 diffusion 模型生成的示例 (旋转是因为数据增强,虽然很搞笑):
示例 diffusion 过程:
总结
正如你所看到的,生成的图像不如数据集中的图像清晰。可以加入许多改进来提高图像质量,但每种改进都会增加复杂性。Phil Wang's diffusion repository 是发现其中一些方法的好地方。
依赖
torch
torchvision
datasets
PIL
numpy
参考文献
[1] Song and Ermon, Generative Modeling by Estimating Gradients of the Data Distribution [2] Ho et al., Denoising Diffusion Probabilistic Models