torchkeras

GitHub
2k 255 非常简单 1 次阅读 昨天Apache-2.0开发框架
AI 解读 由 AI 自动生成,仅供参考

torchkeras 是一款专为 PyTorch 打造的通用模型训练模板工具,旨在让深度学习训练过程像 Keras 一样简洁优雅。它主要解决了 PyTorch 生态中训练代码风格各异、结构复杂且难以修改的痛点。面对不同模型库中嵌套深、私有变量多且互不兼容的训练逻辑,开发者往往需要花费大量精力重写循环,而 torchkeras 通过统一的 compilefitevaluate 接口,让用户无需编写繁琐的训练循环,仅需几行代码即可启动训练。

该工具非常适合希望快速复现模型的研究人员、需要灵活调整训练逻辑的算法工程师,以及被复杂训练代码困扰的深度学习开发者。其核心亮点在于“好看、好用、好改”:不仅内置了美观的动态可视化图表和进度条,支持 TensorBoard 与 WandB 回调,还集成了早停(Early Stopping)、多 GPU 分布式训练及混合精度训练等高级功能。更难得的是,其核心代码仅约 200 行,高度模块化,既保证了功能的强大,又极大地降低了二次开发的门槛,让用户能轻松定制属于自己的训练流程。

使用场景

某计算机视觉算法工程师正在基于 PyTorch 微调一个医疗影像分割模型,急需快速验证多种网络结构并监控训练效果。

没有 torchkeras 时

  • 代码重复且混乱:每次更换模型库(如从 torchvision 切换到 segmentation_models),都需要重新阅读并改写几十行风格迥异的训练循环代码,极易出错。
  • 调试过程盲目:缺乏统一的可视化界面,只能盯着枯燥的终端日志猜测收敛情况,难以直观判断是否过拟合或需要早停。
  • 功能扩展困难:想要添加自定义评估指标或断点续训功能时,往往陷入层层嵌套的私有变量和复杂的 Trainer 类中,修改成本极高。
  • 多卡配置繁琐:部署到多 GPU 环境时,需手动编写分布式数据并行(DDP)样板代码,配置过程耗时且容易引发环境兼容问题。

使用 torchkeras 后

  • 接口统一简洁:只需定义好网络结构和损失函数,通过 KerasModel 封装后调用 fit 方法即可启动训练,彻底告别手写训练循环。
  • 训练过程透明:在 Jupyter Notebook 中自动呈现动态损失曲线和进度条,实时可视化监控“北极星指标”,让调参决策有据可依。
  • 功能即插即用:原生支持 Early Stopping、自定义 Metrics 及 TensorBoard 回调,核心逻辑仅约 200 行,按需修改极其轻松。
  • 高性能一键开启:依托 accelerate 后端,仅需简单参数配置即可无缝切换单卡、多卡 DDP 甚至混合精度训练,大幅提升实验效率。

torchkeras 将 PyTorch 的灵活性与 Keras 的优雅体验完美融合,让算法工程师从繁琐的工程细节中解放出来,真正专注于模型创新本身。

运行环境要求

操作系统
  • 未说明
GPU
  • 非必需
  • 支持 GPU 训练、多 GPU (DDP) 及 TPU,依赖 accelerate 库自动管理设备
  • 具体显存和 CUDA 版本取决于用户所选模型(如 LLM 微调需大显存),工具本身无硬性限制
内存

未说明

依赖
notes该工具是一个通用的 PyTorch 训练模板,核心代码仅约 200 行。它通过封装 KerasModel 类简化训练流程(compile/fit/evaluate 风格),并支持动态可视化日志 (VLog)。具体的硬件需求(如显存大小)完全取决于用户加载的具体模型(例如 examples 中展示的 YOLOv8、Llama、BERT 等),而非 torchkeras 本身。安装只需执行 'pip install torchkeras'。
python未说明
torch
torchkeras
torchmetrics
accelerate
tqdm
matplotlib
tensorboard
wandb
torchkeras hero image

快速开始

炼丹师,这是你的梦中情炉吗?🌹🌹

英文 | 简体中文

