[dynamo] Do not always skip code objects unconditionally #144820
Description
Currently, when Dynamo determines that a frame should be skipped, we also skip tracing all future calls to the same code object. This causes problems when the decision to skip a frame depends on the inputs to the function:
import torch

@torch.compile(dynamic=False)
def fn(x, n):
    if n == 0:
        try:
            # causes frame to be skipped
            torch._dynamo.graph_break()
        finally:
            pass
    if torch.compiler.is_compiling():
        return x + 1
    return x - 1

print(fn(torch.ones(3), 0))  # skipped
print(fn(torch.ones(3), 1))  # skipped

import torch._dynamo
torch._dynamo.reset()

print(fn(torch.ones(3), 1))  # compiled!
print(fn(torch.ones(3), 0))  # skipped

# Output:
# tensor([0., 0., 0.])
# tensor([0., 0., 0.])
# tensor([2., 2., 2.])
# tensor([0., 0., 0.])
We see that whether fn(torch.ones(3), 1) gets compiled depends on calling order! This makes the PT2 programming model harder to understand. Thus, when skipping a frame is condition-dependent, we shouldn't skip the code object unconditionally - we should instead skip only the current frame and use guards to check whether a future call should also skip/fall back to eager.
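To make the difference concrete, here is a toy model in pure Python of the two policies. It is only a sketch and does not reflect Dynamo's internals: `trace_would_skip`, `UnconditionalSkipPolicy`, `GuardedSkipPolicy`, and `run` are hypothetical names invented for illustration, and the guard is assumed to be a simple equality check on `n` (mirroring how `dynamic=False` specializes on the integer argument).

```python
from typing import Callable, List, Tuple


def trace_would_skip(n: int) -> bool:
    """Hypothetical stand-in for tracing fn: only the n == 0 path forces a skip."""
    return n == 0


class UnconditionalSkipPolicy:
    """Toy model of the current behavior: one skipped frame marks the whole
    code object as skipped, so every later call runs eagerly."""

    def __init__(self) -> None:
        self.code_object_skipped = False

    def call(self, n: int) -> str:
        if self.code_object_skipped:
            return "eager"
        if trace_would_skip(n):
            self.code_object_skipped = True  # affects all future calls
            return "eager"
        return "compiled"


class GuardedSkipPolicy:
    """Toy model of the proposal: the skip decision is cached behind a guard
    on the input that caused it, so other inputs are still traced."""

    def __init__(self) -> None:
        # Each entry pairs a guard (assumed here: equality on n) with a decision.
        self.entries: List[Tuple[Callable[[int], bool], str]] = []

    def call(self, n: int) -> str:
        for guard, decision in self.entries:
            if guard(n):
                return decision
        decision = "eager" if trace_would_skip(n) else "compiled"
        self.entries.append((lambda m, expected=n: m == expected, decision))
        return decision


def run(policy, calls: List[int]) -> List[str]:
    """Feed a sequence of n values through a fresh policy object."""
    return [policy.call(n) for n in calls]


if __name__ == "__main__":
    for order in ([0, 1], [1, 0]):
        print(order, "unconditional:", run(UnconditionalSkipPolicy(), order))
        print(order, "guarded:      ", run(GuardedSkipPolicy(), order))
```

In this model the unconditional policy gives different results for `n == 1` depending on whether `n == 0` was seen first, while the guarded policy compiles `n == 1` and falls back to eager for `n == 0` in either order.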
cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames