open-metric-learning

GitHub
986 76 非常简单 1 次阅读 2天前Apache-2.0开发框架图像其他数据工具
AI 解读 由 AI 自动生成,仅供参考

open-metric-learning 是一个基于 PyTorch 的开源框架,专为训练和验证能生成高质量嵌入(Embeddings)的模型而设计。它提供了一套完整的度量学习与检索流程、预训练模型库以及丰富的工具组件,帮助开发者轻松构建高效的相似性搜索系统。

在传统分类任务中,模型通常不直接优化向量间的距离(如余弦相似度或 L2 距离),导致直接提取特征用于检索时效果不佳。open-metric-learning 正是为了解决这一痛点,通过专门的损失函数和训练策略,让模型学会“拉近”相似样本、“推远”不同样本,从而显著提升检索精度。

该工具非常适合从事计算机视觉、推荐系统或搜索引擎开发的工程师与研究人员使用。无论是需要构建以图搜图功能,还是开发个性化推荐算法,都能从中获益。其亮点在于模块化设计清晰,支持多种主流度量学习算法,并兼容 Python 3.10 至 3.12 版本,便于快速实验与部署。此外,该项目已被牛津大学、高等经济大学等学术机构及多家科技企业应用于实际研究与产品中,具备良好的社区支持与可靠性。

使用场景

某时尚电商平台的算法团队正致力于构建一个“以图搜图”功能,希望用户上传一张街拍照片后,系统能精准推荐库中款式最相似的商品。

没有 open-metric-learning 时

  • 训练目标错位:团队直接复用分类模型的倒数层特征,但分类任务优化的是类别边界,并未直接优化向量间的余弦距离或欧氏距离,导致检索排序不准。
  • 实验流程繁琐:每次尝试新的损失函数(如 Triplet Loss 或 Circle Loss)都需要手动重写数据加载器和验证逻辑,开发周期长达数周。
  • 评估标准缺失:缺乏统一的 mAP 或 CMC 曲线评估体系,难以量化不同模型在检索任务上的真实性能差异。
  • 复现难度极高:由于缺少标准化的管道配置,团队成员间难以复现彼此的实验结果,调参过程如同“黑盒”摸索。

使用 open-metric-learning 后

  • 度量对齐优化:利用 open-metric-learning 内置的专用损失函数和采样策略,直接针对向量距离进行优化,显著提升了相似款式的召回率。
  • 流水线极速搭建:通过其模块化的 PyTorch 框架,只需修改配置文件即可快速切换模型架构与损失函数,新实验上线时间从周缩短至小时级。
  • 专业指标监控:集成标准的检索评估指标(如 mAP@K),实时可视化模型性能,让迭代方向清晰明确。
  • 标准化复现:依托其成熟的实验管理范式,确保了从数据增强到模型验证的全流程可复现,团队协作效率大幅提升。

open-metric-learning 将原本碎片化、高门槛的度量学习研发过程,转化为标准化、高效率的工业级落地方案。

运行环境要求

操作系统
  • 未说明
GPU

未说明 (基于 PyTorch,支持 DDP 多卡训练,具体显存和 CUDA 版本取决于所选模型架构)

内存

未说明

依赖
notes该工具是一个基于 PyTorch 和 PyTorch Lightning 的度量学习框架,专注于端到端的训练流程和预训练模型库(Zoo)。README 中未明确列出具体的操作系统、GPU 型号、显存大小或内存需求,这些通常取决于用户选择的具体模型(如 ResNet, ViT 等)和数据集规模。项目提供了针对大规模类别(数千个 ID)和小样本情况的优化策略。
python3.10, 3.11, 3.12
torch
pytorch-lightning
open-metric-learning
open-metric-learning hero image

快速开始

Documentation Status PyPI Status Pipi version python python python

OML 是一个基于 PyTorch 的框架,用于训练和验证能够生成高质量嵌入的模型。

受信任的机构

ㅤㅤ ㅤㅤ ㅤㅤ ㅤㅤ ㅤㅤ ㅤㅤ ㅤㅤ ㅤㅤ

ㅤㅤ

来自 牛津大学HSE 大学 的许多研究人员已经在他们的论文中使用了 OML。 [1] [2] [3]

文档

常见问题解答
为什么需要 OML?

你可能会想:“如果我需要图像嵌入,可以直接训练一个普通的分类器,然后取它的倒数第二层。” 这确实是一个不错的起点。但这样做也存在一些潜在的问题:

  • 如果你想利用嵌入进行检索,就需要计算它们之间的距离(例如余弦距离或 L2 距离)。 在分类任务中,这些距离通常不会在训练过程中被直接优化。因此,你只能寄希望于最终的嵌入具备理想的性质。

  • 第二个问题是验证过程。 在检索任务中,我们通常关心的是前 N 个结果与查询的相关性。评估模型的自然方式是模拟对参考集的检索请求,并使用某种检索指标来衡量性能。 因此,分类准确率并不能保证与这些检索指标相关。

  • 最后,你也可以尝试自己实现度量学习的流程。 这其中涉及大量工作:比如使用三元组损失时,需要以特定方式构建批次、实现不同类型的三元组挖掘、跟踪距离等;而在验证阶段,还需要实现检索指标, 包括高效地积累每个 epoch 的嵌入、处理各种边界情况等。如果你有多块 GPU 并使用 DDP 分布式训练,难度会更大。 此外,你可能还希望可视化自己的检索请求,突出显示好的和坏的检索结果。与其从头开始自己动手,不如直接使用 OML 来满足需求。

Open Metric Learning 和 PyTorch Metric Learning 有什么区别?

PML 是一个流行的度量学习库,它包含了丰富的损失函数、三元组挖掘方法、距离度量和归约策略;因此,我们也提供了如何将这些组件与 OML 结合使用的简单示例。 最初,我们曾尝试使用 PML,但最终还是决定开发自己的库,更加注重流水线和实际应用。 这就是 OML 与 PML 的主要区别:

  • OML 提供了 流水线,只需准备好配置文件和符合要求的数据格式,即可开始训练模型。 这类似于将数据转换为 COCO 格式以便使用 mmdetection 训练目标检测器。

  • OML 更加专注于端到端的流水线和实际应用场景。 它提供了基于配置的示例,涵盖了贴近现实生活的常用基准数据集(如包含数千个类别的商品图片)。我们在这些数据集中找到了一些优秀的超参数组合,并训练和发布了相应的模型及其配置文件。 因此,相比 PML,OML 更加注重“配方”式的解决方案。PML 的作者也曾在评论中表示,他的库更像是工具集而非现成的解决方案,而且 PML 中的示例大多针对 CIFAR 和 MNIST 数据集。

  • OML 拥有 预训练模型库,可以通过代码轻松调用,就像使用 torchvision 中的 resnet50(pretrained=True) 一样。

  • OML 与 PyTorch Lightning 集成,因此我们可以利用其 Trainer 的强大功能。 这在使用 DDP 时尤其有帮助。你可以对比我们的 DDP 示例PML 的示例。 顺便说一下,PML 也有 Trainers,但在示例中并不常用,通常还是直接使用自定义的 traintest 函数。

我们认为,提供流水线、简洁的示例以及预训练模型库,能够将入门门槛降到极低。

什么是度量学习?

度量学习问题(也称为 极端分类 问题)是指我们拥有数千个实体的 ID,但每个实体只有少量样本的情况。 通常假设在测试阶段(或生产环境中)我们会遇到未见过的实体,这就使得无法直接应用常规的分类流程。在这种情况下,通常会利用得到的嵌入向量来进行检索或匹配操作。

以下是计算机视觉领域中的一些此类任务示例:

  • 人物/动物重识别
  • 人脸识别
  • 地标识别
  • 电商网站的搜索引擎 以及其他许多任务。

术语表(命名约定)

  • embedding - 模型的输出(也称为 特征向量描述符)。
  • query - 在检索过程中用作查询的样本。
  • gallery set - 用于搜索与 query 相似项的实体集合(也称为 referenceindex)。
  • Sampler - 用于 DataLoader 的参数,用来组成批次。
  • Miner - 在 Sampler 组成批次之后,用于形成样本对或三元组的对象。不一定只在当前批次内组合样本,有时还会结合内存库来完成这一工作。
  • Samples/Labels/Instances - 以 DeepFashion 数据集为例,它包含数千种时尚单品的 ID(我们称之为 labels),每种 ID 对应几张照片(我们称单张照片为 instancesample)。所有这些时尚单品又可以分为不同的类别,如“裙子”、“夹克”、“短裤”等(我们称之为 categories)。 注意,为了避免误解,我们尽量不使用 “class” 这一术语。
  • training epoch - 对于基于组合损失的训练,我们使用的批次采样器长度通常等于 [训练数据集中标签的数量] / [每个批次中的标签数量]。这意味着在一个 epoch 中,我们并不会遍历所有的训练样本(与常规分类不同),而是确保每个标签都被覆盖到。

使用 OML 训练的模型效果如何?

其性能可以与当前(2022 年)的 SotA 方法相媲美,例如 Hyp-ViT(关于这种方法的简要说明:它是一种基于 ViT 架构并采用对比损失训练的模型,但其嵌入被投影到了双曲空间中。 作者声称,这种空间能够更好地描述现实世界数据的层次化结构。 因此,这篇论文需要大量的数学推导来将常规运算适配到双曲空间中。)

