[{"data":1,"prerenderedAt":-1},["ShallowReactive",2],{"similar-google--paxml":3,"tool-google--paxml":64},[4,17,27,35,43,56],{"id":5,"name":6,"github_repo":7,"description_zh":8,"stars":9,"difficulty_score":10,"last_commit_at":11,"category_tags":12,"status":16},3808,"stable-diffusion-webui","AUTOMATIC1111\u002Fstable-diffusion-webui","stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面，旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点，将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。\n\n无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师，还是想要深入探索模型潜力的开发者与研究人员，都能从中获益。其核心亮点在于极高的功能丰富度：不仅支持文生图、图生图、局部重绘（Inpainting）和外绘（Outpainting）等基础模式，还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外，它内置了 GFPGAN 和 CodeFormer 等人脸修复工具，支持多种神经网络放大算法，并允许用户通过插件系统无限扩展能力。即使是显存有限的设备，stable-diffusion-webui 也提供了相应的优化选项，让高质量的 AI 艺术创作变得触手可及。",162132,3,"2026-04-05T11:01:52",[13,14,15],"开发框架","图像","Agent","ready",{"id":18,"name":19,"github_repo":20,"description_zh":21,"stars":22,"difficulty_score":23,"last_commit_at":24,"category_tags":25,"status":16},1381,"everything-claude-code","affaan-m\u002Feverything-claude-code","everything-claude-code 是一套专为 AI 编程助手（如 Claude Code、Codex、Cursor 等）打造的高性能优化系统。它不仅仅是一组配置文件，而是一个经过长期实战打磨的完整框架，旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。\n\n通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能，everything-claude-code 能显著提升 AI 在复杂任务中的表现，帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略，使得模型响应更快、成本更低，同时有效防御潜在的攻击向量。\n\n这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库，还是需要 AI 协助进行安全审计与自动化测试，everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目，它融合了多语言支持与丰富的实战钩子（hooks），让 AI 真正成长为懂上",138956,2,"2026-04-05T11:33:21",[13,15,26],"语言模型",{"id":28,"name":29,"github_repo":30,"description_zh":31,"stars":32,"difficulty_score":23,"last_commit_at":33,"category_tags":34,"status":16},2271,"ComfyUI","Comfy-Org\u002FComfyUI","ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎，专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式，采用直观的节点式流程图界面，让用户通过连接不同的功能模块即可构建个性化的生成管线。\n\n这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景，也能自由组合模型、调整参数并实时预览效果，轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性，不仅支持 Windows、macOS 和 Linux 全平台，还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构，并率先支持 SDXL、Flux、SD3 等前沿模型。\n\n无论是希望深入探索算法潜力的研究人员和开发者，还是追求极致创作自由度的设计师与资深 AI 绘画爱好者，ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能，使其成为当前最灵活、生态最丰富的开源扩散模型工具之一，帮助用户将创意高效转化为现实。",107662,"2026-04-03T11:11:01",[13,14,15],{"id":36,"name":37,"github_repo":38,"description_zh":39,"stars":40,"difficulty_score":23,"last_commit_at":41,"category_tags":42,"status":16},3704,"NextChat","ChatGPTNextWeb\u002FNextChat","NextChat 是一款轻量且极速的 AI 助手，旨在为用户提供流畅、跨平台的大模型交互体验。它完美解决了用户在多设备间切换时难以保持对话连续性，以及面对众多 AI 模型不知如何统一管理的痛点。无论是日常办公、学习辅助还是创意激发，NextChat 都能让用户随时随地通过网页、iOS、Android、Windows、MacOS 或 Linux 端无缝接入智能服务。\n\n这款工具非常适合普通用户、学生、职场人士以及需要私有化部署的企业团队使用。对于开发者而言，它也提供了便捷的自托管方案，支持一键部署到 Vercel 或 Zeabur 等平台。\n\nNextChat 的核心亮点在于其广泛的模型兼容性，原生支持 Claude、DeepSeek、GPT-4 及 Gemini Pro 等主流大模型，让用户在一个界面即可自由切换不同 AI 能力。此外，它还率先支持 MCP（Model Context Protocol）协议，增强了上下文处理能力。针对企业用户，NextChat 提供专业版解决方案，具备品牌定制、细粒度权限控制、内部知识库整合及安全审计等功能，满足公司对数据隐私和个性化管理的高标准要求。",87618,"2026-04-05T07:20:52",[13,26],{"id":44,"name":45,"github_repo":46,"description_zh":47,"stars":48,"difficulty_score":23,"last_commit_at":49,"category_tags":50,"status":16},2268,"ML-For-Beginners","microsoft\u002FML-For-Beginners","ML-For-Beginners 是由微软推出的一套系统化机器学习入门课程，旨在帮助零基础用户轻松掌握经典机器学习知识。这套课程将学习路径规划为 12 周，包含 26 节精炼课程和 52 道配套测验，内容涵盖从基础概念到实际应用的完整流程，有效解决了初学者面对庞大知识体系时无从下手、缺乏结构化指导的痛点。\n\n无论是希望转型的开发者、需要补充算法背景的研究人员，还是对人工智能充满好奇的普通爱好者，都能从中受益。课程不仅提供了清晰的理论讲解，还强调动手实践，让用户在循序渐进中建立扎实的技能基础。其独特的亮点在于强大的多语言支持，通过自动化机制提供了包括简体中文在内的 50 
多种语言版本，极大地降低了全球不同背景用户的学习门槛。此外，项目采用开源协作模式，社区活跃且内容持续更新，确保学习者能获取前沿且准确的技术资讯。如果你正寻找一条清晰、友好且专业的机器学习入门之路，ML-For-Beginners 将是理想的起点。",84991,"2026-04-05T10:45:23",[14,51,52,53,15,54,26,13,55],"数据工具","视频","插件","其他","音频",{"id":57,"name":58,"github_repo":59,"description_zh":60,"stars":61,"difficulty_score":10,"last_commit_at":62,"category_tags":63,"status":16},3128,"ragflow","infiniflow\u002Fragflow","RAGFlow 是一款领先的开源检索增强生成（RAG）引擎，旨在为大语言模型构建更精准、可靠的上下文层。它巧妙地将前沿的 RAG 技术与智能体（Agent）能力相结合，不仅支持从各类文档中高效提取知识，还能让模型基于这些知识进行逻辑推理和任务执行。\n\n在大模型应用中，幻觉问题和知识滞后是常见痛点。RAGFlow 通过深度解析复杂文档结构（如表格、图表及混合排版），显著提升了信息检索的准确度，从而有效减少模型“胡编乱造”的现象，确保回答既有据可依又具备时效性。其内置的智能体机制更进一步，使系统不仅能回答问题，还能自主规划步骤解决复杂问题。\n\n这款工具特别适合开发者、企业技术团队以及 AI 研究人员使用。无论是希望快速搭建私有知识库问答系统，还是致力于探索大模型在垂直领域落地的创新者，都能从中受益。RAGFlow 提供了可视化的工作流编排界面和灵活的 API 接口，既降低了非算法背景用户的上手门槛，也满足了专业开发者对系统深度定制的需求。作为基于 Apache 2.0 协议开源的项目，它正成为连接通用大模型与行业专有知识之间的重要桥梁。",77062,"2026-04-04T04:44:48",[15,14,13,26,54],{"id":65,"github_repo":66,"name":67,"description_en":68,"description_zh":69,"ai_summary_zh":69,"readme_en":70,"readme_zh":71,"quickstart_zh":72,"use_case_zh":73,"hero_image_url":74,"owner_login":75,"owner_name":76,"owner_avatar_url":77,"owner_bio":78,"owner_company":79,"owner_location":79,"owner_email":80,"owner_twitter":81,"owner_website":82,"owner_url":83,"languages":84,"stars":105,"forks":106,"last_commit_at":107,"license":108,"difficulty_score":109,"env_os":110,"env_gpu":111,"env_ram":112,"env_deps":113,"category_tags":122,"github_topics":123,"view_count":23,"oss_zip_url":79,"oss_zip_packed_at":79,"status":16,"created_at":131,"updated_at":132,"faqs":133,"releases":154},1357,"google\u002Fpaxml","paxml","Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates.","Paxml 是 Google 开源的一套基于 JAX 的机器学习框架，专为“超大模型 + 大规模分布式训练”而生。它把复杂的并行策略、硬件调度、实验配置都封装成可插拔的模块，开发者只需写一份配置，就能在 TPU Pod 上把数十亿甚至上百亿参数的模型高效跑起来。官方测试显示，Paxml 在同等算力下的模型 FLOPs 利用率处于业界领先水平，显著缩短训练时间和成本。\n\n如果你正在做 NLP、多模态或通用大模型的研究，需要快速验证想法、横向扩展实验规模，Paxml 会是趁手的利器；云 TPU 用户可直接用一条命令拉起环境，本地研究者也能在 GPU 上小步快跑。它支持 SPMD（pjit）与传统 pmap 两种并行模式，内置大量可复用的模型模板与教程 Notebook，真正做到“配置即实验”。","# Paxml (aka Pax)\n\nPax is a framework to configure and run machine learning experiments on top of Jax.\n## Quickstart\n### Setting up a Cloud TPU VM\n\nWe refer to\n[this page](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fusers-guide-tpu-vm#managing_tpus)\nfor more exhaustive documentation about starting a Cloud TPU project. The\nfollowing command is sufficient to create a Cloud TPU VM with 8 cores from a\ncorp machine.\n\n```bash\nexport ZONE=us-central2-b\nexport VERSION=tpu-vm-v4-base\nexport PROJECT=\u003Cyour-project>\nexport ACCELERATOR=v4-8\nexport TPU_NAME=paxml\n\n#create a TPU VM\ngcloud compute tpus tpu-vm create $TPU_NAME \\\n--zone=$ZONE --version=$VERSION \\\n--project=$PROJECT \\\n--accelerator-type=$ACCELERATOR\n```\n\nIf you are using TPU Pod slices, please refer to [this guide](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fjax-pods). 
Run all the commands from a local machine using [gcloud](https:\u002F\u002Fcloud.google.com\u002Fsdk\u002Fdocs\u002Finstall) with the `--worker=all` option:\n\n```bash\ngcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \\\n--worker=all --command=\"\u003Ccommands>\"\n```\n\nThe following quickstart sections assume you run on a single-host TPU, so you can ssh to the VM and run the commands there.\n\n```bash\ngcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE\n```\n\n### Installing Pax\n\nAfter ssh-ing into the VM, you can install the paxml stable release from PyPI, or the dev version from github.\n\nFor installing the stable release from PyPI (https:\u002F\u002Fpypi.org\u002Fproject\u002Fpaxml\u002F):\n\n```bash\npython3 -m pip install -U pip\npython3 -m pip install paxml jax[tpu] \\\n-f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\n```\n\nIf you encounter issues with transitive dependencies and you are using the native Cloud TPU VM environment, please navigate to the corresponding release branch rX.Y.Z and download `paxml\u002Fpip_package\u002Frequirements.txt`. This file includes the exact versions of all transitive dependencies needed in the native Cloud TPU VM environment, in which we build\u002Ftest the corresponding release.\n\n```bash\ngit clone -b rX.Y.Z https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\npip install --no-deps -r paxml\u002Fpaxml\u002Fpip_package\u002Frequirements.txt\n```\n\nFor installing the dev version from github, and for ease of editing code:\n\n```bash\n# install the dev version of praxis first\ngit clone https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpraxis\npip install -e praxis\ngit clone https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\npip install -e paxml\npip install \"jax[tpu]\" -f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\n```\n\n### Run a test model\n```bash\n# example model using pjit (SPMD)\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n\n# example model using pmap\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket> \\\n--pmap_use_tensorstore=True\n```\n### Documentation\nPlease visit our [docs folder](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Ftree\u002Fmain\u002Fpaxml\u002Fdocs) for documentation and Jupyter Notebook tutorials. Please see the following section for instructions on running Jupyter Notebooks on a Cloud TPU VM.\n\n### Run a notebook\nYou can run the [example notebooks](paxml\u002Fdocs\u002Ftutorials) in the TPU VM in which you just installed paxml.\n#### Steps to enable a notebook in a `v4-8`\n\n1. ssh into the TPU VM with port forwarding\n`gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_NAME --zone=$ZONE --ssh-flag=\"-4 -L 8080:localhost:8080\"`\n\n2. install jupyter notebook on the TPU VM and downgrade markupsafe\n ```\n pip install notebook\n pip install markupsafe==2.0.1\n ```\n3. export the `jupyter` path\n`export PATH=\u002Fhome\u002F$USER\u002F.local\u002Fbin:$PATH`\n\n4. scp the [example notebooks](paxml\u002Fdocs\u002Ftutorials) to your TPU VM\n`gcloud compute tpus tpu-vm scp \u003Clocal path of the notebooks> $TPU_NAME:\u003Cpath inside TPU VM> --zone=$ZONE --project=$PROJECT`\n\n5. 
start jupyter notebook from the TPU VM and note the token generated by jupyter notebook\n `jupyter notebook --no-browser --port=8080`\n\n6. then in your local browser go to http:\u002F\u002Flocalhost:8080\u002F and enter the token provided\n\nNote: In case you need to start a second notebook while the first notebook is still occupying the TPUs, you can run\n`pkill -9 python3`\nto free up the TPUs.\n\n### Run on GPU\n\nNote: NVIDIA has released an updated version of Pax with H100 FP8 support and broad GPU performance improvements. Please visit the [NVIDIA Rosetta](https:\u002F\u002Fgithub.com\u002FNVIDIA\u002FJAX-Toolbox\u002Ftree\u002Fmain\u002Frosetta\u002Frosetta\u002Fprojects\u002Fpax) repository for more details and usage instructions.\n\n### FAQs\n\n1. Pax runs on Jax; you can find details on running Jax jobs on Cloud TPU [here](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Frun-calculation-jax), and details on running Jax jobs on a Cloud TPU pod [here](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fjax-pods).\n\n2. If you run into dependency errors, please refer to the `requirements.txt` file in the branch corresponding to the stable release you are installing.\nFor example, for the [stable release 0.4.0](https:\u002F\u002Fpypi.org\u002Fproject\u002Fpaxml\u002F0.4.0\u002F), use branch `r0.4.0` and refer to the [requirements.txt](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fblob\u002Fr0.4.0\u002Fpaxml\u002Fpip_package\u002Frequirements.txt) for the exact versions of the dependencies used for the stable release.\n\n## Example Convergence Runs\nHere are some sample convergence runs on the [c4 dataset](https:\u002F\u002Fwww.tensorflow.org\u002Fdatasets\u002Fcatalog\u002Fc4).\n\n### 1B model on c4 dataset\n\nYou can run a `1B` params model on the c4 dataset on TPU `v4-8` using the config `C4Spmd1BAdam4Replicas` from [c4.py](paxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py) as follows:\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\nYou can observe the loss curve and the `log perplexity` graph as follows:\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F1B-loss.png width=\"400\" height=\"300\">\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F1B-pplx.png width=\"400\" height=\"300\">\n\n### 16B model on c4 dataset\n\nYou can run a `16B` params model on the c4 dataset on TPU `v4-64` using the config `C4Spmd16BAdam32Replicas` from [c4.py](paxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py) as follows:\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\nYou can observe the loss curve and the `log perplexity` graph as follows:\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F16B-loss.png width=\"400\" height=\"300\">\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F16B-pplx.png width=\"400\" height=\"300\">\n\n### GPT3-XL model on c4 dataset\n\nYou can run the GPT3-XL model on the c4 dataset on TPU `v4-128` using the config `C4SpmdPipelineGpt3SmallAdam64Replicas` from [c4.py](paxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py) as follows:\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\nYou can observe the loss curve and the 
`log perplexity` graph as follows:\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002FGPT3-XL-loss.png width=\"400\" height=\"300\">\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002FGPT3-XL-pplx.png width=\"400\" height=\"300\">\n\n## Benchmark on Cloud TPU v4\nThe [PaLM](https:\u002F\u002Farxiv.org\u002Fabs\u002F2204.02311) paper introduced an efficiency metric called Model FLOPs Utilization (MFU). This is measured as the ratio of the observed throughput (in, for example, tokens per second for a language model) to the theoretical maximum throughput of a system harnessing 100% of peak FLOPs. It differs from other ways of measuring compute utilization because it doesn’t include FLOPs spent on activation rematerialization during the backward pass, meaning that efficiency as measured by MFU translates directly into end-to-end training speed.\n\nTo evaluate the MFU of a key class of workloads on TPU v4 Pods with Pax, we carried out an in-depth benchmark campaign on a series of decoder-only Transformer language model (GPT) configurations that range in size from billions to trillions of parameters on the [c4 dataset](https:\u002F\u002Fwww.tensorflow.org\u002Fdatasets\u002Fcatalog\u002Fc4). The following graph shows the training efficiency using the \"weak scaling\" pattern, where we grew the model size in proportion to the number of chips used.\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002FWeak_scaling_of_large_language_model_training_on_TPU_v4.png width=\"500\" height=\"300\">\n\n\n## Pax on Multislice\nThe multislice configs in this repo refer to\n[1. Single-slice configs](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fblob\u002Fmain\u002Fpaxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py)\nfor syntax \u002F model architecture\nand\n[2. the MaxText repo](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fmaxtext)\nfor config values.\n\n\nWe provide example runs under `c4_multislice.py` as a starting point for Pax on multislice.\n\n\n### Setting up Cloud TPU VMs using Queued Resources\n\n\nWe refer to\n[this page](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fqueued-resources)\nfor more exhaustive documentation about using Queued Resources for a multi-slice Cloud TPU project. The\nfollowing shows the steps needed to set up TPUs for running the example configs in this repo.\n\n\n```bash\nexport ZONE=us-central2-b\nexport VERSION=tpu-vm-v4-base\nexport PROJECT=\u003Cyour-project>\nexport ACCELERATOR=v4-128 # or v4-384 depending on which config you run\n```\n\n\nSay, for running `C4Spmd22BAdam2xv4_128` on 2 slices of v4-128, you'd need to set up TPUs the following way:\n\n```bash\nexport TPU_PREFIX=\u003Cyour-prefix> # New TPUs will be created based off this prefix\nexport QR_ID=$TPU_PREFIX\nexport NODE_COUNT=\u003Cnumber-of-slices> # 1, 2, or 4 depending on which config you run\n\n\n#create a TPU VM\ngcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX\n```\n\n\n### Installing Pax\n\n\nThe setup commands described earlier need to be run on ALL workers in ALL slices. 
You can either 1) ssh into each worker on each slice individually, or 2) use a for loop with the `--worker=all` flag, as in the following command.\n\n\n```bash\nfor ((i=0; i\u003C$NODE_COUNT; i++))\ndo\ngcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all --command=\"pip install paxml && pip install orbax==0.1.1 && pip install \\\"jax[tpu]\\\" -f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\"\ndone\n```\n\n\n### Run a test multislice model\nIn order to run the multislice configs, open the same number of terminals as your $NODE_COUNT. For our experiments on 2 slices (`C4Spmd22BAdam2xv4_128`), open two terminals. Then, run each of these commands individually from each terminal.\n\nFrom Terminal 0, run the training command for slice 0 as follows:\n\n```bash\nexport TPU_PREFIX=\u003Cyour-prefix>\nexport EXP_NAME=C4Spmd22BAdam2xv4_128\nexport LIBTPU_INIT_ARGS=\\\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\\\"\ngcloud compute tpus tpu-vm ssh $TPU_PREFIX-0 --zone=us-central2-b --worker=all \\\n--command=\"LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \\\npython3 \u002Fhome\u002Fyooh\u002F.local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\"\n```\n\n\nFrom Terminal 1, concurrently run the training command for slice 1 as follows:\n\n```bash\nexport TPU_PREFIX=\u003Cyour-prefix>\nexport EXP_NAME=C4Spmd22BAdam2xv4_128\nexport LIBTPU_INIT_ARGS=\\\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\\\"\ngcloud compute tpus tpu-vm ssh $TPU_PREFIX-1 --zone=us-central2-b --worker=all \\\n--command=\"LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \\\npython3 \u002Fhome\u002Fyooh\u002F.local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\"\n```\n\n\n\n### MaxText to Pax\nThis table covers how MaxText variable names translate to Pax.\n\n\nNote that MaxText has a \"scale\" which is multiplied into several base parameters (base_num_decoder_layers, base_emb_dim, base_mlp_dim, base_num_heads) to produce the final values.\n\n\nAnother thing to note is that while Pax represents the DCN and ICI MESH_SHAPE as arrays, MaxText has separate data_parallelism, fsdp_parallelism and tensor_parallelism variables for DCN and ICI. Since these values are set to 1 by default, only the variables with a value greater than 1 are recorded in this translation table.\n\nThat is, `ICI_MESH_SHAPE = [ici_data_parallelism, ici_fsdp_parallelism, ici_tensor_parallelism]` and `DCN_MESH_SHAPE = [dcn_data_parallelism, dcn_fsdp_parallelism, dcn_tensor_parallelism]`.\n\n
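As a worked check of this mapping (not an official script; the values are taken from the `C4Spmd22BAdam2xv4_128` row of the table below):\n\n```python\n# MaxText side (2xv4-128.sh): \"scale\" multiplies the four base_* values.\nscale = 3\nbase_num_decoder_layers, base_emb_dim, base_num_heads = 16, 2048, 8\n\nNUM_LAYERS = scale * base_num_decoder_layers  # 48\nMODEL_DIMS = scale * base_emb_dim             # 6144\nNUM_HEADS = scale * base_num_heads            # 24\nHIDDEN_DIMS = MODEL_DIMS * 4                  # 24576 (the scaled base_mlp_dim)\n\n# Mesh translation: only parallelism values greater than 1 appear in the table.\nICI_MESH_SHAPE = [1, 64, 1]  # [data, fsdp, tensor] within a slice\nDCN_MESH_SHAPE = [2, 1, 1]   # [data, fsdp, tensor] across slices\n```\n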
\n| Pax C4Spmd22BAdam2xv4_128 | | MaxText 2xv4-128.sh | | (after scale is applied) |\n|---------------------------|--------------|---------------------------------|----------|-----------------------:|\n| | | scale (applied to next 4 variables) | 3 | |\n| NUM_LAYERS | 48 | base_num_decoder_layers | 16 | 48 |\n| MODEL_DIMS | 6144 | base_emb_dim | 2048 | 6144 |\n| HIDDEN_DIMS | 24576 | MODEL_DIMS * 4 (= base_mlp_dim) | 8192 | 24576 |\n| NUM_HEADS | 24 | base_num_heads | 8 | 24 |\n| DIMS_PER_HEAD | 256 | head_dim | 256 | |\n| PERCORE_BATCH_SIZE | 16 | per_device_batch_size | 16 | |\n| MAX_SEQ_LEN | 1024 | max_target_length | 1024 | |\n| VOCAB_SIZE | 32768 | vocab_size | 32768 | |\n| FPROP_DTYPE | jnp.bfloat16 | dtype | bfloat16 | |\n| USE_REPEATED_LAYER | TRUE | | | |\n| SUMMARY_INTERVAL_STEPS | 10 | | | |\n| ICI_MESH_SHAPE | [1, 64, 1] | ici_fsdp_parallelism | 64 | |\n| DCN_MESH_SHAPE | [2, 1, 1] | dcn_data_parallelism | 2 | |\n\n\n# Data inputs\n\n## Intro\n\nInput is an instance of the `BaseInput` class for getting data into the model for train\u002Feval\u002Fdecode.\n\n```python\nclass BaseInput:\n\n  def get_next(self):\n    pass\n\n  def reset(self):\n    pass\n```\n\nIt acts like an iterator: `get_next()` returns a `NestedMap`, where each field is a numerical array with batch size as its leading dimension.\n\nEach input is configured by a subclass of `BaseInput.HParams`. On this page, we use `p` to denote an instance of `BaseInput.HParams`, which instantiates to `input`.\n\n## Multihost infeed\n\nIn Pax, data is always multihost: each Jax process will have a separate, independent `input` instantiated. Their params will have different `p.infeed_host_index`, set automatically by Pax.\n\nHence, the local batch size seen on each host is `p.batch_size`, and the global batch size is `(p.batch_size * p.num_infeed_hosts)`. One will often see `p.batch_size` set to `jax.local_device_count() * PERCORE_BATCH_SIZE`.\n\nDue to this multihost nature, `input` must be sharded properly.\n\nFor training, each `input` must never emit identical batches, and for eval on a finite dataset, each `input` must terminate after the same number of batches. The best solution is to have the input implementation properly shard the data, such that the `input` objects on different hosts do not overlap. Failing that, one can also use a different random seed to avoid duplicate batches during training.\n\n## Input for eval data\n\n`input.reset()` is never called on training data, but it can be called on eval (or decode) data.\n\n\nFor each eval (or decode) run, Pax fetches `N` batches from `input` by calling `input.get_next()` `N` times. The number of batches used, `N`, can be a fixed number specified by the user via `p.eval_loop_num_batches`; or `N` can be dynamic (`p.eval_loop_num_batches=None`), i.e. we call `input.get_next()` until we exhaust all of its data (by raising `StopIteration` or `tf.errors.OutOfRange`).\n\nIf `p.reset_for_eval=True`, `p.eval_loop_num_batches` is ignored and `N` is determined dynamically as the number of batches needed to exhaust the data. In this case, `p.repeat` should be set to False, as doing otherwise would lead to infinite decode\u002Feval.\n\nIf `p.reset_for_eval=False`, Pax will fetch `p.eval_loop_num_batches` batches. This should be set with `p.repeat=True` so that data are not prematurely exhausted.\n\nNote that `LingvoEvalAdaptor` inputs require `p.reset_for_eval=True`.\n\n
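A minimal sketch of the two eval configurations just described, using the field names from this page (`...` stands for a concrete input param, as in the snippets on this page; treat this as illustrative rather than exact API):\n\n```python\n# One-epoch eval over a finite dataset: N is determined dynamically.\np_eval = ...                    # a concrete BaseInput.HParams for the eval split\np_eval.reset_for_eval = True    # eval_loop_num_batches is ignored\np_eval.repeat = False           # input must be finite, or eval never terminates\n\n# Fixed number of batches per eval run.\np_eval = ...\np_eval.reset_for_eval = False\np_eval.eval_loop_num_batches = 100  # fixed N (illustrative value)\np_eval.repeat = True                # so data is not prematurely exhausted\n```\n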
|                          | `N`: static             | `N`: dynamic            |\n| ------------------------ | ----------------------- | ----------------------- |\n| `p.reset_for_eval=True`  | Each eval run uses the first `N` batches. Not supported yet. | One epoch per eval run; `eval_loop_num_batches` is ignored. Input must be finite (`p.repeat=False`). |\n| `p.reset_for_eval=False` | Each eval run uses non-overlapping `N` batches on a rolling basis, according to `eval_loop_num_batches`. Input must repeat indefinitely (`p.repeat=True`) or otherwise may raise an exception. | Not supported. |\n\nIf running decode\u002Feval on exactly one epoch (i.e. when `p.reset_for_eval=True`), the input must handle sharding correctly such that each shard raises at the same step after exactly the same number of batches are produced. This usually means that the input must pad the eval data. This is done automatically by `SeqIOInput` and `LingvoEvalAdaptor` (see more below).\n\n### Eval metrics\n\nFor the majority of inputs, we only ever call `get_next()` on them to get batches of data. One type of eval data is an exception, where \"how to compute metrics\" is also defined on the input object.\n\nThis is only supported with `SeqIOInput` that defines some canonical eval benchmark. Specifically, Pax uses the `predict_metric_fns` and `score_metric_fns` defined on the SeqIO task to compute eval metrics (although Pax does not depend on the SeqIO evaluator directly).\n\n## Best practices\n\nWhen a model uses multiple inputs, either between train\u002Feval or different training data between pretraining\u002Ffinetuning, users must ensure that the tokenizers used by the inputs are identical, especially when importing different inputs implemented by others.\n\nUsers can sanity check the tokenizers by decoding some ids with `input.ids_to_strings()`.\n\nIt's always a good idea to sanity check the data by looking at a few batches. Users can easily reproduce the param in a colab and inspect the data:\n\n```python\np = ... # specify the intended input param\ninp = p.Instantiate()\nb = inp.get_next()\nprint(b)\n```\n\nTraining data typically should not use a fixed random seed. This is because if the training job is preempted, the training data will start to repeat itself. In particular, for Lingvo inputs, we recommend setting `p.input.file_random_seed = 0` for training data.\n\nTo test whether sharding is handled correctly, users can manually set different values for `p.num_infeed_hosts` and `p.infeed_host_index` and see whether the instantiated inputs emit different batches.\n\n
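A minimal sketch of that check, following the inspection snippet above (`...` again stands for a concrete input param; illustrative only):\n\n```python\nimport numpy as np\n\np = ...                  # the intended input param\np.num_infeed_hosts = 2   # pretend there are two hosts\n\nhost_batches = []\nfor host_idx in range(p.num_infeed_hosts):\n    p.infeed_host_index = host_idx\n    inp = p.Instantiate()\n    host_batches.append(inp.get_next())\n\n# Properly sharded training inputs should not emit identical batches.\nb0, b1 = host_batches\nassert any(not np.array_equal(b0[k], b1[k]) for k in b0.keys())\n```\n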
## Input types\n\nPax supports 3 types of inputs: SeqIO, Lingvo, and custom.\n\n### SeqIO\n\n`SeqIOInput` can be used to import datasets.\n\nSeqIO inputs handle correct sharding and padding of eval data automatically.\n\n### Lingvo\n\n`LingvoInputAdaptor` can be used to import datasets.\n\nThe input is fully delegated to the Lingvo implementation, which may or may not handle sharding automatically.\n\nFor GenericInput based Lingvo input implementations using a fixed `packing_factor`, we recommend using `LingvoInputAdaptorNewBatchSize` to specify a bigger batch size for the inner Lingvo input and putting the desired (usually much smaller) batch size on `p.batch_size`.\n\nFor eval data, we recommend using `LingvoEvalAdaptor` to handle sharding and padding for running eval over one epoch.\n\n### Custom\n\nCustom subclasses of `BaseInput`. Users implement their own subclass, typically with `tf.data` or SeqIO.\n\nUsers can also inherit an existing input class to customize only the post-processing of batches. For example:\n\n```python\nclass MyInput(base_input.LingvoInputAdaptor):\n\n  def get_next(self):\n    batch = super().get_next()\n    # modify batch: batch.new_field = ...\n    return batch\n```\n\n# Key Pax components\n\n## Hyperparameters\n\nHyperparameters are an important part of defining models and configuring experiments.\n\nTo integrate better with Python tooling, Pax\u002FPraxis uses a pythonic dataclass-based configuration style for hyperparameters.\n\n```python\nclass Linear(base_layer.BaseLayer):\n  \"\"\"Linear layer without bias.\"\"\"\n\n  class HParams(BaseHParams):\n    \"\"\"Associated hyperparams for this layer class.\n\n    Attributes:\n      input_dims: Depth of the input.\n      output_dims: Depth of the output.\n    \"\"\"\n    input_dims: int = 0\n    output_dims: int = 0\n```\n\n### Nesting\n\nIt's also possible to nest HParams dataclasses; in the example below, the `linear_tpl` attribute is a nested `Linear.HParams`.\n\n```python\nclass FeedForward(base_layer.BaseLayer):\n  \"\"\"Feedforward layer with activation.\"\"\"\n\n  class HParams(BaseHParams):\n    \"\"\"Associated hyperparams for this layer class.\n\n    Attributes:\n      input_dims: Depth of the input.\n      output_dims: Depth of the output.\n      has_bias: Adds bias weights or not.\n      linear_tpl: Linear layer params.\n      activation_tpl: Activation layer params.\n    \"\"\"\n    input_dims: int = 0\n    output_dims: int = 0\n    has_bias: bool = True\n    linear_tpl: BaseHParams = sub_config_field(Linear.HParams)\n    activation_tpl: activations.BaseActivation.HParams = sub_config_field(\n        ReLU.HParams)\n```\n\n\n## Layers\n\nA Layer represents an arbitrary function, possibly with trainable parameters. A Layer can contain other Layers as children. Layers are the essential building blocks of models. Layers inherit from the Flax nn.Module.\n\nTypically layers define two methods:\n\n### setup\n\nThis method creates trainable weights and child layers.\n\n### fprop\n\nThis method defines the forward propagation function, computing some output based on the inputs. Additionally, fprop might add summaries or track auxiliary losses.\n\n
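A minimal sketch of a layer written in this style, extending the `Linear.HParams` example above (the weight-creation calls are illustrative and may differ between Praxis versions):\n\n```python\nclass Linear(base_layer.BaseLayer):\n  \"\"\"Linear layer without bias.\"\"\"\n\n  class HParams(BaseHParams):\n    input_dims: int = 0\n    output_dims: int = 0\n\n  def setup(self):\n    # Create the trainable weight, shaped from the hyperparameters.\n    p = self.hparams\n    self.create_variable(\n        'w', WeightHParams(shape=[p.input_dims, p.output_dims]))\n\n  def fprop(self, inputs):\n    # Forward pass: project the inputs with the created weight.\n    return jnp.matmul(inputs, self.theta.w)\n```\n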
### Fiddle and Shared layers\nFiddle is an open-source Python-first configuration library designed for ML applications. Pax\u002FPraxis supports interoperability with Fiddle Config\u002FPartial(s) and some advanced features like eager error checking and shared parameters.\n\n```python\nfdl_config = Linear.HParams.config(input_dims=1, output_dims=1)\n\n# A typo.\nfdl_config.input_dimz = 31337  # Raises an exception immediately to catch typos fast!\n\n\nfdl_partial = Linear.HParams.partial(input_dims=1)\n```\n\nUsing Fiddle, layers can be configured to be shared (e.g. instantiated only once with shared trainable weights).\n\n## Model\n\nA model defines solely the network: typically a collection of Layers, plus interfaces for interacting with the model, such as decoding.\n\nSome example base models include:\n\n*   LanguageModel\n*   SequenceModel\n*   ClassificationModel\n\n## Task\n\nA Task contains one or more Models and Learners\u002FOptimizers. The simplest Task subclass is a `SingleTask`, which requires the following HParams:\n\n```python\n  class HParams(base_task.BaseTask.HParams):\n    \"\"\"Task parameters.\n\n    Attributes:\n      name: Name of this task object, must be a valid identifier.\n      model: The underlying JAX model encapsulating all the layers.\n      train: HParams to control how this task should be trained.\n      metrics: A BaseMetrics aggregator class to determine how metrics are\n         computed.\n      loss_aggregator: A LossAggregator aggregator class to determine how the\n        losses are aggregated (e.g. single or MultiLoss).\n      vn: HParams to control variational noise.\n    \"\"\"\n```\n\n## Releases\nPyPI Version | Commit\n------------ | ----------------------------------------\n0.1.0        | 546370f5323ef8b27d38ddc32445d7d3d1e4da9a\n\n\n\n    Copyright 2022 Google LLC\n\n    Licensed under the Apache License, Version 2.0 (the \"License\");\n    you may not use this file except in compliance with the License.\n    You may obtain a copy of the License at\n\n        https:\u002F\u002Fwww.apache.org\u002Flicenses\u002FLICENSE-2.0\n\n    Unless required by applicable law or agreed to in writing, software\n    distributed under the License is distributed on an \"AS IS\" BASIS,\n    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n    See the License for the specific language governing permissions and\n    limitations under the License.\n","# Paxml（又名 Pax）\n\nPax 是一个用于在 Jax 之上配置和运行机器学习实验的框架。\n## 快速入门\n### 设置 Cloud TPU VM\n\n关于启动 Cloud TPU 项目的更详尽文档，可参考\n[此页面](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fusers-guide-tpu-vm#managing_tpus)。运行以下命令即可创建一台拥有 8 个核心的 Cloud TPU VM。\n\n```bash\nexport ZONE=us-central2-b\nexport VERSION=tpu-vm-v4-base\nexport PROJECT=\u003Cyour-project>\nexport ACCELERATOR=v4-8\nexport TPU_NAME=paxml\n\n# 创建 TPU VM\ngcloud compute tpus tpu-vm create $TPU_NAME \\\n--zone=$ZONE --version=$VERSION \\\n--project=$PROJECT \\\n--accelerator-type=$ACCELERATOR\n```\n\n如果您使用的是 TPU Pod 切片，请参阅 [本指南](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fjax-pods)。请使用 [gcloud](https:\u002F\u002Fcloud.google.com\u002Fsdk\u002Fdocs\u002Finstall) 并通过 `--worker=all` 选项，在本地机器上运行所有命令：\n\n```bash\ngcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE \\\n--worker=all --command=\"\u003Ccommands>\"\n```\n\n以下快速入门章节假设您在单机 TPU 上运行，因此您可以直接 SSH 到 VM 并在其中执行命令。\n\n```bash\ngcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE\n```
\n\n### 安装 Pax\n\n成功 SSH 到 VM 后，您可以从 PyPI 安装 Pax 的稳定版，或从 GitHub 安装开发版。\n\n要从 PyPI 安装稳定版（https:\u002F\u002Fpypi.org\u002Fproject\u002Fpaxml\u002F）：\n\n```bash\npython3 -m pip install -U pip\npython3 -m pip install paxml jax[tpu] \\\n-f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\n```\n\n如果在原生 Cloud TPU VM 环境中遇到传递依赖问题，请前往对应的发布分支 rX.Y.Z，并下载 `paxml\u002Fpip_package\u002Frequirements.txt` 文件。该文件列出了原生 Cloud TPU VM 环境中所需的所有传递依赖的确切版本，我们正是在该环境中基于这些版本构建并测试对应的发布版本。\n\n```bash\ngit clone -b rX.Y.Z https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\npip install --no-deps -r paxml\u002Fpaxml\u002Fpip_package\u002Frequirements.txt\n```\n\n若需从 GitHub 安装开发版，以便于编辑代码：\n\n```bash\n# 首先安装 Praxis 的开发版\ngit clone https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpraxis\npip install -e praxis\ngit clone https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\npip install -e paxml\npip install \"jax[tpu]\" -f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\n```\n\n### 运行测试模型\n```bash\n# 示例模型：使用 pjit（SPMD）\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n\n# 示例模型：使用 pmap\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.lm_cloud.LmCloudTransformerAdamLimitSteps \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket> \\\n--pmap_use_tensorstore=True\n```\n### 文档说明\n请访问我们的 [文档文件夹](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Ftree\u002Fmain\u002Fpaxml\u002Fdocs) 以获取相关文档和 Jupyter Notebook 教程。如需了解如何在 Cloud TPU VM 上运行 Jupyter Notebook，请参阅以下部分。\n\n### 运行笔记本\n您可以在刚刚安装 Pax 的 TPU VM 中运行 [示例笔记本](paxml\u002Fdocs\u002Ftutorials)。\n#### 在 `v4-8` 中启用笔记本的步骤\n\n1. 使用端口转发 SSH 登录 TPU VM：\n   ```bash\n   gcloud compute tpus tpu-vm ssh $TPU_NAME --project=$PROJECT_NAME --zone=$ZONE --ssh-flag=\"-4 -L 8080:localhost:8080\"\n   ```\n\n2. 在 TPU VM 上安装 Jupyter Notebook，并降级 markupsafe：\n   ```\n   pip install notebook\n   pip install markupsafe==2.0.1\n   ```\n\n3. 导出 `jupyter` 路径：\n   ```bash\n   export PATH=\u002Fhome\u002F$USER\u002F.local\u002Fbin:$PATH\n   ```\n\n4. 将 [示例笔记本](paxml\u002Fdocs\u002Ftutorials) 复制到 TPU VM：\n   ```bash\n   gcloud compute tpus tpu-vm scp \u003C本地笔记本路径> $TPU_NAME:\u003CTPU VM 中的路径> --zone=$ZONE --project=$PROJECT\n   ```\n\n5. 在 TPU VM 上启动 Jupyter Notebook，并记录 Jupyter Notebook 生成的令牌：\n   ```bash\n   jupyter notebook --no-browser --port=8080\n   ```\n\n6. 然后在本地浏览器中访问：http:\u002F\u002Flocalhost:8080\u002F，并输入提供的令牌。\n\n注意：如果您需要在第一个笔记本仍在占用 TPU 时启动第二个笔记本，可以运行 `pkill -9 python3` 来释放 TPU。\n\n### 在 GPU 上运行\n注意：NVIDIA 已发布 Pax 的更新版本，支持 H100 FP8 并大幅提升了 GPU 性能。如需了解更多详情及使用说明，请访问 [NVIDIA Rosetta](https:\u002F\u002Fgithub.com\u002FNVIDIA\u002FJAX-Toolbox\u002Ftree\u002Fmain\u002Frosetta\u002Frosetta\u002Fprojects\u002Fpax) 仓库。\n\n### 常见问题解答\n\n1. Pax 运行在 Jax 之上。您可以在 [此处](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Frun-calculation-jax) 查看在 Cloud TPU 上运行 Jax 作业的详细信息；此外，您也可以在 [此处](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fjax-pods) 查看在 Cloud TPU Pod 上运行 Jax 作业的相关信息。\n\n2. 
如果遇到依赖项错误，请参考您所安装的稳定版对应分支中的 `requirements.txt` 文件。\n例如，对于 [稳定版 0.4.0](https:\u002F\u002Fpypi.org\u002Fproject\u002Fpaxml\u002F0.4.0\u002F)，请使用分支 `r0.4.0`，并参考 [requirements.txt](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fblob\u002Fr0.4.0\u002Fpaxml\u002Fpip_package\u002Frequirements.txt) 文件，以获取稳定版所用依赖项的确切版本。\n\n## 示例收敛运行\n以下是针对 [c4 数据集](https:\u002F\u002Fwww.tensorflow.org\u002Fdatasets\u002Fcatalog\u002Fc4) 的一些示例收敛运行。\n\n### c4 数据集上的 1B 模型\n\n您可以在 TPU `v4-8` 上，使用来自 [c4.py](paxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py) 的配置 `C4Spmd1BAdam4Replicas`，在 c4 数据集上运行一个参数量为 `1B` 的模型，具体操作如下：\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\n\n您可以通过以下方式观察损失曲线和 `log perplexity` 图表：\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F1B-loss.png width=\"400\" height=\"300\">\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F1B-pplx.png width=\"400\" height=\"300\">\n\n### c4 数据集上的 16B 模型\n\n您可以在 TPU `v4-64` 上，使用来自 [c4.py](paxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py) 的配置 `C4Spmd16BAdam32Replicas`，在 c4 数据集上运行一个参数量为 `16B` 的模型，具体操作如下：\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4.C4Spmd16BAdam32Replicas \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\n\n您可以通过以下方式观察损失曲线和 `log perplexity` 图表：\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F16B-loss.png width=\"400\" height=\"300\">\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002F16B-pplx.png width=\"400\" height=\"300\">\n\n### c4 数据集上的 GPT3-XL 模型\n\n您可以在 TPU `v4-128` 上，使用来自 [c4.py](paxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py) 的配置 `C4SpmdPipelineGpt3SmallAdam64Replicas`，在 c4 数据集上运行 GPT3-XL 模型，具体操作如下：\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4.C4SpmdPipelineGpt3SmallAdam64Replicas \\\n--job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\n\n您可以通过以下方式观察损失曲线和 `log perplexity` 图表：\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002FGPT3-XL-loss.png width=\"400\" height=\"300\">\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002FGPT3-XL-pplx.png width=\"400\" height=\"300\">\n\n## 在 Cloud TPU v4 上进行基准测试\n[PaLM](https:\u002F\u002Farxiv.org\u002Fabs\u002F2204.02311) 论文提出了一种名为“模型 FLOPs 利用率”（MFU）的效率指标。该指标通过观测到的吞吐量（例如，语言模型每秒处理的 token 数）与系统在充分利用峰值 FLOPs 时的理论最大吞吐量之比来衡量。与其他计算利用率的衡量方式不同，MFU 不会计入反向传播过程中因激活重计算（rematerialization）而消耗的 FLOPs，因此，以 MFU 衡量的效率直接反映了端到端训练的速度。\n\n为了评估 Pax 在 TPU v4 Pod 上针对关键工作负载的 MFU 水平，我们对一系列仅包含解码器的 Transformer 语言模型（GPT 配置，从数十亿参数到数万亿参数不等）进行了深入的基准测试，这些模型均基于 [c4 数据集](https:\u002F\u002Fwww.tensorflow.org\u002Fdatasets\u002Fcatalog\u002Fc4) 进行训练。下图展示了“弱扩展”模式下的训练效率：我们根据所使用的芯片数量按比例扩大模型规模。\n\n\u003Cimg src=paxml\u002Fdocs\u002Fimages\u002FWeak_scaling_of_large_language_model_training_on_TPU_v4.png width=\"500\" height=\"300\">\n\n## Pax 在多切片环境中的应用\n本仓库中所提及的多切片配置，分别对应于：\n[1. 单切片配置](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fblob\u002Fmain\u002Fpaxml\u002Ftasks\u002Flm\u002Fparams\u002Fc4.py)\n用于语法和模型架构的定义，\n[2. 
MaxText 仓库](https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fmaxtext)\n则用于配置值的设置。\n\n我们提供了 `c4_multislice.py` 中的示例运行配置，作为 Pax 多切片训练的起点。\n\n### 使用排队资源配置 Cloud TPU 虚拟机\n\n有关如何为多切片 Cloud TPU 项目使用排队资源的更详尽文档，请参阅 [此页面](https:\u002F\u002Fcloud.google.com\u002Ftpu\u002Fdocs\u002Fqueued-resources)。以下步骤介绍了为运行本仓库中示例配置而设置 TPU 的所需操作。\n\n```bash\nexport ZONE=us-central2-b\nexport VERSION=tpu-vm-v4-base\nexport PROJECT=\u003Cyour-project>\nexport ACCELERATOR=v4-128 # 或 v4-384，具体取决于您运行的配置\n```\n\n例如，若要针对 v4-128 的 2 个切片运行 `C4Spmd22BAdam2xv4_128`，您需要按照以下方式配置 TPU：\n\n```bash\nexport TPU_PREFIX=\u003Cyour-prefix> # 新创建的 TPU 将基于此前缀命名\nexport QR_ID=$TPU_PREFIX\nexport NODE_COUNT=\u003Cnumber-of-slices> # 1、2 或 4，具体取决于您运行的配置\n```\n\n创建 TPU 虚拟机：\n```bash\ngcloud alpha compute tpus queued-resources create $QR_ID --accelerator-type=$ACCELERATOR --runtime-version=tpu-vm-v4-base --node-count=$NODE_COUNT --node-prefix=$TPU_PREFIX\n```\n\n### 安装 Pax\n\n前面介绍的安装命令需在所有切片的所有 worker 上执行。您可以选择：1）分别登录每个切片的每个 worker；或者 2）使用带有 `--worker=all` 标志的 for 循环，如以下命令所示。\n\n```bash\nfor ((i=0; i\u003C$NODE_COUNT; i++))\ndo\ngcloud compute tpus tpu-vm ssh $TPU_PREFIX-$i --zone=us-central2-b --worker=all --command=\"pip install paxml && pip install orbax==0.1.1 && pip install \\\"jax[tpu]\\\" -f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\"\ndone\n```\n\n### 运行多切片模型测试\n要运行多切片配置，需打开与您的 `NODE_COUNT` 相同数量的终端。对于我们的 2 个切片实验（`C4Spmd22BAdam2xv4_128`），请打开两个终端。然后，分别在每个终端中运行下列命令。\n\n从终端 0 开始，运行切片 0 的训练命令，如下所示：\n\n```bash\nexport TPU_PREFIX=\u003Cyour-prefix>\nexport EXP_NAME=C4Spmd22BAdam2xv4_128\nexport LIBTPU_INIT_ARGS=\\\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\\\"\ngcloud compute tpus tpu-vm ssh $TPU_PREFIX-0 --zone=us-central2-b --worker=all \\\n--command=\"LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \\\npython3 \u002Fhome\u002Fyooh\u002F.local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\"\n```\n\n从终端 1 同时运行切片 1 的训练命令，如下所示：\n\n```bash\nexport TPU_PREFIX=\u003Cyour-prefix>\nexport EXP_NAME=C4Spmd22BAdam2xv4_128\nexport LIBTPU_INIT_ARGS=\\\"--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_enable_async_all_gather=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE\\\"\ngcloud compute tpus tpu-vm ssh $TPU_PREFIX-1 --zone=us-central2-b --worker=all \\\n--command=\"LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS \\\npython3 \u002Fhome\u002Fyooh\u002F.local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n--exp=tasks.lm.params.c4_multislice.${EXP_NAME} --job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\"\n```\n\n### MaxText 转换为 Pax\n下表详细说明了 MaxText 变量名如何对应到 Pax。\n\n\n请注意，MaxText 为多个参数（base_num_decoder_layers、base_emb_dim、base_mlp_dim、base_num_heads）设置了“缩放因子”，乘到这些参数上以得到最终值。\n\n\n另外需要说明的是，虽然 Pax 将 DCN 和 ICI 的 MESH_SHAPE 作为数组来处理，但在 MaxText 中，DCN 和 ICI 分别对应着独立的变量：data_parallelism、fsdp_parallelism 和 tensor_parallelism。由于这些变量默认设置为 1，因此在本对照表中，仅记录了那些值大于 1 的变量。\n\n也就是说，`ICI_MESH_SHAPE = [ici_data_parallelism, ici_fsdp_parallelism, ici_tensor_parallelism]`，而 `DCN_MESH_SHAPE = [dcn_data_parallelism, dcn_fsdp_parallelism, dcn_tensor_parallelism]`。\n\n
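作为该映射的一个简单验算（并非官方脚本；数值取自下表 `C4Spmd22BAdam2xv4_128` 一行）：\n\n```python\n# MaxText 一侧（2xv4-128.sh）：“缩放因子”乘到四个 base_* 值上。\nscale = 3\nbase_num_decoder_layers, base_emb_dim, base_num_heads = 16, 2048, 8\n\nNUM_LAYERS = scale * base_num_decoder_layers  # 48\nMODEL_DIMS = scale * base_emb_dim             # 6144\nNUM_HEADS = scale * base_num_heads            # 24\nHIDDEN_DIMS = MODEL_DIMS * 4                  # 24576（缩放后的 base_mlp_dim）\n\n# 网格对应关系：只有大于 1 的并行度会出现在表中。\nICI_MESH_SHAPE = [1, 64, 1]  # 切片内 [data, fsdp, tensor]\nDCN_MESH_SHAPE = [2, 1, 1]   # 切片间 [data, fsdp, tensor]\n```\n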
\n| Pax C4Spmd22BAdam2xv4_128 | | MaxText 2xv4-128.sh | | （在应用缩放因子后） |\n|---------------------------|--------------|---------------------------------|----------|-----------------------:|\n| | | 缩放因子（应用于接下来的四个变量） | 3 | |\n| NUM_LAYERS | 48 | base_num_decoder_layers | 16 | 48 |\n| MODEL_DIMS | 6144 | base_emb_dim | 2048 | 6144 |\n| HIDDEN_DIMS | 24576 | MODEL_DIMS * 4 (= base_mlp_dim) | 8192 | 24576 |\n| NUM_HEADS | 24 | base_num_heads | 8 | 24 |\n| DIMS_PER_HEAD | 256 | head_dim | 256 | |\n| PERCORE_BATCH_SIZE | 16 | per_device_batch_size | 16 | |\n| MAX_SEQ_LEN | 1024 | max_target_length | 1024 | |\n| VOCAB_SIZE | 32768 | vocab_size | 32768 | |\n| FPROP_DTYPE | jnp.bfloat16 | dtype | bfloat16 | |\n| USE_REPEATED_LAYER | TRUE | | | |\n| SUMMARY_INTERVAL_STEPS | 10 | | | |\n| ICI_MESH_SHAPE | [1, 64, 1] | ici_fsdp_parallelism | 64 | |\n| DCN_MESH_SHAPE | [2, 1, 1] | dcn_data_parallelism | 2 | |\n\n\n# 数据输入\n\n## 简介\n\n输入是 `BaseInput` 类的实例，用于将数据输入模型，以进行训练、评估或解码。\n\n```python\nclass BaseInput:\n\n  def get_next(self):\n    pass\n\n  def reset(self):\n    pass\n```\n\n它类似于一个迭代器：`get_next()` 返回一个 `NestedMap`，其中每个字段都是数值数组，其首维为批量大小。\n\n每个输入由 `BaseInput.HParams` 的子类进行配置。在本页面中，我们用 `p` 表示 `BaseInput.HParams` 的实例，其实例化结果记为 `input`。\n\n## 多主机输入\n\n在 Pax 中，数据始终是多主机的：每个 Jax 进程都会实例化一个各自独立的 `input`。它们的参数会拥有不同的 `p.infeed_host_index`，由 Pax 自动设置。\n\n因此，每台主机上看到的本地批量大小为 `p.batch_size`，而全局批量大小为 `(p.batch_size * p.num_infeed_hosts)`。通常会将 `p.batch_size` 设置为 `jax.local_device_count() * PERCORE_BATCH_SIZE`。\n\n由于这种多主机特性，`input` 必须进行适当的分片处理。\n\n训练时，每个 `input` 绝不能输出相同的批次；对有限数据集进行评估时，每个 `input` 必须在相同数量的批次后终止。最佳方案是让输入实现正确地对数据进行分片，使不同主机上的 `input` 互不重叠。如果无法做到这一点，也可以使用不同的随机种子，以避免训练过程中出现重复的批次。\n\n## 评估数据的输入\n\n`input.reset()` 从不在训练数据上调用，但可以在评估（或解码）数据上调用。\n\n\n对于每次评估（或解码）运行，Pax 会调用 `N` 次 `input.get_next()`，从 `input` 中获取 `N` 个批次。批次数 `N` 可以由用户通过 `p.eval_loop_num_batches` 固定指定；`N` 也可以是动态的（`p.eval_loop_num_batches=None`），即我们不断调用 `input.get_next()`，直到耗尽所有数据（通过抛出 `StopIteration` 或 `tf.errors.OutOfRange`）。\n\n如果 `p.reset_for_eval=True`，`p.eval_loop_num_batches` 将被忽略，`N` 动态确定为耗尽数据所需的批次数。在这种情况下，应将 `p.repeat` 设置为 False，否则会导致无限的解码\u002F评估过程。\n\n如果 `p.reset_for_eval=False`，Pax 将获取 `p.eval_loop_num_batches` 个批次。此时应将 `p.repeat` 设置为 True，以确保数据不会过早耗尽。\n\n需要注意的是，`LingvoEvalAdaptor` 输入要求 `p.reset_for_eval=True`。\n\n\n|                          | `N`：静态 | `N`：动态 |\n| ------------------------ | --------- | --------- |\n| `p.reset_for_eval=True`  | 每次评估运行使用前 `N` 个批次（尚未支持）。 | 每次评估运行恰好一个 epoch；`eval_loop_num_batches` 被忽略；输入必须有限（`p.repeat=False`）。 |\n| `p.reset_for_eval=False` | 每次评估运行按 `eval_loop_num_batches` 滚动使用互不重叠的 `N` 个批次；输入必须无限重复（`p.repeat=True`），否则可能抛出异常。 | 不支持。 |\n\n如果恰好在一个 epoch 上执行解码\u002F评估（即 `p.reset_for_eval=True` 时），输入必须正确处理分片，确保每个分片在产出完全相同数量的批次后，在同一步骤结束。这通常意味着输入需要对评估数据进行填充。`SeqIOInput` 和 `LingvoEvalAdaptor` 会自动完成这一工作（详见下文）。\n\n### 评估指标\n\n对于大多数输入，我们只会调用 `get_next()` 来获取数据批次。有一种评估数据是例外：“如何计算指标”也定义在输入对象上。\n\n此功能仅受 `SeqIOInput` 支持，即定义了某些标准评估基准的输入。具体而言，Pax 使用 SeqIO 任务中定义的 `predict_metric_fns` 和 `score_metric_fns` 来计算评估指标（尽管 Pax 并不直接依赖 SeqIO 评估器）。\n\n## 最佳实践\n\n当模型使用多种输入时，无论是训练\u002F评估之间，还是预训练\u002F微调使用不同训练数据时，用户必须确保各输入所使用的分词器完全一致，尤其是在导入由他人实现的输入时。\n\n用户可以用 `input.ids_to_strings()` 解码一些 id，对分词器做合理性检查。\n\n建议始终通过查看若干批次的数据来做初步检查。用户可以轻松地在 Colab 中重现该参数，并对数据进行检查：\n\n```python\np = ...  # 指定预期的输入参数\ninp = p.Instantiate()\nb = inp.get_next()\nprint(b)\n```\n\n训练数据通常不应使用固定的随机种子。这是因为如果训练任务被抢占（preempted），训练数据将开始重复。特别是对于 Lingvo 输入，我们建议为训练数据设置 `p.input.file_random_seed = 0`。\n\n为了测试分片是否得到正确处理，用户可以手动为 `p.num_infeed_hosts` 和 `p.infeed_host_index` 设置不同的值，观察实例化后的输入是否输出不同的批次。\n\n
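下面是该分片检查的一个最小示意（沿用上面的检查片段；`...` 仍表示具体的输入参数，仅作示意）：\n\n```python\nimport numpy as np\n\np = ...                  # 预期的输入参数\np.num_infeed_hosts = 2   # 假设有两台主机\n\nhost_batches = []\nfor host_idx in range(p.num_infeed_hosts):\n    p.infeed_host_index = host_idx\n    inp = p.Instantiate()\n    host_batches.append(inp.get_next())\n\n# 分片正确的训练输入不应输出完全相同的批次。\nb0, b1 = host_batches\nassert any(not np.array_equal(b0[k], b1[k]) for k in b0.keys())\n```\n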
## 输入类型\n\nPax 支持三种类型的输入：SeqIO、Lingvo 以及自定义输入。\n\n### SeqIO\n\n`SeqIOInput` 可用于导入数据集。\n\nSeqIO 输入会自动处理评估数据的正确分片与填充。\n\n### Lingvo\n\n`LingvoInputAdaptor` 可用于导入数据集。\n\n该输入完全委托给 Lingvo 实现，后者可能自动处理分片，也可能不处理。\n\n对于基于 GenericInput 且使用固定 `packing_factor` 的 Lingvo 输入实现，我们建议使用 `LingvoInputAdaptorNewBatchSize` 为内部 Lingvo 输入指定更大的批量大小，并将所需的（通常小得多的）批量大小设置在 `p.batch_size` 上。\n\n对于评估数据，我们建议使用 `LingvoEvalAdaptor`，以便在一个 epoch 上运行评估时自动处理分片与填充。\n\n### 自定义\n\n即 `BaseInput` 的自定义子类。用户自行实现子类，通常使用 `tf.data` 或 SeqIO。\n\n用户也可以继承现有的输入类，仅定制批次的后处理。例如：\n\n```python\nclass MyInput(base_input.LingvoInputAdaptor):\n\n  def get_next(self):\n    batch = super().get_next()\n    # 修改批次：batch.new_field = ...\n    return batch\n```\n\n# Pax 关键组件\n\n## 超参数\n\n超参数是定义模型和配置实验的重要组成部分。\n\n为了更好地与 Python 工具集成，Pax\u002FPraxis 采用基于 Python dataclass 的配置方式来管理超参数。\n\n```python\nclass Linear(base_layer.BaseLayer):\n  \"\"\"线性层，无偏置。\"\"\"\n\n  class HParams(BaseHParams):\n    \"\"\"此层类的关联超参数。\n\n    属性：\n      input_dims：输入的深度。\n      output_dims：输出的深度。\n    \"\"\"\n    input_dims: int = 0\n    output_dims: int = 0\n```\n\n### 嵌套\n\n超参数数据类也可以嵌套。在下面的例子中，`linear_tpl` 属性是一个嵌套的 `Linear.HParams`。\n\n```python\nclass FeedForward(base_layer.BaseLayer):\n  \"\"\"前馈层，带激活函数。\"\"\"\n\n  class HParams(BaseHParams):\n    \"\"\"此层类的关联超参数。\n\n    属性：\n      input_dims：输入的深度。\n      output_dims：输出的深度。\n      has_bias：是否添加偏置权重。\n      linear_tpl：线性层的参数。\n      activation_tpl：激活函数层的参数。\n    \"\"\"\n    input_dims: int = 0\n    output_dims: int = 0\n    has_bias: bool = True\n    linear_tpl: BaseHParams = sub_config_field(Linear.HParams)\n    activation_tpl: activations.BaseActivation.HParams = sub_config_field(\n        ReLU.HParams)\n```\n\n\n## 层\n\n层表示一个可能带有可训练参数的任意函数。层可以包含其他层作为子层。层是模型的核心构建模块。层继承自 Flax 的 nn.Module。\n\n通常，层会定义两个方法：\n\n### setup\n\n该方法用于创建可训练的权重和子层。\n\n### fprop\n\n该方法定义前向传播函数，根据输入计算输出。此外，fprop 还可以添加摘要信息或跟踪辅助损失。\n\n### Fiddle 与共享层\nFiddle 是一个面向机器学习应用设计的开源 Python-first 配置库。Pax\u002FPraxis 支持与 Fiddle Config\u002FPartial 的互操作，以及即时错误检查、参数共享等高级功能。\n\n```python\nfdl_config = Linear.HParams.config(input_dims=1, output_dims=1)\n\n# 一个拼写错误。\nfdl_config.input_dimz = 31337  # 立即抛出异常，快速捕捉拼写错误！\n\n\nfdl_partial = Linear.HParams.partial(input_dims=1)\n```\n\n借助 Fiddle，层可以被配置为共享（例如，仅实例化一次，并共享可训练权重）。\n\n## 模型\n\n模型仅定义网络结构，通常是多个层的集合，并定义与模型交互的接口，例如解码等。\n\n一些基础模型示例包括：\n\n*   LanguageModel（语言模型）\n*   SequenceModel（序列模型）\n*   ClassificationModel（分类模型）\n\n## 任务\n\n任务包含一个或多个模型以及学习器\u002F优化器。最简单的任务子类是 `SingleTask`，它需要以下超参数：\n\n```python\n  class HParams(base_task.BaseTask.HParams):\n    \"\"\"任务参数。\n\n    属性：\n      name：此任务对象的名称，必须是有效的标识符。\n      model：底层 JAX 模型，封装了所有层。\n      train：用于控制任务训练方式的超参数。\n      metrics：一个 BaseMetrics 汇总类，用于确定如何计算指标。\n      loss_aggregator：一个 LossAggregator 汇总类，用于确定如何聚合损失（例如，单个损失或多损失）。\n      vn：用于控制变分噪声的超参数。\n    \"\"\"\n```\n\n## 发布版本\nPyPI 版本 | 提交\n------------ | ----------------------------------------\n0.1.0        | 546370f5323ef8b27d38ddc32445d7d3d1e4da9a\n\n\n\n    版权所有 2022 Google LLC\n\n    根据 Apache 许可证 2.0 版（“许可证”）授权；除非遵守许可证条款，否则您不得使用本文件。您可以从以下网址获取许可证副本：\n\n        https:\u002F\u002Fwww.apache.org\u002Flicenses\u002FLICENSE-2.0\n\n    除非适用法律要求或书面同意，依据许可证分发的软件按“原样”提供，不附带任何明示或暗示的保证或条件。有关许可证下权限和限制的具体规定，请参阅许可证。","# Paxml 快速上手指南（中文）\n\n> Paxml（简称 Pax）是 Google 开源的基于 JAX 
的大规模机器学习实验框架，专为 TPU\u002FGPU 训练优化。\n\n---\n\n## 环境准备\n\n| 项目 | 要求 |\n|---|---|\n| 操作系统 | Linux（推荐 Ubuntu 20.04+） |\n| Python | 3.8 或 3.9 |\n| 硬件 | 至少 1 张 TPU v4-8 或 1 张 NVIDIA A100\u002FH100 |\n| 网络 | 可访问 Google Cloud（国内用户需科学上网） |\n\n---\n\n## 安装步骤\n\n### 1. 创建 Cloud TPU VM（单节点示例）\n\n```bash\nexport ZONE=us-central2-b\nexport VERSION=tpu-vm-v4-base\nexport PROJECT=\u003Cyour-gcp-project>\nexport ACCELERATOR=v4-8          # 8 核 TPU\nexport TPU_NAME=paxml\n\ngcloud compute tpus tpu-vm create $TPU_NAME \\\n  --zone=$ZONE --version=$VERSION \\\n  --project=$PROJECT \\\n  --accelerator-type=$ACCELERATOR\n```\n\n### 2. SSH 登录 TPU VM\n\n```bash\ngcloud compute tpus tpu-vm ssh $TPU_NAME --zone=$ZONE\n```\n\n### 3. 安装 Paxml（二选一）\n\n#### 方案 A：稳定版（推荐）\n\n```bash\npython3 -m pip install -U pip\npython3 -m pip install paxml jax[tpu] \\\n  -f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\n```\n\n#### 方案 B：开发版（可修改源码）\n\n```bash\n# 先装 praxis\ngit clone https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpraxis\npip install -e praxis\n\n# 再装 paxml\ngit clone https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\npip install -e paxml\npip install \"jax[tpu]\" -f https:\u002F\u002Fstorage.googleapis.com\u002Fjax-releases\u002Flibtpu_releases.html\n```\n\n---\n\n## 基本使用\n\n### 1. 运行 2B 参数测试模型（SPMD）\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n  --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2BLimitSteps \\\n  --job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\n\n### 2. 运行 1B 参数 C4 数据集实验\n\n```bash\npython3 .local\u002Flib\u002Fpython3.8\u002Fsite-packages\u002Fpaxml\u002Fmain.py \\\n  --exp=tasks.lm.params.c4.C4Spmd1BAdam4Replicas \\\n  --job_log_dir=gs:\u002F\u002F\u003Cyour-bucket>\n```\n\n### 3. 
运行 Jupyter Notebook（可选）\n\n```bash\n# 本地端口转发\ngcloud compute tpus tpu-vm ssh $TPU_NAME \\\n  --project=$PROJECT --zone=$ZONE \\\n  --ssh-flag=\"-4 -L 8080:localhost:8080\"\n\n# TPU VM 内安装\npip install notebook markupsafe==2.0.1\nexport PATH=\u002Fhome\u002F$USER\u002F.local\u002Fbin:$PATH\n\n# 启动\njupyter notebook --no-browser --port=8080\n```\n\n浏览器访问 `http:\u002F\u002Flocalhost:8080` 并输入 token 即可。\n\n---\n\n> 如需在 GPU 上运行，请参考 [NVIDIA Rosetta](https:\u002F\u002Fgithub.com\u002FNVIDIA\u002FJAX-Toolbox\u002Ftree\u002Fmain\u002Frosetta\u002Frosetta\u002Fprojects\u002Fpax) 以获得 H100 FP8 优化版本。","一家 20 人规模的初创公司正在训练一个 70 亿参数的中文对话大模型，用于给 B 端客户提供智能客服 API。团队只有 2 名算法工程师和 1 名 MLOps，预算有限，希望在 4 周内完成首轮迭代。\n\n### 没有 paxml 时\n- 需要手写大量 Jax+pjit 代码才能把模型拆到 8 张 TPU v4 上，光是调试张量并行和流水并行就花了 5 天  \n- 训练 10 k step 时 GPU\u002FTPU 利用率只有 38 %，大量时间在等待通信，导致 4 周排期被拉长到 7 周  \n- 超参数一改就要重新写配置脚本，实验管理靠 Excel 手动记录，结果经常对不上  \n- 想在云上拉起 64 卡做消融实验，需要额外写 200 行 bash 脚本，还踩了配额和镜像版本坑  \n\n### 使用 paxml 后\n- 直接复用 `LmCloudSpmd7B` 模板，30 行 YAML 就完成 8 卡并行配置，2 小时跑通第一次训练  \n- 内置的 `flop_utilization` 指标稳定在 62 %，同样 10 k step 训练时间从 3.5 天降到 2.1 天  \n- 通过 `--exp` 参数一键切换学习率、序列长度等超参数，实验结果自动写进 TensorBoard 和 GCS，回溯零成本  \n- 需要扩容到 64 卡时，把 `ACCELERATOR=v4-64` 改一行即可，paxml 自动处理 slice 划分和 checkpoint 分片  \n\npaxml 让这支小团队在 4 周内如期交付了首个可用模型，并把后续迭代周期缩短一半。","https:\u002F\u002Foss.gittoolsai.com\u002Fimages\u002Fgoogle_paxml_6b5009ec.png","google","Google","https:\u002F\u002Foss.gittoolsai.com\u002Favatars\u002Fgoogle_c4bedcda.png","Google ❤️ Open Source",null,"opensource@google.com","GoogleOSS","https:\u002F\u002Fopensource.google\u002F","https:\u002F\u002Fgithub.com\u002Fgoogle",[85,89,93,97,101],{"name":86,"color":87,"percentage":88},"Python","#3572A5",87.2,{"name":90,"color":91,"percentage":92},"Jupyter Notebook","#DA5B0B",7,{"name":94,"color":95,"percentage":96},"Starlark","#76d275",4.2,{"name":98,"color":99,"percentage":100},"Shell","#89e051",1.2,{"name":102,"color":103,"percentage":104},"Dockerfile","#384d54",0.4,550,70,"2026-03-20T23:10:17","Apache-2.0",4,"Linux","可选；官方示例基于 Google Cloud TPU v4（v4-8\u002Fv4-64\u002Fv4-128\u002Fv4-384）。NVIDIA GPU 支持由 NVIDIA Rosetta 分支提供，未说明具体型号与显存要求。","未说明",{"notes":114,"python":115,"dependencies":116},"1) 官方示例默认在 Google Cloud TPU VM 上运行，需提前创建 TPU 实例；2) 若使用 GPU，请转向 NVIDIA Rosetta 分支；3) 多 slice 训练需为每个 slice 单独启动终端并设置 LIBTPU_INIT_ARGS；4) 日志与 checkpoint 需写入 Google Cloud Storage（gs:\u002F\u002F）；5) 首次安装建议从 release 分支获取 requirements.txt 以锁定依赖版本。","3.8",[117,67,118,119,120,121],"jax[tpu]","praxis","orbax==0.1.1","notebook","markupsafe==2.0.1",[26,13,15],[124,125,126,127,128,129,130],"c4","jax","large-language-models","llm","model-flops","parallelism","gpt","2026-03-27T02:49:30.150509","2026-04-06T05:16:47.136778",[134,139,144,149],{"id":135,"question_zh":136,"answer_zh":137,"source_url":138},6205,"运行 bazel 时出现 “error loading package 'paxml'：Every .bzl file must have a corresponding package” 怎么办？","该错误是因为缺少 BUILD 文件导致 bazel 无法定位 praxis 的 build-visibility.bzl。\n1. 在 praxis 目录下新建一个空文件 BUILD（内容可以为空）。\n2. 如果仍报错，可临时修改 paxml 目录下所有 BUILD 文件：\n   将\n   ```\n   load(\"\u002F\u002Fpraxis:build-visibility.bzl\", \"JAX_VISIBILITY\")\n   package(default_visibility = JAX_VISIBILITY)\n   ```\n   替换为\n   ```\n   #load(\"\u002F\u002Fpraxis:build-visibility.bzl\", \"JAX_VISIBILITY\")\n   package(default_visibility = [\"\u002F\u002Fvisibility:public\"])\n   ```\n3. 
确保 praxis 目录与 paxml 同级，且都已 git clone 完整。","https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fissues\u002F1",{"id":140,"question_zh":141,"answer_zh":142,"source_url":143},6206,"pip install paxml[gpu] 时出现 “Cannot install paxml and paxml[gpu]==1.0.0 because these package versions have conflicting dependencies” 如何解决？","该问题已在最新代码中修复，直接重新安装即可：\n```bash\npip install -e '.[gpu]'   # 从源码安装最新 HEAD\n```\n如果仍想使用旧版本，可手动降级 seqio：\n```bash\npip install seqio==0.0.15\n```\n","https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fissues\u002F25",{"id":145,"question_zh":146,"answer_zh":147,"source_url":148},6207,"在 GPU 上使用 Pipeline Parallelism 时，设置 USE_REPEATED_LAYERS=True 会报错，该如何处理？","Pipeline + repeat 两层 jax.scan 会导致意外的 remat 行为，官方目前不推荐这种组合。\n替代方案：\n1. 将 USE_REPEATED_LAYERS=False。\n2. 通过调整 checkpoint_policy 降低显存占用：\n   在 praxis\u002Flayers\u002Fpipeline.py 第 112 行附近修改\n   ```python\n   checkpoint_policy: AutodiffCheckpointType = AutodiffCheckpointType.SAVE_NOTHING\n   ```\n   可选策略：SAVE_NOTHING（边界缓存最小，阶段内缓存最大）、SAVE_ITERATION_INPUT 等，按显存\u002F速度权衡选择。","https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fissues\u002F5",{"id":150,"question_zh":151,"answer_zh":152,"source_url":153},6208,"Pipeline Parallelism 在 NUM_MICROBATCHES > 1 时直接崩溃，提示 “Fatal Python error: Aborted” 怎么办？","该问题已在内部修复，请更新到最新代码即可。若仍出现，可：\n1. 确认已拉取最新 main 分支。\n2. 如仍复现，可导出 HLO 文件（见 issue 中 Google Drive 链接）并提交新的 issue 给维护者。","https:\u002F\u002Fgithub.com\u002Fgoogle\u002Fpaxml\u002Fissues\u002F4",[155,160,165,170,174,178,183,187,192,196,200],{"id":156,"version":157,"summary_zh":158,"released_at":159},105792,"paxml-v1.4.0","paxml-v1.4.0\r\n\r\nUpdate paxml 1.4.0 requirements","2024-04-09T22:33:36",{"id":161,"version":162,"summary_zh":163,"released_at":164},105793,"paxml-v1.3.1","paxml-v1.3.1\r\n\r\nUpdate paxml 1.3.1 requirements","2024-02-21T09:09:47",{"id":166,"version":167,"summary_zh":168,"released_at":169},105794,"paxml-v1.3.0","paxml-v1.3.0\r\n\r\nUpdate Pax 1.3.0 requirements","2024-02-17T13:41:33",{"id":171,"version":172,"summary_zh":79,"released_at":173},105795,"paxml-v1.2.0","2023-10-19T22:11:59",{"id":175,"version":176,"summary_zh":79,"released_at":177},105796,"paxml-v1.1.0","2023-08-28T21:56:58",{"id":179,"version":180,"summary_zh":181,"released_at":182},105797,"paxml-v1.0.0","see RELEASE.md for release notes","2023-04-12T21:34:58",{"id":184,"version":185,"summary_zh":79,"released_at":186},105798,"paxml-v0.4.0","2023-04-12T21:33:44",{"id":188,"version":189,"summary_zh":190,"released_at":191},105799,"paxml-v0.3.0","See RELEASE.md for release notes","2023-02-03T08:36:43",{"id":193,"version":194,"summary_zh":79,"released_at":195},105800,"paxml-v0.2.1","2022-11-22T02:21:09",{"id":197,"version":198,"summary_zh":79,"released_at":199},105801,"paxml-v0.2.0","2022-11-15T08:06:31",{"id":201,"version":202,"summary_zh":79,"released_at":203},105802,"paxml-v0.0.1","2022-06-15T23:52:50"]