• 文档 >
  • torch.cuda >
  • torch.cuda.jiterator._create_jit_fn
快捷键

torch.cuda.jiterator._create_jit_fn

torch.cuda.jiterator._create_jit_fn(code_string, **kwargs)[source][source]

创建一个用于元素级操作的 jiterator 生成的 cuda 内核。

代码字符串必须是一个有效的 CUDA 函数,用于描述单个元素的运算。代码字符串必须遵循以下示例中的 C++模板模式。此函数将被内联到元素级内核模板中,并即时编译。编译后的内核将被缓存到内存以及本地临时目录中。

Jiterator 生成的内核接受非连续张量,并支持广播和类型提升。

参数:
  • code_string(str)- 由 jiterator 编译的 CUDA 代码字符串。条目函数必须按值返回。

  • kwargs(字典,可选)- 生成函数的关键字参数

返回类型:

可调用

示例:

code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
jitted_fn = create_jit_fn(code_string, alpha=1.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b, alpha=3.14)

代码字符串也允许定义多个函数,最后一个函数将被视为入口函数。

示例:

code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
jitted_fn = create_jit_fn(code_string, val=0.0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
# invoke jitted function like a regular python function
result = jitted_fn(a, b)  # using default val=0.0

Jiterator 可以与 Python 注册一起使用,以覆盖运算符的 CUDA 内核。以下示例是覆盖 gelu 的 CUDA 内核为 relu。

示例:

code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
my_gelu = create_jit_fn(code_string)
my_lib = torch.library.Library("aten", "IMPL")
my_lib.impl('aten::gelu', my_gelu, "CUDA")
# torch.nn.GELU and torch.nn.function.gelu are now overridden
a = torch.rand(3, device='cuda')
torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

警告

此 API 处于测试阶段,未来版本中可能会有所变化。

警告

此 API 仅支持最多 8 个输入和 1 个输出。

警告

所有输入张量必须存在于 CUDA 设备上。


© 版权所有 PyTorch 贡献者。

使用 Sphinx 构建,并使用 Read the Docs 提供的主题。

文档

PyTorch 的全面开发者文档

查看文档

教程

深入了解初学者和高级开发者的教程

查看教程

资源

查找开发资源并获得您的疑问解答

查看资源