Skip to content

Commit

Permalink
Fix RM fp32 inference
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom94 committed Feb 9, 2022
1 parent 3777f47 commit a69c534
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion include/tiny-cuda-nn/networks/cutlass_mlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ class CutlassMLP : public Network<T> {
// Storage of inference temporary data
GPUMemory<char> m_inference_buffer;
std::array<GPUMatrix<T>, 2> m_inference_tmp;
- GPUMatrix<T> m_inference_output_tmp;
+ GPUMatrixDynamic<T> m_inference_output_tmp;

// Storage of forward pass data
GPUMemory<char> m_forward_buffer = GPUMemory<char>(0);
Expand Down
2 changes: 1 addition & 1 deletion include/tiny-cuda-nn/networks/cutlass_resnet.h
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ class CutlassResNet : public Network<T> {
GPUMemory<char> m_inference_buffer;
GPUMatrix<T> m_inference_linear_tmp;
std::array<GPUMatrix<T>, 2> m_inference_residual_tmp;
- GPUMatrix<T> m_inference_output_tmp;
+ GPUMatrixDynamic<T> m_inference_output_tmp;

// Storage of forward pass data
GPUMemory<char> m_forward_buffer;
Expand Down
2 changes: 1 addition & 1 deletion include/tiny-cuda-nn/networks/fully_fused_mlp.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class FullyFusedMLP : public Network<T> {
// Storage of inference temporary data
GPUMemory<char> m_inference_buffer;
GPUMatrix<T> m_inference_tmp;
- GPUMatrix<T> m_inference_output_tmp;
+ GPUMatrixDynamic<T> m_inference_output_tmp;

// Storage of forward pass data
GPUMemory<char> m_forward_buffer = GPUMemory<char>(0);
Expand Down
1 change: 1 addition & 0 deletions src/cutlass_mlp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ CutlassMLP<T>::~CutlassMLP() {

template <typename T>
void CutlassMLP<T>::inference(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<float>& output) {
+ m_inference_output_tmp.set_layout(output.layout());
inference_mixed_precision(stream, input, m_inference_output_tmp);

const uint32_t n_elements = (uint32_t)output.n_elements();
Expand Down
1 change: 1 addition & 0 deletions src/cutlass_resnet.cu
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ CutlassResNet<T, input_activation>::~CutlassResNet() {

template <typename T, Activation input_activation>
void CutlassResNet<T, input_activation>::inference(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<float>& output) {
+ m_inference_output_tmp.set_layout(output.layout());
inference_mixed_precision(stream, input, m_inference_output_tmp);

const uint32_t n_elements = (uint32_t)output.n_elements();
Expand Down
1 change: 1 addition & 0 deletions src/fully_fused_mlp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -719,6 +719,7 @@ FullyFusedMLP<T, WIDTH>::~FullyFusedMLP() {

template <typename T, int WIDTH>
void FullyFusedMLP<T, WIDTH>::inference(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<float>& output) {
+ m_inference_output_tmp.set_layout(output.layout());
inference_mixed_precision(stream, input, m_inference_output_tmp);

const uint32_t n_elements = (uint32_t)output.n_elements();
Expand Down

0 comments on commit a69c534

Please sign in to comment.