Skip to content

Commit

Permalink
More separation of CUDA from ALTCPU
Browse files Browse the repository at this point in the history
  • Loading branch information
jahooker committed Jan 17, 2023
1 parent 1507ef8 commit b160585
Show file tree
Hide file tree
Showing 10 changed files with 463 additions and 508 deletions.
30 changes: 7 additions & 23 deletions src/acc/acc_helper_functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -528,33 +528,17 @@ void lowPassFilterMapGPU(
XFLOAT edge_high = std::min((double) Xdim, (ires_filter + filter_edge_halfwidth) / (RFLOAT) ori_size); // in 1/pix
XFLOAT edge_width = edge_high - edge_low;

int blocks = ceilf(
const int blocks = ceilf(
(float) ((size_t) Xdim * (size_t) Ydim * (size_t) Zdim) /
(float) CFTT_BLOCK_SIZE
);

// It's a shame do_highpass isn't const
if (do_highpass) {
AccUtilities::frequencyPass<true>(
blocks, CFTT_BLOCK_SIZE, img_in.getStream(),
img_in.getAccPtr(),
ori_size,
Xdim, Ydim, Zdim,
edge_low, edge_width, edge_high,
(XFLOAT) angpix,
(size_t) Xdim * (size_t) Ydim * (size_t) Zdim
);
} else {
AccUtilities::frequencyPass<false>(
blocks, CFTT_BLOCK_SIZE, img_in.getStream(),
img_in.getAccPtr(),
ori_size,
Xdim, Ydim, Zdim,
edge_low, edge_width, edge_high,
(XFLOAT) angpix,
(size_t) Xdim * (size_t) Ydim * (size_t) Zdim
);
}
AccUtilities::frequencyPass<acc::type>(
blocks, CFTT_BLOCK_SIZE, img_in.getStream(),
img_in.getAccPtr(), ori_size, Xdim, Ydim, Zdim,
edge_low, edge_width, edge_high, (XFLOAT) angpix,
(size_t) Xdim * (size_t) Ydim * (size_t) Zdim, do_highpass
);
LAUNCH_HANDLE_ERROR(cudaGetLastError());
}

Expand Down
170 changes: 12 additions & 158 deletions src/acc/acc_helper_functions_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,164 +310,18 @@ void runWavgKernel(
cudaStream_t stream
) {
// cudaFuncSetCacheConfig(cuda_kernel_wavg_fast, cudaFuncCachePreferShared);

if (ctf_premultiplied) {
if (refs_are_ctf_corrected) {
if (data_is_3D) {
AccUtilities::kernel_wavg<true, true, true, true, WAVG_BLOCK_SIZE_DATA3D>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else if (projector.mdlZ != 0) {
AccUtilities::kernel_wavg<true, true, true, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else {
AccUtilities::kernel_wavg<true, true, false, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
}
} else {
if (data_is_3D) {
AccUtilities::kernel_wavg<true, false, true, true, WAVG_BLOCK_SIZE_DATA3D>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else if (projector.mdlZ != 0) {
AccUtilities::kernel_wavg<true, false, true, false,WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else {
AccUtilities::kernel_wavg<true, false, false, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
}
}
} else { // if (!ctf_premultiplied)
if (refs_are_ctf_corrected) {
if (data_is_3D) {
AccUtilities::kernel_wavg<false, true, true, true, WAVG_BLOCK_SIZE_DATA3D>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else if (projector.mdlZ != 0) {
AccUtilities::kernel_wavg<false, true, true, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else {
AccUtilities::kernel_wavg<false, true, false, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
}
} else {
if (data_is_3D) {
AccUtilities::kernel_wavg<false, false, true, true, WAVG_BLOCK_SIZE_DATA3D>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else if (projector.mdlZ != 0) {
AccUtilities::kernel_wavg<false, false, true, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
} else {
AccUtilities::kernel_wavg<false, false, false, false, WAVG_BLOCK_SIZE>(
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag,
trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream
);
}
}
}
AccUtilities::kernel_wavg<acc::type>(
data_is_3D ? WAVG_BLOCK_SIZE_DATA3D : WAVG_BLOCK_SIZE,
eulers, projector, image_size, orientation_num,
Fimg_real, Fimg_imag, trans_x, trans_y, trans_z,
sorted_weights, ctfs,
wdiff2s_parts, wdiff2s_AA, wdiff2s_XA,
translation_num,
(XFLOAT) op.sum_weight[img_id],
(XFLOAT) op.significant_weight[img_id],
part_scale, stream,
ctf_premultiplied, refs_are_ctf_corrected, data_is_3D || projector.mdlZ != 0, data_is_3D
);
LAUNCH_HANDLE_ERROR(cudaGetLastError());
}

Expand Down
42 changes: 14 additions & 28 deletions src/acc/acc_ml_optimiser_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -615,34 +615,20 @@ void getFourierTransformsAndCtfs(
spectrumAndXi2.accInit(0);
spectrumAndXi2.streamSync();

int gridSize = ceil((float) accMLO->transformer1.fouriers.getSize() / (float) POWERCLASS_BLOCK_SIZE);
if (accMLO->dataIs3D) {
AccUtilities::powerClass<true>(
gridSize, POWERCLASS_BLOCK_SIZE,
accMLO->transformer1.fouriers.getAccPtr(),
spectrumAndXi2.getAccPtr(),
accMLO->transformer1.fouriers.getSize(),
spectrumAndXi2.getSize() - 1,
accMLO->transformer1.sizef[0],
accMLO->transformer1.sizef[1],
accMLO->transformer1.sizef[2],
baseMLO->image_current_size[optics_group] / 2 + 1, // note: NOT baseMLO->image_full_size[optics_group] / 2 + 1
spectrumAndXi2.getAccPtr() + spectrumAndXi2.getSize() - 1
); // last element is the highres_Xi2
} else {
AccUtilities::powerClass<false>(
gridSize, POWERCLASS_BLOCK_SIZE,
accMLO->transformer1.fouriers.getAccPtr(),
spectrumAndXi2.getAccPtr(),
accMLO->transformer1.fouriers.getSize(),
spectrumAndXi2.getSize() - 1,
accMLO->transformer1.sizef[0],
accMLO->transformer1.sizef[1],
accMLO->transformer1.sizef[2],
baseMLO->image_current_size[optics_group] / 2 + 1, // note: NOT baseMLO->image_full_size[optics_group] / 2 + 1
spectrumAndXi2.getAccPtr() + spectrumAndXi2.getSize() - 1
); // last element is the highres_Xi2
}
const int gridSize = ceil((float) accMLO->transformer1.fouriers.getSize() / (float) POWERCLASS_BLOCK_SIZE);
AccUtilities::powerClass<acc::type>(
gridSize, POWERCLASS_BLOCK_SIZE,
accMLO->transformer1.fouriers.getAccPtr(),
spectrumAndXi2.getAccPtr(),
accMLO->transformer1.fouriers.getSize(),
spectrumAndXi2.getSize() - 1,
accMLO->transformer1.sizef[0],
accMLO->transformer1.sizef[1],
accMLO->transformer1.sizef[2],
baseMLO->image_current_size[optics_group] / 2 + 1, // note: NOT baseMLO->image_full_size[optics_group] / 2 + 1
spectrumAndXi2.getAccPtr() + spectrumAndXi2.getSize() - 1, // highres_Xi2
accMLO->dataIs3D
);
LAUNCH_PRIVATE_ERROR(cudaGetLastError(), accMLO->errorStatus);

spectrumAndXi2.streamSync();
Expand Down
105 changes: 16 additions & 89 deletions src/acc/acc_projector_plan_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ void AccProjectorPlan::setup(
unsigned iclass,
bool coarse, bool inverseMatrix, bool do_skip_align, bool do_skip_rotate,
int orientational_prior_mode,
Matrix<RFLOAT> &L_,
Matrix<RFLOAT> &R_
Matrix<RFLOAT> &L_, Matrix<RFLOAT> &R_
) {
TICTOC(TIMING_TOP, ({

Expand Down Expand Up @@ -242,97 +241,25 @@ void AccProjectorPlan::setup(
adjustR.putOnDevice();
}

int grid_size = ceil((float) orientation_num / (float) BLOCK_SIZE);
const int grid_size = ceil((float) orientation_num / (float) BLOCK_SIZE);

/// WARNING: Code multiplication!

if (inverseMatrix) {
if (sampling.is_3D) {
if (doL && doR) {
AccUtilities::acc_make_eulers_3D<true, true, true>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, adjustL.getAccPtr(), adjustR.getAccPtr()
);
} else if (doL) {
AccUtilities::acc_make_eulers_3D<true, true, false>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, adjustL.getAccPtr(), NULL
);
} else if (doR) {
AccUtilities::acc_make_eulers_3D<true, false, true>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, NULL, adjustR.getAccPtr()
);
} else {
AccUtilities::acc_make_eulers_3D<true, false, false>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, NULL, NULL
);
}
} else {
AccUtilities::acc_make_eulers_2D<true>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), eulers.getAccPtr(),
orientation_num
);
}
if (sampling.is_3D) {
AccUtilities::acc_make_eulers_3D<acc::type>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num,
doL ? adjustL.getAccPtr() : nullptr,
doR ? adjustR.getAccPtr() : nullptr,
doL, doR, inverseMatrix
);
} else {
if (sampling.is_3D) {
if (doL && doR) {
AccUtilities::acc_make_eulers_3D<false, true, true>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, adjustL.getAccPtr(), adjustR.getAccPtr()
);
} else if (doL) {
AccUtilities::acc_make_eulers_3D<false, true, false>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, adjustL.getAccPtr(), NULL
);
} else if (doR) {
AccUtilities::acc_make_eulers_3D<false, false, true>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, NULL, adjustR.getAccPtr()
);
} else {
AccUtilities::acc_make_eulers_3D<false, false, false>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(),
orientation_num, NULL, NULL
);
}
} else {
AccUtilities::acc_make_eulers_2D<false>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), eulers.getAccPtr(),
orientation_num
);
}
AccUtilities::acc_make_eulers_2D<acc::type>(
grid_size, BLOCK_SIZE, eulers.getStream(),
alphas.getAccPtr(), eulers.getAccPtr(),
orientation_num, inverseMatrix
);
}

// Why can't this all just be:

// if (sampling.is_3D) {
// AccUtilities::acc_make_eulers_3D<inverseMatrix, doL, doR>(
// grid_size, BLOCK_SIZE, eulers.getStream(),
// alphas.getAccPtr(), betas.getAccPtr(), gammas.getAccPtr(), eulers.getAccPtr(), orientation_num,
// doL ? adjustL.getAccPtr() : NULL, doR ? adjustR.getAccPtr() : NULL
// );
// } else {
// AccUtilities::acc_make_eulers_2D<inverseMatrix>(
// grid_size, BLOCK_SIZE, eulers.getStream(),
// alphas.getAccPtr(), eulers.getAccPtr(), orientation_num
// );
// }

// ? Can we make doL and doR compile-time constants?

}));
}

Expand Down
Loading

0 comments on commit b160585

Please sign in to comment.