benchmark_VAE

GitHub
2k 178 简单 1 次阅读 5天前Apache-2.0图像开发框架其他
AI 解读 由 AI 自动生成,仅供参考

Pythae 是一个基于 PyTorch 的开源库,旨在统一变分自编码器(VAE)及其衍生模型的实现。在深度学习研究中,对比不同 VAE 模型往往因代码架构差异而变得复杂且难以复现。Pythae 通过提供标准化的统一框架,让研究人员能够在完全相同的编码器 - 解码器架构下训练和比较多种主流 VAE 模型,从而有效解决了实验公平性与可复现性的痛点。

该工具特别适合人工智能研究人员、算法工程师及高校学生使用。它不仅支持快速复现 NeurIPS 2022 论文中的基准实验,还具备高度的灵活性:用户既可以调用预置模型,也能轻松接入自定义的网络架构和数据集进行训练。技术亮点方面,Pythae 原生支持分布式训练(DDP),能显著提升大规模数据集上的训练速度;同时深度集成了 WandB、MLflow 等实验监控工具,并支持与 HuggingFace Hub 无缝对接,仅需几行代码即可实现模型的分享与加载。无论是想要深入探究生成模型原理,还是希望高效开展对比实验,Pythae 都能提供简洁、专业且强大的技术支持。

使用场景

某医疗影像实验室的研究团队正试图通过对比多种变分自编码器(VAE)模型,从有限的肺部 CT 扫描数据中学习更鲁棒的潜在特征,以辅助早期病灶检测。

没有 benchmark_VAE 时

  • 代码重复劳动繁重:团队需为 VAE、β-VAE、VQ-VAE 等不同模型分别寻找并适配独立的开源代码库,每切换一个模型就要重写一遍数据加载和训练循环。
  • 公平对比难以保证:由于各源码的编码器/解码器架构、超参数设置及随机种子管理不一致,导致实验结果差异可能源于实现细节而非模型本身的优劣,结论缺乏说服力。
  • 实验监控分散:缺乏统一的接口对接 WandB 或 MLflow,研究人员需手动整理不同脚本产生的日志,难以实时追踪和可视化多组实验的损失曲线与重建效果。
  • 复现与协作成本高:新成员加入时需花费数天理解杂乱的代码结构,且在不同机器上复现论文结果时,常因环境依赖或缺失模块而失败。

使用 benchmark_VAE 后

  • 统一接口快速切换:借助 benchmark_VAE 标准化的 API,团队仅需修改几行配置即可在同一套自定义的 Encoder-Decoder 架构下训练十几种主流 VAE 模型,开发效率提升数倍。
  • 确保控制变量严谨:该工具强制所有模型共享相同的网络骨架和训练流程,消除了实现偏差,使团队能确信性能提升真正源自算法改进,显著增强了论文的可信度。
  • 原生集成监控生态:通过内置插件一键连接 WandB 或 HuggingFace Hub,实验指标自动同步云端,团队成员可实时协作分析生成样本质量,并轻松分享训练好的模型权重。
  • 分布式训练加速迭代:利用其支持的 PyTorch DDP 功能,团队直接在多卡服务器上并行跑通大规模数据集训练,将原本需要数周的对比实验压缩至几天内完成。

benchmark_VAE 通过统一实现标准与自动化流程,将研究人员从繁琐的工程泥潭中解放出来,使其能专注于算法创新与业务价值挖掘。

运行环境要求

操作系统
  • 未说明
GPU

未说明(支持分布式训练 DDP,暗示可使用多 GPU 加速)

内存

未说明

依赖
notes该库名为 pythae,专注于统一实现多种变分自编码器(VAE)模型以进行基准测试。支持使用 PyTorch DDP 进行分布式训练。允许用户自定义编码器和解码器架构。集成 wandb、mlflow 和 comet-ml 用于实验监控,并支持通过 HuggingFace Hub 分享和加载模型。
python3.7, 3.8, 3.9+
pytorch
wandb
mlflow
comet-ml
huggingface_hub
benchmark_VAE hero image

快速开始

Python Python Documentation Status

文档

pythae

