atlas

GitHub
554 72 困难 1 次阅读 1周前NOASSERTION语言模型开发框架
AI 解读 由 AI 自动生成,仅供参考

Atlas 是一个用于支持“检索增强语言模型”研究的开源代码库,核心目标是实现高效的少样本学习。它通过联合预训练一个基于段落的密集检索器和一个编码器-解码器语言模型,让 AI 能够在回答问题时主动“查阅”外部知识库,从而显著提升准确性。

Atlas 主要解决了传统大语言模型依赖海量参数且知识更新滞后的痛点。研究表明,即便仅使用 64 个示例进行训练,Atlas 在自然问答任务上的准确率也能超越参数量大其 50 倍的模型;而在全量数据微调后,它更是刷新了当时的行业最佳纪录。此外,Atlas 证明了文档索引内容可以轻松更新,这意味着模型无需重新训练即可获取最新知识,有效缓解了模型“幻觉”和知识过时的问题。

这款工具特别适合人工智能研究人员、NLP 工程师以及对检索增强生成(RAG)技术感兴趣的开发者使用。它支持从预训练、微调到检索评估的全流程操作,能够处理包括开放域问答、多项选择、事实核查及 KILT 基准测试在内的多种任务。

在技术亮点方面,Atlas 支持高达 110 亿参数的 Fusion-in-Decoder 模型训练,并具备在训练循环中进行端到端检索的能力。它内置了快速的并行分布式 GPU 检索功能,支持对多达 4 亿个段落进行精确或近似搜索,同时提供了内存优化方案和索引原地快速刷新机制,确保在处理大规模语料库时依然保持高效与灵活。需要注意的是,该仓库目前不再维护,代码按原样提供,适合用于学术研究和技术参考。

使用场景

某金融科技公司的算法团队正在构建一个智能投研助手,需要让模型基于最新的上市公司财报和实时新闻,准确回答分析师提出的复杂事实性问题。

没有 atlas 时

  • 知识滞后严重:传统大模型依赖训练时的静态权重,无法获取最新发布的财报数据,导致回答过时甚至错误,必须频繁且昂贵地重新全量微调模型。
  • 小样本适应差:面对特定垂直领域的专业问答,缺乏有效的少样本学习能力,需要收集成千上万条标注数据才能达到可用精度,冷启动成本极高。
  • 推理幻觉频发:模型在缺乏确切依据时倾向于“编造”财务数据,且由于是黑盒生成,难以追溯信息来源,无法满足金融场景对可解释性和准确性的严苛要求。
  • 资源消耗巨大:为了提升精度盲目扩大模型参数量(如使用千亿级参数模型),导致推理延迟高、算力成本高昂,难以在生产环境大规模部署。

使用 atlas 后

  • 知识实时更新:利用 Atlas 的检索增强机制,直接外挂最新的文档索引(如维基百科或内部财报库),无需重新训练即可让模型掌握最新信息,索引支持快速原地刷新。
  • 高效少样本学习:凭借强大的检索增强能力,Atlas 仅用 64 个示例即可在自然问答任务上达到超过 45% 的准确率,大幅降低了对大规模标注数据的依赖。
  • 答案有据可依:模型先检索相关段落再生成答案,显著减少幻觉。每个回答都能关联到具体的参考文档片段,便于人工核查与审计,提升了可信度。
  • 小模型高性能:Atlas 以远少于超大模型的参数量(少 50 倍),在多项基准测试中超越了 540B 参数的模型,实现了更高的性价比和更快的推理速度。

核心价值在于 Atlas 通过检索增强与少样本学习的结合,以极低的算力和数据成本,实现了具备实时知识更新能力且高可信度的专业问答系统。

运行环境要求

操作系统
  • Linux
GPU
  • 必需 NVIDIA GPU
  • 官方测试环境使用 CUDA 11.3
  • 支持多卡分布式训练(示例为4节点每节点8卡)
  • 显存需求取决于模型大小和批次大小,支持高达 110 亿参数模型的训练,建议使用梯度检查点和分片优化以节省显存
内存

未说明(但需足够加载大规模语料索引,支持高达 4 亿条向量索引,建议大内存)

依赖
notes1. 代码库已不再维护,按原样提供研究代码。2. 强烈建议使用 conda 进行环境管理和依赖安装。3. 主要支持 T5 架构的编码器-解码器语言模型和 Contriever 架构的检索器。4. 数据文件需为 jsonlines 格式。5. 支持在训练循环中进行端到端的检索增强训练,并支持索引的就地刷新。
python3.8
pytorch==1.11.0
fairscale==0.4.6
transformers==4.18.0
numpy==1.22.4
faiss-gpu==1.7.2
atlas hero image

快速开始

Atlas: 基于检索增强语言模型的少样本学习

该仓库已不再维护,研究代码按原样提供。

