WIP: Try Jax as GPU accelerator #1213

Draft
wants to merge 4 commits into base: master

Conversation

FranzForstmayr (Collaborator) commented Nov 26, 2024

This is a first attempt at using JAX as a GPU accelerator.

The VectorFitting process depends mainly on QR decomposition, so I just replaced numpy.linalg.qr with jax.numpy.linalg.qr.
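
A minimal sketch of what such a drop-in replacement could look like (the wrapper and its fallback logic are my illustration, not necessarily how this PR wires it into VectorFitting):

import numpy as np

try:
    import jax
    jax.config.update("jax_enable_x64", True)  # match numpy's float64 precision
    import jax.numpy as jnp
    HAS_JAX = True
except ImportError:
    HAS_JAX = False

def stacked_qr_r(a_stack):
    # QR of a stack of matrices, returning only the R factors (hypothetical helper)
    if HAS_JAX:
        # vmap maps the 2-D QR over the leading (response) axis, on GPU if available
        qr_r = jax.vmap(lambda a: jnp.linalg.qr(a, mode='r'))
        return np.asarray(qr_r(a_stack))
    # numpy fallback: loop over the leading axis
    return np.array([np.linalg.qr(a, mode='r') for a in a_stack])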

I did some tests with the s4p file from here.
The matrix shape for the QR decomposition starts at (16, 2992, 20) and grows to (16, 2992, 640) after adding some poles, so I benchmarked the last dimension from 20 to 2000.

import numpy as np
import perfplot
import jax

# enable double precision in JAX to match numpy's float64
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

rng = np.random.default_rng()

# vectorize the 2-D QR over the leading (response) axis
qr_vmap = jax.vmap(lambda e: jnp.linalg.qr(e, mode='r'), in_axes=0)

perfplot.show(
    setup=lambda n: rng.random((16, 2992, n)),
    kernels=[
        lambda arr: np.linalg.qr(arr, mode="r"),
        lambda arr: np.array([np.linalg.qr(e, mode="r") for e in arr]),
        lambda arr: np.array([jnp.linalg.qr(e, mode="r") for e in arr]),
        lambda arr: np.array(qr_vmap(arr)),
    ],
    labels=["np vectorized", "np loop", "jax loop", "jax vectorized"],
    n_range=np.linspace(10, 2000, 20).astype(int),  # matrix sizes must be integers
    xlabel="matrix shape (16, 2992, x)",
)
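
An aside on the benchmark itself (not from the PR): JAX dispatches work asynchronously, so each kernel has to force the result back to the host, which the np.array(...) conversion above already does. An equivalent kernel that waits on the device without copying would be:

# hypothetical kernel variant: wait for the device instead of copying the result back
lambda arr: qr_vmap(arr).block_until_ready()

Without one of the two, the timing would only measure how fast the work is queued.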

Tests were done on Ubuntu 24.04 with a Ryzen 5 5600X, 32 GB RAM and an RTX 3070 (8 GB).

Result with jax[cpu]:
(plot: perfplot results with jax[cpu], x up to 2000)

Basically no difference.

Result with jax[cuda12]:
(plot: perfplot results with jax[cuda12], x up to 2000)

Estimated gain of roughly 30% to 50% on the GPU.

Now try to fit default.s4p with this code:

import skrf as rf

net = rf.Network("default.s4p")
vf = rf.VectorFitting(net)
vf.auto_fit(n_poles_add=20, model_order_max=500)

print(vf.get_rms_error())

Original (Ubuntu): 30.7 s
With jax[cpu]: 27.3 s
With jax[cuda12]: 20.9 s

Original on Windows: 39 s
With jax[cpu]: 43 s

Original on WSL: 32 s
With jax[cpu]: 33 s

@Vinc0110 (Collaborator) commented Nov 27, 2024

This is interesting. I've never heard of JAX, but will definitely check it out.

Speaking of QR decomposition, I remembered the ugly workaround for old numpy. I've not yet compared the speed, but the loop looks like it might not be the most efficient option...

# QR decomposition
# R = np.linalg.qr(np.hstack((A.real, A.imag)), 'r')
# direct QR of stacked matrices for linalg.qr() only works with numpy>=1.22.0
# workaround for old numpy:
R = np.empty((n_responses, dim_k, n_cols_unused + n_cols_used))
A_ri = np.hstack((A.real, A.imag))
for i in range(n_responses):
    R[i] = np.linalg.qr(A_ri[i], mode='r')
# only R22 is required to solve for c_res and d_res
# R12 and R22 can have a different number of rows, depending on K
if dim_k == 2 * n_freqs:
    n_rows_r12 = n_freqs
    n_rows_r22 = n_freqs
else:
    n_rows_r12 = n_cols_unused
    n_rows_r22 = n_cols_used
R22 = R[:, n_rows_r12:, n_cols_unused:]
# weighting
R22 = weights_responses[:, None, None] * R22

Review comment from Vinc0110 (Collaborator) on this hunk:

A_ri = np.hstack((A.real, A.imag))
for i in range(n_responses):

I just realized you have already started to replace the loop. But can't you simply use the direct method (a few lines above), which is currently commented out?

Reply from FranzForstmayr (Collaborator, PR author):

The direct version has been possible since fairly recent numpy versions (>= 1.22), but in my first quick tests it wasn't faster than the loop.
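
For anyone who wants to reproduce that comparison, a minimal timing sketch (shapes follow the numbers quoted above; the variable names and the fixed size of 320 columns are just an example I picked):

import numpy as np
from timeit import timeit

rng = np.random.default_rng()
A_ri = rng.random((16, 2992, 320))  # stacked real/imag system matrices

# direct QR of the whole stack (requires numpy >= 1.22)
t_direct = timeit(lambda: np.linalg.qr(A_ri, mode='r'), number=5)

# per-response loop (the old workaround)
t_loop = timeit(lambda: np.array([np.linalg.qr(a, mode='r') for a in A_ri]), number=5)

print(f"direct: {t_direct:.2f} s, loop: {t_loop:.2f} s")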