该库以统一的实现方式实现了几种最常见的(变分)自编码器模型。特别地,它提供了通过使用相同的自动编码神经网络架构来训练这些模型,从而进行基准实验和比较的可能性。其“自定义自编码器”功能允许您使用自己的数据以及自定义的编码器和解码器神经网络来训练这些模型。该库集成了诸如 wandbmlflowcomet-ml 🧪 等实验监控工具,并支持在几行代码内从 HuggingFace Hub 🤗 上共享和加载模型。

新闻 📢

自 v0.1.0 版本起,Pythae 现已支持使用 PyTorch 的 DDP 进行分布式训练。现在您可以更快地在更大的数据集上训练您喜爱的 VAE,而且仍然只需几行代码。 请参阅我们的加速 基准测试

快速访问:

安装

要安装该库的最新稳定版,请使用 pip 运行以下命令:

$ pip install pythae

要安装该库的最新 GitHub 版本,请使用 pip 运行以下命令:

$ pip install git+https://github.com/clementchadebec/benchmark_VAE.git

或者,您也可以克隆 GitHub 仓库以访问测试、教程和脚本。

$ git clone https://github.com/clementchadebec/benchmark_VAE.git

然后进入目录并安装库:

$ cd benchmark_VAE
$ pip install -e .

可用模型

以下是当前库中已实现的模型列表。

模型 训练示例 论文 官方实现
自编码器 (AE) 在 Colab 中打开
变分自编码器 (VAE) 在 Colab 中打开 链接
Beta 变分自编码器 (BetaVAE) 在 Colab 中打开 链接
变分自编码器结合线性归一化流 (VAE_LinNF) 在 Colab 中打开 链接
变分自编码器结合逆向自回归流 (VAE_IAF) 在 Colab 中打开 链接 链接
解耦合 Beta 变分自编码器 (DisentangledBetaVAE) 在 Colab 中打开 链接
因子分解解耦 (FactorVAE) 在 Colab 中打开 链接
Beta-TC-VAE (BetaTCVAE) 在 Colab 中打开 链接 链接
重要性加权自编码器 (IWAE) 在 Colab 中打开 链接 链接
多重重要性加权自编码器 (MIWAE) 在 Colab 中打开 链接
部分重要性加权自编码器 (PIWAE) 在 Colab 中打开 链接
组合重要性加权自编码器 (CIWAE) 在 Colab 中打开 链接
基于感知度量相似性 (MSSSIM) 的变分自编码器 在 Colab 中打开 链接
Wasserstein 自编码器 (WAE) 在 Colab 中打开 链接 链接
信息变分自编码器 (INFOVAE_MMD) 在 Colab 中打开 链接
VAMP 自编码器 (VAMP) 在 Colab 中打开 链接 链接
超球面变分自编码器 (SVAE) 在 Colab 中打开 链接 链接
波兰克圆盘变分自编码器 (PoincareVAE) 在 Colab 中打开 链接 链接
对抗自编码器 (Adversarial_AE) 在 Colab 中打开 链接
变分自编码器 GAN (VAEGAN) 🥗 在 Colab 中打开 链接 链接
向量量化变分自编码器 (VQVAE) 在 Colab 中打开 链接 链接
哈密顿变分自编码器 (HVAE) 在 Colab 中打开 链接 链接
使用 L2 解码器参数正则化的自编码器 (RAE_L2) 在 Colab 中打开 链接 链接
使用梯度惩罚正则化的自编码器 (RAE_GP) 在 Colab 中打开 链接 链接
黎曼哈密顿变分自编码器 (RHVAE) 在 Colab 中打开 链接 链接
层次残差量化 (HRQVAE) 在 Colab 中打开 链接 链接

请参阅重建生成结果,了解所有上述模型的表现

可用的采样器

以下是当前库中已实现的模型列表。

采样器 模型 论文 官方实现
正态先验 (NormalSampler) 所有模型 链接
高斯混合 (GaussianMixtureSampler) 所有模型 链接 链接
两阶段VAE采样器 (TwoStageVAESampler) 所有基于VAE的模型 链接 链接
单位球面均匀采样器 (HypersphereUniformSampler) SVAE 链接 链接
波兰卡圆盘采样器 (PoincareDiskSampler) PoincareVAE 链接 链接
VAMP先验采样器 (VAMPSampler) VAMP 链接 链接
流形采样器 (RHVAESampler) RHVAE 链接 链接
掩码自回归流采样器 (MAFSampler) 所有模型 链接 链接
逆向自回归流采样器 (IAFSampler) 所有模型 链接 链接
PixelCNN (PixelCNNSampler) VQVAE 链接