torchkeras 是一个通用的pytorch模型训练模版工具,按照如下目标进行设计和实现:

  • 好看 (代码优雅,日志美丽,自带可视化)

  • 好用 (使用方便,支持 进度条、评估指标、early-stopping等常用功能,支持tensorboard,wandb回调函数等扩展功能)

  • 好改 (修改简单,核心代码模块化,仅约200行,并提供丰富的修改使用案例)


1,炼丹之痛 😭😭

无论是学术研究还是工业落地,pytorch几乎都是目前炼丹的首选框架。

pytorch的胜出不仅在于其简洁一致的api设计,更在于其生态中丰富和强大的模型库。

但是我们会发现不同的pytorch模型库提供的训练和验证代码非常不一样。

torchvision官方提供的范例代码主要是一个关联了非常多依赖函数的train_one_epoch和evaluate函数,针对检测和分割各有一套。

yolo系列的主要是支持ddp模式的各种风格迥异的Trainer,每个不同的yolo版本都会改动很多导致不同yolo版本之间都难以通用。

抱抱脸的transformers库在借鉴了pytorch_lightning的基础上也搞了一个自己的Trainer,但与pytorch_lightning并不兼容。

非常有名的facebook的目标检测库detectron2, 也是搞了一个它自己的Trainer,配合一个全局的cfg参数设置对象来训练模型。

还有我用的比较多的语义分割的segmentation_models.pytorch这个库,设计了一个TrainEpoch和一个ValidEpoch来做训练和验证。

在学习和使用这些不同的pytorch模型库时,尝试阅读理解和改动这些训练和验证相关的代码让我受到了一万点伤害。

有些设计非常糟糕,嵌套了十几层,有些实现非常dirty,各种带下划线的私有变量满天飞。

让你每次想要改动一下加入一些自己想要的功能时就感到望而却步。

我不就想finetune一下模型嘛,何必拿这么多垃圾代码搞我?


2,梦中情炉 🤗🤗

这一切的苦不由得让我怀念起tensorflow中keras的美好了。

还记得keras那compile, fit, evalute三连击吗?一切都像行云流水般自然,真正的for humans。

而且你看任何用keras实现的模型库,训练和验证都几乎可以用这一套相同的接口,没有那么多莫名奇妙的野生Trainer。

我能否基于pytorch打造一个接口和keras一样简洁易用,功能强大,但是实现代码非常简短易懂,便于修改的模型训练工具呢?

从2020年7月左右发布1.0版本到最近发布的3.86版本,我陆陆续续在工作中一边使用一边打磨一个工具,总共提交修改了70多次。

现在我感觉我细心雕琢的这个作品终于长成了我心目中接近完美的样子。

她有一个美丽的名字:torchkeras.

是的,她兼具torch的灵动,也有keras的优雅~

并且她的美丽,无与伦比~

她,就是我的梦中情炉~ 🤗🤗


3,使用方法 🍊🍊

安装torchkeras

pip install torchkeras

通过使用torchkeras,你不需要写自己的pytorch模型训练循环。你只要做这样两步就可以了。

(1) 创建你的模型结构net,然后把它和损失函数传入torchkeras.KerasModel构建一个model。

(2) 使用model的fit方法在你的训练数据和验证数据上进行训练,训练数据和验证数据需要封装成两个DataLoader.

核心使用代码就像下面这样:

import torch 
import torchkeras
import torchmetrics
model = torchkeras.KerasModel(net,
                              loss_fn = nn.BCEWithLogitsLoss(),
                              optimizer= torch.optim.Adam(net.parameters(),lr = 1e-4),
                              metrics_dict = {"acc":torchmetrics.Accuracy(task='binary')}
                             )
dfhistory=model.fit(train_data=dl_train, 
                    val_data=dl_val, 
                    epochs=20, 
                    patience=3, 
                    ckpt_path='checkpoint',
                    monitor="val_acc",
                    mode="max",
                    plot=True
                   )

在jupyter notebook中执行训练代码,你将看到类似下面的动态可视化图像和训练日志进度条。

除此之外,torchkeras还提供了一个VLog类,方便你在任意的训练逻辑中使用动态可视化图像和日志进度条。

import time
import math,random
from torchkeras import VLog

epochs = 10
batchs = 30

#0, 指定监控北极星指标,以及指标优化方向
vlog = VLog(epochs, monitor_metric='val_loss', monitor_mode='min') 

#1, log_start 初始化动态图表
vlog.log_start() 

for epoch in range(epochs):
    
    #train
    for step in range(batchs):
        
        #2, log_step 更新step级别日志信息,打日志,并用小进度条显示进度
        vlog.log_step({'train_loss':100-2.5*epoch+math.sin(2*step/batchs)}) 
        time.sleep(0.05)
        
    #eval    
    for step in range(20):
        
        #3, log_step 更新step级别日志信息,指定training=False说明在验证模式,只打日志不更新小进度条
        vlog.log_step({'val_loss':100-2*epoch+math.sin(2*step/batchs)},training=False)
        time.sleep(0.05)
        
    #4, log_epoch 更新epoch级别日志信息,每个epoch刷新一次动态图表和大进度条进度
    vlog.log_epoch({'val_loss':100 - 2*epoch+2*random.random()-1,
                    'train_loss':100-2.5*epoch+2*random.random()-1})  

# 5, log_end 调整坐标轴范围,输出最终指标可视化图表
vlog.log_end()

4,主要特性 🍉🍉

torchkeras 支持以下这些功能特性,稳定支持这些功能的起始版本以及这些功能借鉴或者依赖的库的来源见下表。

功能 稳定支持起始版本 依赖或借鉴库
✅ 训练进度条 3.0.0 依赖tqdm,借鉴keras
✅ 训练评估指标 3.0.0 借鉴pytorch_lightning
✅ notebook中训练自带可视化 3.8.0 借鉴fastai
✅ early stopping 3.0.0 借鉴keras
✅ gpu training 3.0.0 依赖accelerate
✅ multi-gpus training(ddp) 3.6.0 依赖accelerate
✅ fp16/bf16 training 3.6.0 依赖accelerate
✅ tensorboard callback 3.7.0 依赖tensorboard
✅ wandb callback 3.7.0 依赖wandb
✅ VLog 3.9.5 依赖matplotlib

5,基本范例 🌰🌰

以下范例是torchkeras的基础范例,演示了torchkeras的主要功能。

包括基础训练,使用wandb可视化,使用wandb调参,使用tensorboard可视化,使用多GPU的ddp模式训练,通用的VLog动态日志可视化等。

example notebook kaggle链接
①基础范例 🔥🔥 basic example
Open In Kaggle

②wandb可视化 🔥🔥🔥 wandb demo
Open In Kaggle

③wandb自动化调参🔥🔥 wandb sweep demo
Open In Kaggle

④tensorboard可视化 tensorboard example
⑤ddp/tpu训练范例 ddp tpu examples
Open In Kaggle

⑥VLog动态日志可视化范例🔥🔥🔥 VLog example

6,进阶范例 🔥🔥

在炼丹实践中,遇到的数据集结构或者训练推理逻辑往往会千差万别。

例如我们可能会遇到多输入多输出结构,或者希望在训练过程中计算并打印一些特定的指标等等。

这时候炼丹师可能会倾向于使用最纯粹的pytorch编写自己的训练循环。

实际上,torchkeras提供了极致的灵活性来让炼丹师掌控训练过程的每个细节。

从这个意义上说,torchkeras更像是一个训练代码模版。

这个模版由低到高由StepRunner,EpochRunner 和 KerasModel 三个类组成。

在绝大多数场景下,用户只需要在StepRunner上稍作修改并覆盖掉,就可以实现自己想要的训练推理逻辑。

就像下面这段代码范例,这是一个多输入的例子,并且嵌入了特定的accuracy计算逻辑。

这段代码的完整范例,见examples下的CRNN_CTC验证码识别。


import torch.nn.functional as F 
from torchkeras import KerasModel
from accelerate import Accelerator

#我们覆盖KerasModel的StepRunner以实现自定义训练逻辑。
#注意这里把acc指标的结果写在了step_losses中以便和loss一样在Epoch上求平均,这是一个非常灵活而且有用的写法。

