Skip to content

Commit

Permalink
Use 512 threads instead of 1024 threads per CUDA block.
Browse files: browse the repository at this point in the history
  • Loading branch information
kduxin committed Oct 18, 2022
1 parent 2789fd3 commit b1aa3ea
Showing 1 changed file with 7 additions and 10 deletions.
17 changes: 7 additions & 10 deletions scripts/sentsim.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -7,9 +7,7 @@
numba_logger = logging.getLogger("numba")
numba_logger.setLevel(logging.ERROR)

BLOCKDIM_X = 32
BLOCKDIM_Y = 32

BLOCKDIM_X = 512

def sentsim_as_weighted_wordsim_cuda(wordsim, weight, idseqs, device=None):
if len(idseqs) >= 1024 * 1024:
Expand All @@ -22,20 +20,22 @@ def sentsim_as_weighted_wordsim_cuda(wordsim, weight, idseqs, device=None):
for i, idseq in enumerate(idseqs):
idseqs_arr[i][: len(idseq)] = idseq

""" columns are handled by blockIdx.x, rows are handled by blockIdx.y and threadIdx.x """
sim = np.zeros((n, n), dtype=np.float64)
if device is not None:
cuda.select_device(device)
wordsim, weight, idseqs_arr, sim, lens = list(
map(cuda.to_device, [wordsim, weight, idseqs_arr, sim, lens])
)

n_block_y = (n + BLOCKDIM_X * BLOCKDIM_Y - 1) // (BLOCKDIM_X * BLOCKDIM_Y)
_weightsum_kernel[(n, n_block_y), (BLOCKDIM_X, BLOCKDIM_Y)](
n_block_y = (n + BLOCKDIM_X - 1) // BLOCKDIM_X
_weightsum_kernel[(n, n_block_y), BLOCKDIM_X](
wordsim, weight, idseqs_arr, lens, sim, n
)

sim = sim.copy_to_host()

""" Fill in the upper triangle area """
diag = np.diag(sim)
sim = sim + sim.T
sim[np.diag_indices(len(diag))] = diag
Expand All @@ -45,13 +45,10 @@ def sentsim_as_weighted_wordsim_cuda(wordsim, weight, idseqs, device=None):
@cuda.jit
def _weightsum_kernel(wordsim, weight, idseqs, lens, out, n):
x = cuda.blockIdx.x
y = (
cuda.blockIdx.y * cuda.blockDim.x * cuda.blockDim.y
+ cuda.threadIdx.y * cuda.blockDim.x
+ cuda.threadIdx.x
)
y = cuda.blockIdx.y * cuda.blockDim.x + cuda.threadIdx.x
if x >= n or y >= n:
return
""" Compute the lower triangle area (including the diagonal) only """
if x > y:
return

Expand Down

0 comments on commit b1aa3ea

Please sign in to comment.