pro_gan_pytorch

GitHub
540 99 中等 1 次阅读 4个月前MIT开发框架图像
AI 解读 由 AI 自动生成,仅供参考

pro_gan_pytorch 是经典论文《Progressive growing of GANs for improved Quality, Stability, and Variation》的非官方 PyTorch 实现版本。它核心解决了传统生成对抗网络(GAN)在训练高分辨率图像时容易出现的模型不稳定、难以收敛以及生成质量不佳等难题。

该工具采用了“渐进式生长”的独特技术策略:从低分辨率图像开始训练,随着模型逐渐稳定,逐步增加网络层数以生成更高分辨率的图像。这种由粗到细的生长方式显著提升了训练的稳定性,并能生成细节丰富、多样性高的高质量图像。此外,它还集成了均衡学习率(Equalized Learning Rate)和指数移动平均(EMA)等优化技巧,进一步保障训练效果。

pro_gan_pytorch 主要面向 AI 研究人员、深度学习开发者以及对图像生成技术有深入探索需求的技术人员。它提供了完善的命令行工具(如 progan_train),支持多 GPU 并行加速训练,并允许用户灵活调整网络深度、潜变量维度等关键参数。虽然部署需要一定的 Linux 环境与 Nvidia GPU 硬件基础,但其模块化的设计使得复现前沿算法和进行自定义实验变得更加便捷高效,是探索高质量图像生成领域的有力助手。

使用场景

某时尚电商公司的算法团队需要为即将上线的虚拟模特项目生成大量高分辨率、多样化的人脸素材,以解决实拍成本高昂且风格单一的问题。

没有 pro_gan_pytorch 时

  • 训练极不稳定:直接训练高分辨率 GAN 模型时,常因梯度消失或模式崩溃导致生成图像模糊、重复,难以收敛。
  • 硬件门槛过高:传统方法往往需要多张顶级显卡并行才能勉强跑通高分辨率训练,中小团队难以承担算力成本。
  • 调参如“开盲盒”:缺乏渐进式生长机制,开发者需手动反复调整学习率和网络深度,耗费数周时间仍难获得清晰细节。
  • 工程落地困难:从零复现论文代码复杂度高,缺乏成熟的命令行工具支持,导致从实验到部署的周期漫长。

使用 pro_gan_pytorch 后

  • 生成质量显著提升:利用渐进式生长策略,模型从低分辨率逐步细化至高清,有效避免了训练崩溃,生成的虚拟模特皮肤纹理清晰自然。
  • 单卡即可启动:依托优化的 PyTorch 实现,仅需一张显存 8GB 以上的显卡(如 1080 Ti)即可完成训练,大幅降低了硬件投入。
  • 自动化超参管理:内置均衡学习率(EQL)和指数移动平均(EMA)等机制,通过简单的命令行参数即可控制训练流程,减少了人工干预。
  • 快速迭代部署:提供 progan_train 等开箱即用的 CLI 工具,支持断点续训和多 GPU 扩展,将原本数周的调试周期缩短至几天。

pro_gan_pytorch 通过渐进式生长机制与工程化封装,让中小团队也能以低成本稳定地生成高质量图像,彻底打破了高分辨率 GAN 的训练壁垒。

运行环境要求

操作系统
  • Linux (Ubuntu 20.04.3 或更高版本)
GPU

必需 NVIDIA GPU,型号 GeForce 1080 Ti 或更高,显存最低 8GB,驱动版本 >= 470.86,CUDA 11.3(PyTorch 自带可跳过)

内存

未说明

依赖
notes该项目为非官方的 PyTorch 实现。建议使用虚拟环境(venv)安装以避免污染全局环境。训练工具支持多 GPU(需设置 CUDA_VISIBLE_DEVICES),但潜空间漫步和 FID 计算工具仅支持单 GPU。目前不支持断点续训(部分训练恢复)。开发模式下需额外安装 pytest 和 black。
python3.8.3
torch (PyTorch)
pro-gan-pth
pro_gan_pytorch hero image

快速开始

pro_gan_pytorch

非官方 PyTorch 实现,基于论文《用于提升质量、稳定性和多样性的渐进式生成对抗网络》。
官方 TensorFlow 代码请参考 这个仓库

GitHub PyPi

使用方法:

使用该包

要求(即我们测试过的环境):

  1. Ubuntu 20.04.3 或更高版本
  2. Python 3.8.3
  3. Nvidia GPU GeForce 1080 Ti 或更高,显存至少 8GB
  4. Nvidia 驱动程序 ≥ 470.86
  5. Nvidia CUDA 11.3 | 可以跳过,因为 PyTorch 自带 CUDA、cuDNN 等。

安装包

  1. 最简单的方式是创建一个新的虚拟环境,以避免污染你的全局 Python 环境。
  2. 创建并切换到新虚拟环境:
    (your-machine):~$ python3 -m venv <env-store-path>/pro_gan_pth_env 
    (pro_gan_pth_env)(your-machine):~$ source <env-store-path>/pro_gan_pth_env/bin/activate
  1. 如果满足上述所有依赖条件,即可从 PyPI 安装 pro-gan-pth 包:
    (pro_gan_pth_env)(your-machine):~$ pip install pro-gan-pth 
  1. 安装完成后,你可以使用已安装的命令行工具 progan_trainprogan_lsidprogan_fid
    注意,progan_train 支持多 GPU 训练(如果你有多块 GPU :smile:)。只需确保在 CUDA_VISIBLE_DEVICES=0,1,2 环境变量中正确设置可见的 GPU 即可。另外两个工具则仅使用单个 GPU。
    (your-machine):~$ progan_train --help
用法:训练渐进式生成对抗网络
       [-h]
       [--retrain RETRAIN]
       [--generator_path GENERATOR_PATH]
       [--discriminator_path DISCRIMINATOR_PATH]
       [--rec_dir REC_DIR]
       [--flip_horizontal FLIP_HORIZONTAL]
       [--depth DEPTH]
       [--num_channels NUM_CHANNELS]
       [--latent_size LATENT_SIZE]
       [--use_eql USE_EQL]
       [--use_ema USE_EMA]
       [--ema_beta EMA_BETA]
       [--epochs EPOCHS [EPOCHS ...]]
       [--batch_sizes BATCH_SIZES [BATCH_SIZES ...]]
       [--batch_repeats BATCH_REPEATS]
       [--fade_in_percentages FADE_IN_PERCENTAGES [FADE_IN_PERCENTAGES ...]]
       [--loss_fn LOSS_FN]
       [--g_lrate G_LRATE]
       [--d_lrate D_LRATE]
       [--num_feedback_samples NUM_FEEDBACK_SAMPLES]
       [--start_depth START_DEPTH]
       [--num_workers NUM_WORKERS]
       [--feedback_factor FEEDBACK_FACTOR]
       [--checkpoint_factor CHECKPOINT_FACTOR]
       train_path
       output_dir

    位置性参数:
      train_path            用于训练 ProGAN 的图像文件夹路径
      output_dir            用于保存日志和模型的目录路径

    可选参数:
      -h, --help            显示此帮助信息并退出
      --retrain RETRAIN     当需要从已保存的模型继续训练时启用(默认:False)
      --generator_path GENERATOR_PATH
                            用于重新训练 ProGAN 的生成器模型路径(默认:无)
      --discriminator_path DISCRIMINATOR_PATH 
                            用于重新训练 ProGAN 的判别器模型路径(默认:无)
      --rec_dir REC_DIR     指定图像是否存储在一个文件夹下,或采用递归目录结构(默认:True)
      --flip_horizontal FLIP_HORIZONTAL
                            是否应用水平翻转增强(默认:True)
      --depth DEPTH         生成器和判别器的深度(默认:10)
      --num_channels NUM_CHANNELS
                            图像数据的通道数(默认:3)
      --latent_size LATENT_SIZE
                            生成器和判别器的潜在空间大小(默认:512)
      --use_eql USE_EQL     是否使用等化学习率(默认:True)
      --use_ema USE_EMA     是否使用指数移动平均(默认:True)
      --ema_beta EMA_BETA   移动平均的权重值(默认:0.999)
      --epochs EPOCHS [EPOCHS ...]
                            每个阶段在训练数据集上运行的轮数(默认:[42, 42, 42, 42, 42, 42, 42, 42, 42])
      --batch_sizes BATCH_SIZES [BATCH_SIZES ...]
                            每个阶段用于训练模型的批量大小(默认:[32, 32, 32, 32, 16, 16, 8, 4, 2])
      --batch_repeats BATCH_REPEATS
                            每次训练迭代中执行的 G 和 D 步骤次数(默认:4)
      --fade_in_percentages FADE_IN_PERCENTAGES [FADE_IN_PERCENTAGES ...]
                            新层渐变过渡的迭代次数,以百分比表示(默认:[50, 50, 50, 50, 50, 50, 50, 50, 50])
      --loss_fn LOSS_FN     用于训练 GAN 的损失函数。当前选项:[wgan_gp, standard_gan](默认:wgan_gp)
      --g_lrate G_LRATE     生成器的学习率(默认:0.003)
      --d_lrate D_LRATE     判别器的学习率(默认:0.003)
      --num_feedback_samples NUM_FEEDBACK_SAMPLES
                            用于固定种子 GAN 反馈的样本数量(默认:4)
      --start_depth START_DEPTH
                            开始训练的分辨率。例如,2 表示 (4x4) | 3 表示 (8x8) ... | 10 表示 (1024x1024)。请注意,这不是重启部分训练的方法。目前暂不支持恢复训练,但很快就会支持。(默认:2)
      --num_workers NUM_WORKERS
                            数据加载器的子进程数量。这是 PyTorch 的特性,你也可以忽略它;除非你觉得速度异常缓慢,否则保持默认值即可。(默认:4)
      --feedback_factor FEEDBACK_FACTOR
                            每个 epoch 写入的反馈日志数量(默认:10)
      --checkpoint_factor CHECKPOINT_FACTOR
                            每个训练阶段每隔多少个 epoch 保存一次模型快照(默认:10)