本仓库包含用于论文《Atlas: 基于检索增强语言模型的少样本学习》(https://arxiv.org/pdf/2208.03299.pdf)的预训练模型、语料库、索引以及预训练、微调、检索和评估的相关代码。

请阅读我们的 Atlas 博客文章,以快速了解该项目,并学习如何使用 torchrun 运行代码(无需 Slurm)。

我们联合预训练了一个基于检索增强的序列到序列语言模型,该模型由一个基于段落的密集检索器和一个编码器-解码器语言模型组成。我们在 MMLU、KILT 和 NaturalQuestions 等广泛的任务上进行了评估,并研究了文档索引内容的影响,结果表明索引可以轻松更新。值得注意的是,当使用 2018 年的维基百科索引时,Atlas 在仅使用 64 个示例的情况下,在 Natural Questions 数据集上达到了超过 45% 的准确率,尽管参数量仅为 540B 参数模型的 1/50,仍比后者高出 6 个百分点。此外,Atlas 在更大的数据集上进行微调后表现也非常出色——在完整的 Natural Questions 数据上微调后,Atlas 创下了 64% 的新 SOTA 记录,比当前最佳水平高出 8 个百分点。

本仓库支持大规模和小规模数据集的预训练与微调。它具备以下功能:

  • 训练大型融合解码器序列到序列模型,最高已测试至 11B 参数;
  • 使用多种蒸馏方法将融合解码器模型中的相关性信号提炼到密集检索模型中;
  • 对用户提供的段落语料库(已测试最大达 4 亿段落,约 400 亿词)进行端到端的检索增强训练,并在训练过程中集成检索机制;
  • 支持掩码语言建模、前缀语言建模、维基百科章节生成、开放域问答、多选题问答、事实核查以及 KILT 等任务的训练(也可支持任意序列到序列任务);
  • 提供基于 GPU 的快速并行分布式精确和近似最大内积搜索,用于密集向量检索;
  • 支持快速就地索引刷新;
  • 多种内存优化技术及在循环中训练检索器时保持快速准确检索的方法;
  • 更多功能,请参阅命令行参数或自述文件。

目录

安装

Atlas 代码库依赖以下软件包:

  • Python 3(已测试 3.8)
  • fairscale(已测试 0.4.6)
  • transformers(已测试 4.18.0)
  • numpy(已测试 1.22.4)
  • faiss(已测试 1.7.2)

我们建议使用 Conda 进行安装。以下命令将安装所有依赖项:

git clone https://github.com/facebookresearch/atlas.git
cd atlas
conda create --name atlas-env python=3.8
conda activate atlas-env
conda install pytorch==1.11.0 cudatoolkit=11.3 -c pytorch
conda install -c pytorch faiss-gpu=1.7.2 cudatoolkit=11.3
pip install -r requirements.txt

入门与代码库概览

Atlas 仓库提供了训练和评估检索增强生成模型的功能,该模型由编码器-解码器语言模型和密集向量检索器组成。

我们当前支持 T5 架构作为编码器-解码器语言模型,以及 Contriever 架构作为检索器(目前暂不计划支持其他架构,但欢迎提交 PR)。 Atlas 模型由 Contriever 检索器和融合解码器(FID)架构(使用 T5)组成。如果您有兴趣,可以分别在 这里这里 了解更多关于 Contriever 和 FID 的信息,不过所有必要的功能都已经在此代码库中重新实现。

与大多数标准 NLP 训练代码库相比,最大的不同在于 Atlas 会实时进行检索,并且可以在原地刷新其检索嵌入索引。 这是通过一个自定义设计的分布式 GPU 索引实现的,它可以自动处理快速且可扩展的检索。

关于检索实现方式的说明: 在启动训练或评估运行时,代码库首先会加载预训练模型,然后每个 GPU 工作节点会加载一部分要检索的文档——如果有 N 个 GPU,则每个节点会加载 1/N 的文档分片。 随后,每个工作节点会使用检索器的嵌入模块为其分片中的文档生成嵌入,并将这些文档嵌入保留在 GPU 内存中(也可以选择构建 FAISS 索引)。 此时,文档和嵌入分片(称为“索引”)可以选择保存到磁盘,以避免每次运行时都重新计算索引。 检索是并行进行的,每个 GPU 工作节点会对其分片中的所有查询执行精确的最大内积搜索。 更多关于检索的详细信息请参见 检索与索引详情 部分。 请注意,以上所有步骤均由代码库自动完成,因此用户无需过多了解或担心嵌入、索引刷新或检索的具体实现方式,只需注意以下几点:

  1. 用户可以通过传递磁盘上格式正确的文档路径(或已保存的索引),轻松从任何他们喜欢的文档集中进行检索;
  2. 嵌入、索引刷新和检索的速度会随着 GPU 工作节点数量的增加而加快;
  3. 根据可用的 GPU 数量和 CPU 内存大小,Atlas 可以支持训练参数量超过 110 亿的模型,以及包含 4 亿+ 向量的索引,约相当于 400 亿个词(假设每篇文档约 100 个词)。

训练和评估采用数据并行模式:对于 N 个 GPU 工作节点,每个节点处理总 mini-batch 数据的 1/N。为了节省训练时的内存,优化器状态和梯度可以使用 fairscale 的 ShardedDataParallel 进行分片。

所有数据文件(检索器文档以及训练/验证/测试数据)都应以 jsonlines(“jsonl”)格式提供。 用于检索的文档应为 JSON 序列化的对象,每行包含 texttitle 文本字段,每行对应一篇文档。 示例文档文件可用于维基百科(详见 语料库)。 训练/验证/测试数据文件也应为 JSON 序列化的对象,每行代表一个实例。字段名称取决于具体任务(详见 任务),例如,在 NaturalQuestions 数据集中,所需的字段是 question(问题字符串)和 answers(参考答案字符串列表)。

代码库有两个入口脚本:train.py 用于训练,evaluate.py 用于测试时的评估(以及 独立检索模式,如果需要)。您可以通过运行 python train.py -h 打印命令行参数来查看 Atlas 的完整功能(完整输出 此处)。

展示代码库最简单的方式就是通过一个示例:

以下示例展示了如何使用 Atlas-large 在 NaturalQuestions 数据集上进行少样本微调和评估(这些操作也可通过 example_scripts/nq/ 中的可运行 sbatch 脚本完成),并从 2018 年的维基百科转储(约 3000 万篇文档)中进行检索。

# 假设使用 4 个节点,每个节点配备 8 个 GPU
DATA_DIR=./atlas_data
SIZE=large # 使用 large 版本(比 base 版本慢,但仍然相当快速且易于使用,只是准确率不如 xl 或 xxl)
 
# 下载 NaturalQuestions 数据
python preprocessing/prepare_qa.py --output_directory ${DATA_DIR}/data/
# 下载 2018 年的维基百科语料库
python preprocessing/download_corpus.py --corpus corpora/wiki/enwiki-dec2018 --output_directory ${DATA_DIR}

# 下载预训练的 Atlas-large 模型
python preprocessing/download_model.py --model models/atlas/${SIZE} --output_directory ${DATA_DIR}  

port=$(shuf -i 15000-16000 -n 1)
TRAIN_FILE="${DATA_DIR}/data/nq_data/train.64-shot.jsonl"
EVAL_FILES="${DATA_DIR}/data/nq_data/dev.jsonl"
SAVE_DIR=${DATA_DIR}/experiments/
EXPERIMENT_NAME=my-nq-64-shot-example
TRAIN_STEPS=30

srun python train.py \
    --shuffle \
    --train_retriever \
    --gold_score_mode pdist \ # 检索器的损失函数(参见论文)
    --use_gradient_checkpoint_reader --use_gradient_checkpoint_retriever\ # 使用梯度检查点节省显存,但会降低速度
    --precision fp32 \ # 如果你的 GPU 支持,可以使用 "bf16";fp16 通常不稳定
    --shard_optim --shard_grads \ # 通过这些优化节省显存
    --temperature_gold 0.01 --temperature_score 0.01 \ 
    --refresh_index -1 \ # 对于少样本微调,刷新索引(即重新计算嵌入)成本很高,且收益不大
    --query_side_retriever_training\ # 相反,在少样本场景下,仅微调 Contriever 的查询编码器效果较好。去掉该标志则会微调整个检索器
    --target_maxlength 16 \ # 生成的最大长度
    --reader_model_type google/t5-${SIZE}-lm-adapt \ # Atlas 的架构
    --dropout 0.1 --weight_decay 0.01 --lr 4e-5 --lr_retriever 4e-5 --scheduler linear \ # 优化参数
    --text_maxlength 512 \ # 问题与段落拼接后的最大长度
    --model_path "${DATA_DIR}/models/atlas/${SIZE}" \ # 刚才下载的预训练 Atlas 模型路径(传入 'none' 可从纯 T5 和 Contriever 开始训练)
    --train_data "${DATA_DIR}/data/nq_data/train.64-shot.jsonl" \ # 刚才下载的 64 抽样训练数据集路径
    --eval_data "${DATA_DIR}/data/nq_data/dev.jsonl" \ # 刚才下载的 NQ 验证集路径,用于在训练完成后评估
    --per_gpu_batch_size 1 \
    --n_context 40 \ # 将检索器返回的前 40 篇文档传递给语言模型
    --retriever_n_context 40 \ # 使用前 40 篇文档对检索器进行微调
    --name ${EXPERIMENT_NAME} \ # 实验名称(日志和模型也将保存到该目录)
    --checkpoint_dir ${SAVE_DIR} \ # 日志和模型检查点将保存到 ${SAVE_DIR}/${EXPERIMENT_NAME}
    --eval_freq ${TRAIN_STEPS} \ # 训练结束后进行评估
    --log_freq 4 \ # 每 4 个训练步骤记录一次统计信息。日志会写入 ${SAVE_DIR}/${EXPERIMENT_NAME}/run.log,同时如果已安装 TensorBoard,也会生成 TensorBoard 日志
    --total_steps ${TRAIN_STEPS} \ # 训练指定步数
    --warmup_steps 5 \
    --save_freq ${TRAIN_STEPS} \ # 在本示例中,训练完成后只保存一个检查点
    --main_port $port \ # 用于分布式训练
    --write_results \ # 写出预测结果——它们将保存在检查点文件夹中,${SAVE_DIR}/${EXPERIMENT_NAME}
    --task qa \ # 我们执行的是问答任务
    --index_mode flat \ # 不使用 Faiss,保持索引为扁平结构(建议如此操作,除非使用超大索引或显存非常有限)
    --passages "${DATA_DIR}/corpora/wiki/enwiki-dec2018/text-list-100-sec.jsonl" "${DATA_DIR}/corpora/wiki/enwiki-dec2018/infobox.jsonl"\ # 输入维基百科段落以构建索引并进行检索(我们同时使用文本和信息框)
    --save_index_path ${SAVE_DIR}/${EXPERIMENT_NAME}/saved_index # 将构建的索引保存到此路径

训练脚本首先会对 2018 年版维基百科构建索引并将其保存到检查点文件夹内(${SAVE_DIR}/${EXPERIMENT_NAME})。随后,脚本将以 64 抽样的方式对 Atlas-large NQ 模型进行少样本微调,共 30 步,并从整个 2018 年版维基百科中检索内容。该脚本仅微调检索器的查询编码器及 FID 参数,而保持段落编码器冻结不动(详情请参阅论文或下方的相关说明)。最后,脚本将对验证集进行评估并保存检查点。你可以在 ${SAVE_DIR}/${EXPERIMENT_NAME}/run.log 中查看实验日志,其中记录了约 38% 的 NQ 验证集精确匹配分数(我们的运行结果为 38.4%),以及可进一步检查的预测结果。

要评估模型性能(例如在保留的测试集上),我们可以使用 evaluate.py 入口脚本:

srun python evaluate.py \
    --name 'my-nq-64-shot-example-evaluation' \
    --generation_max_length 16 \
    --gold_score_mode "pdist" \
    --precision fp32 \
    --reader_model_type google/t5-${size}-lm-adapt \
    --text_maxlength 512 \
    --model_path ${SAVE_DIR}/${EXPERIMENT_NAME}/checkpoint/step-30 \ # 现在指向我们刚刚训练好的模型
    --eval_data "${DATA_DIR}/data/nq_data/dev.jsonl ${DATA_DIR}/data/nq_data/test.jsonl" \ # 这次我们将同时评估验证集和测试集
    --per_gpu_batch_size 1 \
    --n_context 40 --retriever_n_context 40 \
    --checkpoint_dir ${SAVE_DIR} \
    --main_port $port \
    --index_mode "flat"  \
    --task "qa" \
    --load_index_path ${SAVE_DIR}/${EXPERIMENT_NAME}/saved_index\ # 不再重新嵌入所有维基百科段落,而是直接加载我们之前保存的索引
    --write_results # 写出推理结果

该脚本将加载模型,并由于指定了通过 --load_index_path 加载已保存的索引,因此不会像之前那样从段落中重新嵌入,而是直接使用索引进行检索。随后,它将对开发集和测试集进行评估。检查 ${SAVE_DIR}/my-nq-64-shot-example-evaluation/run.log 中保存的日志,你会看到与先前相同的验证集精确匹配分数,以及约 38% 的测试集分数(我们的情况是 38.8% EM)。

本自述文件的其余部分将详细描述数据、代码和功能。

可下载的数据和模型

目前,Atlas 的维基百科语料库、预训练模型以及预构建的维基百科索引均可下载。

点击展开:

语料库

我们用于 Atlas 检索和预训练的预处理维基百科转储文件可按如下方式下载:

python preprocessing/download_corpus.py --corpus {语料下载键} --output_directory ${DATA_DIR} 

上述命令将下载一个语料并将其解压到 ${DATA_DIR}/{语料下载键}

可用的语料如下表所示:

语料名称 语料下载键 描述 大小
enwiki-dec2017 corpora/wiki/enwiki-dec2017 2017年12月的维基百科转储,已预处理为段落 30.4M (26.9M 文本, 2.7M 信息框)
enwiki-dec2018 corpora/wiki/enwiki-dec2018 2018年12月的维基百科转储,已预处理为段落(推荐用于 NQ、TriviaQA) 32.1M (28.4M 文本, 3.7M 信息框)
enwiki-aug2019 corpora/wiki/enwiki-aug2019 2019年8月的维基百科转储,已预处理为段落 33.1M (29.4M 文本, 3.8M 信息框)
enwiki-dec2020 corpora/wiki/enwiki-dec2020 2020年12月的维基百科转储,已预处理为段落 35.6M (31.5M 文本, 4.1M 信息框)
enwiki-dec2021 corpora/wiki/enwiki-dec2021 2021年12月的维基百科转储,已预处理为段落 37.5M (33.1M 文本, 4.3M 信息框)

段落文件采用 jsonl 格式,每行序列化为一个 JSON 对象。默认情况下,每个段落应按以下格式组织:

{
    "id": "0", # 段落应具有唯一 ID
    "title": "兰花", # 应指定该段落来自的页面标题(如果没有合适的标题,可以为空字符串)
    "text": "兰花与其他植物很容易区分,因为它们具有一些非常明显的衍生特征或共源性状。其中包括:花的两侧对称性(合轴对称)、许多倒置的花朵、几乎总是高度特化的唇瓣、雄蕊和雌蕊融合在一起,以及极其微小的种子。", # 段落的主要文本
    "section": "描述" # 可选字段,表示段落所属的小节标题;如果非空,则默认会将此字段附加到标题后,形成 {title}: {section}
    ... # 您还可以添加其他字段以方便分析,但这些字段实际上不会被使用
}

如果您按照上述格式创建自己的段落文件,那么与 Atlas 配合使用应该会非常简单。

目前,我们无法开源论文中使用的 Common Crawl 索引。

模型

我们正在开源基础、大、XL 和 XXL 尺寸的预训练 Atlas 模型。这些模型同时包含预训练的检索器和阅读器权重。

此外,我们还开源了性能最强的全微调 Natural Questions Atlas 模型,供希望进行最先进问答推理(或在其他问答任务上进行微调)的用户使用。

模型可按如下方式下载:

python preprocessing/download_model.py --model {模型下载键} --output_directory ${DATA_DIR} 

这将会把请求的模型下载到 ${DATA_DIR}/{模型下载键},随后您可以通过将 ${DATA_DIR}/{模型下载键} 传递给 --model_path 参数,在脚本中使用这些模型。

下表详细列出了可用的模型:

模型 模型下载键 描述 参数量(阅读器 / 检索器)
Atlas-xxl models/atlas/xxl 预训练的 Atlas XXL 模型 11B / 110M
Atlas-xl models/atlas/xl 预训练的 Atlas XL 模型 3B / 110M
Atlas-large models/atlas/large 预训练的 Atlas 大模型 770M / 110M
Atlas-base models/atlas/base 预训练的 Atlas 基础模型 220M / 110M
NQ 微调后的 Atlas-xxl models/atlas_nq/xxl 经 Natural Questions 数据微调的 Atlas XXL 模型 11B / 110M
NQ 微调后的 Atlas-xl models/atlas_nq/xl 经 Natural Questions 数据微调的 Atlas XL 模型 3B / 110M
NQ 微调后的 Atlas-large models/atlas_nq/large 经 Natural Questions 数据微调的 Atlas 大模型 770M / 110M
NQ 微调后的 Atlas-base models/atlas_nq/base 经 Natural Questions 数据微调的 Atlas 基础模型 220M / 110M

预构建索引

如果未提供索引,Atlas 会自动构建一个索引。这种方式虽然方便,但耗时较长,尤其是在 GPU 工作节点较少或索引规模较大的情况下。

因此,我们提供了针对预训练 Atlas 检查点以及经 Natural Questions 数据微调的 Atlas 检查点的 wiki-dec2018 语料的预计算索引供下载。

这些索引可按如下方式下载:

python preprocessing/download_index.py --index {索引下载键} --output_directory ${DATA_DIR} 

上述脚本将下载请求的预训练索引,并将其保存到 ${DATA_DIR}/{索引下载键}。随后,您可以通过将这些索引传递给 --load_index_path 参数,在训练或评估中使用它们。关于索引的保存和加载的更多细节,请参阅 检索与索引详情

可供下载的索引如下表所示:

索引 索引下载键 对应模型 描述
Atlas XXL wiki-dec2018 索引 indices/atlas/wiki/xxl models/atlas/xxl 针对预训练 Atlas-xxl 模型的 wiki-dec2018 语料预计算索引
Atlas XL wiki-dec2018 索引 indices/atlas/wiki/xl models/atlas/xl 针对预训练 Atlas-xl 模型的 wiki-dec2018 语料预计算索引
Atlas large wiki-dec2018 索引 indices/atlas/wiki/large models/atlas/large 针对预训练 Atlas-large 模型的 wiki-dec2018 语料预计算索引
Atlas base wiki-dec2018 索引 indices/atlas/wiki/base models/atlas/base 针对预训练 Atlas-base 模型的 wiki-dec2018 语料预计算索引
Atlas-nq XXL wiki-dec2018 索引 indices/atlas_nq/wiki/xxl models/atlas_nq/xxl 针对经 Natural Questions 数据微调的 Atlas xxl 模型的 wiki-dec2018 语料预计算索引
Atlas-nq XL wiki-dec2018 索引 indices/atlas_nq/wiki/xl models/atlas/xl 针对经 Natural Questions 数据微调的 Atlas xl 模型的 wiki-dec2018 语料预计算索引
Atlas-nq large wiki-dec2018 索引 indices/atlas_nq/wiki/large models/atlas/large 针对经 Natural Questions 数据微调的 Atlas 大模型的 wiki-dec2018 语料预计算索引
Atlas-nq base wiki-dec2018 索引 indices/atlas_nq/wiki/base models/atlas/base 针对经 Natural Questions 数据微调的 Atlas 基础模型的 wiki-dec2018 语料预计算索引

任务

Atlas 可以在任何可以表示为“seq2seq”格式的监督学习任务上进行训练(或评估),其中输入是一个由一个或多个标记组成的序列,称为query,而输出则是由一个或多个标记组成的序列,称为target。 例如,一个 query 可能是一个问题,如“百慕大三角在哪里?”;而对应的 target 则是该问题的答案,“北大西洋西部海域”。 这种建模方式对于使用 T5 或 BART 等模型的用户来说会很熟悉。凡是可以使用这些模型的地方,Atlas 也同样适用,并且可以使用完全相同的数据:Atlas 将自行学会从其检索索引中检索段落——无需用于将段落与 (query, target) 对关联的标注。

Atlas 的代码库通过命令行参数 --task 来配置当前执行的任务以及要调用的评估指标。 我们实现了一个 base 任务,仅提供最基本的 seq2seq 训练支持,但同时也为掩码语言建模 (mlm)、语言建模 (lm)、维基百科章节生成 (section)、开放域问答 (QA)、选择题问答 (multiple_choice)、事实核查 (fever) 以及 KILT 套件 (kilt) 提供了更全面的功能。 所有任务都期望输入数据采用 jsonl 格式,但具体的字段名称因任务而异。部分任务还具有额外的命令行参数和专门的评估方法。 添加新任务非常简单,具体说明请参见 这里

以下将更详细地介绍各个任务,大多数任务在 examples/{task}/ 目录下都有示例命令(点击展开)。

基础任务

这是最基础的任务,可能并不是您的最佳选择,尤其是当您的任务与已实现的其他任务非常相似时。 通过向 train.pyevaluate.py 传递 --task base 参数即可指定此任务。 该任务的训练/验证/测试数据应由 jsonl 文件组成,需以空格分隔的形式传入 train.pyevaluate.py,例如 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等。 输入文件应包含 query 字段,用于存储输入查询字符串,以及 target 字段,用于存储输出目标字符串,例如:

{"query": "输入到 Atlas 的内容", "target": "希望 Atlas 生成的内容"}

评估循环会计算评估损失,以及 Atlas 生成的输出与目标完全匹配的验证数据样本所占的比例。 如果您向脚本传递 --write_results 参数,Atlas 在验证数据上的预测结果将以如下格式写入保存检查点的目录:

{"query": "输入到 Atlas 的内容", "answers": ["希望 Atlas 生成的内容"], "generation": "Atlas 对该查询的预测结果", "passages": ["检索到的段落列表"]}

掩码语言建模

掩码语言建模任务实现了由 T5 提出的掩码语言建模预训练任务。这也是我们在论文中用来预训练主模型 Atlas 的任务。 通过向 train.py 传递 --task mlm 参数即可指定此任务。 该任务的训练/验证/测试数据应由 jsonl 文件组成,需以 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等形式传入 train.py。 这些文件应由具有以下格式的 JSON 对象组成:

{
  "text": "需要施加噪声并训练去噪的文本片段",
  "id": "该文本片段的唯一标识符"
  ... # 您还可以保留其他字段以便于分析,但这些字段实际上并不会被使用
}

其设计意图是,您可以将用于检索语料库的文件(通过 --passages 传递)直接用作训练数据。 该任务会应用 T5 的噪声函数处理 text 字段,从而自动创建输入和目标生成内容。 MLM 任务还会阻止 Atlas 检索正在尝试去噪的那篇文档。它通过过滤掉检索结果中与正在去噪的实例具有相同 id 字段的文档来实现这一点。如果去噪训练数据与 Atlas 正在检索的文档来自同一语料库,这一功能就显得尤为重要。 该任务具有以下特定于任务的参数:

  --mlm_noise_density MLM_NOISE_DENSITY
      输入文本中应被掩码跨度覆盖的比例(默认:0.15)
  --mlm_mean_noise_span_length MLM_MEAN_NOISE_SPAN_LENGTH
      MLM 掩码跨度的平均长度(默认:3)
  --min_words_per_lm_instance MIN_WORDS_PER_LM_INSTANCE
      如果实例中的词数少于此值,则会跳过该实例,不参与 MLM/LM/章节生成任务(默认:无)

如果您传递 --write_results 参数,Atlas 会将其填空预测结果写入文件。 在评估过程中,Atlas 会记录以下 MLM 任务的评估指标:

  • eval_loss:生成的 MLM 填空跨度的评估损失
  • accuracy:完美去噪的填空跨度所占比例
  • f1:正确去噪的填空跨度的 token F1 分数
  • rouge_1:生成的填空跨度相对于黄金参考掩码跨度的 ROUGE-1 分数
  • rouge_2:生成的填空跨度相对于黄金参考掩码跨度的 ROUGE-2 分数
  • rouge_L:生成的填空跨度相对于黄金参考掩码跨度的 ROUGE-L 分数

语言建模

通过向 train.py 传递 --task lm 参数,Atlas 可以被训练为执行从左到右的语言建模任务。 该任务的训练/验证/测试数据应由 jsonl 文件组成,需以 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等形式传入 train.py。 这些文件应由具有以下格式的 JSON 对象组成:

{
  "text": "需要训练 Atlas 生成的文本片段",
  "id": "该文本片段的唯一标识符"
  ... # 您还可以保留其他字段以便于分析,但这些字段实际上并不会被使用
}

其设计意图是,您可以将用于检索语料库的文件(通过 --passages 传递)直接用作训练数据。 该任务会自动对 text 字段进行预处理,将其随机分为两部分:左侧作为条件上下文,右侧则作为 Atlas 模型将被训练生成的后续内容。 LM 任务还会阻止 Atlas 检索正在尝试生成的同一文档。它通过过滤掉检索结果中与正在生成的实例具有相同 id 字段的文档来实现这一点。如果去噪训练数据与 Atlas 正在检索的文档来自同一语料库,这一功能就显得尤为重要。

该任务具有以下特定于任务的参数:

  --min_words_per_lm_instance MIN_WORDS_PER_LM_INSTANCE
      具有少于 min_words_per_lm_instance 个词的实例将被跳过,不参与 MLM/LM/Section 生成(默认:无)
  --min_lm_context_ratio MIN_LM_CONTEXT_RATIO
      将文本分割为两个部分用于语言建模。左侧部分作为条件上下文,右侧部分用于生成。左侧部分必须大于右侧部分的 min_lm_context_ratio(默认:0.5)
  --max_lm_context_ratio MAX_LM_CONTEXT_RATIO
      将文本分割为两个部分用于语言建模。左侧部分作为条件上下文,右侧部分用于生成。左侧部分必须小于右侧部分的 max_lm_context_ratio(默认:0.5)

如果您传递 --write_results,Atlas 会将其语言模型预测结果写入文件。

在评估过程中,Atlas 会记录以下语言模型评估指标:

  • eval_loss:参考数据中续写的评估读取损失
  • accuracy:完全正确预测的续写比例
  • f1:生成续写的 token F1 分数,表示正确生成的比例
  • rouge_1:生成续写相对于黄金参考续写的 ROUGE-1 分数
  • rouge_2:生成续写相对于黄金参考续写的 ROUGE-2 分数
  • rouge_L:生成续写相对于黄金参考续写的 ROUGE-L 分数

维基百科章节生成

通过向 train.py 传递 --task section,可以训练 Atlas 根据维基百科条目的标题和章节标题生成相应段落的文本。

此任务的训练/验证/测试数据应由 JSONL 文件组成,其格式应与维基百科转储中的 text-list-100-sec.jsonl 文件一致。这些文件可以通过遵循可下载的数据和模型中的说明获取,例如训练文件:enwiki-dec2018/text-list-100-sec.jsonl。这些文件应由每行一个 JSON 对象组成,格式如下:

{
  "id": "3793043", 
  "title": "百慕大三角",
  "section": "指南针偏差",
  "text": " 指南针问题是许多百慕大三角事件中经常提到的现象之一。尽管有人推测该地区可能存在异常的局部磁异常,但至今尚未发现此类异常。实际上,指南针会因与地磁极的关系而产生自然的偏差,这是航海家们几个世纪以来就已知的事实。"
}

该任务会自动将输入查询格式化为“{标题}, {章节}”——例如,在此示例中,输入到 Atlas 的内容将是 百慕大三角, 指南针偏差。输出将是示例中的 text 字段。 section 任务会防止 Atlas 生成与其检索到的同一段落相同的内容。它通过过滤掉检索结果中与正在生成的实例具有相同 id 字段的段落来实现这一点。

该任务具有以下特定于任务的参数:

  --min_words_per_lm_instance MIN_WORDS_PER_LM_INSTANCE
      具有少于 min_words_per_lm_instance 个词的实例将被跳过,不参与 MLM/LM/Section 生成(默认:无)

如果您传递 --write_results,Atlas 会将其生成的维基百科章节文本预测结果写入文件。

在评估过程中,Atlas 会记录以下 section 任务的评估指标:

  • eval_loss:参考数据中续写的评估读取损失
  • accuracy:完全正确预测的续写比例
  • f1:生成续写的 token F1 分数,表示正确生成的比例
  • rouge_1:生成续写相对于黄金参考续写的 ROUGE-1 分数
  • rouge_2:生成续写相对于黄金参考续写的 ROUGE-2 分数
  • rouge_L:生成续写相对于黄金参考续写的 ROUGE-L 分数

开放域问答(如 NaturalQuestions、TriviaQA、TempLama)

通过向 train.pyevaluate.py 传递 --task qa,可以训练 Atlas 回答开放域问答问题。在快速入门与代码库概览部分有一个 QA 的示例。

我们在论文中使用此任务处理 NaturalQuestions、TriviaQA 和 TempLama 数据集。

此任务的训练/验证/测试数据应由 JSONL 文件组成,这些文件应以 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等形式传递给 train.py。 每个文件应包含一行 JSON 实例,格式如下:

{
  "question": "百慕大三角在哪里",
  "answers": ["北大西洋西部"],
   ... # 您可以保留其他字段以便于分析,但这些字段不会实际用于训练
}

问题将根据任务特定参数 --qa_prompt_format 进行格式化,默认值为 question: {question} answer: <extra_id_0>。 例如,上述问题将自动格式化为输入到 Atlas 的查询:question: 百慕大三角在哪里 answer: <extra_id_0>。 监督目标来自 target 字段。如果该字段不存在,则监督目标将从 answers 字段中的可用答案中随机选择,并格式化为 <extra_id_0> {answer}

如果您传递 --write_results,Atlas 会将其预测的答案写入文件。

在评估过程中,Atlas 会记录以下开放域问答的评估指标:

  • eval_loss:评估答案的评估读取损失
  • exact_match:生成答案的开放域问答精确匹配分数
  • f1:生成答案的开放域问答 F1 分数

Natural Questions 和 TriviaQA

您可以通过运行以下命令下载 NaturalQuestions 和 TriviaQA 数据:

python preprocessing/prepare_qa.py --output_directory ${DATA_DIR} 

这将下载 train.jsonltrain.64-shot.jsonl(我们使用的少量样本训练数据集)、dev.jsonltest.jsonl,并将其保存到 ${DATA_DIR}/data/nq_data${DATA_DIR}/data/triviaqa_data 中。

有关使用 NQ 维基百科索引进行少量样本和标准微调及评估的示例脚本可在 examples/nq 中找到。只需替换训练/验证/测试文件,即可将该脚本用于 TriviaQA。

TempLama

我们基于 TempLAMA 数据集定义了一个完形填空式问答任务,用于评估索引的忠实性和时间迁移能力。

您可以通过运行以下脚本下载 TempLAMA 数据并创建和格式化我们衍生的数据集:

python preprocessing/prepare_templama.py --output_directory ${DATA_DIR} 

这将创建文件 temp_lama.train.2017.jsonltemp_lama.valid.2017.jsonltemp_lama.test.2017.jsonltemp_lama.train.2020.jsonltemp_lama.valid.2020.jsonltemp_lama.test.2020.jsonl,位于 ${DATA_DIR}/data/templama_data/ 目录下。 这些文件将包含完形填空题,答案会根据年份有所不同。

运行 TempLama 训练和评估的示例脚本可以在 examples/templama 中找到。(注意使用了 qa_prompt_format {question},它会关闭 TriviaQA 和 NQ 所使用的自动 QA 提示格式化功能)

多项选择题回答(例如 MMLU)

Atlas 可以通过在 train.pyevaluate.py 中添加 --task multiple_choice 参数来训练回答多项选择题。我们在 MMLU 的实验中使用了这一任务。 该任务的训练/验证/测试数据应由 jsonl 文件组成,这些文件应作为 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等参数传递给 train.py。 每个文件每行应包含一个 JSON 实例,格式如下:

{
  "question": "以下哪个是包含垂体的体腔?", 
  "options": {
    "A": "腹腔",
    "B": "颅腔",
    "C": "胸膜腔", 
    "D": "脊髓腔"
    ... # 你可以有更多的(或更少的)选项,只要它们的键是按字母顺序连续的大写字母,从 A 开始即可
  }, 
  "answer": "B",
  ... # 你还可以保留其他字段以便于分析,但这些字段并不会被实际使用
}

这些数据会被自动格式化为 Atlas 的输入查询,形式为 question: {question} answers: (A) {options['A']} (B) {options['B']} (C) {options['C']} (D) {options['D']} Answer: <extra_id_0>,目标生成格式为 <extra_id_0> {answer letter}。 上述示例会被格式化为:question: {以下哪个是包含垂体的体腔? answers: (A) 腹腔 (B) 颅腔 (C) 胸膜腔 (D) 脊髓腔 Answer: <extra_id_0>,目标生成为 {extra_id_0} B

多项选择问答任务有以下特定参数:

  --multiple_choice_num_options
      多项选择问答中选项的数量(MMLU 是 4 个)(默认值:4)
  --multiple_choice_train_permutations {single,cyclic,all}
      在进行多项选择训练时(例如 MMLU),是否启用答案顺序的排列组合。这有助于消除模型对任意答案顺序的偏好,从而提升效果。建议使用 'all' 模式。single:不进行排列;cyclic:循环排列;all:所有可能的答案排列组合。(默认值:single)
  --multiple_choice_eval_permutations {single,cyclic,all}
      在进行多项选择评估时(例如 MMLU),是否启用答案顺序的排列组合。这同样可以减少模型对答案顺序的偏好,从而提升效果。使用 'all' 模式效果最佳,但速度较慢。'cyclic' 是一个不错的折中方案。single:不进行排列;cyclic:循环排列;all:所有可能的答案排列组合。(默认值:single)

排列选项会自动复制输入数据,并对答案顺序进行排列(例如,“A”变为“颅腔”,“B”变为“胸膜腔”等)。 当监督数据量非常少(或零样本情况下),这种方法可以显著提升效果。 如果你使用了 --multiple_choice_eval_permutations 参数中的 cyclicall 选项,代码会自动对不同排列组合的结果进行汇总,为你提供最终的评估结果。 关于排列去偏化的更多细节,请参阅论文 Atlas: Few-shot Learning with Retrieval Augmented Language Models 的附录。

如果你指定了 --write_results,Atlas 会将其预测的答案写入文件,格式如下:

{
  "question": "应用提示模板后的输入内容",
  "generation": "在对排列组合结果进行汇总后,概率最高的答案字母",
  "choice_probs": "每个答案选项的概率(归一化到总选项数)",
  "all_probs": "未汇总前的所有答案排列组合的概率",
  "permutations": ["针对每种答案排序组合的预测对象列表"]
}

MMLU

专门用于运行 MMLU 实验的 ReadMe 文档可在 这里 查看。 我们提供了一个用于下载和预处理 MMLU 数据的工具,并且在 examples/mmlu 中提供了针对我们所探索的各种实验设置的示例脚本。 这些内容在 MMLU 的专用 ReadMe 文档中都有详细说明。

FEVER 事实核查

Atlas 可以通过使用 --task fever 参数在 train.pyevaluate.py 中训练,使其能够根据语料库将文本陈述分类为“支持”、“反驳”或“信息不足”,例如用于 FEVER 任务。 你可以通过运行以下脚本来下载 FEVER 数据:

python preprocessing/prepare_fever.py --output_directory ${DATA_DIR} 

该任务的训练/验证/测试数据应由 jsonl 文件组成,这些文件应作为 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等参数传递给 train.py。 每个文件每行应包含一个 JSON 实例,格式如下:

{
  "claim": "需要评估的陈述", 
  "label": "要么是 'SUPPORTS'、'REFUTES',要么是 'NOT ENOUGH INFO'",
   ... # 你可以保留其他字段以便于分析,但这些字段并不会被实际使用
}

Atlas 会自动处理这些实例,并将其格式化为输入 question: {claim} answer: <extra_id_0> 和输出 <extra_id_0> {true, false 或 maybe}。 如果你指定了 --write_results,Atlas 会将其预测的标签写入文件。 在评估过程中,Atlas 还会记录以下开放域问答的评估指标:

  • accuracy: 模型正确分类的陈述数量。

KILT

Atlas 可以通过在 train.pyevaluate.py 中使用 --task kilt 参数来训练执行 KILT 任务。 KILT 数据可以从 这里 获取。

该任务的训练/验证/测试数据应由 JSONL 文件组成,这些文件应作为 --train_data train_file_1.jsonl train_file_2.jsonl--eval_data eval_file_1.jsonl eval_file_2.jsonl 等参数传递给 train.py。 每个文件应每行包含一个 JSON 实例,格式如下(即代码库可以直接接受 KILT 格式):

{'id': # 原始数据点的 ID,如果有的话;否则为唯一 ID
 'input': # 问题 / 主张 / 句子 / 等等
 'output': [ # 每个元素可能包含答案、证据来源或两者
    {
    'answer': # 文本形式的答案
    'provenance': [
        # 针对答案的 KILT 数据集中的证据集合
        {
            'wikipedia_id':  # *必须* 
            'title': 
            'section': 
            'start_paragraph_id': 
            'start_character': 
            'end_paragraph_id':
            'end_character': 
            'bleu_score': # 相对于原始证据的 BLEU 分数
            'meta': # 数据集/任务特定的元数据
        }
        ] 
      }
    ]
 'meta': # 数据集/任务特定的元数据
 }

Atlas 会自动根据 input 字段将这些实例处理为 Atlas 查询输入,并根据 answer 字段生成目标输出。

如果您传递 --write_results 参数,Atlas 会将其预测标签写入文件。

在评估过程中,Atlas 将记录以下开放域问答任务的评估指标:

  • accuracy: 生成内容与参考答案完全匹配的频率
  • exact_match: 在应用开放域问答标准化后,生成内容与参考答案完全匹配的频率
  • f1: 生成内容与参考答案之间的 token 级别 F1 分数重叠

检索与索引详情

以下部分提供了关于检索和索引的更多细节。

如前言中简要提及,Atlas 代码利用现代大型神经网络训练的并行特性来处理检索。具体来说,所有现代 GPU 上的训练(以及推理)都需要多个 GPU 工作节点来进行并行计算。

Atlas 利用了这一已有的分布式架构。它会将检索索引(文档片段 + 嵌入后的文档片段)划分为 N 个大小相等的分片,每个 GPU 工作节点负责一个分片。默认情况下,检索完全在 GPU 上使用 PyTorch 进行精确搜索(不使用 FAISS 的近似搜索),尽管如此,由于搜索是在所有 GPU 工作节点之间并行进行的,因此速度仍然很快(假设 GPU 数量足够)。

了解检索的具体实现机制并非运行该代码库的必要条件,但可能有助于针对特定需求调整代码库。因此,我们在此提供简要说明。为简单起见,我们假设正在进行开放域问答任务,每 GPU 的批处理大小为 1,共有 W 个 GPU 工作节点,检索索引中总共有 N 条文档。

检索器需要完成两个高层次的功能:

1. 构建/刷新嵌入

构建或刷新索引涉及为检索索引中的每条文档计算嵌入,然后将其保存在内存中,以便在后续检索时能够快速计算最大内积或最近邻。(假设我们有 W 个 GPU 工作节点,检索索引中总共有 N 条文档)

请注意,每个 GPU 工作节点负责 N/W 条文档的分片。文档嵌入的计算非常简单,过程如下:

  • 暂停正在进行的模型训练
  • 每个工作节点并行计算其分片中文档的嵌入,按批次迭代处理,并将结果保存在一个大型 PyTorch 张量中(FAISS 支持也可用,但稍后讨论)。
  • 当所有工作节点都完成其分片的嵌入计算后,我们可以继续进行模型训练。

N 可能非常大(1000万到1亿),因此嵌入所有文档的速度会很慢。然而,由于可以在各个工作节点之间并行化,如果 GPU 数量足够,整个过程可以相对较快。尽管如此,对于大型索引,这个过程仍然可能耗时较长。我们提供了索引保存功能,可以将索引保存到磁盘,以避免频繁重建索引。此外,随着检索器的不断训练,缓存的嵌入可能会过时或“陈旧”,需要重新计算,这会带来额外的成本。有关减少或避免频繁刷新索引的方法,请参阅处理陈旧索引的策略

2. 执行分布式检索

Atlas 在训练循环中执行检索:即在前向传播过程中调用检索功能。下面我们将简要描述 Atlas 前向传播中的步骤,包括如何完成检索: (假设我们有 W 个 GPU 工作节点,检索索引中总共有 N 条文档,为简化起见,假设每 GPU 的训练批大小为 1,且任务为问答任务)。

  • 每个工作节点会收到一个问题,并将其嵌入为查询向量。
  • 然后执行全归约操作,使得每个工作节点都拥有全部 W 个查询向量的副本(即当前小批量中所有问题的查询向量)。
  • 接着,每个工作节点在其负责的文档分片上对小批量中的所有查询向量执行 GPU 上的最大内积搜索。
  • 每个工作节点会将其分片中每个查询的前 K 个结果通过归约操作发送回嵌入该查询的 GPU。
  • 最终,每个工作节点会获得其查询对应的前 W × K 个结果,从中选出真正的前 K 个结果。
  • 检索过程至此完成,随后可以继续进行标准的分布式数据并行前向传播(即运行模型前向传播、计算梯度、在各工作节点之间聚合梯度,并更新模型参数)。

Flat 与 Faiss

Atlas 实现了两种索引模式。默认情况下,我们使用精确搜索(“Flat”)索引进行检索,检索过程在 GPU 上通过纯 PyTorch 完成。我们还支持 FAISS 模式,该模式有助于为超大规模索引节省 GPU 内存,或在 GPU 内存非常有限的情况下使用。FAISS 是一个用于快速近似最近邻搜索的库。由于我们的检索是在 GPU 上进行的,通常不需要进一步的搜索加速,但 FAISS 可以用来压缩内存中索引的大小,这对于非常大的索引可能会有所帮助。

要使用的模式由 --index_mode {"flat"|"faiss"} 指定。对于大多数用例,flat 索引就足够了,而且通常更为推荐。

如果使用 faiss 索引,用户应指定要使用的 faiss 索引类型,可选如下:

  --faiss_index_type {ivfflat,flat,ivfsq,ivfpq,pq}
      IVFFlat、IndexFlatIP、IVFScalarQuantizer、IndexPQ 或 IndexIVFPQ(需配合 faiss-gpu 使用;默认值:flat)
  --faiss_code_size FAISS_CODE_SIZE
      PQ/SQ 量化参数(默认值:无)

使用 faiss 索引时的一个良好默认设置是 --faiss_index_type ivfpq --faiss_code_size 16。这将使用 IVF-PQ 索引,其中 IVF 聚类的数量设置为每个分片嵌入数量的平方根,PQ 代码长度为 16。有关此索引结构的更多详细信息,请参阅 faiss 文档 FAISS

索引的保存与加载

索引(段落和嵌入分片)可以保存到磁盘并在需要时加载,以避免重新计算它们。有关一些可下载的索引,请参阅上文

可以通过 --save_index_path {path/to/directory/save/index/in} 开启索引保存功能,该命令会创建一个目录,并将每个工作进程的嵌入分片(作为磁盘上的 PyTorch 张量)和段落分片(作为 pickle 文件)保存到该目录中。

要加载索引,只需传递 --load_index_path {path},即可从指定路径加载索引。

保存和加载功能同时适用于 flatfaiss 模式。

为了便于在使用的工作进程数与创建索引时不同的情况下加载索引,我们可以配置 --save_index_n_shards N,这会将索引保存为 N 个分片(例如,如果有 32 个工作进程,可以传递 --save_index_n_shards 128,将索引保存为 128 个分片)。当再次尝试加载索引时,比如使用 64 个工作进程,代码会自动判断每个工作进程应加载 2 个保存的文件。(注意:此功能仅适用于 flat 索引——对于 faiss 索引,只能加载与保存时工作进程数相同的索引。)

处理过时索引的策略

随着检索器的训练,存储在内存中的段落嵌入会逐渐过时。这会影响检索的准确性,并且在长时间内可能导致训练效果不佳或不稳定。Atlas 提供三种方法来应对这一问题:

  1. 索引刷新:最简单但也最昂贵的方法是使用最新的检索器嵌入器重新计算嵌入。索引刷新频率由 --refresh_index 参数控制。格式为:startstep-endstep:refreshrate,例如 --refresh_index 0-1000:500,1000-10000:1000 表示在前 1000 步中每 500 步刷新一次索引,随后从第 1000 步到第 10000 步每 1000 步刷新一次。也可以只传递一个数字,如 --refresh_index 100 表示每 100 步刷新一次索引。传递 --refresh_index -1 则表示永不刷新。我们通常在大型数据集和预训练中使用此设置。
  2. 带重排序的超额检索:在这种方法中,我们不刷新索引,而是检索前 L 个段落(其中 L > K),然后使用最新的嵌入器对这 L 个段落进行实时重排序,并从中选出前 K 个。如果真实的前 K 个确实包含在过时的前 L 个中,这种方法效果很好。要使用此方法,需传递 --retrieve_with_rerank 并指定 --n_to_rerank_with_retrieve_with_rerank L。此方法可以与索引刷新结合使用,以减少两次刷新之间的过时程度。
  3. 查询端微调:为了避免过时问题,我们可以固定检索器的段落嵌入器,仅训练查询嵌入器。如果训练数据量很大,这种方法会牺牲检索性能,但在少样本场景下效果较好。要启用此模式,需传递 --query_side_retriever_training。注意:通常检索器的段落编码器和查询编码器会共享参数——而此模式则例外,我们会解除参数绑定,以保持段落编码器不变。

仅检索模式

Atlas 在评估时可以完全以检索模式运行。这对于希望使用快速、可扩展、易于部署且支持 GPU 的密集型检索器的用户来说非常有用。

在此模式下(仅适用于 evaluate.py),不会加载阅读语言模型,脚本将执行检索操作,并在传递了 --write_results 标志的情况下将检索结果写入文件。

要使用此模式,需在 evaluate.py 中传递 --retrieve_onlyexamples/nq/retrieve_only.sh 中提供了一个使用此模式进行 NaturalQuestions 数据集检索的示例。

使用预先检索或缓存的段落

在某些情况下,用户可能已经完成了检索,并希望为其数据集缓存检索结果,或者事先知道最相关的段落,因此无需再进行检索。

在这种情况下,可以通过以下两种方式让 Atlas 忽略检索步骤,直接使用用户指定的段落:1) 传递 --use_file_passages 标志;2) 在传入的训练/评估文件中包含一个名为 passages 的 JSON 字段,其格式如下(以 qa 任务为例):

