Move glu to Aten(CPU) #33179

Closed
wants to merge 4 commits into from

Conversation

XiaobingSuper
Collaborator

This PR moves glu to ATen (CPU).
Test script:

```
import torch
import torch.nn.functional as F
import time

torch.manual_seed(0)

def _time():
    # Make sure all queued GPU work has finished before reading the clock.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()

device = "cpu"

# warm up
for n in [10, 100, 1000, 10000]:
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n // 2, device=device)
    for i in range(1000):
        output = F.glu(input)
        output.backward(grad_output)

for n in [10, 100, 1000, 10000]:
    fwd_t = 0
    bwd_t = 0
    input = torch.randn(128, n, requires_grad=True, device=device)
    grad_output = torch.ones(128, n // 2, device=device)
    for i in range(10000):
        t1 = _time()
        output = F.glu(input)
        t2 = _time()
        output.backward(grad_output)
        t3 = _time()
        fwd_t = fwd_t + (t2 - t1)
        bwd_t = bwd_t + (t3 - t2)
    fwd_avg = fwd_t / 10000 * 1000  # average per-iteration time in ms
    bwd_avg = bwd_t / 10000 * 1000
    print("input size(128, %d) forward avg time is %.2f (ms); backward avg time is %.2f (ms)."
          % (n, fwd_avg, bwd_avg))
```
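(Note that input.grad is never zeroed, so gradients accumulate across iterations; this does not meaningfully affect the timings, which only bracket the forward and backward calls.)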

Test device: skx-8180.
Before:

```
input size(128, 10) forward avg time is 0.04 (ms); backward avg time is 0.08 (ms).
input size(128, 100) forward avg time is 0.06 (ms); backward avg time is 0.14 (ms).
input size(128, 1000) forward avg time is 0.11 (ms); backward avg time is 0.31 (ms).
input size(128, 10000) forward avg time is 1.52 (ms); backward avg time is 2.04 (ms).
```

After:

```
input size(128, 10) forward avg time is 0.02 (ms); backward avg time is 0.05 (ms).
input size(128, 100) forward avg time is 0.04 (ms); backward avg time is 0.09 (ms).
input size(128, 1000) forward avg time is 0.07 (ms); backward avg time is 0.17 (ms).
input size(128, 10000) forward avg time is 0.13 (ms); backward avg time is 1.03 (ms).
```
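For reference, the biggest win is at n = 10000, where the forward pass goes from 1.52 ms to 0.13 ms (about 11.7x) and the backward pass from 2.04 ms to 1.03 ms (about 2x); the smaller sizes improve by roughly 1.5-2x.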

Fix #24707, #24708.
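For context, glu (the gated linear unit) splits its input in half along a dimension and gates the first half with the sigmoid of the second. Below is a minimal sketch of that reference semantics, checked against F.glu; the shape is illustrative, not taken from the PR's tests:

```
import torch
import torch.nn.functional as F

x = torch.randn(128, 10, requires_grad=True)

# Reference: split the last dimension in half and gate the first half
# with the sigmoid of the second half.
a, b = x.chunk(2, dim=-1)
ref = a * torch.sigmoid(b)

out = F.glu(x, dim=-1)
assert torch.allclose(out, ref)

# The backward path (the CPU part this PR ports to ATen) agrees as well.
out.backward(torch.ones_like(out))
grad_out = x.grad.clone()
x.grad = None
ref.backward(torch.ones_like(ref))
assert torch.allclose(grad_out, x.grad)
```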

@dr-ci

dr-ci bot commented Feb 11, 2020

💊 CircleCI build failures summary and remediations

As of commit 049547b:

Commit 049547b was recently pushed. Waiting for builds...



@XiaobingSuper
Collaborator Author

I also removed some dead code in TH.

Collaborator

@xuhdev xuhdev left a comment

I suggested some changes from Python-style code to C++-style code.

Review comments (outdated, resolved):
aten/src/ATen/native/cpu/Activation.cpp
aten/src/ATen/native/Activation.cpp
@gchanan gchanan added the module: porting Issues related to porting TH/THNN legacy to ATen native label Feb 11, 2020
@VitalyFedyunin
Contributor

Wow, impressive cleanup. I will run bigger tests to make sure there are no dependencies.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@gchanan
Contributor

gchanan commented Feb 11, 2020

Sorry for the really slow response on #26687 -- this looks pretty good.

Some thoughts:
Do you use ghstack? It would be nice to be able to separate out the implementation cleanup from the code cleanup in a nicer way. In particular, we have this kind of unfortunate overlap where #26687 moves CUDA forward and this moves CPU backward, and it would be nice to combine those in a way that is separate from the larger cleanup stuff here.

@ezyang ezyang removed their request for review February 12, 2020 01:30
@XiaobingSuper
Collaborator Author

@VitalyFedyunin, I just changed the code style according to @xuhdev's suggestion and moved the THNN doc to THCUNN. Please re-land this PR. Thanks!

@XiaobingSuper
Collaborator Author

Sorry for the really slow response on #26687 -- this looks pretty good.

Some thoughts:
Do you use ghstack? It would be nice to be able to separate out the implementation cleanup from the code cleanup in a nicer way. In particular, we have this kind of unfortunate overlap where #26687 moves CUDA forward and this moves CPU backward, and it would be nice to combine those in a way that is separate from the larger cleanup stuff here.

I didn't use ghstack; I will try it. Yes, there is an overlap with #26687; for the CUDA part, perhaps we can port the forward and backward code together.

@zhangguanheng66 zhangguanheng66 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Feb 12, 2020
@ezyang
Contributor

ezyang commented Feb 12, 2020

WOAH this kills THNN. Nice work!!

@XiaobingSuper
Collaborator Author

Just rebased the code.

@ezyang
Contributor

ezyang commented Feb 14, 2020

@gchanan, can you please advise on how to resolve the merge conflicts with your other PR?

@XiaobingSuper
Collaborator Author

@gchanan, please tell me when all your PRs are merged. Thanks!

@XiaobingSuper XiaobingSuper force-pushed the glu branch 2 times, most recently from 5f244a0 to 6af02a8 Compare February 23, 2020 09:50
@XiaobingSuper XiaobingSuper requested a review from ezyang February 23, 2020 09:54
@XiaobingSuper
Collaborator Author

Code rebased.

@XiaobingSuper
Collaborator Author

@VitalyFedyunin

@XiaobingSuper
Collaborator Author

@VitalyFedyunin, @ezyang, the code was rebased.

Contributor

@facebook-github-bot facebook-github-bot left a comment

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@VitalyFedyunin
Contributor

Overall looks good; I still have a couple of files to check and internal tests to run. But anyway, almost 3k removed lines; this deserves a medal!

@VitalyFedyunin
Contributor

There are a few internal dependencies on THNN/generic/THNN.h; I will clean them up and land this PR afterwards.

@VitalyFedyunin
Contributor

Btw, there are around 34k lines of TH code; you are killing a little less than 10%!

@XiaobingSuper
Collaborator Author

Btw, there are around 34k lines of TH code; you are killing a little less than 10%!

Yes, there is still a lot of code to be killed. Could you update the CPU ops status in #24507? I will check which ones I can do. Thanks!

@facebook-github-bot
Contributor

@VitalyFedyunin merged this pull request in b678256.

ttumiel pushed a commit to ttumiel/pytorch that referenced this pull request Mar 4, 2020
Summary: This PR moves glu to ATen (CPU); the test script and benchmark results are the same as in the PR description above.

Fix pytorch#24707, pytorch#24708.
Pull Request resolved: pytorch#33179

Differential Revision: D19839835

Pulled By: VitalyFedyunin

fbshipit-source-id: e4d3438556a1068da2c4a7e573d6bbf8d2a6e2b9
Labels
Merged
module: porting (Issues related to porting TH/THNN legacy to ATen native)
open source
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Development

Successfully merging this pull request may close these issues.

Migrate glu from the TH to Aten (CPU)
9 participants