Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Support calculate the flops of matmul with single dimension matrix #970

Merged
merged 3 commits into from
Mar 9, 2023

Conversation

HAOCHENYE
Copy link
Collaborator

@HAOCHENYE HAOCHENYE commented Feb 28, 2023

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

If user write the code like:

import torch
import torch.nn as nn

from mmengine.analysis import flop_count


class MatmulNet(nn.Module):
    """A network with a single torch.matmul operation.

    This is used for testing flop count for torch.matmul.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        x = torch.matmul(x, y)
        return x

m = 20
n = 10
p = 1
m_net = MatmulNet()
x = torch.randn(m, n)
y = torch.randn(n)
flop_dict, _ = flop_count(m_net, (x, y))

flop_count will raise an error since it requires both of x, y at least have two dims:

Traceback (most recent call last):
  File "/home/yehaochen/codebase/mmengine/work_dirs/demo_train.py", line 26, in <module>
    flop_dict, _ = flop_count(m_net, (x, y))
  File "/home/yehaochen/codebase/mmengine/mmengine/analysis/complexity_analysis.py", line 230, in flop_count
    for op, flop in flop_counter.by_operator().items():
  File "/home/yehaochen/codebase/mmengine/mmengine/analysis/jit_analysis.py", line 287, in by_operator
    stats = self._analyze()
  File "/home/yehaochen/codebase/mmengine/mmengine/analysis/jit_analysis.py", line 616, in _analyze
    op_counts = self._op_handles[kind](inputs, outputs)
  File "/home/yehaochen/codebase/mmengine/mmengine/analysis/jit_handles.py", line 216, in matmul_flop_jit
    assert input_shapes[0][-1] == input_shapes[1][  # type: ignore
IndexError: list index out of range

Modification

Please briefly describe what modification is made in this PR.

BC-breaking (Optional)

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

Checklist

  1. Pre-commit or other linting tools are used to fix the potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit test to ensure the correctness.
  3. If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  4. The documentation has been modified accordingly, like docstring or example tutorials.

@codecov
Copy link

codecov bot commented Feb 28, 2023

Codecov Report

❗ No coverage uploaded for pull request base (main@7e1b273). Click here to learn what that means.
Patch has no changes to coverable lines.

❗ Current head c06d3f3 differs from pull request most recent head a78ec8b. Consider uploading reports for the commit a78ec8b to get more accurate results

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #970   +/-   ##
=======================================
  Coverage        ?   76.57%           
=======================================
  Files           ?      138           
  Lines           ?    10843           
  Branches        ?     2168           
=======================================
  Hits            ?     8303           
  Misses          ?     2182           
  Partials        ?      358           
Flag Coverage Δ
unittests 76.57% <0.00%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

Help us with your feedback. Take ten seconds to tell us how you rate us. Have a feature suggestion? Share it here.

☔ View full report at Codecov.
📢 Do you have feedback about the report comment? Let us know in this issue.

@zhouzaida zhouzaida added this to the 0.7.0 milestone Mar 5, 2023
tonysy
tonysy previously approved these changes Mar 6, 2023
Copy link
Collaborator

@tonysy tonysy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
@zhouzaida zhouzaida merged commit 8beacd3 into open-mmlab:main Mar 9, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants