Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A few more pyright fixes #23453

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Conversation

superbobry
Copy link
Member

No description provided.

except FileNotFoundError:
pass
if filename in self._file_cache:
source = self._file_cache[filename]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a change in program logic – this block now executes even if inspect.getmodulename(filename) is False. I'm not sure of the ramifications of that here: is that an intended change?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was a bug in the original implementation.

If the filename is already cached (and it could've only been cached if inspect.getmodulename(filename) returned True AFAICT), this method would raise NotImplementedError.

@@ -1266,6 +1268,7 @@ def check(cond):
dims = []
while flat_assignment.size > 1:
stride = flat_assignment[1]
i = 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we skip this fix please? Or you can send this as a seperate CL to me. I would have to check if this is correct.

SKipping this fix, is not going to break anything since we don't run pyright in our CI.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I can revert, but I'm curious why this can be problematic?

The fix is purely technical: flat_assigments.size > 1 guarantees that the loop will iterate at least one, but there isn't a way to express that in the type system.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then send a separate CL please? I want to make sure this is correct and doesn't cause any problems because a lot of code goes via this code.

@@ -575,6 +576,7 @@ def shard_shape(self, global_shape: Shape) -> Shape:
f'devices passed to PmapSharding. Got sharded dimension {sharded_dim} '
f'with value {global_shape[sharded_dim]} in shape {global_shape} and '
f'the number of devices={len(self._device_assignment)}')
assert sharded_shape is not None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you need this? Looks like this will always be True?

Can you just add a else branch to error after line 570 or silence pyright here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only True if s is guaranteed to either be sharding_specs.{Unstacked,Chunk}. Otherwise, it's undefined.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Won't adding an else branch like I said above silence the error?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants