有两个不同版本的源代码。TensorFlow 版本更新更完善,如果您希望试验我们的技术、在此基础上构建或将其应用于新数据集,我们通常建议将其作为起点。另一方面,原始的Theano 版本是我们用来生成论文中显示的所有结果的版本。我们建议在——且仅当——您希望为基准数据集(如 CIFAR-10、MNIST-RGB 和 CelebA)重现我们的准确结果时使用它。
下表总结了主要差异:
requirements-pip.txt
[hidecontent type="logged" desc="隐藏内容:登录后可查看"]
在 Google Drive 上找到的所有预训练网络,以及训练脚本生成的网络,都存储为 Python PKL 文件。pickle
只要满足两个条件,就可以使用标准机制导入它们:(1) 包含 Progressive GAN 代码存储库的目录必须包含在 PYTHONPATH 环境变量中,以及 (2)tf.Session()
必须事先创建一个对象并将其设置为默认。每个 PKL 文件包含 3 个实例tfutil.Network
:
# Import official CelebA-HQ networks.
with open('karras2018iclr-celebahq-1024x1024.pkl', 'rb') as file:
G, D, Gs = pickle.load(file)
# G = Instantaneous snapshot of the generator, mainly useful for resuming a previous training run.
# D = Instantaneous snapshot of the discriminator, mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator, yielding higher-quality results than the instantaneous snapshot.
也可以导入使用 Theano 实现生成的网络,只要它们不使用 TensorFlow 版本本身不支持的任何功能(小批量歧视、批量归一化等)。但是,要启用 Theano 网络导入,您必须使用misc.load_pkl()
代替pickle.load()
# Import Theano versions of the official CelebA-HQ networks.
import misc
G, D, Gs = misc.load_pkl('200-celebahq-1024x1024/network-final.pkl')
导入网络后,您可以调用Gs.run()
为给定的潜在向量生成一组图像,或Gs.get_output_for()
将生成器网络包含在更大的 TensorFlow 表达式中。有关详细信息,请参阅 Google Drive 上的示例脚本。指示:
pip install -r requirements-pip.txt
import_example.py
自networks/tensorflow-version/example_import_script
karras2018iclr-celebahq-1024x1024.pkl
从中下载networks/tensorflow-version
并将其放在与脚本相同的目录中。python import_example.py
img0.png
- img9.png
),它们与在中找到的图像networks/tensorflow-version/example_import_script
完全匹配。Progressive GAN 代码存储库包含一个命令行工具,用于重新创建我们在论文中使用的数据集的位精确副本。该工具还提供了各种用于操作数据集的实用程序:
usage: dataset_tool.py [-h] <command> ...
display Display images in dataset.
extract Extract images from dataset.
compare Compare two datasets.
create_mnist Create dataset for MNIST.
create_mnistrgb Create dataset for MNIST-RGB.
create_cifar10 Create dataset for CIFAR-10.
create_cifar100 Create dataset for CIFAR-100.
create_svhn Create dataset for SVHN.
create_lsun Create dataset for single LSUN category.
create_celeba Create dataset for CelebA.
create_celebahq Create dataset for CelebA-HQ.
create_from_images Create dataset from a directory full of images.
create_from_hdf5 Create dataset from legacy HDF5 archive.
Type "dataset_tool.py <command> -h" for more information.
数据集由包含多种分辨率的相同图像数据的目录表示,以实现高效流式传输。每个分辨率都有一个单独的*.tfrecords
文件,如果数据集包含标签,它们也会存储在一个单独的文件中:
> python dataset_tool.py create_cifar10 datasets/cifar10 ~/downloads/cifar10
> ls -la datasets/cifar10
drwxr-xr-x 2 user user 7 Feb 21 10:07 .
drwxrwxr-x 10 user user 62 Apr 3 15:10 ..
-rw-r--r-- 1 user user 4900000 Feb 19 13:17 cifar10-r02.tfrecords
-rw-r--r-- 1 user user 12350000 Feb 19 13:17 cifar10-r03.tfrecords
-rw-r--r-- 1 user user 41150000 Feb 19 13:17 cifar10-r04.tfrecords
-rw-r--r-- 1 user user 156350000 Feb 19 13:17 cifar10-r05.tfrecords
-rw-r--r-- 1 user user 2000080 Feb 19 13:17 cifar10-rxx.labels
这些create_*
命令将给定数据集的标准版本作为输入,并生成相应的*.tfrecords
文件作为输出。此外,该create_celebahq
命令需要一组数据文件来表示相对于原始 CelebA 数据集的增量。这些增量 (27.6GB) 可以从下载datasets/celeba-hq-deltas
。
关于模块版本的注意事项:一些数据集命令需要特定版本的 Python 模块和系统库(例如 pillow、libjpeg),如果版本不匹配,它们将给出错误。请注意错误消息——除了安装这些特定版本之外,没有其他方法可以让命令工作。
一旦设置了必要的数据集,您就可以继续训练您自己的网络。一般程序如下:
config.py
以通过取消注释/编辑特定行来指定数据集和训练配置。python train.py
。config.result_dir
默认情况下,config.py
配置为使用单 GPU 为 CelebA-HQ 训练 1024x1024 网络。即使在最高端的 NVIDIA GPU 上,这预计也需要大约两周的时间。实现更快训练的关键是使用多个 GPU 和/或使用较低分辨率的数据集。为此,config.py
包含几个常用数据集的示例,以及一组用于多 GPU 训练的“配置预设”。预计所有预设都会为 CelebA-HQ 产生大致相同的图像质量,但它们的总训练时间可能会有很大差异:
preset-v1-1gpu
:用于生成论文中显示的 CelebA-HQ 和 LSUN 结果的原始配置。在 NVIDIA Tesla V100 上预计需要大约 1 个月的时间。preset-v2-1gpu
:优化的配置比原来的配置收敛速度快得多。预计在 1xV100 上需要大约 2 周的时间。preset-v2-2gpus
:针对 2 个 GPU 的优化配置。2xV100 大约需要 1 周。preset-v2-4gpus
:针对 4 个 GPU 的优化配置。在 4xV100 上大约需要 3 天。preset-v2-8gpus
:针对 8 个 GPU 的优化配置。在 8xV100 上大约需要 2 天。作为参考,可以在以下位置找到 CelebA-HQ 的每个配置预设的预期输出networks/tensorflow-version/example_training_runs
其他值得注意的配置选项:
fp16
:启用FP16 混合精度训练以进一步减少训练时间。实际加速比在很大程度上取决于 GPU 架构和 cuDNN 版本,预计未来会大幅提升。BENCHMARK
:快速迭代分辨率以测量原始训练性能。BENCHMARK0
:与 相同BENCHMARK
,但仅使用最高分辨率。syn1024rgb
: 仅包含黑色图像的合成 1024x1024 数据集。对基准测试很有用。VERBOSE
:非常频繁地保存图像和网络快照,以方便调试。GRAPH
和HIST
:在 TensorBoard 报告中包含其他数据。可以通过多种方式分析训练结果:
fakes*.png
报告整体进度log.txt
。*.tfevents
文件中,该文件可以在 TensorBoard 中以tensorboard --logdir <result_subdir>
.config.py
,有几个预定义的配置来启动实用程序脚本 ( generate_*
)。例如:
010-pgan-celebahq-preset-v1-1gpu-fp32
,并且您想要为最新快照生成随机插值的视频。generate_interpolation_video
行config.py
,替换run_id=10
并运行python train.py
config.py
还包含预定义的配置来计算现有训练运行的各种质量指标(切片 Wasserstein 距离、Fréchet 起始距离等)。为每个网络快照连续计算指标,并将其存储在metric-*.txt
原始结果目录中。[/hidecontent]