可复现性

我们通过复现原始论文中的一些结果来验证实现的正确性,前提是官方代码已发布,或者论文的实验部分提供了足够详细的信息。更多详情请参阅可复现性

启动模型训练

要启动模型训练,只需调用一个 TrainingPipeline 实例即可。

>>> from pythae.pipelines import TrainingPipeline
>>> from pythae.models import VAE, VAEConfig
>>> from pythae.trainers import BaseTrainerConfig

>>> # 设置训练配置
>>> my_training_config = BaseTrainerConfig(
...	output_dir='my_model',
...	num_epochs=50,
...	learning_rate=1e-3,
...	per_device_train_batch_size=200,
...	per_device_eval_batch_size=200,
...	train_dataloader_num_workers=2,
...	eval_dataloader_num_workers=2,
...	steps_saving=20,
...	optimizer_cls="AdamW",
...	optimizer_params={"weight_decay": 0.05, "betas": (0.91, 0.995)},
...	scheduler_cls="ReduceLROnPlateau",
...	scheduler_params={"patience": 5, "factor": 0.5}
... )
>>> # 设置模型配置 
>>> my_vae_config = model_config = VAEConfig(
...	input_dim=(1, 28, 28),
...	latent_dim=10
... )
>>> # 构建模型
>>> my_vae_model = VAE(
...	model_config=my_vae_config
... )
>>> # 构建流水线
>>> pipeline = TrainingPipeline(
... 	training_config=my_training_config,
... 	model=my_vae_model
... )
>>> # 启动流水线
>>> pipeline(
...	train_data=your_train_data, # 必须是 torch.Tensor、np.array 或 torch 数据集
...	eval_data=your_eval_data # 必须是 torch.Tensor、np.array 或 torch 数据集
... )

训练结束后,最佳模型权重、模型配置和训练配置将存储在 my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss 文件夹中的 final_model 目录下(其中 my_modelBaseTrainerConfigoutput_dir 参数)。如果进一步设置了 steps_saving 参数,则还会在 my_model/MODEL_NAME_training_YYYY-MM-DD_hh-mm-ss 中出现名为 checkpoint_epoch_k 的文件夹,其中包含第 k 个 epoch 时的最佳模型权重、优化器、调度器、配置和训练配置。

在基准数据集上启动训练

我们还提供了一个训练脚本示例此处,可用于在基准数据集(mnist、cifar10、celeba 等)上训练模型。该脚本可以通过以下命令行启动:

python training.py --dataset mnist --model_name ae --model_config 'configs/ae_config.json' --training_config 'configs/base_training_config.json'

有关此脚本的更多详细信息,请参阅 README.md

启动数据生成

使用 GenerationPipeline

从已训练模型中启动数据生成的最简单方法是使用 Pythae 内置的 GenerationPipeline。假设您想使用 MAFSampler 生成 100 个样本,您只需执行以下步骤:1) 重新加载已训练的模型,2) 定义采样器的配置,3) 创建并启动 GenerationPipeline,如下所示:

>>> from pythae.models import AutoModel
>>> from pythae.samplers import MAFSamplerConfig
>>> from pythae.pipelines import GenerationPipeline
>>> # 恢复已训练的模型
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> my_sampler_config = MAFSamplerConfig(
...	n_made_blocks=2,
...	n_hidden_in_made=3,
...	hidden_size=128
... )
>>> # 构建流水线
>>> pipe = GenerationPipeline(
...	model=my_trained_vae,
...	sampler_config=my_sampler_config
... )
>>> # 启动数据生成
>>> generated_samples = pipe(
...	num_samples=args.num_samples,
...	return_gen=True, # 如果为假则不返回任何内容
...	train_data=train_data, # 用于拟合采样器
...	eval_data=eval_data, # 用于拟合采样器
...	training_config=BaseTrainerConfig(num_epochs=200) # 用于拟合采样器的训练配置
... )

使用采样器

或者,你也可以直接通过采样器从训练好的模型中启动数据生成过程。例如,要使用你的采样器生成新数据,可以运行以下代码:

>>> from pythae.models import AutoModel
>>> from pythae.samplers import NormalSampler
>>> # 获取训练好的模型
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> # 定义你的采样器
>>> my_samper = NormalSampler(
...	model=my_trained_vae
... )
>>> # 生成样本
>>> gen_data = my_samper.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )

如果你将 output_dir 设置为一个特定路径,生成的图像将会以 .png 文件的形式保存,文件名分别为 00000000.png, 00000001.png 等等。

只要模型适配,采样器就可以用于任何模型。例如,GaussianMixtureSampler 实例可以用于任何模型,而 VAMPSampler 只能与 VAMP 模型一起使用。请查看 此处 以了解哪些采样器适用于你的模型。请注意,某些采样器,比如 GaussianMixtureSampler,在使用前可能需要调用 fit 方法进行拟合。以下是 GaussianMixtureSampler 的示例:

>>> from pythae.models import AutoModel
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # 获取训练好的模型
>>> my_trained_vae = AutoModel.load_from_folder(
...	'path/to/your/trained/model'
... )
>>> # 定义你的采样器
... gmm_sampler_config = GaussianMixtureSamplerConfig(
...	n_components=10
... )
>>> my_samper = GaussianMixtureSampler(
...	sampler_config=gmm_sampler_config,
...	model=my_trained_vae
... )
>>> # 拟合采样器
>>> gmm_sampler.fit(train_dataset)
>>> # 生成样本
>>> gen_data = my_samper.sample(
...	num_samples=50,
...	batch_size=10,
...	output_dir=None,
...	return_gen=True
... )

自定义自编码器架构

Pythae 提供了在 VAE 模型中定义自定义神经网络的可能性。例如,假设你想训练一个带有特定编码器和解码器的 Wassertstein AE,你可以这样做:

>>> from pythae.models.nn import BaseEncoder, BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput
>>> class My_Encoder(BaseEncoder):
...	def __init__(self, args=None): # Args 是一个 ModelConfig 实例
...		BaseEncoder.__init__(self)
...		self.layers = my_nn_layers()
...		
...	def forward(self, x:torch.Tensor) -> ModelOutput:
...		out = self.layers(x)
...		output = ModelOutput(
...			embedding=out # 将编码器的输出放入 ModelOutput 实例中 
...		)
...		return output
...
... class My_Decoder(BaseDecoder):
...	def __init__(self, args=None):
...		BaseDecoder.__init__(self)
...		self.layers = my_nn_layers()
...		
...	def forward(self, x:torch.Tensor) -> ModelOutput:
...		out = self.layers(x)
...		output = ModelOutput(
...			reconstruction=out # 将解码器的输出放入 ModelOutput 实例中
...		)
...		return output
...
>>> my_encoder = My_Encoder()
>>> my_decoder = My_Decoder()

然后构建模型:

>>> from pythae.models import WAE_MMD, WAE_MMD_Config
>>> # 设置模型配置 
>>> my_wae_config = model_config = WAE_MMD_Config(
...	input_dim=(1, 28, 28),
...	latent_dim=10
... )
...
>>> # 构建模型
>>> my_wae_model = WAE_MMD(
...	model_config=my_wae_config,
...	encoder=my_encoder, # 在构建模型时传入你的编码器
...	decoder=my_decoder # 在构建模型时传入你的解码器
... )

重要提示 1:对于所有基于 AE 的模型(AE、WAE、RAE_L2、RAE_GP),编码器和解码器都必须返回一个 ModelOutput 实例。对于编码器,ModelOutput 实例必须在 embedding 键下包含嵌入向量。对于解码器,ModelOutput 实例必须在 reconstruction 键下包含重建结果。

重要提示 2:对于所有基于 VAE 的模型(VAE、BetaVAE、IWAE、HVAE、VAMP、RHVAE),编码器和解码器都必须返回一个 ModelOutput 实例。对于编码器,ModelOutput 实例必须分别在 embeddinglog_covariance 键下包含嵌入向量和对数协方差矩阵(形状为 batch_size × latent_space_dim)。对于解码器,ModelOutput 实例必须在 reconstruction 键下包含重建结果。

使用基准神经网络

你还可以找到针对最常见数据集(如 MNIST、CIFAR、CELEBA 等)的预定义神经网络架构,可以通过以下方式加载:

