• 教程 >
  • 火炬多模态教程:FLAVA 微调
快捷键

火炬多模态教程:FLAVA 微调 ¶

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

多模态 AI 近年来因其无处不在的特性而变得非常流行,从图像字幕和视觉搜索等用例到更近期的从文本生成图像等应用。TorchMultimodal 是一个由 Pytorch 支持的库,包含构建块和端到端示例,旨在促进和加速多模态研究。

在本教程中,我们将演示如何使用来自 TorchMultimodal 库的预训练 SoTA 模型 FLAVA,在多模态任务上进行微调,即视觉问答(VQA)。该模型由两个基于 transformer 的单模态编码器(用于文本和图像)和一个多模态编码器组成,用于结合两个嵌入。它使用对比、图像文本匹配以及文本、图像和多模态掩码损失进行预训练。

安装¶

对于本教程,我们将使用 TextVQA 数据集和 Hugging Face 的 bert tokenizer 。因此,您需要安装数据集和 transformers,以及 TorchMultimodal。

备注

当在 Google Colab 中运行此教程时,请创建一个新单元格并运行以下命令来安装所需的包:

!pip install torchmultimodal-nightly
!pip install datasets
!pip install transformers

步骤 ¶

  1. 在您的计算机上运行以下命令将 Hugging Face 数据集下载到目录中:

    wget http://dl.fbaipublicfiles.com/pythia/data/vocab.tar.gz
    tar xf vocab.tar.gz
    

    备注

    如果您在 Google Colab 中运行此教程,请在新的单元中运行以下命令,并在这些命令前加上感叹号(!):

  2. 对于这个教程,我们将 VQA 视为一个分类任务,其中输入是图像和问题(文本),输出是答案类别。因此,我们需要下载包含答案类别的词汇表文件并创建答案到标签的映射。

    我们还从 Hugging Face 加载包含 34602 个训练样本(图像、问题和答案)的 textvqa 数据集。

我们看到有 3997 个答案类别,包括一个表示未知答案的类别。

with open("data/vocabs/answers_textvqa_more_than_1.txt") as f:
  vocab = f.readlines()

answer_to_idx = {}
for idx, entry in enumerate(vocab):
  answer_to_idx[entry.strip("\n")] = idx
print(len(vocab))
print(vocab[:5])

from datasets import load_dataset
dataset = load_dataset("textvqa")

让我们显示数据集中的一个样本条目:

import matplotlib.pyplot as plt
import numpy as np
idx = 5
print("Question: ", dataset["train"][idx]["question"])
print("Answers: " ,dataset["train"][idx]["answers"])
im = np.asarray(dataset["train"][idx]["image"].resize((500,500)))
plt.imshow(im)
plt.show()

3. 接下来,我们编写转换函数将图像和文本转换为模型可消费的张量 - 对于图像,我们使用 torchvision 中的转换将其转换为张量并调整到统一大小 - 对于文本,我们使用 Hugging Face 的 BertTokenizer 进行标记(并填充) - 对于答案(即标签),我们取最频繁出现的答案作为训练标签:

import torch
from torchvision import transforms
from collections import defaultdict
from transformers import BertTokenizer
from functools import partial

def transform(tokenizer, input):
  batch = {}
  image_transform = transforms.Compose([transforms.ToTensor(), transforms.Resize([224,224])])
  image = image_transform(input["image"][0].convert("RGB"))
  batch["image"] = [image]

  tokenized=tokenizer(input["question"],return_tensors='pt',padding="max_length",max_length=512)
  batch.update(tokenized)


  ans_to_count = defaultdict(int)
  for ans in input["answers"][0]:
    ans_to_count[ans] += 1
  max_value = max(ans_to_count, key=ans_to_count.get)
  ans_idx = answer_to_idx.get(max_value,0)
  batch["answers"] = torch.as_tensor([ans_idx])
  return batch

tokenizer=BertTokenizer.from_pretrained("bert-base-uncased",padding="max_length",max_length=512)
transform=partial(transform,tokenizer)
dataset.set_transform(transform)

4. 最后,我们从 torchmultimodal 导入 flava_model_for_classification 。它默认加载预训练的 FLAVA 检查点,并包括分类头。

模型前向函数将图像通过视觉编码器,问题通过文本编码器传递。然后,将图像和问题的嵌入传递到多模态编码器。最后,与 CLS 标记对应的最终嵌入通过 MLP 头传递,最终给出每个可能答案的概率分布。

from torchmultimodal.models.flava.model import flava_model_for_classification
model = flava_model_for_classification(num_classes=len(vocab))

5. 我们将数据集和模型组合在一起,在一个玩具训练循环中演示如何对模型进行 3 次迭代训练:

from torch import nn
BATCH_SIZE = 2
MAX_STEPS = 3
from torch.utils.data import DataLoader

train_dataloader = DataLoader(dataset["train"], batch_size= BATCH_SIZE)
optimizer = torch.optim.AdamW(model.parameters())


epochs = 1
for _ in range(epochs):
  for idx, batch in enumerate(train_dataloader):
    optimizer.zero_grad()
    out = model(text = batch["input_ids"], image = batch["image"], labels = batch["answers"])
    loss = out.loss
    loss.backward()
    optimizer.step()
    print(f"Loss at step {idx} = {loss}")
    if idx >= MAX_STEPS-1:
      break

结论 ¶

本教程介绍了使用 TorchMultimodal 的 FLAVA 在多模态任务上进行微调的基本方法。请参阅库中的其他示例,如 MDETR,这是一个用于目标检测的多模态模型,以及 Omnivore,这是一个跨越图像、视频和 3D 分类的多任务模型。

脚本总运行时间:(0 分钟 0.000 秒)

由 Sphinx-Gallery 生成的画廊


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源