我们在保持其他参数不变的情况下,使用相同的架构训练了一个采用三元组损失的模型:包括训练和测试时的数据增强、图像尺寸以及优化器等。相关配置请参阅 Models Zoo。 关键在于我们所采用的启发式矿工和采样器:

  • Category Balance Sampler 通过限制每个批次中包含的类别数量 C 来生成批次。例如,当 C = 1 时,它会把所有的夹克放在一个批次里,而所有的牛仔裤则放在另一个批次里(仅为示例)。这种方式自动提高了负样本的难度:让模型区分两件夹克的不同要比区分一件夹克和一件T恤更有意义。

  • Hard Triplets Miner 通过仅保留最难的三元组(即正样本距离最大、负样本距离最小的三元组),进一步提升了任务难度。

以下是两个流行基准上的 CMC@1 分数: SOP 数据集:Hyp-ViT — 85.9,我们的模型 — 86.6。DeepFashion 数据集:Hyp-ViT — 92.5,我们的模型 — 92.1。 由此可见,通过简单的启发式方法和避免复杂的数学计算,我们同样能够在 SotA 水平上取得优异表现。

自监督学习呢?

最近的自监督学习研究确实取得了显著成果。然而,这类方法往往需要非常庞大的计算资源才能训练出模型。而在我们的框架中,我们主要考虑的是普通用户通常只有几块 GPU 的情况。

尽管如此,我们也不会忽视该领域的成功,而是从两个方面加以利用:

  • 作为预训练检查点的来源,以便更好地进行模型初始化。根据文献和我们的经验,这些检查点作为初始权重比直接使用 ImageNet 上预训练的监督模型要好得多。因此,我们提供了在配置文件或构造函数中传入参数即可加载这些预训练检查点的功能。
  • 作为灵感来源。例如,我们将 MoCo 中的内存库思想应用到了 TripletLoss 上。

使用 OML 是否需要了解其他深度学习框架?

不需要。OML 是框架无关的。虽然我们在实验中使用 PyTorch Lightning 作为训练循环的运行者,但也保留了完全使用原生 PyTorch 运行的能力。因此,OML 中与 Lightning 相关的部分非常少,并且这部分逻辑与其他代码是分开存放的(见 oml.lightning)。即使你使用 Lightning,也不必深入了解它,因为我们已经提供了开箱即用的 Pipelines

由于支持纯 PyTorch 运行以及模块化的代码结构,你可以在实现必要的封装后,轻松地将 OML 与自己喜爱的框架结合使用。

没有数据科学基础也能使用 OML 吗?

是的。要使用流水线运行实验,你只需要编写一个转换器,将其数据转换为我们的格式(即准备包含几个预定义列的.csv表格)。就这么简单!

很可能我们在模型仓库中已经为你所在的领域准备了合适的预训练模型。在这种情况下,你甚至无需再进行训练。

我可以将模型导出为ONNX格式吗?

目前我们还不支持直接将模型导出为ONNX格式。不过,你可以利用PyTorch内置的功能来实现这一点。更多信息请参阅此议题

文档

入门教程: 英文 | 俄文 | 中文

更多

安装

pip install -U open-metric-learning  # 最小依赖
pip install -U open-metric-learning[nlp]
pip install -U open-metric-learning[audio]
pip install -U open-metric-learning[pipelines]

# 如果出现冲突,可以不带依赖项安装,并手动管理版本:
pip install --no-deps open-metric-learning

OML功能

损失函数 | 挖掘器
miner = AllTripletsMiner()
miner = NHardTripletsMiner()
miner = MinerWithBank()
...
criterion = TripletLossWithMiner(0.1, miner)
criterion = ArcFaceLoss()
criterion = SurrogatePrecision()
采样器
labels = train.get_labels()
l2c = train.get_label2category()


sampler = BalanceSampler(labels)
sampler = CategoryBalanceSampler(labels, l2c)
sampler = DistinctCategoryBalanceSampler(labels, l2c)
配置支持
max_epochs: 10
sampler:
  name: balance
  args:
    n_labels: 2
    n_instances: 2
多模态预训练模型
model_hf = AutoModel.from_pretrained("roberta-base")
tokenizer = AutoTokenizer.from_pretrained("roberta-base")
extractor_txt = HFWrapper(model_hf)

extractor_img = ViTExtractor.from_pretrained("vits16_dino")
transforms, _ = get_transforms_for_pretrained("vits16_dino")

extractor_audio = ECAPATDNNExtractor.from_pretrained()
后处理
emb = inference(extractor, dataset)
rr = RetrievalResults.from_embeddings(emb, dataset)

postprocessor = AdaptiveThresholding()
rr_upd = postprocessor.process(rr, dataset)
基于神经网络的后处理 | 论文
embeddings = inference(extractor, dataset)
rr = RetrievalResults.from_embeddings(embeddings, dataset)

postprocessor = PairwiseReranker(ConcatSiamese(), top_n=3)
rr_upd = postprocessor.process(rr, dataset)
日志记录
logger = TensorBoardPipelineLogger()
logger = NeptunePipelineLogger()
logger = WandBPipelineLogger()
logger = MLFlowPipelineLogger()
logger = ClearMLPipelineLogger()
PML
from pytorch_metric_learning import losses

