备注:本文的先前版本于 2022 年 11 月发布。鉴于 torchvision 将于 2023 年 3 月与 PyTorch 2.0 一同发布的 0.15 版本,我们对本文进行了更新,以包含最新信息。
TorchVision 正在扩展其 Transforms API!以下是新功能:
- 您不仅可以用它们进行图像分类,还可以用于目标检测、实例和语义分割以及视频分类。
- 您可以使用新的功能转换来转换视频、边界框和分割掩码。
该 API 与之前的版本完全向后兼容,并保持不变,以协助迁移和采用。我们现在将这个新 API 作为 Beta 版本发布在 torchvision.transforms.v2 命名空间中,我们非常希望得到您的早期反馈以改进其功能。如果您有任何问题或建议,请与我们联系。
当前 Transforms 的限制
TorchVision(又称 V1)现有的 Transforms API 仅支持单张图像。因此,它只能用于分类任务:
from torchvision import transforms
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs = trans(imgs)
上述方法不支持目标检测和分割。这种限制使得任何非分类的计算机视觉任务都成了二等公民,因为无法使用 Transforms API 进行必要的增强。从历史上看,这使得使用 TorchVision 的原始方法训练高精度模型变得困难,因此我们的模型库落后于 SoTA 几个百分点。
为了绕过这个限制,TorchVision 在其参考脚本中提供了自定义实现,展示了如何在每个任务中执行增强。尽管这种做法使我们能够训练高精度的分类、目标检测和分割模型,但它是一种临时解决方案,使得这些转换无法从 TorchVision 二进制文件中导入。
新的 Transforms API
Transforms V2 API 支持视频、边界框和分割掩码,这意味着它为许多计算机视觉任务提供了原生支持。新的解决方案是一个即插即用的替代品:
import torchvision.transforms.v2 as transforms
# Exactly the same interface as V1:
trans = transforms.Compose([
transforms.ColorJitter(contrast=0.5),
transforms.RandomRotation(30),
transforms.CenterCrop(480),
])
imgs, bboxes, labels = trans(imgs, bboxes, labels)
新的转换类可以接收任意数量的输入,而不强制执行特定的顺序或结构:
# Already supported:
trans(imgs) # Image Classification
trans(videos) # Video Tasks
trans(imgs, bboxes, labels) # Object Detection
trans(imgs, bboxes, masks, labels) # Instance Segmentation
trans(imgs, masks) # Semantic Segmentation
trans({"image": imgs, "box": bboxes, "tag": labels}) # Arbitrary Structure
# Future support:
trans(imgs, bboxes, labels, keypoints) # Keypoint Detection
trans(stereo_images, disparities, masks) # Depth Perception
trans(image1, image2, optical_flows, masks) # Optical Flow
trans(imgs_or_videos, labels) # MixUp/CutMix-style Transforms
转换类确保对所有输入应用相同的随机转换,以确保结果一致。
功能性 API 已更新,以支持所有必要的信号处理内核(调整大小、裁剪、仿射变换、填充等)以处理所有输入:
from torchvision.transforms.v2 import functional as F
# High-level dispatcher, accepts any supported input type, fully BC
F.resize(inpt, size=[224, 224])
# Image tensor kernel
F.resize_image_tensor(img_tensor, size=[224, 224], antialias=True)
# PIL image kernel
F.resize_image_pil(img_pil, size=[224, 224], interpolation=BILINEAR)
# Video kernel
F.resize_video(video, size=[224, 224], antialias=True)
# Mask kernel
F.resize_mask(mask, size=[224, 224])
# Bounding box kernel
F.resize_bounding_box(bbox, size=[224, 224], spatial_size=[256, 256])
在底层,API 使用 Tensor 子类封装输入,附加有用的元数据并调度到正确的内核。为了使您的数据与这些新转换兼容,您可以使用提供的 dataset 包装器,它应该适用于大多数 torchvision 内置数据集,或者您可以手动将数据封装到 Datapoints 中:
from torchvision.datasets import wrap_dataset_for_transforms_v2
ds = CocoDetection(..., transforms=v2_transforms)
ds = wrap_dataset_for_transforms_v2(ds) # data is now compatible with transforms v2!
# Or wrap your data manually using the lower-level Datapoint classes:
from torchvision import datapoints
imgs = datapoints.Image(images)
vids = datapoints.Video(videos)
masks = datapoints.Mask(target["masks“])
bboxes = datapoints.BoundingBox(target["boxes“], format=”XYXY”, spatial_size=imgs.shape)
除了新的 API 之外,我们还提供了可导入的几种数据增强实现,这些增强在 SoTA 研究中得到应用,例如大规模抖动、自动增强方法以及多种新的几何、颜色和类型转换变换。
该 API 继续支持图像的 PIL 和 Tensor 后端,包括单张或批量输入,并在功能 API 和类 API 上保持 JIT 脚本性。新 API 已验证与之前的实现具有相同的准确性。
一个端到端示例
下面是使用以下图像的新 API 示例。它既适用于 PIL 图像也适用于张量。更多示例和教程,请查看我们的画廊!
from torchvision import io, utils
from torchvision import datapoints
from torchvision.transforms import v2 as T
from torchvision.transforms.v2 import functional as F
# Defining and wrapping input to appropriate Tensor Subclasses
path = "COCO_val2014_000000418825.jpg"
img = datapoints.Image(io.read_image(path))
# img = PIL.Image.open(path)
bboxes = datapoints.BoundingBox(
[[2, 0, 206, 253], [396, 92, 479, 241], [328, 253, 417, 332],
[148, 68, 256, 182], [93, 158, 170, 260], [432, 0, 438, 26],
[422, 0, 480, 25], [419, 39, 424, 52], [448, 37, 456, 62],
[435, 43, 437, 50], [461, 36, 469, 63], [461, 75, 469, 94],
[469, 36, 480, 64], [440, 37, 446, 56], [398, 233, 480, 304],
[452, 39, 463, 63], [424, 38, 429, 50]],
format=datapoints.BoundingBoxFormat.XYXY,
spatial_size=F.get_spatial_size(img),
)
labels = [59, 58, 50, 64, 76, 74, 74, 74, 74, 74, 74, 74, 74, 74, 50, 74, 74]
# Defining and applying Transforms V2
trans = T.Compose(
[
T.ColorJitter(contrast=0.5),
T.RandomRotation(30),
T.CenterCrop(480),
]
)
img, bboxes, labels = trans(img, bboxes, labels)
# Visualizing results
viz = utils.draw_bounding_boxes(F.to_image_tensor(img), boxes=bboxes)
F.to_pil_image(viz).show()
开发里程碑和未来工作
这里是我们开发中的位置:
- 设计 API
- 编写转换视频、边界框、遮罩和标签的内核
- 在新 API 上重写所有现有的转换类(稳定版+引用)
- 图像分类
- 视频分类
- 目标检测
- 实例分割
- 语义分割
- 验证所有支持的任务和后端的新 API 的准确性
- 速度基准和性能优化(进行中 - 计划于 12 月)
- 毕业于原型(计划于 Q1)
- 添加深度感知、关键点检测、光流等功能支持(未来)
- 添加对批处理变换如 MixUp 和 CutMix 的平滑支持
我们非常希望得到您的反馈以改进其功能。如果您有任何问题或建议,请与我们联系。