>>> from pythae.models.nn.benchmark.mnist import (
...	Encoder_Conv_AE_MNIST, # 用于基于 AE 的模型(仅返回嵌入)
...	Encoder_Conv_VAE_MNIST, # 用于基于 VAE 的模型(返回嵌入和对数协方差)
...	Decoder_Conv_AE_MNIST
... )

mnist 替换为 cifarceleba,即可访问其他神经网络。

使用 Pythae 进行分布式训练

v0.1.0 起,Pythae 现已支持使用 PyTorch 的 DDP 进行分布式训练。这使你能够利用多 GPU 和/或多节点训练,更快地在更大的数据集上训练你喜欢的 VAE。

为此,你可以编写一个 Python 脚本,然后由启动程序(如集群上的 srun)来运行该脚本。脚本中唯一需要做的就是在训练配置中直接指定与分布式环境相关的参数,如下所示:

>>> training_config = BaseTrainerConfig(
...     num_epochs=10,
...     learning_rate=1e-3,
...     per_device_train_batch_size=64,
...     per_device_eval_batch_size=64,
...     train_dataloader_num_workers=8,
...     eval_dataloader_num_workers=8,
...     dist_backend="nccl", # 分布式后端
...     world_size=8 # 使用的 GPU 数量(节点数 × 每个节点的 GPU 数量),
...     rank=5 # 全局 GPU ID,
...     local_rank=1 # 节点内的 GPU ID,
...     master_addr="localhost" # 主节点地址,
...     master_port="12345" # 主节点端口,
... )

请参阅此 示例脚本,其中定义了一个在 ImageNet 数据集上进行多 GPU VQVAE 训练的脚本。请注意,分布式环境变量(world_sizerank 等)的获取方式可能因你使用的集群和启动程序而异。

基准测试

以下是使用 Pythae 在 V100 16GB GPU 上对 MNIST 数据集训练 100 个 epoch 的向量量化变分自编码器 (VQ-VAE)、在 FFHQ(1024×1024 图像)上训练 50 个 epoch,以及在 V100 32GB GPU 上对 ImageNet-1k 训练 20 个 epoch 的训练时间。

训练数据 1 张 GPU 4 张 GPU 2×4 张 GPU
MNIST (VQ-VAE) 28×28 图像(5 万张) 235.18 秒 62.00 秒 35.86 秒
FFHQ 1024×1024 (VQVAE) 1024×1024 RGB 图像(6 万张) 19 小时 1 分钟 5 小时 6 分钟 2 小时 37 分钟
ImageNet-1k 128×128 (VQVAE) 128×128 RGB 图像(约 120 万张) 6 小时 25 分钟 1 小时 41 分钟 51 分钟 26 秒

对于每个数据集,我们提供了基准测试脚本,可在这里找到。

使用 HuggingFace Hub 分享你的模型 🤗

Pythae 还允许你将模型分享到 HuggingFace Hub。为此你需要:

  • 一个有效的 HuggingFace 账号
  • 在你的虚拟环境中安装了 huggingface_hub 包。如果没有,可以通过以下命令安装:
$ python -m pip install huggingface_hub
  • 使用以下命令登录你的 HuggingFace 账号:
$ huggingface-cli login

将模型上传到 Hub

任何 Pythae 模型都可以通过 push_to_hf_hub 方法轻松上传:

>>> my_vae_model.push_to_hf_hub(hf_hub_path="your_hf_username/your_hf_hub_repo")

注意: 如果 your_hf_hub_repo 已经存在且不为空,文件将会被覆盖。如果该仓库不存在,系统会创建一个同名的文件夹。

从 Hub 下载模型

同样地,你可以直接从 Hub 下载或重新加载任何 Pythae 模型,只需使用 load_from_hf_hub 方法:

>>> from pythae.models import AutoModel
>>> my_downloaded_vae = AutoModel.load_from_hf_hub(hf_hub_path="path_to_hf_repo")

使用 wandb 监控你的实验 🧪

Pythae 还集成了实验跟踪工具 wandb,允许用户存储配置、监控训练过程,并通过图形界面比较不同运行的结果。要使用此功能,你需要:

  • 一个有效的 wandb 账号
  • 在你的虚拟环境中安装了 wandb 包。如果没有,可以通过以下命令安装:
$ pip install wandb
  • 使用以下命令登录你的 wandb 账号:
$ wandb login

创建 WandbCallback

在 Pythae 中使用 wandb 启动实验监控非常简单。用户只需创建一个 WandbCallback 实例……

