textgen

GitHub
981 112 中等 1 次阅读 1周前Apache-2.0开发框架语言模型
AI 解读 由 AI 自动生成,仅供参考

TextGen 是一个功能强大的开源文本生成工具库,旨在让大语言模型的训练与应用变得简单高效。它集成了 LLaMA、ChatGLM、BLOOM、GPT2、T5、BART 等多种主流架构,为开发者提供了一套“开箱即用”的解决方案,有效解决了从模型微调(SFT)、LoRA 高效训练到多卡推理部署全流程中的技术门槛高、配置复杂等痛点。

无论是希望快速验证想法的算法研究人员,还是需要将大模型落地到具体业务场景的 AI 工程师,都能通过 TextGen 轻松上手。它不仅支持通用的对话生成、文本摘要和翻译任务,还特色化地实现了诗词歌词生成(SongNet)以及基于 UDA/EDA 的文本数据扩增功能,特别适合需要定制化领域模型或解决数据稀缺问题的用户。

在技术亮点方面,TextGen 紧跟前沿进展,支持 NEFTune 噪声嵌入训练以提升微调效果,并优化了多显卡并行推理能力,显著提升了运行效率。此外,项目还配套发布了多个经过中文语料精细微调的现成模型(如医疗问诊、多轮对话模型),用户可直接调用或作为基座进行二次开发。凭借友好的接口设计和详尽的文档,TextGen 致力于成为连接学术研究与工业应用的便捷桥梁。

使用场景

某电商初创团队急需构建一个能处理售后咨询、支持多轮对话且具备行业知识的智能客服系统,以缓解人工压力。

没有 textgen 时

  • 模型适配门槛高:团队需手动编写复杂的 PyTorch 代码来适配 LLaMA 或 ChatGLM 等大模型,环境配置繁琐,调试周期长达数周。
  • 领域知识缺失:通用大模型不懂电商术语(如“仅退款”、“运费险”),回答生硬且经常产生幻觉,无法直接用于生产环境。
  • 数据扩充困难:缺乏高质量的客服对话语料,团队难以通过有效手段(如回译、同义词替换)低成本扩增训练数据,导致模型泛化能力差。
  • 资源利用率低:推理阶段不支持多卡并行,响应速度慢,难以应对高峰期的并发请求,硬件成本居高不下。

使用 textgen 后

  • 开箱即用微调:利用 textgen 内置的 LoRA 微调脚本,团队仅需几条命令即可基于 ChatGLM2 或 LLaMA2 启动训练,将部署周期从数周缩短至两天。
  • 精准领域定制:通过加载官方发布的医疗或通用对话 LoRA 模型作为基座,并注入自有客服数据进行 SFT 训练,模型迅速掌握了电商业务逻辑,回答准确自然。
  • 智能数据增强:调用 textgen 集成的 UDA 和回译算法,自动将少量种子问答扩充为千级高质量训练样本,显著提升了模型对多样化用户提问的理解力。
  • 高效多卡推理:借助 textgen 新增的多卡推理支持,团队轻松实现批量请求处理,响应速度加倍,在同等硬件下支撑了更高的并发流量。

textgen 通过提供一站式的大模型微调与推理方案,让中小团队也能低门槛地拥有定制化、高性能的行业专属生成式 AI 能力。

运行环境要求

操作系统
  • 未说明
GPU
  • 训练和运行大模型(如 LLaMA, ChatGLM, Baichuan)通常需要 NVIDIA GPU
  • 具体显存需求取决于模型大小:7B/6B 模型建议 16GB+ 显存,13B 模型建议 24GB+ 显存或使用多卡推理
  • 支持多卡加速
内存

未说明(建议 16GB 以上以加载大型模型)

依赖
notes该工具支持多种架构(LLaMA, ChatGLM, Baichuan, T5, GPT2 等)的 LoRA 微调和推理。大模型训练或推理对显存要求较高,小显存用户建议使用 LoRA 微调或量化技术。部分功能(如 NEFTune)需特定参数启用。模型文件首次运行时会自动从 HuggingFace 下载。
python3.8+
torch
transformers
accelerate
peft
datasets
sentencepiece
protobuf
textgen hero image

快速开始

🇨🇳中文 | 🌐English | 📖文档/Docs | 🤖模型/Models


TextGen: 文本生成模型的实现

PyPI版本 下载量 欢迎贡献 许可证 Apache 2.0 Python版本 GitHub问题 微信交流群

📖 简介

TextGen实现了多种文本生成模型,包括:LLaMA、ChatGLM、UDA、GPT2、Seq2Seq、BART、T5、SongNet等模型,开箱即用。