criterion = losses.TripletMarginLoss(0.2, "all")
pred = ViTExtractor()(data)
criterion(pred, gts)
类别支持
# 训练
loader = DataLoader(CategoryBalanceSampler())

# 验证
rr = RetrievalResults.from_embeddings()
m.calc_retrieval_metrics_rr(rr, query_categories)
其他指标
embeddigs = inference(model, dataset)
rr = RetrievalResults.from_embeddings(embeddings, dataset)

m.calc_retrieval_metrics_rr(rr, precision_top_k=(5,))
m.calc_fnmr_at_fmr_rr(rr, fmr_vals=(0.1,))
m.calc_topological_metrics(embeddings, pcf_variance=(0.5,))
Lightning
import pytorch_lightning as pl

model = ViTExtractor.from_pretrained("vits16_dino")
clb = MetricValCallback(EmbeddingMetrics(dataset))
module = ExtractorModule(model, criterion, optimizer)

trainer = pl.Trainer(max_epochs=3, callbacks=[clb])
trainer.fit(module, train_loader, val_loader)
Lightning DDP
clb = MetricValCallback(EmbeddingMetrics(val))
module = ExtractorModuleDDP(
    model, criterion, optimizer, train, val
)

ddp = {"devices": 2, "strategy": DDPStrategy()}
trainer = pl.Trainer(max_epochs=3, callbacks=[clb], **ddp)
trainer.fit(module)

示例

以下是一个如何在小型数据集上训练、验证和后处理模型的示例,该数据集包含 图像文本, 或 音频。 有关数据集格式的更多详细信息,请参阅 文档

向右滚动以查看 图像 > 文本 > 音频

图像 文本 音频
from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import HardTripletsMiner
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, AdaptiveThresholding
from oml.samplers import BalanceSampler
from oml.utils import get_mock_images_dataset

model = ViTExtractor.from_pretrained("vits16_dino").to("cpu").train()
transform, _ = get_transforms_for_pretrained("vits16_dino")

df_train, df_val = get_mock_images_dataset(global_paths=True)
train = d.ImageLabeledDataset(df_train, transform=transform)
val = d.ImageQueryGalleryLabeledDataset(df_val, transform=transform)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(0.1, HardTripletsMiner(), need_logs=True)
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)


# 训练1个epoch
for batch in DataLoader(train, batch_sampler=sampler):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(criterion.last_logs)


# 验证:通过检索相关项目
embeddings = inference(model, val, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
rr = AdaptiveThresholding(n_std=2).process(rr)
rr.visualize(query_ids=[2, 1], dataset=val, show=True)
print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))


from torch.optim import Adam
from torch.utils.data import DataLoader
from transformers import AutoModel, AutoTokenizer

from oml import datasets as d
from oml.inference import inference
from oml.losses import TripletLossWithMiner
from oml.metrics import calc_retrieval_metrics_rr
from oml.miners import NHardTripletsMiner
from oml.models import HFWrapper
from oml.retrieval import RetrievalResults, AdaptiveThresholding
from oml.samplers import BalanceSampler
from oml.utils import get_mock_texts_dataset

model = HFWrapper(AutoModel.from_pretrained("bert-base-uncased"), 768).to("cpu").train()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

df_train, df_val = get_mock_texts_dataset()
train = d.TextLabeledDataset(df_train, tokenizer=tokenizer)
val = d.TextQueryGalleryLabeledDataset(df_val, tokenizer=tokenizer)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = TripletLossWithMiner(
    0.1, NHardTripletsMiner(n_positive=2, n_negative=2), need_logs=True
)
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)


# 训练1个epoch
for batch in DataLoader(train, batch_sampler=sampler):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(criterion.last_logs)


# 验证:通过检索相关项目
embeddings = inference(model, val, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
rr = AdaptiveThresholding(n_std=2).process(rr)
rr.visualize(query_ids=[2, 1], dataset=val, show=True)
print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))


from torch.optim import Adam
from torch.utils.data import DataLoader

from oml import datasets as d
from oml.inference import inference
from oml.losses import ArcFaceLoss
from oml.metrics import calc_retrieval_metrics_rr
from oml.models import ECAPATDNNExtractor
from oml.retrieval import AdaptiveThresholding, RetrievalResults
from oml.samplers import BalanceSampler
from oml.utils import get_mock_audios_dataset

model = ECAPATDNNExtractor.from_pretrained("ecapa_tdnn_taoruijie").to("cpu").train()

df_train, df_val = get_mock_audios_dataset(global_paths=True)
train = d.AudioLabeledDataset(df_train)
val = d.AudioQueryGalleryLabeledDataset(df_val)

