备注
点击此处下载完整示例代码
使用 torch.compiler.set_stance
¶ 动态编译控制
作者:威廉·文
torch.compiler.set_stance
是一个 torch.compiler
API,允许您在不重新应用 torch.compile
到您的模型的情况下,改变模型在不同调用中的行为。
本菜谱提供了一些关于如何使用 torch.compiler.set_stance
的示例。
前提条件 ¶
torch >= 2.6
描述 ¶
torch.compile.set_stance
可以用作装饰器、上下文管理器或原始函数,以改变模型在不同调用中对 torch.compile
的行为。
在下面的示例中, "force_eager"
立场忽略所有 torch.compile
指令。
import torch
@torch.compile
def foo(x):
if torch.compiler.is_compiling():
# torch.compile is active
return x + 1
else:
# torch.compile is not active
return x - 1
inp = torch.zeros(3)
print(foo(inp)) # compiled, prints 1
样例装饰器使用
@torch.compiler.set_stance("force_eager")
def bar(x):
# force disable the compiler
return foo(x)
print(bar(inp)) # not compiled, prints -1
样例上下文管理器使用
with torch.compiler.set_stance("force_eager"):
print(foo(inp)) # not compiled, prints -1
样例原始函数使用
torch.compiler.set_stance("force_eager")
print(foo(inp)) # not compiled, prints -1
torch.compiler.set_stance("default")
print(foo(inp)) # compiled, prints 1
仅能在任何区域外更改立场。尝试在其他情况下进行更改将导致错误。
@torch.compile
def baz(x):
# error!
with torch.compiler.set_stance("force_eager"):
return x + 1
try:
baz(inp)
except Exception as e:
print(e)
@torch.compiler.set_stance("force_eager")
def inner(x):
return x + 1
@torch.compile
def outer(x):
# error!
return inner(x)
try:
outer(inp)
except Exception as e:
print(e)
- 其他立场包括:
默认立场:用于正常编译。
当需要重新编译时,会积极运行代码。如果存在针对输入有效的缓存编译代码,它仍然会被使用。
当重新编译函数时引发错误。
请参阅 torch.compiler.set_stance
文档页面以获取更多立场和选项。未来也可能添加更多立场/选项。
示例
防止重新编译
一些模型不需要任何重新编译——例如,您可能总是有相同形状的输入。由于重新编译可能代价高昂,我们希望在尝试重新编译时引发错误,以便我们可以检测并修复重新编译的情况。可以使用 "fail_on_recompilation"
立场来实现这一点。
@torch.compile
def my_big_model(x):
return torch.relu(x)
# first compilation
my_big_model(torch.randn(3))
with torch.compiler.set_stance("fail_on_recompile"):
my_big_model(torch.randn(3)) # no recompilation - OK
try:
my_big_model(torch.randn(4)) # recompilation - error
except Exception as e:
print(e)
如果错误中断过于破坏性,我们可以使用 "eager_on_recompile"
代替,这将导致 torch.compile
回退到急切而不是错误中断。这可能在我们不期望频繁重新编译的情况下很有用,但当我们需要时,我们宁愿承担运行急切的成本,也不愿承担重新编译的成本。
@torch.compile
def my_huge_model(x):
if torch.compiler.is_compiling():
return x + 1
else:
return x - 1
# first compilation
print(my_huge_model(torch.zeros(3))) # 1
with torch.compiler.set_stance("eager_on_recompile"):
print(my_huge_model(torch.zeros(3))) # 1
print(my_huge_model(torch.zeros(4))) # -1
print(my_huge_model(torch.zeros(3))) # 1
测量性能提升
torch.compiler.set_stance
可用于比较急切与编译性能,而无需定义单独的急切模型。
# Returns the result of running `fn()` and the time it took for `fn()` to run,
# in seconds. We use CUDA events and synchronization for the most accurate
# measurements.
def timed(fn):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
result = fn()
end.record()
torch.cuda.synchronize()
return result, start.elapsed_time(end) / 1000
@torch.compile
def my_gigantic_model(x, y):
x = x @ y
x = x @ y
x = x @ y
return x
inps = torch.randn(5, 5), torch.randn(5, 5)
with torch.compiler.set_stance("force_eager"):
print("eager:", timed(lambda: my_gigantic_model(*inps))[1])
# warmups
for _ in range(3):
my_gigantic_model(*inps)
print("compiled:", timed(lambda: my_gigantic_model(*inps))[1])
提前崩溃
在使用 "force_eager"
立场之前先进行一次急切迭代,然后再进行编译迭代,这有助于我们在尝试非常长的编译之前捕捉到与 torch.compile
无关的错误。
@torch.compile
def my_humongous_model(x):
return torch.sin(x, x)
try:
with torch.compiler.set_stance("force_eager"):
print(my_humongous_model(torch.randn(3)))
# this call to the compiled model won't run
print(my_humongous_model(torch.randn(3)))
except Exception as e:
print(e)
结论 ¶
在本菜谱中,我们学习了如何使用 torch.compiler.set_stance
API 来修改模型在不同调用中的行为,而无需重新应用它。该菜谱演示了如何使用 torch.compiler.set_stance
作为装饰器、上下文管理器或原始函数来控制编译立场,如 force_eager
、 default
、 eager_on_recompile
以及“fail_on_recompile”。
更多信息请参阅:torch.compiler.set_stance API 文档。
脚本总运行时间:(0 分钟 0.000 秒)