Utilize allocate_workspace for activation visualization & input gradients
Tom94 committed Feb 10, 2022
1 parent a7eef24 commit 9ceb1db
Showing 8 changed files with 21 additions and 29 deletions.
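
At a glance: each network gains a forward_clear() override that releases its cached forward-pass state (m_forward), a no-op virtual default is added to DifferentiableObject, and the input-gradient helper in object.h trades its lazily allocated member buffers for stream-allocated temporaries. Below is a minimal, self-contained C++ sketch of the virtual-hook part of the pattern; it is illustrative only, and SketchBase and SketchNetwork are invented names standing in for DifferentiableObject and the concrete networks:

    #include <iostream>

    struct SketchBase {
        virtual ~SketchBase() = default;

        // Default hook: objects without cached forward state do nothing.
        virtual void forward_clear() {}

        void visualize_activation() {
            // ... run the forward pass and extract the requested activations ...
            forward_clear(); // release the forward context once it has been consumed
        }
    };

    struct SketchNetwork : SketchBase {
        void forward_clear() override {
            // Stands in for m_forward.clear() in the real networks.
            std::cout << "clearing cached forward activations\n";
        }
    };

    int main() {
        SketchNetwork net;
        net.visualize_activation(); // the override runs through the base-class hook
    }

The hook lets shared base-class code paths (activation visualization, the backward passes in the .cu files below) release derived-class state without knowing its concrete type.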
1 change: 1 addition & 0 deletions include/tiny-cuda-nn/network.h
@@ -61,6 +61,7 @@ class Network : public DifferentiableObject<T, PARAMS_T, PARAMS_T> {
 
 		this->forward(stream, input);
 		extract_dimension_pos_neg<PARAMS_T>(stream, output.n_elements(), dimension, width(layer), output.rows(), forward_activations(layer), output.data());
+		this->forward_clear();
 	}
 
 	virtual uint32_t width(uint32_t layer) const = 0;
3 changes: 3 additions & 0 deletions include/tiny-cuda-nn/networks/cutlass_mlp.h
@@ -58,6 +58,9 @@ class CutlassMLP : public Network<T> {
 	void inference_mixed_precision(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>& output, bool use_inference_matrices = true) override;
 
 	void forward(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>* output = nullptr, bool use_inference_matrices = false, bool prepare_input_gradients = false) override;
+	void forward_clear() override {
+		m_forward.clear();
+	}
 
 	void backward(
 		cudaStream_t stream,
3 changes: 3 additions & 0 deletions include/tiny-cuda-nn/networks/cutlass_resnet.h
@@ -55,6 +55,9 @@ class CutlassResNet : public Network<T> {
 	void inference_mixed_precision(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>& output, bool use_inference_matrices = true) override;
 
 	void forward(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>* output = nullptr, bool use_inference_matrices = false, bool prepare_input_gradients = false) override;
+	void forward_clear() override {
+		m_forward.clear();
+	}
 
 	void backward(
 		cudaStream_t stream,
3 changes: 3 additions & 0 deletions include/tiny-cuda-nn/networks/fully_fused_mlp.h
@@ -51,6 +51,9 @@ class FullyFusedMLP : public Network<T> {
 	void inference_mixed_precision(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>& output, bool use_inference_matrices = true) override;
 
 	void forward(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<T>* output = nullptr, bool use_inference_matrices = false, bool prepare_input_gradients = false) override;
+	void forward_clear() override {
+		m_forward.clear();
+	}
 
 	void backward(
 		cudaStream_t stream,
34 changes: 8 additions & 26 deletions include/tiny-cuda-nn/object.h
@@ -98,6 +98,8 @@ class DifferentiableObject : public ParametricObject<PARAMS_T> {
 		forward(nullptr, input, output, use_inference_matrices, prepare_input_gradients);
 	}
 
+	virtual void forward_clear() {}
+
 	virtual void backward(
 		cudaStream_t stream,
 		const GPUMatrixDynamic<T>& input,
@@ -127,19 +129,19 @@
 	) {
 		// Make sure our temporary buffers have the correct size for the given batch size
 		uint32_t batch_size = input.n();
-		if (m_input_gradient_output.n() != batch_size) {
-			allocate_input_gradient_buffers(batch_size);
-		}
+		GPUMatrix<COMPUTE_T> d_doutput = {padded_output_width(), batch_size, stream};
+		GPUMatrix<COMPUTE_T> output = {padded_output_width(), batch_size, stream};
 
 		if (dim >= padded_output_width()) {
 			throw std::runtime_error{"Invalid dimension to compute the input gradient for."};
 		}
 
 		// Set "loss gradient" at network outputs to 1 at the chosen dimension and 0 elsewhere.
-		one_hot_batched(stream, m_input_gradient_output.n_elements(), padded_output_width(), dim, m_input_gradient_d_doutput.data(), backprop_scale);
+		one_hot_batched(stream, output.n_elements(), padded_output_width(), dim, d_doutput.data(), backprop_scale);
 
-		forward(stream, input, &m_input_gradient_output, true /* inference matrices */, true /* prep forward buffers for input gradients */);
-		backward(stream, input, m_input_gradient_output, m_input_gradient_d_doutput, &d_dinput, true /* inference matrices */, false /* no param gradients */);
+		forward(stream, input, &output, true /* inference matrices */, true /* prep forward buffers for input gradients */);
+		backward(stream, input, output, d_doutput, &d_dinput, true /* inference matrices */, false /* no param gradients */);
 
 		mult(stream, d_dinput.n_elements(), d_dinput.data(), 1.0f / backprop_scale);
 	}
@@ -148,26 +150,6 @@ class DifferentiableObject : public ParametricObject<PARAMS_T> {
 	virtual uint32_t output_width() const = 0;
 
 	virtual uint32_t required_input_alignment() const = 0;
-
-private:
-	void allocate_input_gradient_buffers(uint32_t batch_size) {
-		m_input_gradient_d_doutput.set_size(padded_output_width(), batch_size);
-		m_input_gradient_output.set_size(padded_output_width(), batch_size);
-
-		GPUMatrixBase::allocate_shared_memory(
-			m_input_gradient_buffer,
-			{
-				&m_input_gradient_d_doutput,
-				&m_input_gradient_output,
-			}
-		);
-	}
-
-	// Temporary buffers for computing input gradients.
-	// (Lazily allocated on demand.)
-	GPUMemory<char> m_input_gradient_buffer;
-	GPUMatrix<COMPUTE_T> m_input_gradient_d_doutput;
-	GPUMatrix<COMPUTE_T> m_input_gradient_output;
 };

TCNN_NAMESPACE_END
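
The deleted members above were sized on demand and kept alive between calls; the replacement GPUMatrix temporaries take the stream as a constructor argument, which, per the commit title, draws their storage from allocate_workspace for the duration of the call. A rough single-threaded analogue of such a workspace arena, purely illustrative (ScratchArena and its methods are invented here, not tiny-cuda-nn API):

    #include <cstddef>
    #include <cstdint>
    #include <new>
    #include <vector>

    class ScratchArena {
    public:
        explicit ScratchArena(std::size_t capacity) : m_storage(capacity) {}

        // Hands out `bytes` of scratch memory; pointers stay valid because
        // the backing store never reallocates.
        void* allocate(std::size_t bytes) {
            if (m_offset + bytes > m_storage.size()) {
                throw std::bad_alloc{};
            }
            void* ptr = m_storage.data() + m_offset;
            m_offset += bytes;
            return ptr;
        }

        // Reclaims everything at once; the backing store is reused next call.
        void reset() { m_offset = 0; }

    private:
        std::vector<std::uint8_t> m_storage;
        std::size_t m_offset = 0;
    };

    int main() {
        ScratchArena arena{1 << 20};
        // Analogue of the two padded_output_width() x batch_size matrices above.
        float* d_doutput = static_cast<float*>(arena.allocate(16 * 128 * sizeof(float)));
        float* output = static_cast<float*>(arena.allocate(16 * 128 * sizeof(float)));
        (void)d_doutput; (void)output;
        arena.reset(); // scratch is recycled once the forward/backward pair completes
    }

The resize check from the old code disappears because each call simply carves fresh scratch space and the workspace recycles it wholesale.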
2 changes: 1 addition & 1 deletion src/cutlass_mlp.cu
@@ -403,7 +403,7 @@ void CutlassMLP<T>::backward(
 		}
 	}
 
-	m_forward.clear();
+	forward_clear();
 }
 
 template <typename T>
2 changes: 1 addition & 1 deletion src/cutlass_resnet.cu
@@ -345,7 +345,7 @@ void CutlassResNet<T, input_activation>::backward(
 		}
 	}
 
-	m_forward.clear();
+	forward_clear();
 }
 
 template <typename T, Activation input_activation>
2 changes: 1 addition & 1 deletion src/fully_fused_mlp.cu
@@ -940,7 +940,7 @@ void FullyFusedMLP<T, WIDTH>::backward(
 		}
 	}
 
-	m_forward.clear();
+	forward_clear();
 }
 
 template <typename T, int WIDTH>
