• 文档 >
  • torch.cuda >
  • torch.cuda.jiterator._create_multi_output_jit_fn
快捷键

torch.cuda.jiterator._create_multi_output_jit_fn

torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs, **kwargs)[source][source]

创建一个支持返回一个或多个输出的元素级操作的 jiterator 生成的 cuda 内核。

参数:
  • code_string (str) – CUDA 代码字符串,由 jiterator 编译。条目函数必须通过引用返回值。

  • num_outputs (int) – 内核返回的输出数量

  • kwargs (Dict, 可选) – 生成函数的关键字参数

返回类型:

可调用

示例:

code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)

警告

此 API 处于测试阶段,未来版本中可能会有所变化。

警告

此 API 仅支持最多 8 个输入和 8 个输出


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源