🔥 新闻

[2023/11/02] v1.1.2版本: GPT模型支持了NEFTune给embedding加噪SFT训练方法,SFT中使用 --neft_alpha 参数启用 NEFTune,例如 --neft_alpha 5。详见Release-v1.1.2

[2023/09/05] v1.1.1版本: 支持多卡推理,推理速度加倍,调库textgen做batch推理,多卡推理更方便、快速。详见Release-v1.1.1

[2023/08/23] v1.1.0版本: 发布基于ShareGPT4数据集微调的中英文Vicuna-13B模型shibing624/vicuna-baichuan-13b-chat,和对应的LoRA模型shibing624/vicuna-baichuan-13b-chat-lora,支持多轮对话,评测效果有提升,详见Release-v1.1.0

[2023/08/02] v1.0.2版本: 新增支持ChatGLM2和LLaMA2模型的SFT微调训练,详见Release-v1.0.2

[2023/06/15] v1.0.0版本: 新增ChatGLM/LLaMA/Bloom模型的多轮对话微调训练,并发布医疗问诊LoRA模型shibing624/ziya-llama-13b-medical-lora。详见Release-v1.0.0

[2023/06/02] v0.2.7版本: 新增ChatGLM/LLaMA/Bloom模型的SFT微调训练,并发布适用于通用对话和中文纠错的LoRA模型。详见Release-v0.2.7

😊 特性

  • GPT:本项目基于PyTorch实现了 ChatGLM-6B 1,2,3 / Baichuan 1,2 / LLaMA 1,2 / BLOOM / Mistral / QWen 等GPT模型LoRA微调训练和预测,可以用于对话生成任务和领域微调训练
  • UDA/EDA:本项目实现了UDA(非核心词替换)、EDA和Back Translation(回译)算法,基于TF-IDF将句子中部分不重要词替换为同义词,随机词插入、删除、替换等方法,产生新的文本,实现了文本扩增
  • Seq2Seq:本项目基于PyTorch实现了Seq2Seq、ConvSeq2Seq、BART模型的训练和预测,可以用于文本翻译、对话生成、摘要生成等文本生成任务
  • T5:本项目基于PyTorch实现了T5和CopyT5模型训练和预测,可以用于文本翻译、对话生成、对联生成、文案撰写等文本生成任务
  • GPT2:本项目基于PyTorch实现了GTP2模型训练和预测,可以用于文章生成、对联生成等文本生成任务
  • SongNet:本项目基于PyTorch实现了SongNet模型训练和预测,可以用于规范格式的诗词、歌词等文本生成任务
  • TGLS:本项目实现了TGLS无监督相似文本生成模型,是一种“先搜索后学习”的文本生成方法,通过反复迭代学习候选集,最终模型能生成类似候选集的高质量相似文本

发布模型

release基于textgen训练的中文模型,模型已经release到HuggingFace models,指定模型名称textgen会自动下载模型,可直接使用。

模型 架构 简介 训练脚本 预测脚本
shibing624/t5-chinese-couplet T5 经过中文对联数据微调后的模型 对联生成模型调研 predict script
shibing624/songnet-base-chinese-songci SongNet 经过宋词数据微调后的模型 training script predict script
shibing624/songnet-base-chinese-couplet SongNet 经过对联数据微调后的模型 training script predict script
shibing624/chatglm-6b-csc-zh-lora ChatGLM-6B 在27万中文拼写纠错数据shibing624/CSC上微调了一版ChatGLM-6B,纠错效果有提升,发布微调后的LoRA权重 training script predict script
shibing624/chatglm-6b-belle-zh-lora ChatGLM-6B 在100万条中文ChatGPT指令Belle数据集BelleGroup/train_1M_CN上微调了一版ChatGLM-6B,问答效果有提升,发布微调后的LoRA权重 training script predict script
shibing624/llama-13b-belle-zh-lora LLaMA-13B 在100万条中文ChatGPT指令Belle数据集BelleGroup/train_1M_CN上微调了一版Llama-13B,问答效果有提升,发布微调后的LoRA权重 training script predict script
shibing624/chinese-alpaca-plus-7b-hf LLaMA-7B 中文LLaMA-Plus, Alpaca-Plus 7B版本,在LLaMA-7B上扩充了中文词表并继续预训练120G文本(通用领域),在4M指令数据集上微调后得到的中文Alpaca-plus模型 training script predict script
shibing624/chinese-alpaca-plus-13b-hf LLaMA-13B 中文LLaMA-Plus, Alpaca-Plus 13B版本,在LLaMA-13B上扩充了中文词表并继续预训练120G文本(通用领域),在4.3M指令数据集上微调后得到的中文Alpaca-plus模型 training script predict script
shibing624/ziya-llama-13b-medical-lora LLaMA-13B 在240万条中英文医疗数据集shibing624/medical上微调了一版Ziya-LLaMA-13B模型,医疗问答效果有提升,发布微调后的LoRA权重 training script predict script
shibing624/vicuna-baichuan-13b-chat Baichuan-13B-Chat 在10万条多语言ShareGPT GPT4多轮对话数据集shibing624/sharegpt_gpt4上SFT微调了一版baichuan-13b-chat多轮问答模型,日常问答和医疗问答效果有提升,发布微调后的完整模型权重 training script predict script

