Skip to content

Commit

Permalink
Fix GridEncoding gradient with max_level / Interpolation::Nearest
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom94 committed Feb 11, 2022
1 parent 9b9d046 commit d829746
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/tiny-cuda-nn/encodings/grid.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ __global__ void kernel_grid(
encoded_positions[i + (level * N_FEATURES_PER_LEVEL + f) * num_elements] = (T)0.0f;
}

// Gradient is zero for zeroed-out dimensions.
if (dy_dx) {
#pragma unroll
for (uint32_t grad_dim = 0; grad_dim < N_POS_DIMS; ++grad_dim) {
const uint32_t fan_out_grad = num_grid_features * N_POS_DIMS;
*(vector_fullp_t<N_FEATURES_PER_LEVEL>*)&dy_dx[i * fan_out_grad + level * N_FEATURES_PER_LEVEL + grad_dim * num_grid_features] = {0};
}
}

return;
}

Expand Down Expand Up @@ -189,6 +198,15 @@ __global__ void kernel_grid(
encoded_positions[i + (level * N_FEATURES_PER_LEVEL + f) * num_elements] = result[f];
}

// Gradient is zero when there's no interpolation.
if (dy_dx) {
#pragma unroll
for (uint32_t grad_dim = 0; grad_dim < N_POS_DIMS; ++grad_dim) {
const uint32_t fan_out_grad = num_grid_features * N_POS_DIMS;
*(vector_fullp_t<N_FEATURES_PER_LEVEL>*)&dy_dx[i * fan_out_grad + level * N_FEATURES_PER_LEVEL + grad_dim * num_grid_features] = {0};
}
}

return;
}

Expand Down

0 comments on commit d829746

Please sign in to comment.