paxml
Paxml 是 Google 开源的一套基于 JAX 的机器学习框架,专为“超大模型 + 大规模分布式训练”而生。它把复杂的并行策略、硬件调度、实验配置都封装成可插拔的模块,开发者只需写一份配置,就能在 TPU Pod 上把数十亿甚至上百亿参数的模型高效跑起来。官方测试显示,Paxml 在同等算力下的模型 FLOPs 利用率处于业界领先水平,显著缩短训练时间和成本。
如果你正在做 NLP、多模态或通用大模型的研究,需要快速验证想法、横向扩展实验规模,Paxml 会是趁手的利器;云 TPU 用户可直接用一条命令拉起环境,本地研究者也能在 GPU 上小步快跑。它支持 SPMD(pjit)与传统 pmap 两种并行模式,内置大量可复用的模型模板与教程 Notebook,真正做到“配置即实验”。
使用场景
一家 20 人规模的初创公司正在训练一个 70 亿参数的中文对话大模型,用于给 B 端客户提供智能客服 API。团队只有 2 名算法工程师和 1 名 MLOps,预算有限,希望在 4 周内完成首轮迭代。
没有 paxml 时
- 需要手写大量 Jax+pjit 代码才能把模型拆到 8 张 TPU v4 上,光是调试张量并行和流水并行就花了 5 天
- 训练 10 k step 时 GPU/TPU 利用率只有 38 %,大量时间在等待通信,导致 4 周排期被拉长到 7 周
- 超参数一改就要重新写配置脚本,实验管理靠 Excel 手动记录,结果经常对不上
- 想在云上拉起 64 卡做消融实验,需要额外写 200 行 bash 脚本,还踩了配额和镜像版本坑
使用 paxml 后
- 直接复用
LmCloudSpmd7B模板,30 行 YAML 就完成 8 卡并行配置,2 小时跑通第一次训练 - 内置的
flop_utilization指标稳定在 62 %,同样 10 k step 训练时间从 3.5 天降到 2.1 天 - 通过
--exp参数一键切换学习率、序列长度等超参数,实验结果自动写进 TensorBoard 和 GCS,回溯零成本 - 需要扩容到 64 卡时,把
ACCELERATOR=v4-64改一行即可,paxml 自动处理 slice 划分和 checkpoint 分片
paxml 让这支小团队在 4 周内如期交付了首个可用模型,并把后续迭代周期缩短一半。
运行环境要求
- Linux
- 可选
- 官方示例基于 Google Cloud TPU v4(v4-8/v4-64/v4-128/v4-384)
- NVIDIA GPU 支持由 NVIDIA Rosetta 分支提供,未说明具体型号与显存要求
未说明