评估

模型 架构 简介 分数
LLaMA-7B-Chinese-Alpaca LLaMA-7B 复用ymcui/Chinese-LLaMA-Alpaca的评估case和得分 4.92
LLaMA-13B-Chinese-Alpaca LLaMA-13B 复用ymcui/Chinese-LLaMA-Alpaca的评估case和得分 7.05
ChatGLM-6B ChatGLM-6B 基于原生THUDM/chatglm-6b评估测试集得分 7.16
ChatGLM-6B-v1.1 ChatGLM-6B 基于原生THUDM/chatglm-6bv1.1英文优化版模型评估测试集得分 7.18
shibing624/chatglm-6b-belle-zh-lora ChatGLM-6B 基于THUDM/chatglm-6b加载shibing624/chatglm-6b-belle-zh-loraLoRA模型后评估测试集得分 7.03
facat/alpaca-lora-cn-13b LLaMA-13B 基于decapoda-research/llama-13b-hf加载facat/alpaca-lora-cn-13bLoRA模型后评估测试集并标注得分 4.13
Chinese-Vicuna/Chinese-Vicuna-lora-13b-belle-and-guanaco LLaMA-13B 基于decapoda-research/llama-13b-hf加载Chinese-Vicuna/Chinese-Vicuna-lora-13b-belle-and-guanacoLoRA模型后评估测试集并标注得分 3.98
shibing624/chinese-alpaca-plus-7b-hf LLaMA-7B 使用ymcui/Chinese-LLaMA-Alpaca 合并模型方法合并HF权重后,评估测试集并标注得分 6.93
shibing624/chinese-alpaca-plus-13b-hf LLaMA-13B 使用ymcui/Chinese-LLaMA-Alpaca 合并模型方法合并HF权重后,评估测试集并标注得分 7.07
TheBloke/vicuna-13B-1.1-HF LLaMA-13B 使用原生vicuna-13B-1.1合并后的模型,评估测试集并标注得分 5.13
IDEA-CCNL/Ziya-LLaMA-13B-v1 LLaMA-13B 使用姜子牙通用大模型V1,评估测试集并标注得分 6.63

说明:

  • 评估case,详见在线文档:中文LLM-benchmark多任务评估集(腾讯文档) https://docs.qq.com/sheet/DUUpsREtWbFBsUVJE?tab=r7io7g 感谢韩俊明、杨家铭等同学的标注
  • 评估任务类型包括:知识问答,开放式问答,数值计算,诗词、音乐、体育,娱乐,写文章,文本翻译,代码编程,伦理、拒答类,多轮问答,Score 评分是前100条(10分制)的平均分数,人工打分,越高越好
  • 评估数量少,任务类型不够全面,评分之间的大小关系有一些参考价值,分数的绝对值没太大参考价值
  • 评估脚本:tests/test_benchmark.py ,使用fp16预测,无int量化处理,运行脚本可复现评估结果,但生成结果具有随机性,受解码超参、随机种子等因素影响。评测并非绝对严谨,测试结果仅供晾晒参考
  • 结论:ChatGLM-6B、LLaMA-13B的中文衍生模型(包括alpaca-plus, vicuna, ziya)的表现属于第一梯队,原版LLaMA-7B的表现整体稍差些
  • LLaMA-13B-Chinese-Alpaca是在原版LLaMA上扩充了中文词表,并融入了约20G的通用中文语料后的指令微调模型,表明了LLaMA的底座优秀,具有强大的语言迁移能力
  • ChatGLM这种原生的中文预训练模型更理解中文语义,且在中文知识问答、开放式问答得分高
  • LLaMA系列模型数值计算、中英翻译、代码编程类得分高
  • 经过中文预训练和SFT微调后的Chinese-LLaMA模型在中文诗词、娱乐、伦理类得分相较原版LLaMA有提升

🚀 演示

HuggingFace 演示:https://huggingface.co/spaces/shibing624/chinese-couplet-generate

运行示例:examples/T5/gradio_demo.py 查看演示:

python examples/T5/gradio_demo.py

模型由 examples/t5/T5_Finetune_Chinese_Couplet.ipynb 训练得到。

💾 安装

pip install -U textgen

或者

安装开发版本:

pip install torch # conda install pytorch
git clone https://github.com/shibing624/textgen.git
cd textgen
python setup.py install

▶️ 使用

ChatGLM-6B 模型

使用 ChatGLM-6B 微调后的模型

示例:examples/gpt/inference_demo.py

from textgen import GptModel

model = GptModel("chatglm", "THUDM/chatglm-6b", peft_name="shibing624/chatglm-6b-csc-zh-lora")
r = model.predict(["介绍下北京"])
print(r)  # ['北京是中国的首都...']

训练 ChatGLM-6B 微调模型

  1. 支持自定义训练数据集和训练参数,数据集格式参考examples/data/sharegpt_zh_100_format.jsonl
  2. 支持QLoRA、AdaLoRA、LoRA、P_Tuning、Prefix_Tuning等部分参数微调方法,也支持全参微调
  3. 支持多卡训练,支持混合精度训练
  4. 支持多卡推理

示例:examples/gpt/training_chatglm_demo.py

单卡训练:

cd examples/gpt
CUDA_VISIBLE_DEVICES=0 python training_chatglm_demo.py --do_train --do_predict --num_epochs 1 --output_dir outputs_chatglm_v1

多卡训练:

cd examples/gpt
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 training_chatglm_demo.py --do_train --do_predict --num_epochs 20 --output_dir outputs_chatglm_v1

多卡推理:

cd examples/gpt
CUDA_VISIBLE_DEVICES=0,1 torchrun --nproc_per_node 2 inference_multigpu_demo.py --model_type chatglm --base_model THUDM/chatglm-6b

LLaMA 模型

使用 LLaMA 微调后的模型

示例:examples/gpt/inference_demo.py

显示代码示例和结果
import sys

sys.path.append('../..')
from textgen import GptModel

model = GptModel("llama", "decapoda-research/llama-7b-hf", peft_name="ziqingyang/chinese-alpaca-lora-7b")
r = model.predict(["用一句话描述地球为什么是独一无二的。"])
print(r)  # ['地球是唯一一颗拥有生命的行星。']

训练 LLaMA 微调模型

  1. 支持自定义训练数据集和训练参数,数据集格式参考examples/data/sharegpt_zh_100_format.jsonl
  2. 支持QLoRA、AdaLoRA、LoRA、P_Tuning、Prefix_Tuning等部分参数微调方法,也支持全参微调
  3. 支持多卡训练,支持混合精度训练,使用方法同上(ChatGLM多卡训练)
  4. 支持多卡推理

示例:examples/gpt/training_llama_demo.py

基于微调(LoRA)模型继续训练

如果需要基于Lora模型继续训练,可以使用下面的脚本合并模型为新的base model,再微调训练即可。

执行以下命令:

python -m textgen/gpt/merge_peft_adapter \
    --model_type llama \
    --base_model_name_or_path path/to/llama/model \
    --tokenizer_path path/to/llama/tokenizer \
    --peft_model_path path/to/lora/model \
    --output_dir merged

参数说明:

--model_type:模型类型,目前支持bloom,llama,baichuan和chatglm
--base_model_name_or_path:存放HF格式的底座模型权重和配置文件的目录
--tokenizer_path:存放HF格式的底座模型tokenizer文件的目录
--peft_model_path:中文LLaMA/Alpaca LoRA解压后文件所在目录,也可使用HF上的Lora模型名称,如`ziqingyang/chinese-alpaca-lora-7b`会自动下载对应模型
--output_dir:指定保存全量模型权重的目录,默认为./merged

训练领域模型

注意:为了全面地介绍训练医疗大模型的过程,把4阶段训练方法(Pretraining, Supervised Finetuning, Reward Modeling and Reinforcement Learning)单独新建了一个repo:shibing624/MedicalGPT,请移步该repo查看训练方法。

ConvSeq2Seq 模型

训练并预测ConvSeq2Seq模型:

示例:examples/seq2sesq/training_convseq2seq_model_demo.py

显示代码示例和结果
import argparse
from loguru import logger
import sys