(点击展开查看示例)
{
  "question": "百慕大三角在哪里",
  "answers": ["北大西洋西部海域"],
  "passages": [
    {
      "text": "第一段落的内容",
      "title": "第一段落的标题",
      "id": "第一段落的唯一标识符"
      ... # 其他字段也可存在,但不会被使用
    },
    {
      "text": "第二段落的内容",
      "title": "第二段落的标题",
      "id": "第二段落的唯一标识符"
    },
    ... # 如有需要,可添加更多段落
  ]
}

其他功能

以下是 Atlas 为高级用户提供的其他功能:

封闭书本模式

Atlas 可以作为标准的非检索增强 T5 模型运行,在文献中常被称为“封闭书本”模式。这对于进行基线实验以及验证您的模型是否确实从针对特定任务的检索增强中受益很有帮助。传递 --closed_book 参数即可进行封闭书本训练,并忽略检索到的段落。

指定格式

可以通过注入格式字符串来更精细地控制输入如何呈现给 Atlas 模型:

  --encoder_format ENCODER_FORMAT
    阅读器编码器预处理的格式字符串(默认: "{query} title: {title} context: {text}")
  --retriever_format RETRIEVER_FORMAT
    检索器编码器预处理的格式字符串(默认: "{title} {text}")

例如,传递 --encoder_format "{query} text: {text}" 将不会把检索到的段落标题传递给阅读器模型。