快速开始
Paxml(又名 Pax)
Pax 是一个用于在 Jax 之上配置和运行机器学习实验的框架。
快速入门
设置 Cloud TPU VM
有关如何启动 Cloud TPU 项目,我们可参考 此页面 获取更详尽的文档。以下命令足以从一台主机上创建一台拥有 8 个核心的 Cloud TPU VM。
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-8
export TPU_NAME=paxml
# 创建 TPU VM
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --version=$VERSION \
--project=$PROJECT \
--accelerator-type=$ACCELERATOR
如果您使用的是 TPU Pod 片段,请参阅 本指南。请使用 gcloud 并通过 --worker=all 选项,在本地机器上运行所有命令:
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \
--worker=all --command="<commands>"
以下快速入门章节假设您在单机 TPU 上运行,因此您可以直接 SSH 到 VM 并在其中执行命令。
gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE
安装 Pax
在成功 SSH 到 VM 后,您可以从 PyPI 安装 Pax 的稳定版,或从 GitHub 安装开发版。
要从 PyPI 安装稳定版(https://pypi.org/project/paxml/):
python3 -m pip install -U pip
python3 -m pip install paxml jax[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
如果在使用原生 Cloud TPU VM 环境时遇到依赖项传递问题,请前往对应的发布分支 rX.Y.Z,并下载 paxml/pip_package/requirements.txt 文件。该文件列出了原生 Cloud TPU VM 环境中所需的所有依赖项的确切版本,而我们正是基于这些版本构建并测试相应版本的发布内容。
git clone -b rX.Y.Z https://github.com/google/paxml
pip install --no-deps -r paxml/paxml/pip_package/requirements.txt
若需从 GitHub 安装开发版,且为了方便编辑代码:
# 首先安装 Praxis 的开发版
git clone https://github.com/google/praxis
pip install -e praxis
git clone https://github.com/google/paxml
pip install -e paxml
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
运行测试模型
# 示例模型:使用 pjit(SPMD)
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://<your-bucket>
# 示例模型:使用 pmap
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \
--job_log_dir=gs://<your-bucket> \
--pmap_use_tensorstore=True
文档说明
请访问我们的 文档文件夹 以获取相关文档和 Jupyter Notebook 教程。如需了解如何在 Cloud TPU VM 上运行 Jupyter Notebook,请参阅以下部分。
运行笔记本
您可以在刚刚安装 Pax 的 TPU VM 中运行 示例笔记本。
在 v4-8 中启用笔记本的步骤
使用端口转发 SSH 登录 TPU VM:
gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_NAME --zone=$ZONE --ssh-flag="-4 -L 8080:localhost:8080"在 TPU VM 上安装 Jupyter Notebook,并降级 markupsafe:
pip install notebook pip install markupsafe==2.0.1导出
jupyter路径:export PATH=/home/$USER/.local/bin:$PATH将 示例笔记本 复制到 TPU VM:
gcloud compute tpus tpu-vm scp $TPU_NAME:<路径在 TPU 中> <本地笔记本路径> --zone=$ZONE --project=$PROJECT在 TPU VM 上启动 Jupyter Notebook,并记录 Jupyter Notebook 生成的令牌:
jupyter notebook --no-browser --port=8080然后在本地浏览器中访问:http://localhost:8080/,并输入提供的令牌。
注意:如果您需要在第一个笔记本仍在占用 TPU 时启动第二个笔记本,可以运行 pkill -9 python3 来释放 TPU。
在 GPU 上运行
注意:NVIDIA 已发布 Pax 的更新版本,支持 H100 FP8 并大幅提升了 GPU 性能。如需了解更多详情及使用说明,请访问 NVIDIA Rosetta 仓库。
常见问题解答
Pax 适用于 Jax,您可以在 此处 查看在 Cloud TPU 上运行 Jax 作业的详细信息;此外,您也可以在 此处 查看在 Cloud TPU Pod 上运行 Jax 作业的相关信息。
如果遇到依赖项错误,请参考您所安装的稳定版对应分支中的
requirements.txt文件。 例如,对于 稳定版 0.4.0,请使用分支r0.4.0,并参考 requirements.txt 文件,以获取稳定版所用依赖项的确切版本。
示例收敛运行
以下是针对 c4 数据集 的一些示例收敛运行。
c4 数据集上的 1B 模型
您可以在 TPU v4-8 上,使用来自 c4.py 的配置 C4Spmd1BAdam4Replicas,在 c4 数据集上运行一个参数量为 1B 的模型,具体操作如下:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
--job_log_dir=gs://<your-bucket>
您可以通过以下方式观察损失曲线和 log perplexity 图表:
c4 数据集上的 16B 模型
您可以在 TPU v4-64 上,使用来自 c4.py 的配置 C4Spmd16BAdam32Replicas,在 c4 数据集上运行一个参数量为 16B 的模型,具体操作如下:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \
--job_log_dir=gs://<your-bucket>
您可以通过以下方式观察损失曲线和 log perplexity 图表:
GPT3-XL 模型在 c4 数据集上的应用
您可以通过使用 c4.py 中的配置 C4SpmdPipelineGpt3SmallAdam64Replicas,在 TPU v4-128 上运行 GPT3-XL 模型,以 c4 数据集为基准。具体操作如下:
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \
--job_log_dir=gs://<your-bucket>
您将能够观察到损失曲线以及 log perplexity 图表,如下所示:
 
在 Cloud TPU v4 上进行基准测试
PaLM 论文(https://arxiv.org/abs/2204.02311)提出了一种名为“模型 FLOPs 利用率”(MFU)的效率指标。该指标通过观测到的吞吐量(例如,语言模型每秒处理的 token 数)与系统在充分利用峰值 FLOPs 时的理论最大吞吐量之比来衡量。与其他计算利用率的衡量方式不同,MFU 不会计入在反向传播过程中因激活重置而消耗的 FLOPs,因此,以 MFU 衡量的效率直接反映了端到端训练的速度。
为了评估 Pax 在 TPU v4 Pod 上针对关键工作负载的 MFU 水平,我们对一系列仅包含解码器的 Transformer 语言模型(从数十亿参数到数万亿参数不等)进行了深入的基准测试,这些模型均基于 c4 数据集 进行训练。下图展示了采用“弱扩展”模式下的训练效率——我们根据所使用的芯片数量按比例扩大模型规模。

Pax 在多切片环境中的应用
本仓库中所提及的多切片配置,分别对应于: 1. 单切片配置 用于语法和模型架构的定义, 2. MaxText 仓库 则用于配置值的设置。
我们提供了 c4_multislice.py 中的示例运行脚本,作为 Pax 在多切片环境中的起点。
使用排队资源配置 Cloud TPU 虚拟机
有关如何为多切片 Cloud TPU 项目使用排队资源的更详尽文档,请参阅 此页面。以下步骤介绍了为运行本仓库中示例配置而设置 TPU 的所需操作。
export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-128 # 或 v4-384,具体取决于您运行的配置
例如,若要针对 v4-128 的 2 个切片运行 C4Spmd22BAdam2xv4_128,您需要按照以下方式配置 TPU:
export TPU_PREFIX=<your-prefix> # 新创建的 TPU 将基于此前缀命名
export QR_ID=$TPU_PREFIX
export NODE_COUNT=<number-of-slices> # 1、2 或 4,具体取决于您运行的配置
创建 TPU 虚拟机:
gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX
安装 Pax
前面介绍的安装命令需在所有切片的所有 worker 上执行。您可以选择:1)分别登录每个 worker 和每个切片;或者 2)使用带有 --worker=all 标志的 for 循环,如以下命令所示。
for ((i=0; i<$NODE_COUNT; i++))
do
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all --command="pip install paxml && pip install orbax==0.1.1 && pip install \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
done
运行多切片模型测试
要运行多切片配置,需打开与您的 NODE_COUNT 相同数量的终端。对于我们的 2 个切片实验(C4Spmd22BAdam2xv4_128),请打开两个终端。然后,分别从每个终端单独运行上述命令。
从终端 0 开始,运行切片 0 的训练命令,如下所示:
export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-0 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"
从终端 1 同时运行切片 1 的训练命令,如下所示:
export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-1 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"
MaxText 转换为 Pax
本表格详细介绍了 MaxText 变量名称如何被转换为 Pax 的具体细节。
请注意,MaxText 为多个参数(如 base_num_decoder_layers、base_emb_dim、base_mlp_dim、base_num_heads)设置了“缩放因子”,用于计算最终的值。
另外需要说明的是,虽然 Pax 将 DCN 和 ICN 的 MESH_SHAPE 作为数组来处理,但在 MaxText 中,DCN 和 ICI 分别对应着独立的变量:data_parallelism、fsdp_parallelism 和 tensor_parallelism。由于这些变量默认设置为 1,因此在本次翻译表中,仅记录了那些值大于 1 的变量。
也就是说,ICI_MESH_SHAPE = [ici_data_parallelism, ici_fsdp_parallelism, ici_tensor_parallelism],而 DCN_MESH_SHAPE = [dcn_data_parallelism, dcn_fsdp_parallelism, dcn_tensor_parallelism]。
| Pax C4Spmd22BAdam2xv4_128 | MaxText 2xv4-128.sh | (在应用缩放因子后) | ||
|---|---|---|---|---|
| 缩放因子(应用于接下来的四个变量) | 3 | |||
| NUM_LAYERS | 48 | base_num_decoder_layers | 16 | 48 |
| MODEL_DIMS | 6144 | base_emb_dim | 2048 | 6144 |
| HIDDEN_DIMS | 24576 | MODEL_DIMS * 4 (= base_mlp_dim) | 8192 | 24576 |
| NUM_HEADS | 24 | base_num_heads | 8 | 24 |
| DIMS_PER_HEAD | 256 | head_dim | 256 | |
| PERCORE_BATCH_SIZE | 16 | per_device_batch_size | 16 | |
| MAX_SEQ_LEN | 1024 | max_target_length | 1024 | |
| VOCAB_SIZE | 32768 | vocab_size | 32768 | |
| FPROP_DTYPE | jnp.bfloat16 | dtype | bfloat16 | |
| USE_REPEATED_LAYER | TRUE | |||
| SUMMARY_INTERVAL_STEPS | 10 | |||
| ICI_MESH_SHAPE | [1, 64, 1] | ici_fsdp_parallelism | 64 | |
| DCN_MESH_SHAPE | [2, 1, 1] | dcn_data_parallelism | 2 |
数据输入
简介
输入是一个 BaseInput 类的实例,用于将数据输入模型,以进行训练、评估或解码。
class BaseInput:
def get_next(self):
pass
def reset(self):
pass
它类似于一个迭代器:get_next() 会返回一个 NestedMap,其中每个字段
都是一维数值数组,其首维表示批量大小。
每个输入由 BaseInput.HParams 的子类进行配置。在本页面中,我们用 p 来表示 BaseInput.Params 的实例,并将其初始化为 input。
多主机输入
在 Pax 中,数据始终是多主机的:每个 Jax 进程都会拥有独立且独立的 input 实例。它们的参数会拥有不同的 p.infeed_host_index,由 Pax 自动设置。
因此,在每台主机上看到的本地批量大小为 p.batch_size,而全局批量大小则为 (p.batch_size * p.num_infeed_hosts)。通常情况下,我们会将 p.batch_size 设置为 jax.local_device_count() * PERCORE_BATCH_SIZE。
由于具备多主机特性,input 必须进行适当的分片处理。
在训练时,每个 input 从不输出相同的批次;而在对有限数据集进行评估时,每个 input 应在相同数量的批次后终止。最佳方案是让输入实现能够正确地对数据进行分片,使得不同主机上的每个 input 不会相互重叠。如果无法做到这一点,也可以使用不同的随机种子,以避免在训练过程中出现重复的批次。
评估数据的输入
在训练数据中,input.reset() 从不被调用,但在评估(或解码)数据中,可以调用 input.reset()。
对于每次评估(或解码)运行,Pax 会通过多次调用 input.get_next(),从 input 中获取 N 批次。N 的数量可以由用户通过 p.eval_loop_num_batches 固定指定;或者,N 也可以是动态的(p.eval_loop_num_batches=None),即我们不断调用 input.get_next(),直到耗尽所有数据(通过抛出 StopIteration 或 tf.errors.OutOfRange)。
如果 p.reset_for_eval=True,p.eval_loop_num_batches 将被忽略,N 会根据实际需求动态确定,以满足数据的消耗需求。在这种情况下,应将 p.repeat 设置为 False,否则会导致无限的解码/评估过程。
如果 p.reset_for_eval=False,Pax 将获取 p.eval_loop_num_batches 批次。此时,应将 p.repeat=True 设置为 True,以确保数据不会过早耗尽。
需要注意的是,LingvoEvalAdaptor 输入要求 p.reset_for_eval=True。
N: 静态 |
N: 动态 |
|
|---|---|---|
p.reset_for_eval=True |
每次评估运行会使用前 N 批次。尚未支持 eval_loop_num_batches。 |
每个评估运行一次一个 epoch。 |
: : 第一批次开始使用 N 批次。未支持 eval_loop_num_batches。 |
输入必须是有限的。 | |
p.reset_for_eval=False |
每个评估运行会使用非重叠的 N 批次,按滚动方式逐步获取数据,依据 eval_loop_num_batches 的设定。 |
不支持此功能。输入必须无限循环。 |
请将 p.repeat 设置为 False,否则可能导致解码/评估过程无限进行。 |
如果在单个 epoch 内执行解码/评估(即当 p.reset_for_eval=True 时),输入必须正确处理分片问题,确保每个分片在完成相同数量的批次后,以完全相同的步长进行下一轮操作。这通常意味着输入需要对评估数据进行填充。SeqIOInput 和 LingvoEvalAdaptor 会自动完成这一工作(详见下文)。
评估指标
对于大多数输入,我们只会调用 get_next() 来获取数据批次。然而,有一种类型的评估数据例外——在这种情况下,“如何计算指标”也会在输入对象上进行定义。
此功能仅适用于 SeqIOInput,该输入定义了一些标准的评估基准。具体而言,Pax 使用 SeqIO 任务中定义的 predict_metric_fns 和 score_metric_fns 来计算评估指标(尽管 Pax 并不直接依赖于 SeqIO 评估器)。
最佳实践
当模型使用多种输入时,无论是训练/评估阶段之间,还是在预训练/微调不同训练数据之间的切换过程中,用户必须确保各输入所使用的分词器完全一致,尤其是在导入由他人实现的不同输入时。
用户可以通过使用 input.ids_to_strings() 对分词器进行校验,以验证其正确性。
建议始终通过查看若干批次的数据来对数据进行初步检查。用户可以轻松地在 Colab 中重现该参数,并对数据进行检查:
p = ... # 指定预期的输入参数
inp = p.Instantiate()
b = inp.get_next()
print(b)
训练数据通常不应使用固定的随机种子。这是因为如果训练任务被提前终止,训练数据将会开始重复出现。尤其对于 Lingvo 输入,我们建议在训练数据中将 p.input.file_random_seed = 0 设置为默认值。
为了测试分片是否得到正确处理,用户可以手动为 p.num_infeed_hosts 和 p.infeed_host_index 设置不同的值,观察实例化后的输入是否生成不同的批次。
输入类型
Pax 支持三种类型的输入:SeqIO、Lingvo 以及自定义输入。
SeqIO
SeqIOInput 可用于导入数据集。
SeqIO 输入会自动处理评估数据的正确分片与填充。
Lingvo
LingvoInputAdaptor 可用于导入数据集。
该输入完全委托给 Lingvo 实现,而 Lingvo 实现可能自动处理分片,也可能不自动处理分片。
对于基于 GenericInput 的 Lingvo 输入实现,若采用固定 packing_factor,我们建议使用 LingvoInputAdaptorNewBatchSize 来为内部 Lingvo 输入指定更大的批量大小,并将所需的(通常要小得多)批量大小设置为 p.batch_size。
对于评估数据,我们建议使用 LingvoEvalAdaptor,以在单个 epoch 内运行评估时,自动处理分片与填充。
自定义
自定义类继承自 BaseInput。用户可以自行实现子类,通常使用 tf.data 或 SeqIO。
用户也可以直接继承现有的输入类,仅对批次的后处理进行个性化定制。例如:
class MyInput(base_input.LingvoInputAdaptor):
def get_next(self):
batch = super().get_next()
# 修改批次:batch.new_field = ...
return batch
Pax 关键组件
超参数
超参数是定义模型和配置实验的重要组成部分。
为了更好地与 Python 工具集成,Pax/Praxis 采用了一种 Python 风格的 dataclass 配置方式来管理超参数。
class Linear(base_layer.BaseLayer):
"""线性层,无偏置。”
class HParams(BaseHParams):
"""此层类的关联超参数。
属性:
input_dims:输入的深度。
output_dims:输出的深度。
"""
input_dims: int = 0
output_dims: int = 0
嵌套
超参数数据类也可以进行嵌套。如下面的例子所示,linear_tpl 属性是一个嵌套的 Linear.HParams。
class FeedForward(base_layer.BaseLayer):
"""前馈层,带激活函数。”
class HParams(BaseHParams):
"""此层类的关联超参数。
属性:
input_dims:输入的深度。
output_dims:输出的深度。
has_bias:是否添加偏置权重。
linear_tpl:线性层的参数。
activation_tpl:激活函数层的参数。
"""
input_dims: int = 0
output_dims: int = 0
has_bias: bool = True
linear_tpl: BaseHParams = sub_config_field(Linear.HParams)
activation_tpl: activations.BaseActivation.HParams = sub_config_field(
ReLU.HParams)
层
层代表任意函数,该函数可能包含可训练的参数。层可以作为子级包含其他层。层是模型的核心构建模块。层继承自 Flax nn.Module。
通常,层会定义两个方法:
setup
该方法用于创建可训练的权重和子级层。
fprop
该方法定义前向传播函数,根据输入计算出某些输出。此外,fprop 还可以添加摘要信息或跟踪辅助损失。
Fiddle 与共享层
Fiddle 是一个开源的 Python 首先配置库,专为机器学习应用设计。Pax/Praxis 支持与 Fiddle Config/Partial 的互操作性,以及一些高级功能,如即时错误检查和共享参数。
fdl_config = Linear.HParams.config(input_dims=1, output_dims=1)
# 一个拼写错误。
fdl_config.input_dimz = 31337 # 一旦发现拼写错误,就会立即抛出异常,快速捕捉错误!
fdl_partial = Linear.HParams.partial(input_dims=1)
借助 Fiddle,层可以被配置为共享(例如,仅实例化一次,并使用共享的可训练权重)。
模型
模型仅定义网络结构,通常是多个层的集合,并定义了与模型交互的接口,例如解码等。
一些基础模型示例包括:
- 语言模型
- 序列模型
- 分类模型
任务
任务包含多个模型以及学习器/优化器。最简单的任务子类是 SingleTask,它需要以下超参数:
class HParams(base_task.BaseTask.HParams):
"""任务参数。
属性:
name:此任务对象的名称,必须是有效的标识符。
model:底层 JAX 模型,封装了所有层。
train:用于控制任务训练方式的超参数。
metrics:一个 BaseMetrics 汇总类,用于确定如何计算指标。
loss_aggregator:一个 LossAggregator 汇总类,用于确定如何聚合损失(例如,单个损失或多损失)。
vn:用于控制变分噪声的超参数。
发布版本
| PyPI 版本 | 提交 |
|---|---|
| 0.1.0 | 546370f5323ef8b27d38ddc32445d7d3d1e4da9a |
版权所有 2022 Google LLC
根据 Apache 许可证第 2.0 版(“许可证”)发布;您不得使用本文件,除非遵守许可证条款。您可以从以下网址获取许可证副本:
https://www.apache.org/licenses/LICENSE-2.0
除非法律另有要求或书面协议约定,否则软件将以“原样”提供,且不附带任何明示或暗示的保证。请参阅许可证,了解有关特定语言的许可条款及限制,以及相关法律责任。
版本历史
paxml-v1.4.02024/04/09paxml-v1.3.12024/02/21paxml-v1.3.02024/02/17paxml-v1.2.02023/10/19paxml-v1.1.02023/08/28paxml-v1.0.02023/04/12paxml-v0.4.02023/04/12paxml-v0.3.02023/02/03paxml-v0.2.12022/11/22paxml-v0.2.02022/11/15paxml-v0.0.12022/06/15常见问题
相似工具推荐
stable-diffusion-webui
stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。
everything-claude-code
everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上
ComfyUI
ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。
NextChat
NextChat 是一款轻量且极速的 AI 助手,旨在为用户提供流畅、跨平台的大模型交互体验。它完美解决了用户在多设备间切换时难以保持对话连续性,以及面对众多 AI 模型不知如何统一管理的痛点。无论是日常办公、学习辅助还是创意激发,NextChat 都能让用户随时随地通过网页、iOS、Android、Windows、MacOS 或 Linux 端无缝接入智能服务。 这款工具非常适合普通用户、学生、职场人士以及需要私有化部署的企业团队使用。对于开发者而言,它也提供了便捷的自托管方案,支持一键部署到 Vercel 或 Zeabur 等平台。 NextChat 的核心亮点在于其广泛的模型兼容性,原生支持 Claude、DeepSeek、GPT-4 及 Gemini Pro 等主流大模型,让用户在一个界面即可自由切换不同 AI 能力。此外,它还率先支持 MCP(Model Context Protocol)协议,增强了上下文处理能力。针对企业用户,NextChat 提供专业版解决方案,具备品牌定制、细粒度权限控制、内部知识库整合及安全审计等功能,满足公司对数据隐私和个性化管理的高标准要求。
ML-For-Beginners
ML-For-Beginners 是由微软推出的一套系统化机器学习入门课程,旨在帮助零基础用户轻松掌握经典机器学习知识。这套课程将学习路径规划为 12 周,包含 26 节精炼课程和 52 道配套测验,内容涵盖从基础概念到实际应用的完整流程,有效解决了初学者面对庞大知识体系时无从下手、缺乏结构化指导的痛点。 无论是希望转型的开发者、需要补充算法背景的研究人员,还是对人工智能充满好奇的普通爱好者,都能从中受益。课程不仅提供了清晰的理论讲解,还强调动手实践,让用户在循序渐进中建立扎实的技能基础。其独特的亮点在于强大的多语言支持,通过自动化机制提供了包括简体中文在内的 50 多种语言版本,极大地降低了全球不同背景用户的学习门槛。此外,项目采用开源协作模式,社区活跃且内容持续更新,确保学习者能获取前沿且准确的技术资讯。如果你正寻找一条清晰、友好且专业的机器学习入门之路,ML-For-Beginners 将是理想的起点。
ragflow
RAGFlow 是一款领先的开源检索增强生成(RAG)引擎,旨在为大语言模型构建更精准、可靠的上下文层。它巧妙地将前沿的 RAG 技术与智能体(Agent)能力相结合,不仅支持从各类文档中高效提取知识,还能让模型基于这些知识进行逻辑推理和任务执行。 在大模型应用中,幻觉问题和知识滞后是常见痛点。RAGFlow 通过深度解析复杂文档结构(如表格、图表及混合排版),显著提升了信息检索的准确度,从而有效减少模型“胡编乱造”的现象,确保回答既有据可依又具备时效性。其内置的智能体机制更进一步,使系统不仅能回答问题,还能自主规划步骤解决复杂问题。 这款工具特别适合开发者、企业技术团队以及 AI 研究人员使用。无论是希望快速搭建私有知识库问答系统,还是致力于探索大模型在垂直领域落地的创新者,都能从中受益。RAGFlow 提供了可视化的工作流编排界面和灵活的 API 接口,既降低了非算法背景用户的上手门槛,也满足了专业开发者对系统深度定制的需求。作为基于 Apache 2.0 协议开源的项目,它正成为连接通用大模型与行业专有知识之间的重要桥梁。