------------------------------------------------------------------------------------------------------------------------------------------------------------------

    (your-machine):~$ progan_lsid --help
    用法:ProGAN 潜在空间漫步演示视频制作工具 [-h] [--output_path OUTPUT_PATH] [--generation_depth GENERATION_DEPTH] [--time TIME] [--fps FPS] [--smoothing SMOOTHING] model_path

    位置性参数:
      model_path            已训练模型文件 trained_model.bin 的路径

可选参数:
      -h, --help            显示此帮助消息并退出
      --output_path OUTPUT_PATH
                            输出视频文件的保存路径。请仅使用 mp4 格式(.mp4 扩展名)。我已经为此折腾了很久,其他格式都无法正常工作 :(。(默认:./latent_space_walk.mp4)
      --generation_depth GENERATION_DEPTH
                            生成图像的深度级别。从 2 开始 --> (4x4) | 3 --> (8x8) 等。(默认:无)
      --time TIME           视频时长(秒)。(默认:30)
      --fps FPS             生成视频的帧率。(默认:60)
      --smoothing SMOOTHING
                            隐空间行走的平滑度。值越高,平滑效果越好。(默认:0.75)

------------------------------------------------------------------------------------------------------------------------------------------------------------------

    (你的机器):~$ progan_fid --help
    使用方法:ProGAN fid_score 计算工具 [-h] [--generated_images_path GENERATED_IMAGES_PATH] [--batch_size BATCH_SIZE] [--num_generated_images NUM_GENERATED_IMAGES] 模型路径 数据集路径

    位置参数:
      model_path            训练好的 model.bin 文件路径
      dataset_path          包含数据集图像的目录路径。请注意,该目录必须是扁平结构

    取消参数:
      -h, --help            显示此帮助消息并退出
      --generated_images_path GENERATED_IMAGES_PATH
                            用于存放生成图像的目录路径。默认使用临时目录。如果您想亲自查看生成的图像,请提供此路径 :)。(默认:无)
      --batch_size BATCH_SIZE
                            用于生成随机图像的批量大小。(默认:4)
      --num_generated_images NUM_GENERATED_IMAGES
                            用于计算 FID 的生成图像数量。(默认:50000)
  1. 或者,您也可以在代码中将其作为 Python 包导入 以用于更高级的用例:
    import pro_gan_pytorch as pg 