>>> # 创建回调
>>> from pythae.trainers.training_callbacks import WandbCallback
>>> callbacks = [] # TrainingPipeline 需要一个回调列表
>>> wandb_cb = WandbCallback() # 构建回调
>>> # 设置回调
>>> wandb_cb.setup(
...	training_config=your_training_config, # 训练配置
...	model_config=your_model_config, # 模型配置
...	project_name="your_wandb_project", # 指定你的 wandb 项目
...	entity_name="your_wandb_entity", # 指定你的 wandb 实体
... )
>>> callbacks.append(wandb_cb) # 添加到回调列表

……然后将其传递给 TrainingPipeline

>>> pipeline = TrainingPipeline(
...	training_config=config,
...	model=model
... )
>>> pipeline(
...	train_data=train_dataset,
...	eval_data=eval_dataset,
...	callbacks=callbacks # 将回调传递给 TrainingPipeline,大功告成!
... )
>>> # 你可以登录 https://wandb.ai/your_wandb_entity/your_wandb_project 来监控你的训练

请参阅详细教程

使用 mlflow 监控你的实验 🧪

Pythae 还集成了实验跟踪工具 mlflow,允许用户存储配置、监控训练过程,并通过图形界面比较不同运行的结果。要使用此功能,你需要:

  • 在你的虚拟环境中安装了 mlflow 包。如果没有,可以通过以下命令安装:
$ pip install mlflow

创建 MLFlowCallback

在 Pythae 中使用 mlflow 启动实验监控非常简单。用户只需创建一个 MLFlowCallback 实例……

>>> # 创建回调
>>> from pythae.trainers.training_callbacks import MLFlowCallback
>>> callbacks = [] # TrainingPipeline 需要一个回调列表
>>> mlflow_cb = MLFlowCallback() # 构建回调
>>> # 设置回调
>>> mlflow_cb.setup(
...	training_config=your_training_config, # 训练配置
...	model_config=your_model_config, # 模型配置
...	run_name="mlflow_cb_example", # 指定你的 mlflow 运行名称
... )
>>> callbacks.append(mlflow_cb) # 添加到回调列表

……然后将其传递给 TrainingPipeline

>>> pipeline = TrainingPipeline(
...	training_config=config,
...	model=model
... )
>>> pipeline(
...	train_data=train_dataset,
...	eval_data=eval_dataset,
...	callbacks=callbacks # 将回调传递给 TrainingPipeline,大功告成!
... )

你可以在包含 ./mlruns 的目录中运行以下命令来可视化指标:

$ mlflow ui

请参阅详细教程

使用 comet_ml 监控你的实验 🧪

Pythae 还集成了实验跟踪工具 comet_ml,允许用户存储配置、监控训练过程,并通过图形界面比较不同运行的结果。要使用此功能,你需要:

  • 在你的虚拟环境中安装了 comet_ml 包。如果没有,可以通过以下命令安装:
$ pip install comet_ml

创建 CometCallback

在 Pythae 中使用 comet_ml 启动实验监控非常简单。用户只需创建一个 CometCallback 实例……

>>> # 创建回调
>>> from pythae.trainers.training_callbacks import CometCallback
>>> callbacks = [] # TrainingPipeline 需要一个回调列表
>>> comet_cb = CometCallback() # 构建回调
>>> # 设置回调
>>> comet_cb.setup(
...	training_config=training_config, # 训练配置
...	model_config=model_config, # 模型配置
...	api_key="your_comet_api_key", # 指定你的 comet API 密钥
...	project_name="your_comet_project", # 指定你的 comet 项目
...	#offline_run=True, # 以离线模式运行
...	#offline_directory='my_offline_runs' # 设置用于存储离线运行的目录
... )
>>> callbacks.append(comet_cb) # 添加到回调列表

……然后将其传递给 TrainingPipeline

>>> pipeline = TrainingPipeline(
...	training_config=config,
...	model=model
... )
>>> pipeline(
...	train_data=train_dataset,
...	eval_data=eval_dataset,
...,callbacks=callbacks # 将回调传递给 TrainingPipeline,大功告成!
... )
>>> # 你可以登录 https://comet.com/your_comet_username/your_comet_project 来监控你的训练

请参阅详细教程

获取代码

为了帮助您理解 pythae 的工作原理以及如何使用该库训练您的模型,我们还提供了以下教程:

处理问题 🛠️

如果您在运行代码时遇到任何问题,或希望添加新的功能/模型,请在 github 上提交一个问题

贡献 🚀

您想通过添加一个模型、采样器,或者只是修复一个 bug 来为这个库做出贡献吗?太棒了!非常感谢!请参阅 CONTRIBUTING.md,以了解主要的贡献指南。

结果

重建效果

首先让我们来看看从评估集抽取的重建样本。

模型 MNIST CELEBA
评估数据 Eval AE
自编码器 AE AE
变分自编码器 VAE VAE
Beta-变分自编码器 Beta Beta Normal
线性流变分自编码器 VAE_LinNF VAE_IAF Normal
IAF变分自编码器 VAE_IAF VAE_IAF Normal
解耦合Beta-变分自编码器 Disentangled Beta Disentangled Beta
FactorVAE FactorVAE FactorVAE
BetaTCVAE BetaTCVAE BetaTCVAE
IWAE IWAE IWAE
MSSSIM_VAE MSSSIM VAE MSSSIM VAE
WAE WAE WAE
INFO VAE INFO INFO
VAMP VAMP VAMP
SVAE SVAE SVAE
对抗自编码器 AAE AAE
VAE_GAN VAEGAN VAEGAN
VQVAE VQVAE VQVAE
HVAE HVAE HVAE
RAE_L2 RAE L2 RAE L2
RAE_GP RAE GMM RAE GMM
黎曼哈密顿变分自编码器 (RHVAE) RHVAE RHVAE RHVAE

生成效果

在这里,我们展示了使用库中实现的各个模型以及不同采样器生成的样本。

模型 MNIST CELEBA
AE + 高斯混合采样器 AE GMM AE GMM
VAE + 正态采样器 VAE 正常 VAE 正常
VAE + 高斯混合采样器 VAE GMM VAE GMM
VAE + 两阶段VAE采样器 VAE 2阶段 VAE 2阶段
VAE + MAF采样器 VAE MAF VAE MAF
Beta-VAE + 正态采样器 Beta 正常 Beta 正常
VAE Lin NF + 正态采样器 VAE_LinNF 正常 VAE_LinNF 正常
VAE IAF + 正态采样器 VAE_IAF 正常 VAE IAF 正常
解耦合Beta-VAE + 正态采样器 解耦合Beta 正常 解耦合Beta 正常
FactorVAE + 正态采样器 FactorVAE 正常 FactorVAE 正常
BetaTCVAE + 正态采样器 BetaTCVAE 正常 BetaTCVAE 正常
IWAE + 正态采样器 IWAE 正常 IWAE 正常
MSSSIM_VAE + 正态采样器 MSSSIM_VAE 正常 MSSSIM_VAE 正常
WAE + 正态采样器 WAE 正常 WAE 正常
INFO VAE + 正态采样器 INFO 正常 INFO 正常
SVAE + 超球面均匀采样器 SVAE 球体 SVAE 球体
VAMP + VAMP采样器 VAMP Vamp VAMP Vamp
对抗自编码器 + 正态采样器 AAE_正态 AAE_正态
VAEGAN + 正态采样器 VAEGAN_正态 VAEGAN_正态
VQVAE + MAF采样器 VQVAE_MAF VQVAE_MAF
HVAE + 正态采样器 HVAE 正常 HVAE GMM
RAE_L2 + 高斯混合采样器 RAE L2 GMM RAE L2 GMM
RAE_GP + 高斯混合采样器 RAE GMM RAE GMM
黎曼哈密顿VAE (RHVAE) + RHVAE采样器 RHVAE RHVAE RHVAE RHVAE

引用

如果您觉得这项工作有用,或在您的研究中使用了它,请考虑引用我们。