sys.path.append('../..')
from textgen.seq2seq.conv_seq2seq_model import ConvSeq2SeqModel


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='../data/zh_dialog.tsv', type=str, help='Training data file')
    parser.add_argument('--do_train', action='store_true', help='Whether to run training.')
    parser.add_argument('--do_predict', action='store_true', help='Whether to run predict.')
    parser.add_argument('--output_dir', default='./outputs/convseq2seq_zh/', type=str, help='Model output directory')
    parser.add_argument('--max_seq_length', default=50, type=int, help='Max sequence length')
    parser.add_argument('--num_epochs', default=200, type=int, help='Number of training epochs')
    parser.add_argument('--batch_size', default=32, type=int, help='Batch size')
    args = parser.parse_args()
    logger.info(args)

    if args.do_train:
        logger.info('Loading data...')
        model = ConvSeq2SeqModel(epochs=args.num_epochs, batch_size=args.batch_size,
                                 model_dir=args.output_dir, max_length=args.max_seq_length)
        model.train_model(args.train_file)
        print(model.eval_model(args.train_file))

    if args.do_predict:
        model = ConvSeq2SeqModel(epochs=args.num_epochs, batch_size=args.batch_size,
                                 model_dir=args.output_dir, max_length=args.max_seq_length)
        sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
        print("inputs:", sentences)
        print('outputs:', model.predict(sentences))


if __name__ == '__main__':
    main()

输出:

inputs: ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
outputs: ['人工智能是工程和科学的分支,致力于构建思维的机器。', '我的程序运行在python,所以我在任何运脑上工作!', '我不能错热是一个疯狂的人工智能"200年。']

BART 模型

训练并预测BART模型:

示例:examples/seq2sesq/training_bartseq2seq_zh_demo.py

输出:

inputs: ['什么是ai', '你是什么类型的计算机', '你知道热力学吗']
outputs: ['人工智能是工程和科学的分支,致力于构', '我的程序运行在python,所以我在任何电脑上', '什么是热力学吗?']

T5 模型

示例:examples/t5/training_zh_t5_model_demo.py

显示代码示例和结果
import argparse
from loguru import logger
import pandas as pd
import sys

sys.path.append('../..')
from textgen.t5 import T5Model


def load_data(file_path):
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip('\n')
            terms = line.split('\t')
            if len(terms) == 2:
                data.append(['QA', terms[0], terms[1]])
            else:
                logger.warning(f'line error: {line}')
    return data


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--train_file', default='../data/zh_dialog.tsv', type=str, help='训练数据文件')
    parser.add_argument('--model_type', default='t5', type=str, help='Transformers模型类型')
    parser.add_argument('--model_name', default='Langboat/mengzi-t5-base', type=str, help='Transformers模型或路径')
    parser.add_argument('--do_train', action='store_true', help='是否进行训练。')
    parser.add_argument('--do_predict', action='store_true', help='是否进行预测。')
    parser.add_argument('--output_dir', default='./outputs/mengzi_t5_zh/', type=str, help='模型输出目录')
    parser.add_argument('--max_seq_length', default=50, type=int, help='最大序列长度')
    parser.add_argument('--num_epochs', default=10, type=int, help='训练轮数')
    parser.add_argument('--batch_size', default=32, type=int, help='批量大小')
    args = parser.parse_args()
    logger.info(args)

    if args.do_train:
        logger.info('加载数据...')
        # train_data: 包含3列的Pandas DataFrame - `prefix`, `input_text`, `target_text`。
        #   - `prefix`: 表示要执行的任务的字符串。(例如 `"question"`、`"stsb"`)
        #   - `input_text`: 输入文本。`prefix` 会前置以形成完整的输入。(<prefix>: <input_text>)
        #   - `target_text`: 目标序列
        train_data = load_data(args.train_file)
        logger.debug('train_data: {}'.format(train_data[:10]))
        train_df = pd.DataFrame(train_data, columns=["prefix", "input_text", "target_text"])

        eval_data = load_data(args.train_file)[:10]
        eval_df = pd.DataFrame(eval_data, columns=["prefix", "input_text", "target_text"])

        model_args = {
            "reprocess_input_data": True,
            "overwrite_output_dir": True,
            "max_seq_length": args.max_seq_length,
            "train_batch_size": args.batch_size,
            "num_train_epochs": args.num_epochs,
            "save_eval_checkpoints": False,
            "save_model_every_epoch": False,
            "evaluate_generated_text": True,
            "evaluate_during_training": True,
            "evaluate_during_training_verbose": True,
            "use_multiprocessing": True,
            "save_best_model": True,
            "output_dir": args.output_dir,
            "use_early_stopping": True,
        }
        # model_type: t5  model_name: Langboat/mengzi-t5-base
        model = T5Model(args.model_type, args.model_name, args=model_args)

        def count_matches(labels, preds):
            logger.debug(f"labels: {labels[:10]}")
            logger.debug(f"preds: {preds[:10]}")
            match = sum([1 if label == pred else 0 for label, pred in zip(labels, preds)])
            logger.debug(f"match: {match}")
            return match

        model.train_model(train_df, eval_data=eval_df, matches=count_matches)
        print(model.eval_model(eval_df, matches=count_matches))

    if args.do_predict:
        model = T5Model(args.model_type, args.output_dir)
        sentences = ["什么是ai", "你是什么类型的计算机", "你知道热力学吗"]
        print("inputs:", sentences)
        print("outputs:", model.predict(sentences))


