from jax import numpy as jnp

from jwave.geometry import _circ_mask


def three_circles(N: tuple) -> jnp.ndarray:
    """Build a phantom from three weighted masks produced by ``_circ_mask``.

    The three masks are scaled by 5.0, 3.0 and 4.0 and added together, so
    wherever the masks overlap their weighted values accumulate.  The radius
    argument of each mask is 5%, 10% and 15% of the mean of the entries of
    ``N``; the centre arguments are derived from ``N[0]`` and ``N[1]`` only.

    NOTE(review): ``_circ_mask`` is presumed to return a 0/1 disc mask of
    shape ``N`` — confirm against ``jwave.geometry``.

    Args:
        N: Size of the phantom grid.  Only ``N[0]`` and ``N[1]`` feed the
            centres, so this looks intended for 2D grids — confirm before
            passing a longer tuple.

    Returns:
        jnp.ndarray: The weighted sum of the three masks with a trailing
        singleton axis appended.
    """
    # Mean grid extent; every radius below is a fixed fraction of it.
    mean_extent = sum(N) / float(len(N))

    mid_x = N[0] / 2
    mid_y = N[1] / 2

    # Smallest mask: shifted by +N[0]/8 along the first axis.
    small_centre = (int(mid_x + N[0] / 8), int(mid_y))
    small_disc = _circ_mask(N, mean_extent * 0.05, small_centre)

    # Medium mask: shifted by -N[0]/8 and +N[1]/6.
    medium_centre = (int(mid_x - N[0] / 8), int(mid_y + N[1] / 6))
    medium_disc = _circ_mask(N, mean_extent * 0.1, medium_centre)

    # Largest mask: placed at the grid midpoint.
    large_centre = (int(mid_x), int(mid_y))
    large_disc = _circ_mask(N, mean_extent * 0.15, large_centre)

    phantom = 5.0 * small_disc + 3.0 * medium_disc + 4.0 * large_disc
    return jnp.expand_dims(phantom, axis=-1)