• 教程 >
  • 通过 PrivateUse1 促进新的后端集成
快捷键

通过私有用途 1 促进新后端集成

创建于:2025 年 4 月 1 日 | 最后更新:2025 年 4 月 1 日 | 最后验证:2024 年 11 月 5 日

在本教程中,我们将介绍一些必要的步骤,以集成位于 pytorch/pytorch 仓库之外的新后端。请注意,本教程假设您已经对 PyTorch 有基本的了解。您是 PyTorch 的高级用户。

备注

本教程仅涉及与 PrivateUse1 机制相关的部分,该机制有助于新设备的集成,其他部分将不会涉及。同时,本教程中涉及的所有模块并非都是必需的,您可以根据实际需求选择对您有帮助的模块。

什么是 PrivateUse1?

在 PyTorch 2.0 之前,PyTorch 提供了三个保留的调度键(及其对应的 Autograd 键)用于原型设计树外后端扩展,这三个调度键如下:

  • PrivateUse1/AutogradPrivateUse1

  • PrivateUse2/AutogradPrivateUse2

  • PrivateUse3/AutogradPrivateUse3

在原型验证通过后,您可以申请为新后端(如 CUDA、XLA、MPS 等)申请私钥。

然而,随着 PyTorch 的快速发展,越来越多的硬件制造商试图将他们的后端集成到 PyTorch 中,这可能会引起以下问题:

  • 每次新的后端集成都涉及大量的文件修改

  • 目前对调度密钥的数量有一个硬限制( DispatchKeySet 64 位限制)

备注

将新的后端通过 PrivateUse1 密钥集成到 PyTorch 中也有问题,因为同时集成多个后端是不可能的。幸运的是,这些树外后端很少同时使用。

鉴于上述原因,社区开始推荐通过 PrivateUse1 将新的后端集成到 PyTorch 中。

然而,之前的 PrivateUse1 机制并不能完全与新的后端集成,因为它在某些模块(如存储、AMP、分布式等)中缺少一些相关支持。

随着 Pytorch 2.1.0 的到来,针对 PrivateUse1 的新后端集成进行了一系列优化和增强,现在可以快速有效地支持新设备的集成。

如何通过 PrivateUse1 集成新的后端

在本节中,我们将讨论通过 PrivateUse1 将新后端集成到 Pytorch 中的细节,这主要包含以下部分:

  1. 为新后端注册内核。

  2. 为新后端注册生成器。

  3. 为新后端注册设备保护器。

  4. 为新后端元数据注册序列化和反序列化函数。

  5. 其他模块

注册新后端的核心

新后端可能包含一些高性能的算子实现,这些算子可以通过 C++中描述的 TORCH_LIBRARY_IMPL API 注册到调度器。这涉及到几种情况:

  1. 将新后端支持的所有前向算子注册到调度器,并注册回退,以便当新后端不支持某些算子时,这些算子可以回退到 CPU 执行,以确保功能的可用性。

at::Tensor wrapper_Custom_Tensor_add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) {
  // Implementation of add kernel in new backend
  ...
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  ...
  m.impl("add.Tensor", TORCH_FN(wrapper_Custom_Tensor_add));
  ...
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  // Add some hints about new devices that do not support and need to fall back to cpu
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(_, PrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
  1. 将从 torch::autograd::Function 到调度器的内核进行注册,如果新后端需要覆盖 AutogradPrivateUse1 ,调度器和自动微分系统将自动调用这些算子的前向和反向实现。

class CumtomSeluFunction : public torch::autograd::Function<CumtomSeluFunction> {
  // Implementation of selu kernel in new backend
}

at::Tensor wrapper_AutogradCumstom__selu(const at::Tensor & self) {
  return CumtomSeluFunction::apply(self);
}

TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) {
  ...
  m.impl("selu", TORCH_FN(wrapper_AutogradCustom__selu));
  ...
}
  1. 通过 AutocastPrivateUse1 注册希望支持自动混合精度(AMP)和回退机制的内核到调度器,当需要时,自动转换系统将自动调用这些内核。

TORCH_LIBRARY_IMPL(aten, AutocastPrivateUse1, m) {
  ...
  KERNEL_PRIVATEUSEONE(<operator>, <policy>)
  ...
}

TORCH_LIBRARY_IMPL(_, AutocastPrivateUse1, m) {
  m.fallback(torch::CppFunction::makeFallthrough());
}

需要添加的是,如果您想在新的后端支持 AMP,您需要通过 torch._register_device_module("backend_name", BackendModule) 注册一个新的 BackendModule ,并且该 BackendModule 需要具备以下 API:

  • get_amp_supported_dtype() -> List[torch.dtype]

    获取新后端在 AMP 中支持的 dtype,可能支持一个额外的 dtype

  • is_autocast_enabled() -> bool

    检查新后端是否启用了 AMP。

  • get_autocast_dtype() -> torch.dtype

    获取新后端在 AMP 中支持的 dtype ,该设置由 set_autocast_dtype 或默认的 dtype 决定,默认的 dtypetorch.float16

  • set_autocast_enabled(bool) -> None

    在新后端中启用或禁用 AMP。

  • set_autocast_dtype(dtype) -> None

    设置新后端在 AMP 中支持的 dtype ,并且 dtype 应包含在从 get_amp_supported_dtype 获取的 dtypes 中。

注册新后端生成器的生成器

必须支持对应新设备的生成器。目前, PrivateUse1 可以动态注册自定义生成器,主要分为以下步骤。

  1. 继承 GeneratorImpl 类来实现对应新后端的生成器类,并实现各种通用方法。

  2. 定义一个带有单个参数 device index 的新后端 builder

  3. 调用 REGISTER_GENERATOR_PRIVATEUSE1 宏以完成动态注册。

struct CustomGeneratorImpl : public c10::GeneratorImpl {
  // Implementation of generator in new backend
}

at::Generator make_custom_generator(c10::DeviceIndex device_index) {
  return at::make_generator<CustomGeneratorImpl>(device_index);
}

REGISTER_GENERATOR_PRIVATEUSE1(make_cumstom_generator)

为新后端注册设备保护程序 ¶

PyTorch 通过 DeviceGuard 提供与设备、流和事件切换相关的功能。此函数也适用于 PrivateUse1 键。

  1. 继承 DeviceGuardImplInterface 类以实现针对新后端的相应各种通用方法。

  2. 调用 C10_REGISTER_GUARD_IMPL 宏以完成动态注册。

struct CustomGuardImpl final : public c10::impl::DeviceGuardImplInterface {
  // Implementation of guard in new backend
}

C10_REGISTER_GUARD_IMPL(PrivateUse1, CustomGuardImpl);

注册序列化和反序列化函数以支持新后端元数据的序列化和反序列化 ¶

PyTorch 目前能够动态注册序列化和反序列化函数,以支持名为 backend_meta_ 的新后端元数据的序列化和反序列化。您可以参考以下步骤:

  1. 继承 BackendMeta 类以实现 CustomBackendMetadata 对应的新后端,新后端的各个字段可以在类中进行自定义。

  2. 实现新后端的序列化和反序列化函数,函数签名是 void(const at::Tensor&, std::unordered_map<std::string, bool>&)

  3. 调用 TensorBackendMetaRegistry 宏以完成动态注册。

struct CustomBackendMetadata : public c10::BackendMeta {
  // Implementation of backend metadata in new backend
}

void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  // Implementation of serialization
}

void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  // Implementation of deserialization
}

TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1, &for_serialization, &for_deserialization);

其他模块 ¶

除了上述部分之外,还有一些其他模块可以通过 PrivateUse1 进行扩展,例如 distributed collective communicationbenchmark timer 等,未来还将添加更多。关于 PrivateUse1 集成的例子是 Ascend NPU。

如何通过 Privateuse1 提高用户体验

通过 PrivateUse1 集成新设备的主要目标是满足基本功能需求,接下来要做的是提高易用性,这主要涉及以下方面。

  1. 将新后端模块注册到 Pytorch。

  2. 将 PrivateUse1 重命名为新后端的自定义名称。

  3. 生成与新的后端相关的方法和属性。

将新的后端模块注册到 Pytorch 中。

PyTorch 中的一些 CUDA 相关接口可以通过以下形式调用: torch.cuda.xxx 。因此,为了符合用户习惯,通过 PrivateUse1 机制实现的新后端也应提供类似的接口。

例如,使用 Ascend NPU :

torch._register_device_module('npu', torch_npu.npu)

在完成上述操作后,用户可以通过 Ascend NPU 通过 torch.npu.xxx 调用一些专属 API

将 PrivateUse1 重命名为新后端的自定义名称 ¶

PrivateUse1 键是新后端集成到 PyTorch 中的内部机制。对于用户来说,与新的后端强相关的自定义名称应该更加友好。

Ascend NPU 为例,第一次使用将更加用户友好。

torch.rand((2,2),device='npu:0')
torch.rand((2,2),device='privateuse1:0')

现在,PyTorch 为自命名的 PrivateUse1 后端提供了一个新的 C++/Python API,使用非常简单。

torch.rename_privateuse1_backend("npu")
c10::register_privateuse1_backend("npu")

未来工作 §

PrivateUse1 机制的改进仍在进行中,因此新模块的 PrivateUse1 集成方法将依次添加。以下是我们在积极工作的几个项目:

  • 添加 distributed collective communication 的集成方法。

  • 添加 benchmark timer 的集成方法。

结论 ¶

本教程向您介绍了通过 PrivateUse1 将新后端集成到 PyTorch 中的过程,包括但不限于操作注册、生成器注册、设备保护注册等。同时,介绍了一些提高用户体验的方法。


评分这个教程

© 版权所有 2024,PyTorch。

使用 Sphinx 构建,主题由 Read the Docs 提供。
//暂时添加调查链接

文档

访问 PyTorch 的全面开发者文档

查看文档

教程

获取初学者和高级开发者的深入教程

查看教程

资源

查找开发资源并解答您的问题

查看资源