if __name__ == '__main__':
    main()

输出:

inputs: ['什么是ai', '你是什么类型的计算机', '你知道热力学吗']
outputs: ['人工智能有两个广义的定义,任何拟人的机械,如在卡雷尔capeks', '我的程序运行在Python,所以我在任何电脑上工作!', '什么是热力学']

GPT2 模型

中文GPT2 - 文章生成

使用中文数据集(段落格式,\n间隔),训练GPT2模型,可以用于诗歌生成、文章生成等任务。

示例:examples/gpt2/training_zh_gpt2_demo.py

中文GPT2 - 对联生成

使用中文对联数据集(tsv格式,\t间隔),自定义数据集读取Dataset,训练GPT2模型,可以用于对联生成、对话生成等任务。

示例:examples/gpt2/training_couplet_gpt2_demo.py

GPT2与T5的对比:

  1. 两者都是基于Transformer改进而来,T5同时具有编码器和解码器,而GPT2只有解码器。
  2. T5的优势在于处理给定输入并生成对应输出的任务,如翻译、对话、问答等。
  3. GPT2的优势在于自由创作,比如撰写短文。
  4. T5在对联生成方面效果优于GPT2,而GPT2在诗词生成方面效果优于T5。

SongNet 模型

格式控制的文本生成模型,论文见SongNet: Rigid Formats Controlled Text Generation,适用于对韵律格式要求较高的诗歌、对联、歌词生成等任务。

示例:examples/songnet/training_zh_songnet_demo.py

关键词文本增强(EDA/UDA)

示例:examples/text_augmentation/text_augmentation_demo.py

显示代码示例和结果
import sys

sys.path.append('..')
from textgen.augment import TextAugment

if __name__ == '__main__':
    docs = ['主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容',
            '晚上肚子好难受',
            '你会武功吗,我不会',
            '组装标题质量受限于广告主自提物料的片段质量,且表达丰富度有限',
            ]
    m = TextAugment(sentence_list=docs)
    a = docs[0]
    print(a)

    b = m.augment(a, aug_ops='random-0.2')
    print('random-0.2:', b)

    b = m.augment(a, aug_ops='insert-0.2')
    print('insert-0.2:', b)

    b = m.augment(a, aug_ops='delete-0.2')
    print('delete-0.2:', b)

    b = m.augment(a, aug_ops='tfidf-0.2')
    print('tfidf-0.2:', b)

    b = m.augment(a, aug_ops='mix-0.2')
    print('mix-0.2:', b)

输出:

主要研究机器学习、深度学习、计算机视觉、智能对话系统相关内容
random-0.2: ('主要陪陪机器学习、深度学习主要计算机视觉、智能对话系统受限于内容', [('研究', '陪陪', 2, 4), ('、', '主要', 13, 15), ('相关', '受限于', 27, 30)])
insert-0.2: ('主要研究机器机器学习学习、深度深度学习、计算机视觉、智能对话系统相关内容', [('机器', '机器机器', 4, 8), ('学习', '学习学习', 8, 12), ('深度', '深度深度', 13, 17)])
delete-0.2: ('主要研究机器学习、深度学习、计算机视觉、对话系统相关内容', [('智能', '', 20, 20)])
tfidf-0.2: ('一是研究机器学习、深度学习、计算机听觉、智能交谈系统密切相关内容',([('主要', '一是', 0, 2), ('视觉', '听觉', 17, 19), ('对话', '交谈', 22, 24), ('相关', '密切相关', 26, 30)]))
mix-0.2: ('主要研究机器学习、深度学、计算机听觉、智能对话软件系统相关内容',([('学习', '学', 11, 12), ('视觉', '听觉', 16, 18), ('系统', '软件系统', 23, 27)]))

TGLS 模型(无监督相似文本生成模型)

无监督的中文电商评论生成:从电商评论中提取用户表达观点的短句并进行组合来生成仿真评论。

example: examples/unsup_generation/unsup_generation_demo.py

展示代码示例和结果
import os
import sys

sys.path.append('..')
from textgen.unsup_generation import TglsModel, load_list

pwd_path = os.path.abspath(os.path.dirname(__file__))