class StepRunner:
    def __init__(self, net, loss_fn, accelerator=None, stage = "train", metrics_dict = None, 
                 optimizer = None, lr_scheduler = None
                 ):
        self.net,self.loss_fn,self.metrics_dict,self.stage = net,loss_fn,metrics_dict,stage
        self.optimizer,self.lr_scheduler = optimizer,lr_scheduler
        self.accelerator = accelerator if accelerator is not None else Accelerator()
        if self.stage=='train':
            self.net.train() 
        else:
            self.net.eval()
    
    def __call__(self, batch):
        
        images, targets, input_lengths, target_lengths = batch
        
        #loss
        preds = self.net(images)
        preds_log_softmax = F.log_softmax(preds, dim=-1)
        loss = F.ctc_loss(preds_log_softmax, targets, input_lengths, target_lengths)
        acc = eval_acc(targets,preds)
            

        #backward()
        if self.optimizer is not None and self.stage=="train":
            self.accelerator.backward(loss)
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.optimizer.zero_grad()
            
            
        all_loss = self.accelerator.gather(loss).sum()
        
        #losses (or plain metric that can be averaged)
        step_losses = {self.stage+"_loss":all_loss.item(),
                       self.stage+'_acc':acc}
        
        #metrics (stateful metric)
        step_metrics = {}
        if self.stage=="train":
            if self.optimizer is not None:
                step_metrics['lr'] = self.optimizer.state_dict()['param_groups'][0]['lr']
            else:
                step_metrics['lr'] = 0.0
        return step_losses,step_metrics
    
#覆盖掉默认StepRunner 
KerasModel.StepRunner = StepRunner 

可以看到,这种修改实际上是非常简单并且灵活的,保持每个模块的输出与原始实现格式一致就行,中间处理逻辑根据需要灵活调整。

同理,用户也可以修改并覆盖EpochRunner来实现自己的特定逻辑,但我一般很少遇到有这样需求的场景。

examples目录下的范例库包括了使用torchkeras对一些非常常用的库中的模型进行训练的例子。

例如:

  • torchvision
  • transformers
  • segmentation_models_pytorch
  • ultralytics
  • timm

如果你想掌握一个东西,那么就去使用它,如果你想真正理解一个东西,那么尝试去改变它。 ———— 爱因斯坦

example 使用模型库 notebook
RL
强化学习——Q-Learning 🔥🔥 - Q-learning
强化学习——DQN - DQN
Tabular
二分类——LightGBM - LightGBM
多分类——Tabm🔥🔥🔥🔥🔥 - Tabm
多分类——FTTransformer🔥🔥 - FTTransformer
二分类——FM - FM
二分类——DeepFM - DeepFM
二分类——DeepCross - DeepCross
CV
图片分类——Resnet - Resnet
语义分割——UNet - UNet
目标检测——SSD - SSD
文字识别——CRNN 🔥🔥 - CRNN-CTC
目标检测——FasterRCNN torchvision FasterRCNN
语义分割——DeepLabV3++ segmentation_models_pytorch Deeplabv3++
实例分割——MaskRCNN detectron2 MaskRCNN
图片分类——SwinTransformer timm Swin
目标检测——YOLOv8 🔥🔥🔥 ultralytics YOLOv8_Detect
实例分割——YOLOv8 🔥🔥🔥 ultralytics YOLOv8_Segment
NLP
序列翻译——Transformer🔥🔥 - Transformer
文本生成——Llama🔥 - Llama
文本分类——BERT transformers BERT
命名实体识别——BERT transformers BERT_NER
LLM微调——ChatGLM2_LoRA 🔥🔥🔥 transformers ChatGLM2_LoRA
LLM微调——ChatGLM2_AdaLoRA 🔥 transformers ChatGLM2_AdaLoRA
LLM微调——ChatGLM2_QLoRA transformers ChatGLM2_QLoRA_Kaggle
LLM微调——BaiChuan13B_QLoRA transformers BaiChuan13B_QLoRA
LLM微调——BaiChuan13B_NER 🔥🔥🔥 transformers BaiChuan13B_NER
LLM微调——BaiChuan13B_MultiRounds 🔥 transformers BaiChuan13B_MultiRounds
LLM微调——Qwen7B_MultiRounds 🔥🔥🔥 transformers Qwen7B_MultiRounds
LLM微调——BaiChuan2_13B 🔥 transformers BaiChuan2_13B

7,鼓励和联系作者 🎈🎈

如果本项目对你有所帮助,想鼓励一下作者,记得给本项目加一颗星星star⭐️,并分享给你的朋友们喔😊!

如果在torchkeras的使用中遇到问题,可以在项目中提交issue。

如果想要获得更快的反馈或者与其他torchkeras用户小伙伴进行交流,

可以在公众号算法美食屋后台回复关键字:加群

版本历史

v3.7.22023/02/21

常见问题

相似工具推荐

openclaw

OpenClaw 是一款专为个人打造的本地化 AI 助手,旨在让你在自己的设备上拥有完全可控的智能伙伴。它打破了传统 AI 助手局限于特定网页或应用的束缚,能够直接接入你日常使用的各类通讯渠道,包括微信、WhatsApp、Telegram、Discord、iMessage 等数十种平台。无论你在哪个聊天软件中发送消息,OpenClaw 都能即时响应,甚至支持在 macOS、iOS 和 Android 设备上进行语音交互,并提供实时的画布渲染功能供你操控。 这款工具主要解决了用户对数据隐私、响应速度以及“始终在线”体验的需求。通过将 AI 部署在本地,用户无需依赖云端服务即可享受快速、私密的智能辅助,真正实现了“你的数据,你做主”。其独特的技术亮点在于强大的网关架构,将控制平面与核心助手分离,确保跨平台通信的流畅性与扩展性。 OpenClaw 非常适合希望构建个性化工作流的技术爱好者、开发者,以及注重隐私保护且不愿被单一生态绑定的普通用户。只要具备基础的终端操作能力(支持 macOS、Linux 及 Windows WSL2),即可通过简单的命令行引导完成部署。如果你渴望拥有一个懂你

349.3k|★★★☆☆|5天前
Agent开发框架图像

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|6天前
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

150.7k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

108.3k|★★☆☆☆|昨天
开发框架图像Agent

gemini-cli

gemini-cli 是一款由谷歌推出的开源 AI 命令行工具,它将强大的 Gemini 大模型能力直接集成到用户的终端环境中。对于习惯在命令行工作的开发者而言,它提供了一条从输入提示词到获取模型响应的最短路径,无需切换窗口即可享受智能辅助。 这款工具主要解决了开发过程中频繁上下文切换的痛点,让用户能在熟悉的终端界面内直接完成代码理解、生成、调试以及自动化运维任务。无论是查询大型代码库、根据草图生成应用,还是执行复杂的 Git 操作,gemini-cli 都能通过自然语言指令高效处理。 它特别适合广大软件工程师、DevOps 人员及技术研究人员使用。其核心亮点包括支持高达 100 万 token 的超长上下文窗口,具备出色的逻辑推理能力;内置 Google 搜索、文件操作及 Shell 命令执行等实用工具;更独特的是,它支持 MCP(模型上下文协议),允许用户灵活扩展自定义集成,连接如图像生成等外部能力。此外,个人谷歌账号即可享受免费的额度支持,且项目基于 Apache 2.0 协议完全开源,是提升终端工作效率的理想助手。

100.8k|★★☆☆☆|2天前
插件Agent图像

markitdown

MarkItDown 是一款由微软 AutoGen 团队打造的轻量级 Python 工具,专为将各类文件高效转换为 Markdown 格式而设计。它支持 PDF、Word、Excel、PPT、图片(含 OCR)、音频(含语音转录)、HTML 乃至 YouTube 链接等多种格式的解析,能够精准提取文档中的标题、列表、表格和链接等关键结构信息。 在人工智能应用日益普及的今天,大语言模型(LLM)虽擅长处理文本,却难以直接读取复杂的二进制办公文档。MarkItDown 恰好解决了这一痛点,它将非结构化或半结构化的文件转化为模型“原生理解”且 Token 效率极高的 Markdown 格式,成为连接本地文件与 AI 分析 pipeline 的理想桥梁。此外,它还提供了 MCP(模型上下文协议)服务器,可无缝集成到 Claude Desktop 等 LLM 应用中。 这款工具特别适合开发者、数据科学家及 AI 研究人员使用,尤其是那些需要构建文档检索增强生成(RAG)系统、进行批量文本分析或希望让 AI 助手直接“阅读”本地文件的用户。虽然生成的内容也具备一定可读性,但其核心优势在于为机器

93.4k|★★☆☆☆|5天前
插件开发框架