
Remove dependency of finput in accGradParams of SpatialConvolutionMM #501

Open

fmassa opened this issue Nov 27, 2015 · 1 comment

fmassa commented Nov 27, 2015

Currently, SpatialConvolutionMM is quite fast on CPU, but its memory requirements are too high.

As it parallelizes the computation over batch examples, it requires a huge buffer (whose size depends on the batch size) for storing the unfolded image.
I'm reasonably ok with that, as we could share this buffer over multiple convolutions to reduce memory usage (even though it still requires a lot of memory).
But because accGradParameters reuses the finput that was already computed in forward (see here and here), this sharing can't be used for training. Worse, reusing finput in accGradParameters forces us to keep another huge buffer, fgradInput, of the same size as finput.
Thus, I think the memory requirements are too high to use the CPU version of SpatialConvolutionMM in any real-world scenario.
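
For reference, the kind of sharing I have in mind is something like this minimal sketch (`model` stands for any hypothetical nn container holding the convolutions; as explained above, this is only safe for inference under the current implementation):

```lua
require 'nn'

-- Point every SpatialConvolutionMM at one shared finput buffer, so only
-- the largest unfolded image is ever allocated. Inference-only for now,
-- since accGradParameters reads back the finput saved during forward.
local shared = torch.Tensor()
for _, m in ipairs(model:findModules('nn.SpatialConvolutionMM')) do
   m.finput = shared
end
```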

What do you think of the following:

  • Recompute finput in accGradParameters. This halves the amount of buffer memory (fgradInput is no longer needed), and buffers can also be shared between modules. Forward timings stay the same, though there is a penalty for backward. A rough sketch follows below.
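
To make the proposal concrete, here is a rough Lua sketch of how accGradParameters could rebuild the unfolded input on the fly instead of reading the saved finput. It ignores padding for brevity, and `im2col` is a hypothetical helper name, not an existing function:

```lua
-- Recompute the unfolded input (im2col) for a single sample of size
-- nInputPlane x H x W; padding handling omitted for brevity.
local function im2col(input, kW, kH, dW, dH)
   local u = input:unfold(2, kH, dH):unfold(3, kW, dW) -- nIn x oH x oW x kH x kW
   local nIn, oH, oW = u:size(1), u:size(2), u:size(3)
   return u:permute(1, 4, 5, 2, 3):contiguous():view(nIn * kH * kW, oH * oW)
end

-- gradWeight accumulation without the saved finput, where gradOutput2d is
-- gradOutput viewed as nOutputPlane x (oH*oW):
--   gradWeight:addmm(scale, gradOutput2d, im2col(input, kW, kH, dW, dH):t())
```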

There are other possibilities as well that would reduce the required memory further, but they hurt performance even more.

What do you think?

fmassa commented Nov 28, 2015

I gave it a second thought, and I think the finput buffer doesn't need to depend on the batch size, only on the number of threads that will be used: since the parallelization is over batch samples, each thread processes one sample at a time and thus only needs one sample's worth of unfolded input. This could reduce the memory requirements by a big margin (say 10x for 12 threads at batch size 128), without any loss in runtime.
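
As a back-of-the-envelope check, assuming purely hypothetical layer sizes (64 input planes, 3x3 kernel, 56x56 output, float storage):

```lua
-- Per-sample unfolded-input size: nInputPlane * kH * kW * oH * oW elements.
local perSample = 64 * 3 * 3 * 56 * 56            -- ~1.8e6 elements
local perBatch  = perSample * 128 * 4 / 2^20      -- ~882 MB: one slot per sample
local perThread = perSample * 12 * 4 / 2^20       -- ~83 MB: one slot per thread
print(perBatch, perThread, perBatch / perThread)  -- ratio = 128/12, ~10.7x
```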

What do you think? Am I missing something here?
