paxml

GitHub
550 70 较难 1 次阅读 2周前Apache-2.0语言模型开发框架Agent
AI 解读 由 AI 自动生成,仅供参考

Paxml 是 Google 开源的一套基于 JAX 的机器学习框架,专为“超大模型 + 大规模分布式训练”而生。它把复杂的并行策略、硬件调度、实验配置都封装成可插拔的模块,开发者只需写一份配置,就能在 TPU Pod 上把数十亿甚至上百亿参数的模型高效跑起来。官方测试显示,Paxml 在同等算力下的模型 FLOPs 利用率处于业界领先水平,显著缩短训练时间和成本。

如果你正在做 NLP、多模态或通用大模型的研究,需要快速验证想法、横向扩展实验规模,Paxml 会是趁手的利器;云 TPU 用户可直接用一条命令拉起环境,本地研究者也能在 GPU 上小步快跑。它支持 SPMD(pjit)与传统 pmap 两种并行模式,内置大量可复用的模型模板与教程 Notebook,真正做到“配置即实验”。

使用场景

一家 20 人规模的初创公司正在训练一个 70 亿参数的中文对话大模型,用于给 B 端客户提供智能客服 API。团队只有 2 名算法工程师和 1 名 MLOps,预算有限,希望在 4 周内完成首轮迭代。

没有 paxml 时

  • 需要手写大量 Jax+pjit 代码才能把模型拆到 8 张 TPU v4 上,光是调试张量并行和流水并行就花了 5 天
  • 训练 10 k step 时 GPU/TPU 利用率只有 38 %,大量时间在等待通信,导致 4 周排期被拉长到 7 周
  • 超参数一改就要重新写配置脚本,实验管理靠 Excel 手动记录,结果经常对不上
  • 想在云上拉起 64 卡做消融实验,需要额外写 200 行 bash 脚本,还踩了配额和镜像版本坑

使用 paxml 后

  • 直接复用 LmCloudSpmd7B 模板,30 行 YAML 就完成 8 卡并行配置,2 小时跑通第一次训练
  • 内置的 flop_utilization 指标稳定在 62 %,同样 10 k step 训练时间从 3.5 天降到 2.1 天
  • 通过 --exp 参数一键切换学习率、序列长度等超参数,实验结果自动写进 TensorBoard 和 GCS,回溯零成本
  • 需要扩容到 64 卡时,把 ACCELERATOR=v4-64 改一行即可,paxml 自动处理 slice 划分和 checkpoint 分片

paxml 让这支小团队在 4 周内如期交付了首个可用模型,并把后续迭代周期缩短一半。

运行环境要求

操作系统
  • Linux
GPU
  • 可选
  • 官方示例基于 Google Cloud TPU v4(v4-8/v4-64/v4-128/v4-384)
  • NVIDIA GPU 支持由 NVIDIA Rosetta 分支提供,未说明具体型号与显存要求
内存

未说明

依赖
notes1) 官方示例默认在 Google Cloud TPU VM 上运行,需提前创建 TPU 实例;2) 若使用 GPU,请转向 NVIDIA Rosetta 分支;3) 多 slice 训练需为每个 slice 单独启动终端并设置 LIBTPU_INIT_ARGS;4) 日志与 checkpoint 需写入 Google Cloud Storage(gs://);5) 首次安装建议从 release 分支获取 requirements.txt 以锁定依赖版本。
python3.8
jax[tpu]
paxml
praxis
orbax==0.1.1
notebook
markupsafe==2.0.1
paxml hero image

快速开始

Paxml(又名 Pax)

Pax 是一个用于在 Jax 之上配置和运行机器学习实验的框架。

快速入门

设置 Cloud TPU VM

有关如何启动 Cloud TPU 项目,我们可参考 此页面 获取更详尽的文档。以下命令足以从一台主机上创建一台拥有 8 个核心的 Cloud TPU VM。

export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-8
export TPU_NAME=paxml

# 创建 TPU VM
gcloud compute tpus tpu-vm create $TPU_NAME \
--zone=$ZONE --version=$VERSION \
--project=$PROJECT \
--accelerator-type=$ACCELERATOR

如果您使用的是 TPU Pod 片段,请参阅 本指南。请使用 gcloud 并通过 --worker=all 选项,在本地机器上运行所有命令:

gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \
--worker=all --command="<commands>"

以下快速入门章节假设您在单机 TPU 上运行,因此您可以直接 SSH 到 VM 并在其中执行命令。

gcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE

安装 Pax

在成功 SSH 到 VM 后,您可以从 PyPI 安装 Pax 的稳定版,或从 GitHub 安装开发版。

要从 PyPI 安装稳定版(https://pypi.org/project/paxml/):

python3 -m pip install -U pip
python3 -m pip install paxml jax[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html

如果在使用原生 Cloud TPU VM 环境时遇到依赖项传递问题,请前往对应的发布分支 rX.Y.Z,并下载 paxml/pip_package/requirements.txt 文件。该文件列出了原生 Cloud TPU VM 环境中所需的所有依赖项的确切版本,而我们正是基于这些版本构建并测试相应版本的发布内容。

git clone -b rX.Y.Z https://github.com/google/paxml
pip install --no-deps -r paxml/paxml/pip_package/requirements.txt

若需从 GitHub 安装开发版,且为了方便编辑代码:

# 首先安装 Praxis 的开发版
git clone https://github.com/google/praxis
pip install -e praxis
git clone https://github.com/google/paxml
pip install -e paxml
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

运行测试模型

# 示例模型:使用 pjit(SPMD)
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \
--job_log_dir=gs://<your-bucket>

# 示例模型:使用 pmap
python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \
--job_log_dir=gs://<your-bucket> \
--pmap_use_tensorstore=True

文档说明

请访问我们的 文档文件夹 以获取相关文档和 Jupyter Notebook 教程。如需了解如何在 Cloud TPU VM 上运行 Jupyter Notebook,请参阅以下部分。

运行笔记本

您可以在刚刚安装 Pax 的 TPU VM 中运行 示例笔记本

v4-8 中启用笔记本的步骤

  1. 使用端口转发 SSH 登录 TPU VM:

    gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_NAME    --zone=$ZONE    --ssh-flag="-4 -L 8080:localhost:8080"
    
  2. 在 TPU VM 上安装 Jupyter Notebook,并降级 markupsafe:

    pip install notebook
    pip install markupsafe==2.0.1
    
  3. 导出 jupyter 路径:

    export PATH=/home/$USER/.local/bin:$PATH
    
  4. 示例笔记本 复制到 TPU VM:

    gcloud compute tpus tpu-vm scp $TPU_NAME:<路径在 TPU 中> <本地笔记本路径>   --zone=$ZONE --project=$PROJECT
    
  5. 在 TPU VM 上启动 Jupyter Notebook,并记录 Jupyter Notebook 生成的令牌:

    jupyter notebook --no-browser --port=8080
    
  6. 然后在本地浏览器中访问:http://localhost:8080/,并输入提供的令牌。

注意:如果您需要在第一个笔记本仍在占用 TPU 时启动第二个笔记本,可以运行 pkill -9 python3 来释放 TPU。

在 GPU 上运行

注意:NVIDIA 已发布 Pax 的更新版本,支持 H100 FP8 并大幅提升了 GPU 性能。如需了解更多详情及使用说明,请访问 NVIDIA Rosetta 仓库。

常见问题解答

  1. Pax 适用于 Jax,您可以在 此处 查看在 Cloud TPU 上运行 Jax 作业的详细信息;此外,您也可以在 此处 查看在 Cloud TPU Pod 上运行 Jax 作业的相关信息。

  2. 如果遇到依赖项错误,请参考您所安装的稳定版对应分支中的 requirements.txt 文件。 例如,对于 稳定版 0.4.0,请使用分支 r0.4.0,并参考 requirements.txt 文件,以获取稳定版所用依赖项的确切版本。

示例收敛运行

以下是针对 c4 数据集 的一些示例收敛运行。

c4 数据集上的 1B 模型

您可以在 TPU v4-8 上,使用来自 c4.py 的配置 C4Spmd1BAdam4Replicas,在 c4 数据集上运行一个参数量为 1B 的模型,具体操作如下:

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \
--job_log_dir=gs://<your-bucket>

您可以通过以下方式观察损失曲线和 log perplexity 图表:

c4 数据集上的 16B 模型

您可以在 TPU v4-64 上,使用来自 c4.py 的配置 C4Spmd16BAdam32Replicas,在 c4 数据集上运行一个参数量为 16B 的模型,具体操作如下:

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \
--job_log_dir=gs://<your-bucket>

您可以通过以下方式观察损失曲线和 log perplexity 图表:

GPT3-XL 模型在 c4 数据集上的应用

您可以通过使用 c4.py 中的配置 C4SpmdPipelineGpt3SmallAdam64Replicas,在 TPU v4-128 上运行 GPT3-XL 模型,以 c4 数据集为基准。具体操作如下:

python3 .local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \
--job_log_dir=gs://<your-bucket>

您将能够观察到损失曲线以及 log perplexity 图表,如下所示:

![GPT3-XL 损失曲线](paxml/docs/images/GPT3-XL-loss.png width="400" height="300") ![GPT3-XL 平均困惑度图](paxml/docs/images/GPT3-XL-pplx.png width="400" height="300")

在 Cloud TPU v4 上进行基准测试

PaLM 论文(https://arxiv.org/abs/2204.02311)提出了一种名为“模型 FLOPs 利用率”(MFU)的效率指标。该指标通过观测到的吞吐量(例如,语言模型每秒处理的 token 数)与系统在充分利用峰值 FLOPs 时的理论最大吞吐量之比来衡量。与其他计算利用率的衡量方式不同,MFU 不会计入在反向传播过程中因激活重置而消耗的 FLOPs,因此,以 MFU 衡量的效率直接反映了端到端训练的速度。

为了评估 Pax 在 TPU v4 Pod 上针对关键工作负载的 MFU 水平,我们对一系列仅包含解码器的 Transformer 语言模型(从数十亿参数到数万亿参数不等)进行了深入的基准测试,这些模型均基于 c4 数据集 进行训练。下图展示了采用“弱扩展”模式下的训练效率——我们根据所使用的芯片数量按比例扩大模型规模。

![TPU v4 大型语言模型训练的弱扩展图](paxml/docs/images/Weak_scaling_of_large_language_model_training_on_TPU_v4.png width="500" height="300")

Pax 在多切片环境中的应用

本仓库中所提及的多切片配置,分别对应于: 1. 单切片配置 用于语法和模型架构的定义, 2. MaxText 仓库 则用于配置值的设置。

我们提供了 c4_multislice.py 中的示例运行脚本,作为 Pax 在多切片环境中的起点。

使用排队资源配置 Cloud TPU 虚拟机

有关如何为多切片 Cloud TPU 项目使用排队资源的更详尽文档,请参阅 此页面。以下步骤介绍了为运行本仓库中示例配置而设置 TPU 的所需操作。

export ZONE=us-central2-b
export VERSION=tpu-vm-v4-base
export PROJECT=<your-project>
export ACCELERATOR=v4-128 # 或 v4-384,具体取决于您运行的配置

例如,若要针对 v4-128 的 2 个切片运行 C4Spmd22BAdam2xv4_128,您需要按照以下方式配置 TPU:

export TPU_PREFIX=<your-prefix> # 新创建的 TPU 将基于此前缀命名
export QR_ID=$TPU_PREFIX
export NODE_COUNT=<number-of-slices> # 1、2 或 4,具体取决于您运行的配置

创建 TPU 虚拟机:

gcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX

安装 Pax

前面介绍的安装命令需在所有切片的所有 worker 上执行。您可以选择:1)分别登录每个 worker 和每个切片;或者 2)使用带有 --worker=all 标志的 for 循环,如以下命令所示。

for ((i=0; i<$NODE_COUNT; i++))
do
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all --command="pip install paxml && pip install orbax==0.1.1 && pip install \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
done

运行多切片模型测试

要运行多切片配置,需打开与您的 NODE_COUNT 相同数量的终端。对于我们的 2 个切片实验(C4Spmd22BAdam2xv4_128),请打开两个终端。然后,分别从每个终端单独运行上述命令。

从终端 0 开始,运行切片 0 的训练命令,如下所示:

export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-0 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"

从终端 1 同时运行切片 1 的训练命令,如下所示:

export TPU_PREFIX=<your-prefix>
export EXP_NAME=C4Spmd22BAdam2xv4_128
export LIBTPU_INIT_ARGS=\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\"
gcloud compute tpus tpu-vm ssh $TPU_PREFIX-1 --zone=us-central2-b --worker=all \
--command="LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \
python3 /home/yooh/.local/lib/python3.8/site-packages/paxml/main.py \
--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs://<your-bucket>"

MaxText 转换为 Pax

本表格详细介绍了 MaxText 变量名称如何被转换为 Pax 的具体细节。

请注意,MaxText 为多个参数(如 base_num_decoder_layers、base_emb_dim、base_mlp_dim、base_num_heads)设置了“缩放因子”,用于计算最终的值。

另外需要说明的是,虽然 Pax 将 DCN 和 ICN 的 MESH_SHAPE 作为数组来处理,但在 MaxText 中,DCN 和 ICI 分别对应着独立的变量:data_parallelism、fsdp_parallelism 和 tensor_parallelism。由于这些变量默认设置为 1,因此在本次翻译表中,仅记录了那些值大于 1 的变量。

也就是说,ICI_MESH_SHAPE = [ici_data_parallelism, ici_fsdp_parallelism, ici_tensor_parallelism],而 DCN_MESH_SHAPE = [dcn_data_parallelism, dcn_fsdp_parallelism, dcn_tensor_parallelism]

Pax C4Spmd22BAdam2xv4_128 MaxText 2xv4-128.sh (在应用缩放因子后)
缩放因子(应用于接下来的四个变量) 3
NUM_LAYERS 48 base_num_decoder_layers 16 48
MODEL_DIMS 6144 base_emb_dim 2048 6144
HIDDEN_DIMS 24576 MODEL_DIMS * 4 (= base_mlp_dim) 8192 24576
NUM_HEADS 24 base_num_heads 8 24
DIMS_PER_HEAD 256 head_dim 256
PERCORE_BATCH_SIZE 16 per_device_batch_size 16
MAX_SEQ_LEN 1024 max_target_length 1024
VOCAB_SIZE 32768 vocab_size 32768
FPROP_DTYPE jnp.bfloat16 dtype bfloat16
USE_REPEATED_LAYER TRUE
SUMMARY_INTERVAL_STEPS 10
ICI_MESH_SHAPE [1, 64, 1] ici_fsdp_parallelism 64
DCN_MESH_SHAPE [2, 1, 1] dcn_data_parallelism 2

数据输入

简介

输入是一个 BaseInput 类的实例,用于将数据输入模型,以进行训练、评估或解码。

class BaseInput:

  def get_next(self):
    pass

  def reset(self):
    pass

它类似于一个迭代器:get_next() 会返回一个 NestedMap,其中每个字段 都是一维数值数组,其首维表示批量大小。

每个输入由 BaseInput.HParams 的子类进行配置。在本页面中,我们用 p 来表示 BaseInput.Params 的实例,并将其初始化为 input

多主机输入

在 Pax 中,数据始终是多主机的:每个 Jax 进程都会拥有独立且独立的 input 实例。它们的参数会拥有不同的 p.infeed_host_index,由 Pax 自动设置。

因此,在每台主机上看到的本地批量大小为 p.batch_size,而全局批量大小则为 (p.batch_size * p.num_infeed_hosts)。通常情况下,我们会将 p.batch_size 设置为 jax.local_device_count() * PERCORE_BATCH_SIZE

由于具备多主机特性,input 必须进行适当的分片处理。

在训练时,每个 input 从不输出相同的批次;而在对有限数据集进行评估时,每个 input 应在相同数量的批次后终止。最佳方案是让输入实现能够正确地对数据进行分片,使得不同主机上的每个 input 不会相互重叠。如果无法做到这一点,也可以使用不同的随机种子,以避免在训练过程中出现重复的批次。

评估数据的输入

在训练数据中,input.reset() 从不被调用,但在评估(或解码)数据中,可以调用 input.reset()

对于每次评估(或解码)运行,Pax 会通过多次调用 input.get_next(),从 input 中获取 N 批次。N 的数量可以由用户通过 p.eval_loop_num_batches 固定指定;或者,N 也可以是动态的(p.eval_loop_num_batches=None),即我们不断调用 input.get_next(),直到耗尽所有数据(通过抛出 StopIterationtf.errors.OutOfRange)。

如果 p.reset_for_eval=Truep.eval_loop_num_batches 将被忽略,N 会根据实际需求动态确定,以满足数据的消耗需求。在这种情况下,应将 p.repeat 设置为 False,否则会导致无限的解码/评估过程。

如果 p.reset_for_eval=False,Pax 将获取 p.eval_loop_num_batches 批次。此时,应将 p.repeat=True 设置为 True,以确保数据不会过早耗尽。

需要注意的是,LingvoEvalAdaptor 输入要求 p.reset_for_eval=True

N: 静态 N: 动态
p.reset_for_eval=True 每次评估运行会使用前 N 批次。尚未支持 eval_loop_num_batches 每个评估运行一次一个 epoch。
: : 第一批次开始使用 N 批次。未支持 eval_loop_num_batches 输入必须是有限的。
p.reset_for_eval=False 每个评估运行会使用非重叠的 N 批次,按滚动方式逐步获取数据,依据 eval_loop_num_batches 的设定。 不支持此功能。输入必须无限循环。
请将 p.repeat 设置为 False,否则可能导致解码/评估过程无限进行。

如果在单个 epoch 内执行解码/评估(即当 p.reset_for_eval=True 时),输入必须正确处理分片问题,确保每个分片在完成相同数量的批次后,以完全相同的步长进行下一轮操作。这通常意味着输入需要对评估数据进行填充。SeqIOInputLingvoEvalAdaptor 会自动完成这一工作(详见下文)。

评估指标

对于大多数输入,我们只会调用 get_next() 来获取数据批次。然而,有一种类型的评估数据例外——在这种情况下,“如何计算指标”也会在输入对象上进行定义。

此功能仅适用于 SeqIOInput,该输入定义了一些标准的评估基准。具体而言,Pax 使用 SeqIO 任务中定义的 predict_metric_fnsscore_metric_fns 来计算评估指标(尽管 Pax 并不直接依赖于 SeqIO 评估器)。

最佳实践

当模型使用多种输入时,无论是训练/评估阶段之间,还是在预训练/微调不同训练数据之间的切换过程中,用户必须确保各输入所使用的分词器完全一致,尤其是在导入由他人实现的不同输入时。

用户可以通过使用 input.ids_to_strings() 对分词器进行校验,以验证其正确性。

建议始终通过查看若干批次的数据来对数据进行初步检查。用户可以轻松地在 Colab 中重现该参数,并对数据进行检查:

p = ... # 指定预期的输入参数
inp = p.Instantiate()
b = inp.get_next()
print(b)

训练数据通常不应使用固定的随机种子。这是因为如果训练任务被提前终止,训练数据将会开始重复出现。尤其对于 Lingvo 输入,我们建议在训练数据中将 p.input.file_random_seed = 0 设置为默认值。

为了测试分片是否得到正确处理,用户可以手动为 p.num_infeed_hostsp.infeed_host_index 设置不同的值,观察实例化后的输入是否生成不同的批次。

输入类型

Pax 支持三种类型的输入:SeqIO、Lingvo 以及自定义输入。

SeqIO

SeqIOInput 可用于导入数据集。

SeqIO 输入会自动处理评估数据的正确分片与填充。

Lingvo

LingvoInputAdaptor 可用于导入数据集。

该输入完全委托给 Lingvo 实现,而 Lingvo 实现可能自动处理分片,也可能不自动处理分片。

对于基于 GenericInput 的 Lingvo 输入实现,若采用固定 packing_factor,我们建议使用 LingvoInputAdaptorNewBatchSize 来为内部 Lingvo 输入指定更大的批量大小,并将所需的(通常要小得多)批量大小设置为 p.batch_size

对于评估数据,我们建议使用 LingvoEvalAdaptor,以在单个 epoch 内运行评估时,自动处理分片与填充。

自定义

自定义类继承自 BaseInput。用户可以自行实现子类,通常使用 tf.data 或 SeqIO。

用户也可以直接继承现有的输入类,仅对批次的后处理进行个性化定制。例如:

class MyInput(base_input.LingvoInputAdaptor):

  def get_next(self):
    batch = super().get_next()
    # 修改批次:batch.new_field = ...
    return batch

Pax 关键组件

超参数

超参数是定义模型和配置实验的重要组成部分。

为了更好地与 Python 工具集成,Pax/Praxis 采用了一种 Python 风格的 dataclass 配置方式来管理超参数。

class Linear(base_layer.BaseLayer):
  """线性层,无偏置。”

  class HParams(BaseHParams):
    """此层类的关联超参数。

    属性:
      input_dims:输入的深度。
      output_dims:输出的深度。
    """
    input_dims: int = 0
    output_dims: int = 0

嵌套

超参数数据类也可以进行嵌套。如下面的例子所示,linear_tpl 属性是一个嵌套的 Linear.HParams

class FeedForward(base_layer.BaseLayer):
  """前馈层,带激活函数。”

  class HParams(BaseHParams):
    """此层类的关联超参数。

    属性:
      input_dims:输入的深度。
      output_dims:输出的深度。
      has_bias:是否添加偏置权重。
      linear_tpl:线性层的参数。
      activation_tpl:激活函数层的参数。
    """
    input_dims: int = 0
    output_dims: int = 0
    has_bias: bool = True
    linear_tpl: BaseHParams = sub_config_field(Linear.HParams)
    activation_tpl: activations.BaseActivation.HParams = sub_config_field(
        ReLU.HParams)

层代表任意函数,该函数可能包含可训练的参数。层可以作为子级包含其他层。层是模型的核心构建模块。层继承自 Flax nn.Module。

通常,层会定义两个方法:

setup

该方法用于创建可训练的权重和子级层。

fprop

该方法定义前向传播函数,根据输入计算出某些输出。此外,fprop 还可以添加摘要信息或跟踪辅助损失。

Fiddle 与共享层

Fiddle 是一个开源的 Python 首先配置库,专为机器学习应用设计。Pax/Praxis 支持与 Fiddle Config/Partial 的互操作性,以及一些高级功能,如即时错误检查和共享参数。

fdl_config = Linear.HParams.config(input_dims=1, output_dims=1)

# 一个拼写错误。
fdl_config.input_dimz = 31337  # 一旦发现拼写错误,就会立即抛出异常,快速捕捉错误!


fdl_partial = Linear.HParams.partial(input_dims=1)

借助 Fiddle,层可以被配置为共享(例如,仅实例化一次,并使用共享的可训练权重)。

模型

模型仅定义网络结构,通常是多个层的集合,并定义了与模型交互的接口,例如解码等。

一些基础模型示例包括:

  • 语言模型
  • 序列模型
  • 分类模型

任务

任务包含多个模型以及学习器/优化器。最简单的任务子类是 SingleTask,它需要以下超参数:

  class HParams(base_task.BaseTask.HParams):
    """任务参数。

    属性:
      name:此任务对象的名称,必须是有效的标识符。
      model:底层 JAX 模型,封装了所有层。
      train:用于控制任务训练方式的超参数。
      metrics:一个 BaseMetrics 汇总类,用于确定如何计算指标。
      loss_aggregator:一个 LossAggregator 汇总类,用于确定如何聚合损失(例如,单个损失或多损失)。
      vn:用于控制变分噪声的超参数。

发布版本

PyPI 版本 提交
0.1.0 546370f5323ef8b27d38ddc32445d7d3d1e4da9a
版权所有 2022 Google LLC

根据 Apache 许可证第 2.0 版(“许可证”)发布;您不得使用本文件,除非遵守许可证条款。您可以从以下网址获取许可证副本:

    https://www.apache.org/licenses/LICENSE-2.0

除非法律另有要求或书面协议约定,否则软件将以“原样”提供,且不附带任何明示或暗示的保证。请参阅许可证,了解有关特定语言的许可条款及限制,以及相关法律责任。

版本历史

paxml-v1.4.02024/04/09
paxml-v1.3.12024/02/21
paxml-v1.3.02024/02/17
paxml-v1.2.02023/10/19
paxml-v1.1.02023/08/28
paxml-v1.0.02023/04/12
paxml-v0.4.02023/04/12
paxml-v0.3.02023/02/03
paxml-v0.2.12022/11/22
paxml-v0.2.02022/11/15
paxml-v0.0.12022/06/15

常见问题

相似工具推荐

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|今天
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

139k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

107.7k|★★☆☆☆|2天前
开发框架图像Agent

NextChat

NextChat 是一款轻量且极速的 AI 助手,旨在为用户提供流畅、跨平台的大模型交互体验。它完美解决了用户在多设备间切换时难以保持对话连续性,以及面对众多 AI 模型不知如何统一管理的痛点。无论是日常办公、学习辅助还是创意激发,NextChat 都能让用户随时随地通过网页、iOS、Android、Windows、MacOS 或 Linux 端无缝接入智能服务。 这款工具非常适合普通用户、学生、职场人士以及需要私有化部署的企业团队使用。对于开发者而言,它也提供了便捷的自托管方案,支持一键部署到 Vercel 或 Zeabur 等平台。 NextChat 的核心亮点在于其广泛的模型兼容性,原生支持 Claude、DeepSeek、GPT-4 及 Gemini Pro 等主流大模型,让用户在一个界面即可自由切换不同 AI 能力。此外,它还率先支持 MCP(Model Context Protocol)协议,增强了上下文处理能力。针对企业用户,NextChat 提供专业版解决方案,具备品牌定制、细粒度权限控制、内部知识库整合及安全审计等功能,满足公司对数据隐私和个性化管理的高标准要求。

87.6k|★★☆☆☆|今天
开发框架语言模型

ML-For-Beginners

ML-For-Beginners 是由微软推出的一套系统化机器学习入门课程,旨在帮助零基础用户轻松掌握经典机器学习知识。这套课程将学习路径规划为 12 周,包含 26 节精炼课程和 52 道配套测验,内容涵盖从基础概念到实际应用的完整流程,有效解决了初学者面对庞大知识体系时无从下手、缺乏结构化指导的痛点。 无论是希望转型的开发者、需要补充算法背景的研究人员,还是对人工智能充满好奇的普通爱好者,都能从中受益。课程不仅提供了清晰的理论讲解,还强调动手实践,让用户在循序渐进中建立扎实的技能基础。其独特的亮点在于强大的多语言支持,通过自动化机制提供了包括简体中文在内的 50 多种语言版本,极大地降低了全球不同背景用户的学习门槛。此外,项目采用开源协作模式,社区活跃且内容持续更新,确保学习者能获取前沿且准确的技术资讯。如果你正寻找一条清晰、友好且专业的机器学习入门之路,ML-For-Beginners 将是理想的起点。

85k|★★☆☆☆|今天
图像数据工具视频

ragflow

RAGFlow 是一款领先的开源检索增强生成(RAG)引擎,旨在为大语言模型构建更精准、可靠的上下文层。它巧妙地将前沿的 RAG 技术与智能体(Agent)能力相结合,不仅支持从各类文档中高效提取知识,还能让模型基于这些知识进行逻辑推理和任务执行。 在大模型应用中,幻觉问题和知识滞后是常见痛点。RAGFlow 通过深度解析复杂文档结构(如表格、图表及混合排版),显著提升了信息检索的准确度,从而有效减少模型“胡编乱造”的现象,确保回答既有据可依又具备时效性。其内置的智能体机制更进一步,使系统不仅能回答问题,还能自主规划步骤解决复杂问题。 这款工具特别适合开发者、企业技术团队以及 AI 研究人员使用。无论是希望快速搭建私有知识库问答系统,还是致力于探索大模型在垂直领域落地的创新者,都能从中受益。RAGFlow 提供了可视化的工作流编排界面和灵活的 API 接口,既降低了非算法背景用户的上手门槛,也满足了专业开发者对系统深度定制的需求。作为基于 Apache 2.0 协议开源的项目,它正成为连接通用大模型与行业专有知识之间的重要桥梁。

77.1k|★★★☆☆|昨天
Agent图像开发框架