实现您自己的任务

要为 Atlas 实现新任务,有两种选择:最简单的方法是使用已实现的任务之一对您的任务进行预处理或格式化,使其兼容(base 任务应支持几乎所有潜在用例)。

另一种方法是在 src/tasks/your_task_name.py 中实现您自己的任务,并在 src/tasks/__init__.py 中将其导入。

请参阅 src/tasks/qa.py 以获取示例。

process 函数接受传递给 --train_data--eval_data 的原始解析后的 jsonl 对象,并应返回一个字典,包含 {query: "传递给 Atlas 的查询", "target": "目标字符串", "passages": [黄金检索段落列表,可以为空]}

evaluate 函数接受任务的预测生成和参考答案,并返回一个特定于任务的评估分数字典,代码库会针对所有评估实例计算这些分数的平均值。

命令行参数完整列表:

点击展开
用法: train.py/evaluate.py [-h] [--name NAME] [--checkpoint_dir CHECKPOINT_DIR] [--model_path MODEL_PATH] [--per_gpu_batch_size PER_GPU_BATCH_SIZE] [--per_gpu_embedder_batch_size PER_GPU_EMBEDDER_BATCH_SIZE] [--local_rank LOCAL_RANK]
                [--main_port MAIN_PORT] [--seed SEED] [--log_freq LOG_FREQ] [--eval_freq EVAL_FREQ] [--save_freq SAVE_FREQ] [--train_data TRAIN_DATA [TRAIN_DATA ...]] [--eval_data EVAL_DATA [EVAL_DATA ...]] [--write_results]
                [--dont_write_passages] [--load_index_path LOAD_INDEX_PATH] [--save_index_path SAVE_INDEX_PATH] [--save_index_n_shards SAVE_INDEX_N_SHARDS] [--index_mode {flat,faiss}] [--faiss_index_type {ivfflat,flat,ivfsq,sq,pq}]
                [--faiss_code_size FAISS_CODE_SIZE] --reader_model_type
                {t5-small,t5-base,t5-large,t5-3b,t5-11b,google/t5-v1_1-base,google/t5-v1_1-large,google/t5-v1_1-xl,google/t5-v1_1-xxl,google/t5-base-lm-adapt,google/t5-large-lm-adapt,google/t5-xl-lm-adapt,google/t5-xxl-lm-adapt}
                [--text_maxlength TEXT_MAXLENGTH] [--target_maxlength TARGET_MAXLENGTH] [--n_context N_CONTEXT] [--passages PASSAGES [PASSAGES ...]] [--max_passages MAX_PASSAGES] [--retriever_model_path RETRIEVER_MODEL_PATH]
                [--retrieve_only] [--train_retriever] [--use_file_passages] [--retriever_n_context RETRIEVER_N_CONTEXT] [--gold_score_mode {evalnormsum,loop,ppmean,emdr,pdist,adist}] [--closed_book]
                [--temperature_score TEMPERATURE_SCORE] [--temperature_gold TEMPERATURE_GOLD] [--compute_crossattention_stats] [--filtering_overretrieve_ratio FILTERING_OVERRETRIEVE_RATIO]
                [--freeze_retriever_steps FREEZE_RETRIEVER_STEPS] [--query_side_retriever_training] [--retrieve_with_rerank] [--n_to_rerank_with_retrieve_with_rerank N_TO_RERANK_WITH_RETRIEVE_WITH_RERANK]
                [--decoder_format DECODER_FORMAT] [--decoder_prompt_format DECODER_PROMPT_FORMAT] [--encoder_format ENCODER_FORMAT] [--retriever_format RETRIEVER_FORMAT] [--generation_max_length GENERATION_MAX_LENGTH]
                [--generation_min_length GENERATION_MIN_LENGTH] [--generation_length_penalty GENERATION_LENGTH_PENALTY] [--generation_num_beams GENERATION_NUM_BEAMS] [--task {base,mlm,lm,multiple_choice,kilt,section,fever,qa}]
                [--mlm_noise_density MLM_NOISE_DENSITY] [--mlm_mean_noise_span_length MLM_MEAN_NOISE_SPAN_LENGTH] [--min_words_per_lm_instance MIN_WORDS_PER_LM_INSTANCE] [--min_lm_context_ratio MIN_LM_CONTEXT_RATIO]
                [--max_lm_context_ratio MAX_LM_CONTEXT_RATIO] [--qa_prompt_format QA_PROMPT_FORMAT] [--multiple_choice_num_options MULTIPLE_CHOICE_NUM_OPTIONS] [--multiple_choice_train_permutations {single,cyclic,all}]
                [--multiple_choice_eval_permutations {single,cyclic,all}] [--warmup_steps WARMUP_STEPS] [--total_steps TOTAL_STEPS] [--scheduler_steps SCHEDULER_STEPS] [--accumulation_steps ACCUMULATION_STEPS] [--dropout DROPOUT]
                [--lr LR] [--lr_retriever LR_RETRIEVER] [--clip CLIP] [--scheduler {linear,cosine,fixed}] [--weight_decay WEIGHT_DECAY] [--save_optimizer] [--epsilon EPSILON] [--alpha ALPHA] [--beta2 BETA2]
                [--refresh_index REFRESH_INDEX] [--shuffle] [--precision {fp16,fp32,bf16}] [--shard_optim] [--shard_grads] [--use_gradient_checkpoint_reader] [--use_gradient_checkpoint_retriever]