samples = load_list(os.path.join(pwd_path, './data/ecommerce_comments.txt'))
docs_text = [
    ["挺好的,速度很快,也很实惠,不知效果如何",
     "产品没得说,买了以后就降价,心情不美丽。",
     "刚收到,包装很完整,不错",
     "发货速度很快,物流也不错,同一时间买的两个东东,一个先到一个还在路上。这个水水很喜欢,不过盖子真的开了。盖不牢了现在。",
     "包装的很好,是正品",
     "被种草兰蔻粉水三百元一大瓶囤货,希望是正品好用,收到的时候用保鲜膜包裹得严严实实,只敢买考拉自营的护肤品",
     ],
    ['很温和,清洗的也很干净,不油腻,很不错,会考虑回购,第一次考拉买护肤品,满意',
     '这款卸妆油我会无限回购的。即使我是油痘皮,也不会闷痘,同时在脸部按摩时,还能解决白头的脂肪粒的问题。用清水洗完脸后,非常的清爽。',
     '自从用了fancl之后就不用其他卸妆了,卸的舒服又干净',
     '买贵了,大润发才卖79。9。',
     ],
    samples
]
m = TglsModel(docs_text)
r = m.generate(samples[:500])
print('size:', len(r))
for review in r:
    print('\t' + review)

output:

美迪惠尔 N.M.F针剂水库保湿面膜有如下的20句评论,其中有10句是真实用户评论,10句是生成的评论,能看出来么?😂

还不错还不错还不错还不错。
东西到了,不知道好不好用。试用过后再来评价。到时看网评都还可以。
哺乳期唯一使用的护肤品,每天都是素颜,脸面全靠面膜吊着😄补水💦不粘腻一如既往的支持,喜欢💕
搞活动时买的面膜,不知道这个面膜是真是假敷在脸上面膜纸都有小水泡鼓起来。
很不错,非常补水,用过的都知道,性价比之王,好用又不贵,正品,用着放心,物流也很快。
面膜非常好用哦。面膜薄薄的。好像是蚕丝面膜啊。精华很多呢。敷在脸上很舒服。感觉挺保湿的,味道也挺好闻的。就是里面只有单纯的面膜直接敷脸上有点不好弄,哈哈哈
还可以保湿效果不错水润润的每天贴一片脸也不干了用完了在买点,不错还会继续回购的。
快递很快,东西很赞!想要得点考拉豆不容易,还要三十个字。时间宝贵,废话不说!用过了就知道了
挺好用的,朋友推荐来的
挺好用的,淡淡的,虽然不是很浓精华的感觉,但是效果也蛮好的。划算
不得不说美迪惠尔的面膜是我用过的最好的面膜之一😎补水效果非常好,没想到这么便宜的价格竟真的能买到真品。
保湿效果挺好的,面膜很好用。
期待好的产品。
一打开包装里面的精华刚刚好,用了补水补水效果不错,物流非常快。
皮肤很光滑😇比上去速度快三天就到了。
前两天皮肤干燥连续敷了两个晚上感觉还不错😂补水效果明显!可想而知精华液又多充足😍敷上以后凉凉的很舒服。
补水效果一般吧~但是我用的韩国背回来的面膜纸不算薄,希望好用会回购的,敷上脸感觉比较清爽~价格还不便宜。
希望好用,面膜用过了很好用,皮肤水嫩光滑白皙,补水不错,价格也合适。
就是精华液太少了,保湿效果不错。
面膜的补水效果非常好,保湿效果确实很赞,这个面膜相对于胶原蛋白和美白的那两款的面膜纸要厚一些,看着价格合适。

前10句是真实用户评论,后10句是生成的。

📚 数据集

SFT 数据集

奖励模型数据集

✅ 待办事项

  1. 添加多轮对话数据的微调方法
  2. 添加奖励模型的微调,前往 shibing624/MeidcalGPT
  3. 添加强化学习微调,前往 shibing624/MeidcalGPT
  4. 添加医疗奖励数据集
  5. 添加llama in4训练,前往 shibing624/MeidcalGPT
  6. 在colab中添加所有训练和预测的演示

☎️ 联系方式

  • 提问(建议): :GitHub issues
  • 邮件我:xuming: xuming624@qq.com
  • 微信我: 加我微信号:xuming624, 备注:姓名-公司名-NLP 进NLP交流群。

😇 引用

如果你在研究中使用了textgen,请按如下格式引用:

@misc{textgen,
  title={textgen: Text Generation Tool},
  author={Ming Xu},
  year={2021},
  howpublished={\url{https://github.com/shibing624/textgen}},
}

🤗 许可证

本仓库采用 Apache License 2.0 许可证。

请遵循 Model Card 使用LLaMA模型。

请遵循 RAIL License 使用BLOOM & BLOOMZ模型。

😍 贡献

项目代码还很粗糙,如果大家对代码有所改进,欢迎提交回本项目,在提交之前,注意以下两点:

  • tests添加相应的单元测试
  • 使用python -m pytest来运行所有单元测试,确保所有单测都是通过的

之后即可提交PR。

💕 致谢

感谢他们的杰出工作!

版本历史

1.1.22023/11/02
1.1.12023/09/08
1.1.02023/08/23
1.0.22023/08/02
1.0.02023/06/15
0.2.72023/06/15
0.2.52023/05/12
0.1.72022/11/28
0.1.52022/09/10
0.1.12022/06/30
0.1.02022/06/20
0.0.52022/05/10

常见问题

相似工具推荐

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|今天
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

139k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

107.7k|★★☆☆☆|2天前
开发框架图像Agent

NextChat

NextChat 是一款轻量且极速的 AI 助手,旨在为用户提供流畅、跨平台的大模型交互体验。它完美解决了用户在多设备间切换时难以保持对话连续性,以及面对众多 AI 模型不知如何统一管理的痛点。无论是日常办公、学习辅助还是创意激发,NextChat 都能让用户随时随地通过网页、iOS、Android、Windows、MacOS 或 Linux 端无缝接入智能服务。 这款工具非常适合普通用户、学生、职场人士以及需要私有化部署的企业团队使用。对于开发者而言,它也提供了便捷的自托管方案,支持一键部署到 Vercel 或 Zeabur 等平台。 NextChat 的核心亮点在于其广泛的模型兼容性,原生支持 Claude、DeepSeek、GPT-4 及 Gemini Pro 等主流大模型,让用户在一个界面即可自由切换不同 AI 能力。此外,它还率先支持 MCP(Model Context Protocol)协议,增强了上下文处理能力。针对企业用户,NextChat 提供专业版解决方案,具备品牌定制、细粒度权限控制、内部知识库整合及安全审计等功能,满足公司对数据隐私和个性化管理的高标准要求。

87.6k|★★☆☆☆|今天
开发框架语言模型

ML-For-Beginners

ML-For-Beginners 是由微软推出的一套系统化机器学习入门课程,旨在帮助零基础用户轻松掌握经典机器学习知识。这套课程将学习路径规划为 12 周,包含 26 节精炼课程和 52 道配套测验,内容涵盖从基础概念到实际应用的完整流程,有效解决了初学者面对庞大知识体系时无从下手、缺乏结构化指导的痛点。 无论是希望转型的开发者、需要补充算法背景的研究人员,还是对人工智能充满好奇的普通爱好者,都能从中受益。课程不仅提供了清晰的理论讲解,还强调动手实践,让用户在循序渐进中建立扎实的技能基础。其独特的亮点在于强大的多语言支持,通过自动化机制提供了包括简体中文在内的 50 多种语言版本,极大地降低了全球不同背景用户的学习门槛。此外,项目采用开源协作模式,社区活跃且内容持续更新,确保学习者能获取前沿且准确的技术资讯。如果你正寻找一条清晰、友好且专业的机器学习入门之路,ML-For-Beginners 将是理想的起点。

85k|★★☆☆☆|今天
图像数据工具视频

ragflow

RAGFlow 是一款领先的开源检索增强生成(RAG)引擎,旨在为大语言模型构建更精准、可靠的上下文层。它巧妙地将前沿的 RAG 技术与智能体(Agent)能力相结合,不仅支持从各类文档中高效提取知识,还能让模型基于这些知识进行逻辑推理和任务执行。 在大模型应用中,幻觉问题和知识滞后是常见痛点。RAGFlow 通过深度解析复杂文档结构(如表格、图表及混合排版),显著提升了信息检索的准确度,从而有效减少模型“胡编乱造”的现象,确保回答既有据可依又具备时效性。其内置的智能体机制更进一步,使系统不仅能回答问题,还能自主规划步骤解决复杂问题。 这款工具特别适合开发者、企业技术团队以及 AI 研究人员使用。无论是希望快速搭建私有知识库问答系统,还是致力于探索大模型在垂直领域落地的创新者,都能从中受益。RAGFlow 提供了可视化的工作流编排界面和灵活的 API 接口,既降低了非算法背景用户的上手门槛,也满足了专业开发者对系统深度定制的需求。作为基于 Apache 2.0 协议开源的项目,它正成为连接通用大模型与行业专有知识之间的重要桥梁。

77.1k|★★★☆☆|2天前
Agent图像开发框架