您可以使用包中的所有模块,例如:pg.networks.Generator, pg.networks.Discriminator, pg.gan.ProGAN 等。通常,您只需要 pg.gan.ProGAN 模块来进行训练。而在推理阶段,您可能只需要 pg.networks.Generator。请参考第 4 条中 pro_gan_pytorch_scripts/ 下的工具脚本,以获取有关如何使用该包的示例。此外,您也可以直接阅读代码。它非常容易理解 (至少我希望如此 :sweat_smile: :grimacing:)。

包的开发

对于项目中更高级的用例,或者如果您希望为该项目贡献新功能,以下步骤将帮助您设置开发环境。目前这里没有标准的贡献规范(没有 CONTRIBUTING.md),但我们仍会尽量保持代码库的整体风格 :smile:。

  1. 克隆此仓库
    (你的机器):~$ cd <项目路径>
    (你的机器):<项目路径>$ git clone https://github.com/akanimax/pro_gan_pytorch.git
  1. 提前致歉,因为第 1 步可能会花费一些时间。当时我把一些 GIF 和其他大文件资产上传到了 Git 仓库。那时我还不太了解最佳实践 :sad:。我会尝试解决这个问题。完成克隆后,设置一个开发虚拟环境,
    (你的机器):<项目路径>$ python3 -m venv pro-gan-pth-dev-env
    (你的机器):<项目路径>$ source pro-gan-pth-dev-env/source/activate
  1. 以开发模式安装包:
    (pro-gan-pth-dev-env)(你的机器):<项目路径>$ pip install -e .
  1. 同时安装开发依赖:
    (pro-gan-pth-dev-env)(你的机器):<项目路径>$ pip install -r requirements-dev.txt
  1. 现在,您可以在自己选择的编辑器中打开项目,就可以开始了。我使用 pytest 进行测试,使用 black 进行代码格式化。请参阅 this_link 以了解如何将 black 配置到各种 IDE 中。

  2. 由于这是一个规模较小的项目,因此没有复杂的 CI、自动化测试或文档构建流程。不过,如果未来有更多功能加入,我们愿意考虑引入这些工具。

训练好的模型

我们将使用此包在不同的数据集上持续训练模型。同时,如果您为自己的数据集训练了模型,请随时为下表提交 PR。如果您参与贡献,请设置一个文件托管方案来发布训练好的模型。

出处 数据集 大小 分辨率 使用的 GPU 每个阶段的轮数 训练时长 FID 分数 链接 定性样本
@owang Metfaces ~1.3K 1024 x 1024 1 V100-32GB 42 24 小时 101.624 model_link image

请注意,我们使用来自 Parmar 等人 的 clean_fid 版本来计算 FID。

有趣的内容 :smile:

训练延时摄影(固定隐变量点):

根据训练过程中记录的图像制作的训练延时摄影看起来非常酷。


请观看这个 YT 视频 以获取 4K 版本 :smile:。

如果您感兴趣,可以查看我撰写的这篇 medium 博客 ,其中解释了渐进式增长技术。

参考文献

1. Tero Karras, Timo Aila, Samuli Laine, & Jaakko Lehtinen (2018). 
用于提升质量、稳定性和多样性的 GAN 渐进式增长。 
国际表示学习大会论文集。

2. Parmar, Gaurav, Richard Zhang, and Jun-Yan Zhu. 
“关于有缺陷的缩放库以及 FID 计算中的意外微妙之处。” 
arXiv 预印本 arXiv:2104.11222 (2021)。

功能需求

  • 条件 GAN 支持
  • 从日志图像生成延时视频的工具
  • 将 fid 指标计算集成到训练日志中

感谢

一如既往,
欢迎在此处提交 PR/问题/建议。希望这项工作对您的项目有所帮助 :smile:.