optimizer = Adam(model.parameters(), lr=1e-4)
criterion = ArcFaceLoss(m=0.2, s=30, in_features=192, num_classes=4)  # 类似于论文
sampler = BalanceSampler(train.get_labels(), n_labels=2, n_instances=2)

# 训练1个epoch
for batch in DataLoader(train, batch_sampler=sampler):
    embeddings = model(batch["input_tensors"])
    loss = criterion(embeddings, batch["labels"])
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    print(criterion.last_logs)


# 通过检索相关项目进行验证
embeddings = inference(model, val, batch_size=4, num_workers=0)
rr = RetrievalResults.from_embeddings(embeddings, val, n_items=3)
rr = AdaptiveThresholding(n_std=2).process(rr)
rr.visualize_as_html(query_ids=[2, 1], dataset=val, show=True)
print(calc_retrieval_metrics_rr(rr, map_top_k=(3,), cmc_top_k=(1,)))


输出
{'active_tri': 0.125, 'pos_dist': 82.5, 'neg_dist': 100.5}  # batch 1
{'active_tri': 0.0, 'pos_dist': 36.3, 'neg_dist': 56.9}     # batch 2

{'cmc': {1: 0.75}, 'precision': {5: 0.75}, 'map': {3: 0.8}}

Open In Colab

输出
{'active_tri': 0.0, 'pos_dist': 8.5, 'neg_dist': 11.0}  # batch 1
{'active_tri': 0.25, 'pos_dist': 8.9, 'neg_dist': 9.8}  # batch 2

{'cmc': {1: 0.8}, 'precision': {5: 0.7}, 'map': {3: 0.9}}

Open In Colab

输出
{'active_tri': 0.25, 'pos_dist': 17.3, 'neg_dist': 18.4}  # batch 1
{'active_tri': 0.0, 'pos_dist': 17.1, 'neg_dist': 18.5}   # batch 2

{'cmc': {1: 1.0}, 'precision': {5: 1.0}, 'map': {3: 1.0}}

Open In Colab


额外的插图、解释和技巧 用于上述代码。

由训练好的模型进行检索

这里是一个推理时的例子(换句话说,就是在测试集上进行检索)。 下面的代码既适用于文本也适用于图像。

查看示例

from oml.datasets import ImageQueryGalleryDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.utils import get_mock_images_dataset
from oml.retrieval import RetrievalResults, AdaptiveThresholding

_, df_test = get_mock_images_dataset(global_paths=True)
del df_test["label"]  # 我们不需要真实标签来进行预测

extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")

dataset = ImageQueryGalleryDataset(df_test, transform=transform)
embeddings = inference(extractor, dataset, batch_size=4, num_workers=0)

rr = RetrievalResults.from_embeddings(embeddings, dataset, n_items=5)
rr = AdaptiveThresholding(n_std=3.5).process(rr)
rr.visualize(query_ids=[0, 1], dataset=dataset, show=True)

# 你会得到检索到的项目的ID以及对应的距离
print(rr)

由训练好的模型进行检索:流式处理与文本转图像

这里有一个查询和图库分别处理的例子。

  • 首先,这可能对流式检索很有用,当图库(索引)集合非常庞大且固定时,而查询则是分批到达。
  • 其次,查询和图库的性质不同,例如,查询是文本,而图库是图像
查看示例

import pandas as pd

from oml.datasets import ImageBaseDataset
from oml.inference import inference
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained
from oml.retrieval import RetrievalResults, ConstantThresholding
from oml.utils import get_mock_images_dataset

extractor = ViTExtractor.from_pretrained("vits16_dino").to("cpu")
transform, _ = get_transforms_for_pretrained("vits16_dino")

paths = pd.concat(get_mock_images_dataset(global_paths=True))["path"]
galleries, queries1, queries2 = paths[:20], paths[20:22], paths[22:24]

# 图库非常庞大且固定,所以我们只处理一次
dataset_gallery = ImageBaseDataset(galleries, transform=transform)
embeddings_gallery = inference(extractor, dataset_gallery, batch_size=4, num_workers=0)

# 查询以“在线”流的形式到来
for queries in [queries1, queries2]:
    dataset_query = ImageBaseDataset(queries, transform=transform)
    embeddings_query = inference(extractor, dataset_query, batch_size=4, num_workers=0)

    # 对于下面的操作,我们将提供与向量搜索数据库(如QDrant或Faiss)的集成
    rr = RetrievalResults.from_embeddings_qg(
        embeddings_query=embeddings_query, embeddings_gallery=embeddings_gallery,
        dataset_query=dataset_query, dataset_gallery=dataset_gallery
    )
    rr = ConstantThresholding(th=80).process(rr)
    rr.visualize_qg([0, 1], dataset_query=dataset_query, dataset_gallery=dataset_gallery, show=True)
    print(rr)

流水线

流水线提供了一种仅通过更改配置文件即可运行度量学习实验的方式。你所需要做的就是将你的数据集准备成所需的格式。