可选参数:
  -h, --help            显示此帮助信息并退出
  --name NAME           实验名称,也用作目录名(默认值:experiment_name)
  --checkpoint_dir CHECKPOINT_DIR
                        模型保存在此目录下(默认值:./checkpoint/)
  --model_path MODEL_PATH
                        用于初始化的预训练模型路径(传入 'none' 表示从 T5 和 Contriever 初始化)(默认值:无)
  --per_gpu_batch_size PER_GPU_BATCH_SIZE
                        每个 GPU/CPU 的训练批次大小。(默认值:1)
  --per_gpu_embedder_batch_size PER_GPU_EMBEDDER_BATCH_SIZE
                        Embedder 每个 GPU 的批次大小。(默认值:512)
  --local_rank LOCAL_RANK
                        用于分布式训练:本地进程编号(默认值:-1)
  --main_port MAIN_PORT
                        主端口(用于多节点任务)(默认值:-1)
  --seed SEED           初始化时使用的随机种子(默认值:0)
  --log_freq LOG_FREQ   训练过程中每 <log_freq> 步记录一次训练统计信息(默认值:100)
  --eval_freq EVAL_FREQ
                        训练过程中每 <eval_freq> 步评估一次模型(默认值:500)
  --save_freq SAVE_FREQ
                        训练过程中每 <save_freq> 步保存一次模型(默认值:5000)
  --train_data TRAIN_DATA [TRAIN_DATA ...]
                        以空格分隔的 JSONL 格式训练数据集路径列表(默认值:空列表)
  --eval_data EVAL_DATA [EVAL_DATA ...]
                        以空格分隔的 JSONL 格式评估数据集路径列表(默认值:空列表)
  --write_results       将评估结果保存到文件中(默认值:False)
  --dont_write_passages
                        如果要写结果,段落可能会占用大量空间,使用此标志可以不将段落写入导出的结果中(默认值:False)
  --load_index_path LOAD_INDEX_PATH
                        用于加载索引、段落嵌入和段落的路径(默认值:None)
  --save_index_path SAVE_INDEX_PATH
                        用于保存索引和/或嵌入的路径(默认值:None)
  --save_index_n_shards SAVE_INDEX_N_SHARDS
                        将索引保存为多少个分片文件。必须是工作进程数的整数倍。(默认值:128)
  --index_mode {flat,faiss}
                        使用扁平的 PyTorch 索引或 Faiss 索引来检索 k 个最近邻(默认值:flat)
  --faiss_index_type {ivfflat,flat,ivfsq,sq,pq}
                        IVFFlat、IndexFlatIP、IVFScalarQuantizer、ScalarQuantizer 或带有 faiss-gpu 的 IndexPQ(默认值:flat)
  --faiss_code_size FAISS_CODE_SIZE
                        PQ 量化参数(默认值:None)
  --reader_model_type {t5-small,t5-base,t5-large,t5-3b,t5-11b,google/t5-v1_1-base,google/t5-v1_1-large,google/t5-v1_1-xl,google/t5-v1_1-xxl,google/t5-base-lm-adapt,google/t5-large-lm-adapt,google/t5-xl-lm-adapt,google/t5-xxl-lm-adapt}
                        阅读器 FID 模型的 T5 架构,例如 google/t5-xl-lm-adapt(默认值:None)
  --text_maxlength TEXT_MAXLENGTH
                        输入文本片段(问题+段落拼接后)的最大 token 数。超过此长度的输入将被截断。(默认值:200)
  --target_maxlength TARGET_MAXLENGTH
                        训练模型时目标输出的最大 token 长度。超过此长度的目标将被截断。如果设置为 -1,则不进行截断(默认值:None)
  --n_context N_CONTEXT
                        传递给阅读器的 top k 段落数量(默认值:1)
  --passages PASSAGES [PASSAGES ...]
                        包含要索引和检索的段落的 JSONL 文件路径列表。如果使用 --load_index_path 加载已保存的索引,则此参数无效(默认值:None)
  --max_passages MAX_PASSAGES
                        要索引的最大段落数量。设置为 -1 表示读取段落文件中的所有段落(默认值:-1)
  --retriever_model_path RETRIEVER_MODEL_PATH
                        用于初始化的 Contriever 模型路径(如果传入 --model_path 参数,则覆盖此值)(默认值:facebook/contriever)
  --retrieve_only       传入此参数以防止加载阅读器,仅运行检索评估(默认值:False)
  --train_retriever     传入此参数以同时训练检索器和阅读器(默认值:False)
  --use_file_passages   使用训练或评估 JSONL 文件中 "passages" 字段中的段落,而不是通过检索获取段落(默认值:False)
  --retriever_n_context RETRIEVER_N_CONTEXT
                        用于训练检索器的 top k 段落数量(默认值:5)
  --gold_score_mode {evalnormsum,loop,ppmean,emdr,pdist,adist}
                        训练检索器的方法。`pdist` 是论文中对 `ppmean` 的称呼。`adist` 是论文中对 `evalnormsum` 的称呼(默认值:ppmean)
  --closed_book         不使用检索功能——退化为 T5 模型。如果设置了 n_context、n_context_retriever 和 encoder_format,则此选项会覆盖它们(默认值:False)
  --temperature_score TEMPERATURE_SCORE
                        检索器的 softmax 温度(默认值:0.01)
  --temperature_gold TEMPERATURE_GOLD
                        检索器蒸馏目标分布的 softmax 温度(默认值:0.01)
  --compute_crossattention_stats
  --filtering_overretrieve_ratio FILTERING_OVERRETRIEVE_RATIO
                        如果进行过滤,先按此比例超额检索 topK,然后再过滤掉不需要的结果。很有用;只有在处理不需过滤检索结果的任务时才设为 1(默认值:2)
  --freeze_retriever_steps FREEZE_RETRIEVER_STEPS
                        冻结检索器 n 步(默认值:-1)
  --query_side_retriever_training
                        传入此参数以启用查询端的检索器微调(解绑 Contriever 编码器中段落和查询编码器的参数,并冻结段落编码器。有助于避免索引刷新。(默认值:False)
  --retrieve_with_rerank
                        传入此参数以启用使用全新段落编码器进行重新排序的检索(默认值:False)
  --n_to_rerank_with_retrieve_with_rerank N_TO_RERANK_WITH_RETRIEVE_WITH_RERANK
                        当传入 --retrieve_with_rerank 时,需要重新排序的段落数量。数值越高越慢但越准确。推荐 64-128(默认值:128)
  --decoder_format DECODER_FORMAT
                        解码器的格式。模型将按照该格式进行训练,评估也将采用与 decoder_prompt_format 选项相反的格式(默认值:无)
  --decoder_prompt_format DECODER_PROMPT_FORMAT
                        解码器提示的格式,例如“{query} 的答案是什么?”(默认值:无)
  --encoder_format ENCODER_FORMAT
                        阅读器编码器预处理的格式字符串(默认值:{query} 标题:{title} 上下文:{text})
  --retriever_format RETRIEVER_FORMAT
                        检索器编码器预处理的格式字符串(默认值:{title} {text})
  --generation_max_length GENERATION_MAX_LENGTH
  --generation_min_length GENERATION_MIN_LENGTH
  --generation_length_penalty GENERATION_LENGTH_PENALTY
  --generation_num_beams GENERATION_NUM_BEAMS
  --task {base,mlm,lm,multiple_choice,kilt,section,fever,qa}
                        模型执行的任务。用于设置预处理、检索过滤、评估等。(默认值:无)
  --mlm_noise_density MLM_NOISE_DENSITY
                        输入文本中应被掩码跨度遮盖的比例(默认值:0.15)
  --mlm_mean_noise_span_length MLM_MEAN_NOISE_SPAN_LENGTH
                        MLM 掩码跨度的平均长度(默认值:3)
  --min_words_per_lm_instance MIN_WORDS_PER_LM_INSTANCE
                        对于 MLM/LM/Section Generation,如果实例中的单词数少于 min_words_per_lm_instance,则会跳过该实例(默认值:无)
  --min_lm_context_ratio MIN_LM_CONTEXT_RATIO
                        将文本分为两个部分进行语言建模。“左半部分作为条件上下文,右半部分用于生成。”左半部分必须大于右半部分的 min_lm_context_ratio。
                        (默认值:0.5)
  --max_lm_context_ratio MAX_LM_CONTEXT_RATIO
                        将文本分为两个部分进行语言建模。“左半部分作为条件上下文,右半部分用于生成。”左半部分必须小于右半部分的 max_lm_context_ratio。
                        (默认值:0.5)
  --qa_prompt_format QA_PROMPT_FORMAT
                        当使用 --task qa 时,如何将问题格式化为输入提示(默认值:问题:{question} 答案:<extra_id_0>)
  --multiple_choice_num_options MULTIPLE_CHOICE_NUM_OPTIONS
                        多项选择问答中有多少个选项(例如 MMLU 是 4 个)(默认值:4)
  --multiple_choice_train_permutations {single,cyclic,all}
                        在多项选择任务(如 MMLU)训练时,是否使用答案顺序的排列组合。这可以改善结果,消除模型对任意答案顺序的偏好。建议使用 all。
                        single:无排列;cyclic:循环排列;all:所有可能的答案排列组合。(默认值:single)
  --multiple_choice_eval_permutations {single,cyclic,all}
                        在多项选择任务(如 MMLU)评估时,是否使用答案顺序的排列组合。这可以改善结果,消除模型对任意答案顺序的偏好。最好使用 all,但速度很慢。
                        cyclic 是一个不错的折中方案。single:无排列;cyclic:循环排列;all:所有可能的答案排列组合。(默认值:single)
  --warmup_steps WARMUP_STEPS
                        学习率预热步数(默认值:1000)
  --total_steps TOTAL_STEPS
                        总训练步数(默认值:1000)
  --scheduler_steps SCHEDULER_STEPS
                        调度器的总步数。如果未指定,则 scheduler_total_step = total_step。(默认值:无)
  --accumulation_steps ACCUMULATION_STEPS
                        梯度累积(默认值:1)
  --dropout DROPOUT     掉落率(默认值:0.1)
  --lr LR               学习率(默认值:0.0001)
  --lr_retriever LR_RETRIEVER
                        检索器的学习率(默认值:1e-05)
  --clip CLIP           梯度裁剪(默认值:1.0)
  --scheduler {linear,cosine,fixed}
                        使用的学习率调度策略(默认值:cosine)
  --weight_decay WEIGHT_DECAY
                        训练中应用的权重衰减量(默认值:0.1)
  --save_optimizer      传入此标志以在保存的检查点中保存优化器状态(默认值:False)
  --epsilon EPSILON     AdamW 的 epsilon 值(默认值:1e-06)
  --alpha ALPHA         AdamW 的 alpha 值(默认值:1.0)
  --beta2 BETA2         AdamW 的 beta2 值(默认值:0.999)
  --refresh_index REFRESH_INDEX
                        索引刷新计划。格式:起始步-结束步:刷新频率,起始步-结束步:刷新频率。例如,--refresh_index 0-100:10,100-1000000:500 将在前 100 步中每 10 步刷新一次索引,然后从第 100 步到第 100 万步每 500 步刷新一次。对于固定计划,可以直接传入一个数字,例如 --refresh_index 100 将每 100 步刷新一次索引。传入 -1 表示永不刷新。(默认值:0-1000000:1000000)
  --shuffle             在训练时打乱数据(默认值:False)
  --precision {fp16,fp32,bf16}
                        数值精度——如果可用,建议使用 bf16;fp16 在训练中可能不稳定(默认值:fp32)
  --shard_optim         训练时的内存优化:使用分片数据并行将优化器状态分散到可用的 GPU 上,推荐用于大型模型(默认值:False)
  --shard_grads         训练时的内存优化:使用分片数据并行将梯度分散到可用的 GPU 上,推荐用于大型模型(默认值:False)
  --use_gradient_checkpoint_reader
                        在阅读器中使用梯度检查点技术(默认值:False)
  --use_gradient_checkpoint_retriever
                        在检索器中使用梯度检查点技术(默认值:False)

引用

如需引用本工作,请使用以下 BibTeX 格式:

@article{izacard_few-shot_2022,
	title = {基于检索增强语言模型的少样本学习},
	url = {http://arxiv.org/abs/2208.03299},
	publisher = {arXiv},
	author = {Izacard, Gautier 和 Lewis, Patrick 和 Lomeli, Maria 和 Hosseini, Lucas 和 Petroni, Fabio 和 Schick, Timo 和 Dwivedi-Yu, Jane 和 Joulin, Armand 和 Riedel, Sebastian 和 Grave, Edouard},
	year = {2022},
}

许可证

代码许可证:

Atlas 项目的大部分代码采用 CC-BY-NC 许可证,但项目中的部分组件则遵循单独的许可条款:Hugging Face Transformers 库采用 Apache 2.0 许可证,该许可证适用于 src/modeling_bert.pysrc/modeling_t5.py 文件。

数据许可证:

仓库中使用的维基百科相关数据,例如通过 download_corpus.pydownload_index.py 获取的语料库和索引,均依据 CC-BY-SA 许可证授权。

常见问题

相似工具推荐

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|今天
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

139k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

107.7k|★★☆☆☆|2天前
开发框架图像Agent

NextChat

NextChat 是一款轻量且极速的 AI 助手,旨在为用户提供流畅、跨平台的大模型交互体验。它完美解决了用户在多设备间切换时难以保持对话连续性,以及面对众多 AI 模型不知如何统一管理的痛点。无论是日常办公、学习辅助还是创意激发,NextChat 都能让用户随时随地通过网页、iOS、Android、Windows、MacOS 或 Linux 端无缝接入智能服务。 这款工具非常适合普通用户、学生、职场人士以及需要私有化部署的企业团队使用。对于开发者而言,它也提供了便捷的自托管方案,支持一键部署到 Vercel 或 Zeabur 等平台。 NextChat 的核心亮点在于其广泛的模型兼容性,原生支持 Claude、DeepSeek、GPT-4 及 Gemini Pro 等主流大模型,让用户在一个界面即可自由切换不同 AI 能力。此外,它还率先支持 MCP(Model Context Protocol)协议,增强了上下文处理能力。针对企业用户,NextChat 提供专业版解决方案,具备品牌定制、细粒度权限控制、内部知识库整合及安全审计等功能,满足公司对数据隐私和个性化管理的高标准要求。

87.6k|★★☆☆☆|今天
开发框架语言模型

ML-For-Beginners

ML-For-Beginners 是由微软推出的一套系统化机器学习入门课程,旨在帮助零基础用户轻松掌握经典机器学习知识。这套课程将学习路径规划为 12 周,包含 26 节精炼课程和 52 道配套测验,内容涵盖从基础概念到实际应用的完整流程,有效解决了初学者面对庞大知识体系时无从下手、缺乏结构化指导的痛点。 无论是希望转型的开发者、需要补充算法背景的研究人员,还是对人工智能充满好奇的普通爱好者,都能从中受益。课程不仅提供了清晰的理论讲解,还强调动手实践,让用户在循序渐进中建立扎实的技能基础。其独特的亮点在于强大的多语言支持,通过自动化机制提供了包括简体中文在内的 50 多种语言版本,极大地降低了全球不同背景用户的学习门槛。此外,项目采用开源协作模式,社区活跃且内容持续更新,确保学习者能获取前沿且准确的技术资讯。如果你正寻找一条清晰、友好且专业的机器学习入门之路,ML-For-Beginners 将是理想的起点。

85k|★★☆☆☆|今天
图像数据工具视频

ragflow

RAGFlow 是一款领先的开源检索增强生成(RAG)引擎,旨在为大语言模型构建更精准、可靠的上下文层。它巧妙地将前沿的 RAG 技术与智能体(Agent)能力相结合,不仅支持从各类文档中高效提取知识,还能让模型基于这些知识进行逻辑推理和任务执行。 在大模型应用中,幻觉问题和知识滞后是常见痛点。RAGFlow 通过深度解析复杂文档结构(如表格、图表及混合排版),显著提升了信息检索的准确度,从而有效减少模型“胡编乱造”的现象,确保回答既有据可依又具备时效性。其内置的智能体机制更进一步,使系统不仅能回答问题,还能自主规划步骤解决复杂问题。 这款工具特别适合开发者、企业技术团队以及 AI 研究人员使用。无论是希望快速搭建私有知识库问答系统,还是致力于探索大模型在垂直领域落地的创新者,都能从中受益。RAGFlow 提供了可视化的工作流编排界面和灵活的 API 接口,既降低了非算法背景用户的上手门槛,也满足了专业开发者对系统深度定制的需求。作为基于 Apache 2.0 协议开源的项目,它正成为连接通用大模型与行业专有知识之间的重要桥梁。

77.1k|★★★☆☆|昨天
Agent图像开发框架