干杯 :beers:!
@akanimax :sunglasses:

版本历史

v_1.2.02018/07/17

常见问题

相似工具推荐

openclaw

OpenClaw 是一款专为个人打造的本地化 AI 助手,旨在让你在自己的设备上拥有完全可控的智能伙伴。它打破了传统 AI 助手局限于特定网页或应用的束缚,能够直接接入你日常使用的各类通讯渠道,包括微信、WhatsApp、Telegram、Discord、iMessage 等数十种平台。无论你在哪个聊天软件中发送消息,OpenClaw 都能即时响应,甚至支持在 macOS、iOS 和 Android 设备上进行语音交互,并提供实时的画布渲染功能供你操控。 这款工具主要解决了用户对数据隐私、响应速度以及“始终在线”体验的需求。通过将 AI 部署在本地,用户无需依赖云端服务即可享受快速、私密的智能辅助,真正实现了“你的数据,你做主”。其独特的技术亮点在于强大的网关架构,将控制平面与核心助手分离,确保跨平台通信的流畅性与扩展性。 OpenClaw 非常适合希望构建个性化工作流的技术爱好者、开发者,以及注重隐私保护且不愿被单一生态绑定的普通用户。只要具备基础的终端操作能力(支持 macOS、Linux 及 Windows WSL2),即可通过简单的命令行引导完成部署。如果你渴望拥有一个懂你

349.3k|★★★☆☆|5天前
Agent开发框架图像

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|6天前
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

150k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

108.3k|★★☆☆☆|昨天
开发框架图像Agent

gemini-cli

gemini-cli 是一款由谷歌推出的开源 AI 命令行工具,它将强大的 Gemini 大模型能力直接集成到用户的终端环境中。对于习惯在命令行工作的开发者而言,它提供了一条从输入提示词到获取模型响应的最短路径,无需切换窗口即可享受智能辅助。 这款工具主要解决了开发过程中频繁上下文切换的痛点,让用户能在熟悉的终端界面内直接完成代码理解、生成、调试以及自动化运维任务。无论是查询大型代码库、根据草图生成应用,还是执行复杂的 Git 操作,gemini-cli 都能通过自然语言指令高效处理。 它特别适合广大软件工程师、DevOps 人员及技术研究人员使用。其核心亮点包括支持高达 100 万 token 的超长上下文窗口,具备出色的逻辑推理能力;内置 Google 搜索、文件操作及 Shell 命令执行等实用工具;更独特的是,它支持 MCP(模型上下文协议),允许用户灵活扩展自定义集成,连接如图像生成等外部能力。此外,个人谷歌账号即可享受免费的额度支持,且项目基于 Apache 2.0 协议完全开源,是提升终端工作效率的理想助手。

100.8k|★★☆☆☆|昨天
插件Agent图像

markitdown

MarkItDown 是一款由微软 AutoGen 团队打造的轻量级 Python 工具,专为将各类文件高效转换为 Markdown 格式而设计。它支持 PDF、Word、Excel、PPT、图片(含 OCR)、音频(含语音转录)、HTML 乃至 YouTube 链接等多种格式的解析,能够精准提取文档中的标题、列表、表格和链接等关键结构信息。 在人工智能应用日益普及的今天,大语言模型(LLM)虽擅长处理文本,却难以直接读取复杂的二进制办公文档。MarkItDown 恰好解决了这一痛点,它将非结构化或半结构化的文件转化为模型“原生理解”且 Token 效率极高的 Markdown 格式,成为连接本地文件与 AI 分析 pipeline 的理想桥梁。此外,它还提供了 MCP(模型上下文协议)服务器,可无缝集成到 Claude Desktop 等 LLM 应用中。 这款工具特别适合开发者、数据科学家及 AI 研究人员使用,尤其是那些需要构建文档检索增强生成(RAG)系统、进行批量文本分析或希望让 AI 助手直接“阅读”本地文件的用户。虽然生成的内容也具备一定可读性,但其核心优势在于为机器

93.4k|★★☆☆☆|4天前
插件开发框架