更多详情请参阅Pipelines文件夹:

动物园:图像模型

你可以使用我们动物园中的图像模型,或者在继承自 IExtractor 的基础上使用其他任意模型。

查看如何使用模型

from oml.const import CKPT_SAVE_ROOT as CKPT_DIR, MOCK_DATASET_PATH as DATA_DIR
from oml.models import ViTExtractor
from oml.registry import get_transforms_for_pretrained

model = ViTExtractor.from_pretrained("vits16_dino").eval()
transforms, im_reader = get_transforms_for_pretrained("vits16_dino")

img = im_reader(DATA_DIR / "images" / "circle_1.jpg")  # 在此处放置你的图像路径
img_tensor = transforms(img)
# img_tensor = transforms(image=img)["image"]  # 对于 Albumentations 提供的变换
features = model(img_tensor.unsqueeze(0))

# 查看其他可用模型:
print(list(ViTExtractor.pretrained_models.keys()))

# 加载保存在磁盘上的检查点:
model_ = ViTExtractor(weights=CKPT_DIR / "vits16_dino.ckpt", arch="vits16", normalise_features=False)

图像模型动物园

由我们训练的模型。以下指标适用于 224 x 224 的图像:

模型 cmc1 数据集 权重 实验
ViTExtractor.from_pretrained("vits16_inshop") 0.921 DeepFashion Inshop 链接 链接
ViTExtractor.from_pretrained("vits16_sop") 0.866 Stanford Online Products 链接 链接
ViTExtractor.from_pretrained("vits16_cars") 0.907 CARS 196 链接 链接
ViTExtractor.from_pretrained("vits16_cub") 0.837 CUB 200 2011 链接 链接

由其他研究人员训练的模型。请注意,某些基准上的指标之所以很高,是因为这些数据曾被用作训练集的一部分(例如 unicom)。以下指标同样适用于 224 x 224 的图像:

模型 Stanford Online Products DeepFashion InShop CUB 200 2011 CARS 196
ViTUnicomExtractor.from_pretrained("vitb16_unicom") 0.700 0.734 0.847 0.916
ViTUnicomExtractor.from_pretrained("vitb32_unicom") 0.690 0.722 0.796 0.893
ViTUnicomExtractor.from_pretrained("vitl14_unicom") 0.726 0.790 0.868 0.922
ViTUnicomExtractor.from_pretrained("vitl14_336px_unicom") 0.745 0.810 0.875 0.924
ViTCLIPExtractor.from_pretrained("sber_vitb32_224") 0.547 0.514 0.448 0.618
ViTCLIPExtractor.from_pretrained("sber_vitb16_224") 0.565 0.565 0.524 0.648
ViTCLIPExtractor.from_pretrained("sber_vitl14_224") 0.512 0.555 0.606 0.707
ViTCLIPExtractor.from_pretrained("openai_vitb32_224") 0.612 0.491 0.560 0.693
ViTCLIPExtractor.from_pretrained("openai_vitb16_224") 0.648 0.606 0.665 0.767
ViTCLIPExtractor.from_pretrained("openai_vitl14_224") 0.670 0.675 0.745 0.844
ViTExtractor.from_pretrained("vits16_dino") 0.648 0.509 0.627 0.265
ViTExtractor.from_pretrained("vits8_dino") 0.651 0.524 0.661 0.315
ViTExtractor.from_pretrained("vitb16_dino") 0.658 0.514 0.541 0.288
ViTExtractor.from_pretrained("vitb8_dino") 0.689 0.599 0.506 0.313
ViTExtractor.from_pretrained("vits14_dinov2") 0.566 0.334 0.797 0.503
ViTExtractor.from_pretrained("vits14_reg_dinov2") 0.566 0.332 0.795 0.740
ViTExtractor.from_pretrained("vitb14_dinov2") 0.565 0.342 0.842 0.644
ViTExtractor.from_pretrained("vitb14_reg_dinov2") 0.557 0.324 0.833 0.828
ViTExtractor.from_pretrained("vitl14_dinov2") 0.576 0.352 0.844 0.692
ViTExtractor.from_pretrained("vitl14_reg_dinov2") 0.571 0.340 0.840 0.871
ResnetExtractor.from_pretrained("resnet50_moco_v2") 0.493 0.267 0.264 0.149
ResnetExtractor.from_pretrained("resnet50_imagenet1k_v1") 0.515 0.284 0.455 0.247

这些指标可能与论文中报告的不同,因为训练/验证集划分的方式以及是否使用了边界框可能存在差异。

动物园:文本

这里提供了一个与HuggingFace Transformers模型的轻量级集成。 你可以将其替换为其他任意继承自IExtractor的模型。

pip install open-metric-learning[nlp]
查看如何使用模型

from transformers import AutoModel, AutoTokenizer

from oml.models import HFWrapper

