Fix visualization of RM forward layers
Tom94 committed Feb 11, 2022
1 parent d7b408b commit 5730a5f
Showing 7 changed files with 21 additions and 18 deletions.
10 changes: 6 additions & 4 deletions include/tiny-cuda-nn/common_device.h
@@ -558,22 +558,24 @@ __global__ void cast_from(const uint32_t num_elements, const T* __restrict__ pre
}

template <typename T>
-__global__ void extract_dimension_pos_neg_kernel(const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* __restrict__ encoded, float* __restrict__ output) {
+__global__ void extract_dimension_pos_neg_kernel(const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* __restrict__ encoded, const MatrixLayout layout, float* __restrict__ output) {
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= num_elements) return;

const uint32_t elem_idx = i / fan_out;
const uint32_t dim_idx = i % fan_out;

+const uint32_t encoded_idx = layout == MatrixLayout::AoS ? (elem_idx * fan_in + dim) : (elem_idx + dim * num_elements / fan_out);
+
if (fan_out == 1) {
-output[i] = (float)encoded[elem_idx * fan_in + dim];
+output[i] = (float)encoded[encoded_idx];
return;
}

if (dim_idx == 0) {
-output[i] = fmaxf(-(float)encoded[elem_idx * fan_in + dim], 0.0f);
+output[i] = fmaxf(-(float)encoded[encoded_idx], 0.0f);
} else if (dim_idx == 1) {
-output[i] = fmaxf((float)encoded[elem_idx * fan_in + dim], 0.0f);
+output[i] = fmaxf((float)encoded[encoded_idx], 0.0f);
} else if (dim_idx == 2) {
output[i] = 0;
} else {
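Side note (not part of the commit): the new encoded_idx handles the two activation memory layouts. In the AoS layout, the fan_in dimensions of one element are contiguous, so dimension dim of element elem_idx lives at elem_idx * fan_in + dim; in the other layout, each dimension's values for all elements are contiguous, so the same value lives at dim * batch_size + elem_idx, with batch_size = num_elements / fan_out. A minimal host-side sketch of the same arithmetic, with hypothetical names (Layout, activation_index):

#include <cstdint>

// Illustrative only: mirrors the kernel's index computation for the two layouts.
enum class Layout { AoS, SoA };

inline uint32_t activation_index(Layout layout, uint32_t elem_idx, uint32_t dim, uint32_t fan_in, uint32_t batch_size) {
	// AoS: all dims of one element are adjacent; SoA: all elements of one dim are adjacent.
	return layout == Layout::AoS ? (elem_idx * fan_in + dim) : (dim * batch_size + elem_idx);
}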
7 changes: 4 additions & 3 deletions include/tiny-cuda-nn/network.h
@@ -42,7 +42,7 @@ enum class WeightUsage {
};

template <typename T>
-void extract_dimension_pos_neg(cudaStream_t stream, const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* encoded, float* output);
+void extract_dimension_pos_neg(cudaStream_t stream, const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* encoded, MatrixLayout layout, float* output);

template <typename T, typename PARAMS_T=T>
class Network : public DifferentiableObject<T, PARAMS_T, PARAMS_T> {
@@ -59,13 +59,14 @@ class Network : public DifferentiableObject<T, PARAMS_T, PARAMS_T> {
dimension = std::min(dimension, width(layer)-1);

this->forward(stream, input);
-extract_dimension_pos_neg<PARAMS_T>(stream, output.n_elements(), dimension, width(layer), output.rows(), forward_activations(layer), output.data());
+auto vals = forward_activations(layer);
+extract_dimension_pos_neg<PARAMS_T>(stream, output.n_elements(), dimension, width(layer), output.rows(), vals.first, vals.second, output.data());
this->forward_clear();
}

virtual uint32_t width(uint32_t layer) const = 0;
virtual uint32_t num_forward_activations() const = 0;
-virtual const PARAMS_T* forward_activations(uint32_t layer) const = 0;
+virtual std::pair<const PARAMS_T*, MatrixLayout> forward_activations(uint32_t layer) const = 0;
};

template <typename T>
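Side note (not part of the commit; the names below are hypothetical, not tiny-cuda-nn's API): the interface change follows a simple pattern, returning the activation pointer together with its layout as a std::pair so callers no longer assume a fixed layout. A minimal sketch of that pattern:

#include <cstdint>
#include <utility>

enum class Layout { AoS, SoA };

struct ActivationsSource {
	// The pointer and its memory layout travel together.
	virtual std::pair<const float*, Layout> activations(uint32_t layer) const = 0;
	virtual ~ActivationsSource() = default;
};

struct ExampleNetwork : ActivationsSource {
	const float* m_hidden = nullptr; // filled by a forward pass elsewhere
	std::pair<const float*, Layout> activations(uint32_t) const override {
		return {m_hidden, Layout::SoA};
	}
};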
4 changes: 2 additions & 2 deletions include/tiny-cuda-nn/network_with_input_encoding.h
@@ -203,11 +203,11 @@ class NetworkWithInputEncoding : public Network<float, T> {
return m_network->num_forward_activations() + 1;
}

-const T* forward_activations(uint32_t layer) const override {
+std::pair<const T*, MatrixLayout> forward_activations(uint32_t layer) const override {
if (!m_forward.network_input.data()) {
throw std::runtime_error{"Must call forward() before accessing activations."};
}
-return layer == 0 ? m_forward.network_input.data() : m_network->forward_activations(layer - 1);
+return layer == 0 ? std::make_pair<const T*, MatrixLayout>(m_forward.network_input.data(), m_encoding->output_layout()) : m_network->forward_activations(layer - 1);
}

uint32_t input_width() const {
4 changes: 2 additions & 2 deletions include/tiny-cuda-nn/networks/cutlass_mlp.h
@@ -133,11 +133,11 @@ class CutlassMLP : public Network<T> {
return m_can_fuse_activation ? m_n_hidden_layers : (m_n_hidden_layers * 2);
}

-const T* forward_activations(uint32_t layer) const override {
+std::pair<const T*, MatrixLayout> forward_activations(uint32_t layer) const override {
if (m_forward.hidden.size() == 0) {
throw std::runtime_error{"Must call forward() before accessing activations."};
}
-return m_forward.hidden.at(layer).data();
+return {m_forward.hidden.at(layer).data(), CM};
}

private:
4 changes: 2 additions & 2 deletions include/tiny-cuda-nn/networks/cutlass_resnet.h
@@ -130,11 +130,11 @@ class CutlassResNet : public Network<T> {
return m_n_blocks * m_n_matrices_per_block + 1;
}

-const T* forward_activations(uint32_t layer) const override {
+std::pair<const T*, MatrixLayout> forward_activations(uint32_t layer) const override {
if (m_forward.hidden.size() == 0) {
throw std::runtime_error{"Must call forward() before accessing activations."};
}
-return m_forward.hidden.at(layer).data();
+return {m_forward.hidden.at(layer).data(), CM};
}

private:
4 changes: 2 additions & 2 deletions include/tiny-cuda-nn/networks/fully_fused_mlp.h
@@ -141,11 +141,11 @@ class FullyFusedMLP : public Network<T> {
return m_n_hidden_layers;
}

-const T* forward_activations(uint32_t layer) const override {
+std::pair<const T*, MatrixLayout> forward_activations(uint32_t layer) const override {
if (m_forward.hidden.size() == 0) {
throw std::runtime_error{"Must call forward() before accessing activations."};
}
-return m_forward.hidden.at(layer).data();
+return {m_forward.hidden.at(layer).data(), CM};
}

private:
6 changes: 3 additions & 3 deletions src/network.cu
@@ -61,11 +61,11 @@ Activation string_to_activation(std::string activation_name) {
}

template <typename T>
-void extract_dimension_pos_neg(cudaStream_t stream, const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* encoded, float* output) {
-linear_kernel(extract_dimension_pos_neg_kernel<T>, 0, stream, num_elements, dim, fan_in, fan_out, encoded, output);
+void extract_dimension_pos_neg(cudaStream_t stream, const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const T* encoded, MatrixLayout layout, float* output) {
+linear_kernel(extract_dimension_pos_neg_kernel<T>, 0, stream, num_elements, dim, fan_in, fan_out, encoded, layout, output);
}

-template void extract_dimension_pos_neg(cudaStream_t stream, const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const network_precision_t* encoded, float* output);
+template void extract_dimension_pos_neg(cudaStream_t stream, const uint32_t num_elements, const uint32_t dim, const uint32_t fan_in, const uint32_t fan_out, const network_precision_t* encoded, MatrixLayout layout, float* output);

template <typename T>
Network<T>* create_network(const json& network) {
