-
-
Notifications
You must be signed in to change notification settings - Fork 294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Try Jax as GPU accelerator #1213
base: master
Are you sure you want to change the base?
Conversation
This is interesting. I've never heard of JAX, but will definitely check it out. Speaking of QR decomposition, I remembered the ugly workaround for old numpy. I've not yet compared the speed, but the loop looks like it might not be the most efficient option... scikit-rf/skrf/vectorFitting.py Lines 804 to 825 in ecaa7cb
|
A_ri = np.hstack((A.real, A.imag)) | ||
for i in range(n_responses): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just realized you have already started to replace the loop. But can't you simply use the direct method (a few lines above), which is currently commented out?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The direct version is possible since some recent numpy versions, however it wasn't faster than the loop in first quick tests.
This is a first try to use Jax as GPU accelerator.
VectorFitting process is mainly depending on
qr
decomposition, so I just replacednumpy.linalg.qr
withjax.numpy.linalg.qr
.I did some tests with the s4p file from here.
The matrix shape for
qr
composition starts at (16,2992,20) and ends at (16,2992,640) after adding some poles, so i benchmarked the last dimension from 20 to 2000.Tests are done on Ubuntu 24.04, Ryzen 5 5600X, 32GB RAM, RTX3070 (8GB)
Result with jax[cpu]:
Bascally no difference
Result with jax[cuda12]:
Estimated ~30% .. 50% gain on GPU.
Now try to fit
default.s4p
with this code:Original: 30.7s
with jax[cpu]: 27.3s
with jax[cuda12]: 20.9s
Original on Windows: 39s
with jax[cpu]: 43s
Original on WSL: 32s
with jax[cpu]: 33s