
[dynamo] Do not always skip code objects unconditionally #144820

Open
@williamwen42

Description

Currently, when Dynamo determines that a frame should be skipped, we also skip tracing all future calls to the same code object. This causes problems when the decision to skip a frame depends on the function's inputs:

```python
import torch
import torch._dynamo

@torch.compile(dynamic=False)
def fn(x, n):
    if n == 0:
        try:
            # graph break inside try/finally: causes frame to be skipped
            torch._dynamo.graph_break()
        finally:
            pass
    if torch.compiler.is_compiling():
        return x + 1
    return x - 1

print(fn(torch.ones(3), 0))  # skipped
print(fn(torch.ones(3), 1))  # skipped

torch._dynamo.reset()

print(fn(torch.ones(3), 1))  # compiled!
print(fn(torch.ones(3), 0))  # skipped

# Output:
# tensor([0., 0., 0.])
# tensor([0., 0., 0.])
# tensor([2., 2., 2.])
# tensor([0., 0., 0.])
```

Whether `fn(torch.ones(3), 1)` gets compiled depends on call order! This makes the PT2 programming model harder to understand. So when skipping a frame depends on a condition, we shouldn't skip the code object unconditionally; instead, we should skip only the current frame and use guards to decide whether a future call should also skip and fall back to eager.
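To make the difference between the two policies concrete, here is a toy model in plain Python. It is a sketch under stated assumptions, not Dynamo's actual cache or API: it assumes `dynamic=False` specialization, so the integer `n` itself acts as the guard, and it uses a hypothetical `needs_skip` predicate as a stand-in for "tracing this frame hits a construct that forces a skip". All class and function names are illustrative only.

```python
from dataclasses import dataclass, field


def needs_skip(n: int) -> bool:
    # Hypothetical stand-in for "tracing this frame forces a skip"
    # (in the repro: a graph break inside try/finally when n == 0).
    return n == 0


@dataclass
class CodeObjectSkipCache:
    """Toy model of the current behavior: one skip marks the whole code object."""
    skipped: bool = False
    compiled_for: set = field(default_factory=set)

    def call(self, n: int) -> str:
        if self.skipped:
            return "eager"  # every later call falls back, regardless of n
        if n in self.compiled_for:
            return "compiled"
        if needs_skip(n):
            self.skipped = True  # decision leaks into all future calls
            return "eager"
        self.compiled_for.add(n)
        return "compiled"


@dataclass
class GuardedSkipCache:
    """Toy model of the proposal: each guarded entry is either compiled or eager."""
    entries: dict = field(default_factory=dict)  # guard value (n) -> "compiled" / "eager"

    def call(self, n: int) -> str:
        if n not in self.entries:  # guard miss: trace this input once
            self.entries[n] = "eager" if needs_skip(n) else "compiled"
        return self.entries[n]


def run(policy_cls, order):
    cache = policy_cls()
    return [cache.call(n) for n in order]


print(run(CodeObjectSkipCache, [0, 1]))  # ['eager', 'eager']
print(run(CodeObjectSkipCache, [1, 0]))  # ['compiled', 'eager']
print(run(GuardedSkipCache, [0, 1]))     # ['eager', 'compiled']
print(run(GuardedSkipCache, [1, 0]))     # ['compiled', 'eager']
```

The code-object policy reproduces the order dependence seen in the repro, while the guard-keyed policy gives each input the same outcome no matter which call comes first.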

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @chenyang78 @kadeng @amjames

Metadata


Labels: module: dynamo, oncall: pt2, triaged