model = AutoModel.from_pretrained('bert-base-uncased').eval()
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
extractor = HFWrapper(model=model, feat_dim=768)

inp = tokenizer(text="Hello world", return_tensors="pt", add_special_tokens=True)
embeddings = extractor(inp)

请注意,目前我们还没有自己的文本模型动物园。

动物园:音频

你可以使用我们动物园中的音频模型,或者在从IExtractor继承后使用其他任意模型。

pip install open-metric-learning[audio]
查看如何使用模型

import torchaudio

from oml.models import ECAPATDNNExtractor
from oml.const import CKPT_SAVE_ROOT as CKPT_DIR, MOCK_AUDIO_DATASET_PATH as DATA_DIR

# 替换为你的实际路径
ckpt_path = CKPT_DIR / "ecapa_tdnn_taoruijie.pth"
file_path = DATA_DIR / "voices" / "voice0_0.wav"

model = ECAPATDNNExtractor(weights=ckpt_path, arch="ecapa_tdnn_taoruijie", normalise_features=False).to("cpu").eval()
audio, sr = torchaudio.load(file_path)

if audio.shape[0] > 1:
    audio = audio.mean(dim=0, keepdim=True)  # 按通道取平均
if sr != 16000:
    audio = torchaudio.functional.resample(audio, sr, 16000)

embeddings = model.extract(audio)

音频模型动物园

模型 Vox1_O Vox1_E Vox1_H
ECAPATDNNExtractor.from_pretrained("ecapa_tdnn_taoruijie") 0.86 1.18 2.17

以上指标表示等错误率(EER)。数值越低越好。

贡献指南

我们欢迎新贡献者!请参阅我们的:

致谢

该项目于2020年作为Catalyst库的一个模块启动。 我要感谢当时与我一起开发该模块的人员: Julia Shenshina, Nikita Balagansky, Sergey Kolesnikov 以及其他成员。

同时,我也要感谢那些在项目独立出来后继续推进这一流程的人: Julia Shenshina, Misha Kindulov, Aron Dik, Aleksei Tarasov以及 Verkhovtsev Leonid

此外,我也要感谢NewYorker公司,因为其中一部分功能是由我领导的计算机视觉团队开发并使用的。

版本历史

release.4.0.02025/04/14
release.3.1.02024/06/13

常见问题

相似工具推荐

openclaw

OpenClaw 是一款专为个人打造的本地化 AI 助手,旨在让你在自己的设备上拥有完全可控的智能伙伴。它打破了传统 AI 助手局限于特定网页或应用的束缚,能够直接接入你日常使用的各类通讯渠道,包括微信、WhatsApp、Telegram、Discord、iMessage 等数十种平台。无论你在哪个聊天软件中发送消息,OpenClaw 都能即时响应,甚至支持在 macOS、iOS 和 Android 设备上进行语音交互,并提供实时的画布渲染功能供你操控。 这款工具主要解决了用户对数据隐私、响应速度以及“始终在线”体验的需求。通过将 AI 部署在本地,用户无需依赖云端服务即可享受快速、私密的智能辅助,真正实现了“你的数据,你做主”。其独特的技术亮点在于强大的网关架构,将控制平面与核心助手分离,确保跨平台通信的流畅性与扩展性。 OpenClaw 非常适合希望构建个性化工作流的技术爱好者、开发者,以及注重隐私保护且不愿被单一生态绑定的普通用户。只要具备基础的终端操作能力(支持 macOS、Linux 及 Windows WSL2),即可通过简单的命令行引导完成部署。如果你渴望拥有一个懂你

349.3k|★★★☆☆|4天前
Agent开发框架图像

stable-diffusion-webui

stable-diffusion-webui 是一个基于 Gradio 构建的网页版操作界面,旨在让用户能够轻松地在本地运行和使用强大的 Stable Diffusion 图像生成模型。它解决了原始模型依赖命令行、操作门槛高且功能分散的痛点,将复杂的 AI 绘图流程整合进一个直观易用的图形化平台。 无论是希望快速上手的普通创作者、需要精细控制画面细节的设计师,还是想要深入探索模型潜力的开发者与研究人员,都能从中获益。其核心亮点在于极高的功能丰富度:不仅支持文生图、图生图、局部重绘(Inpainting)和外绘(Outpainting)等基础模式,还独创了注意力机制调整、提示词矩阵、负向提示词以及“高清修复”等高级功能。此外,它内置了 GFPGAN 和 CodeFormer 等人脸修复工具,支持多种神经网络放大算法,并允许用户通过插件系统无限扩展能力。即使是显存有限的设备,stable-diffusion-webui 也提供了相应的优化选项,让高质量的 AI 艺术创作变得触手可及。

162.1k|★★★☆☆|5天前
开发框架图像Agent

everything-claude-code

everything-claude-code 是一套专为 AI 编程助手(如 Claude Code、Codex、Cursor 等)打造的高性能优化系统。它不仅仅是一组配置文件,而是一个经过长期实战打磨的完整框架,旨在解决 AI 代理在实际开发中面临的效率低下、记忆丢失、安全隐患及缺乏持续学习能力等核心痛点。 通过引入技能模块化、直觉增强、记忆持久化机制以及内置的安全扫描功能,everything-claude-code 能显著提升 AI 在复杂任务中的表现,帮助开发者构建更稳定、更智能的生产级 AI 代理。其独特的“研究优先”开发理念和针对 Token 消耗的优化策略,使得模型响应更快、成本更低,同时有效防御潜在的攻击向量。 这套工具特别适合软件开发者、AI 研究人员以及希望深度定制 AI 工作流的技术团队使用。无论您是在构建大型代码库,还是需要 AI 协助进行安全审计与自动化测试,everything-claude-code 都能提供强大的底层支持。作为一个曾荣获 Anthropic 黑客大奖的开源项目,它融合了多语言支持与丰富的实战钩子(hooks),让 AI 真正成长为懂上

148.6k|★★☆☆☆|今天
开发框架Agent语言模型

ComfyUI

ComfyUI 是一款功能强大且高度模块化的视觉 AI 引擎,专为设计和执行复杂的 Stable Diffusion 图像生成流程而打造。它摒弃了传统的代码编写模式,采用直观的节点式流程图界面,让用户通过连接不同的功能模块即可构建个性化的生成管线。 这一设计巧妙解决了高级 AI 绘图工作流配置复杂、灵活性不足的痛点。用户无需具备编程背景,也能自由组合模型、调整参数并实时预览效果,轻松实现从基础文生图到多步骤高清修复等各类复杂任务。ComfyUI 拥有极佳的兼容性,不仅支持 Windows、macOS 和 Linux 全平台,还广泛适配 NVIDIA、AMD、Intel 及苹果 Silicon 等多种硬件架构,并率先支持 SDXL、Flux、SD3 等前沿模型。 无论是希望深入探索算法潜力的研究人员和开发者,还是追求极致创作自由度的设计师与资深 AI 绘画爱好者,ComfyUI 都能提供强大的支持。其独特的模块化架构允许社区不断扩展新功能,使其成为当前最灵活、生态最丰富的开源扩散模型工具之一,帮助用户将创意高效转化为现实。

108.1k|★★☆☆☆|2天前
开发框架图像Agent

gemini-cli

gemini-cli 是一款由谷歌推出的开源 AI 命令行工具,它将强大的 Gemini 大模型能力直接集成到用户的终端环境中。对于习惯在命令行工作的开发者而言,它提供了一条从输入提示词到获取模型响应的最短路径,无需切换窗口即可享受智能辅助。 这款工具主要解决了开发过程中频繁上下文切换的痛点,让用户能在熟悉的终端界面内直接完成代码理解、生成、调试以及自动化运维任务。无论是查询大型代码库、根据草图生成应用,还是执行复杂的 Git 操作,gemini-cli 都能通过自然语言指令高效处理。 它特别适合广大软件工程师、DevOps 人员及技术研究人员使用。其核心亮点包括支持高达 100 万 token 的超长上下文窗口,具备出色的逻辑推理能力;内置 Google 搜索、文件操作及 Shell 命令执行等实用工具;更独特的是,它支持 MCP(模型上下文协议),允许用户灵活扩展自定义集成,连接如图像生成等外部能力。此外,个人谷歌账号即可享受免费的额度支持,且项目基于 Apache 2.0 协议完全开源,是提升终端工作效率的理想助手。

100.8k|★★☆☆☆|今天
插件Agent图像

markitdown

MarkItDown 是一款由微软 AutoGen 团队打造的轻量级 Python 工具,专为将各类文件高效转换为 Markdown 格式而设计。它支持 PDF、Word、Excel、PPT、图片(含 OCR)、音频(含语音转录)、HTML 乃至 YouTube 链接等多种格式的解析,能够精准提取文档中的标题、列表、表格和链接等关键结构信息。 在人工智能应用日益普及的今天,大语言模型(LLM)虽擅长处理文本,却难以直接读取复杂的二进制办公文档。MarkItDown 恰好解决了这一痛点,它将非结构化或半结构化的文件转化为模型“原生理解”且 Token 效率极高的 Markdown 格式,成为连接本地文件与 AI 分析 pipeline 的理想桥梁。此外,它还提供了 MCP(模型上下文协议)服务器,可无缝集成到 Claude Desktop 等 LLM 应用中。 这款工具特别适合开发者、数据科学家及 AI 研究人员使用,尤其是那些需要构建文档检索增强生成(RAG)系统、进行批量文本分析或希望让 AI 助手直接“阅读”本地文件的用户。虽然生成的内容也具备一定可读性,但其核心优势在于为机器

93.4k|★★☆☆☆|3天前
插件开发框架