@inproceedings{chadebec2022pythae,
 author = {Chadebec, Cl\'{e}ment and Vincent, Louis and Allassonniere, Stephanie},
 booktitle = {神经信息处理系统进展},
 editor = {S. Koyejo and S. Mohamed and A. Agarwal and D. Belgrave and K. Cho and A. Oh},
 pages = {21575--21589},
 publisher = {柯兰联合公司},
 title = {Pythae:在 Python 中统一生成式自编码器——一个基准测试用例},
 volume = {35},
 year = {2022}
}

版本历史

v0.1.22023/09/06
v0.1.12023/02/23
v0.1.02023/02/06
v0.0.92022/10/19
v.0.0.82022/09/07
v.0.0.72022/09/03
v.0.0.62022/07/22
v0.0.52022/07/07
v.0.0.32022/07/05
v.0.0.22022/07/04
v.0.0.12022/06/14

常见问题

相似工具推荐

openclaw

OpenClaw 是一款专为个人打造的本地化 AI 助手,旨在让你在自己的设备上拥有完全可控的智能伙伴。它打破了传统 AI 助手局限于特定网页或应用的束缚,能够直接接入你日常使用的各类通讯渠道,包括微信、WhatsApp、Telegram、Discord、iMessage 等数十种平台。无论你在哪个聊天软件中发送消息,OpenClaw 都能即时响应,甚至支持在 macOS、iOS 和 Android 设备上进行语音交互,并提供实时的画布渲染功能供你操控。 这款工具主要解决了用户对数据隐私、响应速度以及“始终在线”体验的需求。通过将 AI 部署在本地,用户无需依赖云端服务即可享受快速、私密的智能辅助,真正实现了“你的数据,你做主”。其独特的技术亮点在于强大的网关架构,将控制平面与核心助手分离,确保跨平台通信的流畅性与扩展性。 OpenClaw 非常适合希望构建个性化工作流的技术爱好者、开发者,以及注重隐私保护且不愿被单一生态绑定的普通用户。只要具备基础的终端操作能力(支持 macOS、Linux 及 Windows WSL2),即可通过简单的命令行引导完成部署。如果你渴望拥有一个懂你

349.3k|★★★☆☆|1周前
Agent开发框架图像

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|2周前
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

160.4k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

109.2k|★★☆☆☆|昨天
开发框架图像Agent

gemini-cli

gemini-cli 是一款由谷歌推出的开源 AI 命令行工具,它将强大的 Gemini 大模型能力直接集成到用户的终端环境中。对于习惯在命令行工作的开发者而言,它提供了一条从输入提示词到获取模型响应的最短路径,无需切换窗口即可享受智能辅助。 这款工具主要解决了开发过程中频繁上下文切换的痛点,让用户能在熟悉的终端界面内直接完成代码理解、生成、调试以及自动化运维任务。无论是查询大型代码库、根据草图生成应用,还是执行复杂的 Git 操作,gemini-cli 都能通过自然语言指令高效处理。 它特别适合广大软件工程师、DevOps 人员及技术研究人员使用。其核心亮点包括支持高达 100 万 token 的超长上下文窗口,具备出色的逻辑推理能力;内置 Google 搜索、文件操作及 Shell 命令执行等实用工具;更独特的是,它支持 MCP(模型上下文协议),允许用户灵活扩展自定义集成,连接如图像生成等外部能力。此外,个人谷歌账号即可享受免费的额度支持,且项目基于 Apache 2.0 协议完全开源,是提升终端工作效率的理想助手。

100.8k|★★☆☆☆|1周前
插件Agent图像

markitdown

MarkItDown 是一款由微软 AutoGen 团队打造的轻量级 Python 工具,专为将各类文件高效转换为 Markdown 格式而设计。它支持 PDF、Word、Excel、PPT、图片(含 OCR)、音频(含语音转录)、HTML 乃至 YouTube 链接等多种格式的解析,能够精准提取文档中的标题、列表、表格和链接等关键结构信息。 在人工智能应用日益普及的今天,大语言模型(LLM)虽擅长处理文本,却难以直接读取复杂的二进制办公文档。MarkItDown 恰好解决了这一痛点,它将非结构化或半结构化的文件转化为模型“原生理解”且 Token 效率极高的 Markdown 格式,成为连接本地文件与 AI 分析 pipeline 的理想桥梁。此外,它还提供了 MCP(模型上下文协议)服务器,可无缝集成到 Claude Desktop 等 LLM 应用中。 这款工具特别适合开发者、数据科学家及 AI 研究人员使用,尤其是那些需要构建文档检索增强生成(RAG)系统、进行批量文本分析或希望让 AI 助手直接“阅读”本地文件的用户。虽然生成的内容也具备一定可读性,但其核心优势在于为机器

93.4k|★★☆☆☆|1周前
插件开发框架