Fix backward() for zero-hidden-layer MLPs
Tom94 committed Feb 14, 2022
1 parent c3d9792 commit 8430c46
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion src/cutlass_mlp.cu
@@ -278,7 +278,7 @@ void CutlassMLP<T>::backward(
 	bool use_inference_matrices,
 	bool compute_param_gradients
 ) {
-	if (m_forward.hidden.size() == 0) {
+	if (m_n_hidden_layers > 0 && m_forward.hidden.size() == 0) {
 		throw std::runtime_error{"Must call forward() before calling backward()."};
 	}
2 changes: 1 addition & 1 deletion src/fully_fused_mlp.cu
@@ -826,7 +826,7 @@ void FullyFusedMLP<T, WIDTH>::backward(
 	bool use_inference_matrices,
 	bool compute_param_gradients
 ) {
-	if (m_forward.hidden.size() == 0) {
+	if (m_n_hidden_layers > 0 && m_forward.hidden.size() == 0) {
 		throw std::runtime_error{"Must call forward() before calling backward()."};
 	}
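
The change is identical in both files: the guard that detects a missing forward() call now also checks m_n_hidden_layers, presumably because m_forward.hidden holds per-hidden-layer activations and is therefore legitimately empty for a network configured with zero hidden layers, so the old check threw even after a valid forward pass. The toy program below is a minimal sketch, not tiny-cuda-nn code; ToyMLP, n_hidden_layers, and hidden are hypothetical stand-ins for the real members, used only to illustrate why the extra condition is needed.

	// toy_guard.cpp -- hypothetical stand-in, not tiny-cuda-nn code.
	// Illustrates the guard fixed by this commit: with zero hidden layers,
	// the per-hidden-layer storage stays empty even after forward(), so
	// "empty" alone cannot mean "forward() was never called".
	#include <cstdio>
	#include <stdexcept>
	#include <vector>

	struct ToyMLP {
		int n_hidden_layers;       // stand-in for m_n_hidden_layers
		std::vector<int> hidden;   // stand-in for m_forward.hidden (one entry per hidden layer)

		void forward() {
			for (int i = 0; i < n_hidden_layers; ++i) {
				hidden.push_back(i); // a zero-hidden-layer network adds nothing here
			}
		}

		void backward() {
			// Fixed guard: only treat an empty `hidden` as a missing forward()
			// when the network actually has hidden layers.
			if (n_hidden_layers > 0 && hidden.size() == 0) {
				throw std::runtime_error{"Must call forward() before calling backward()."};
			}
			std::printf("backward() ran with %d hidden layer(s)\n", n_hidden_layers);
		}
	};

	int main() {
		ToyMLP mlp{0};  // zero hidden layers
		mlp.forward();
		mlp.backward(); // the pre-fix guard (hidden.size() == 0 alone) would throw here
	}

Compiled with any C++11 compiler, this prints the backward() message instead of throwing, which mirrors the behavior the commit restores for zero-hidden-layer configurations.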
