Commit
Merge pull request NVlabs#33 from NVlabs/misc-additions
Row-major inputs; reduced memory usage; misc improvements
Tom94 authored Feb 8, 2022
2 parents 7ce5a49 + 1fd5448 commit 48edf10
Showing 17 changed files with 253 additions and 167 deletions.
20 changes: 19 additions & 1 deletion include/tiny-cuda-nn/common.h
@@ -42,6 +42,16 @@
#include <stdexcept>
#include <string>

#ifdef __CUDA_NO_HALF_OPERATORS__
#undef __CUDA_NO_HALF_OPERATORS__
#endif
#ifdef __CUDA_NO_HALF_CONVERSIONS__
#undef __CUDA_NO_HALF_CONVERSIONS__
#endif
#ifdef __CUDA_NO_HALF2_OPERATORS__
#undef __CUDA_NO_HALF2_OPERATORS__
#endif

#include <cuda_fp16.h>
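When defined, the __CUDA_NO_HALF_OPERATORS__, __CUDA_NO_HALF_CONVERSIONS__ and __CUDA_NO_HALF2_OPERATORS__ macros suppress the arithmetic and conversion operators that cuda_fp16.h would otherwise declare for __half and __half2, so undefining them before the include restores native half-precision arithmetic. A minimal device-side sketch of what this enables; the scale_half kernel and its parameters are illustrative, not part of the commit, and assume an architecture with __half arithmetic (sm_53 or newer):

#include <cstdint>
#include <cuda_fp16.h>

// Illustrative only: relies on the __half operators declared by cuda_fp16.h
// when the __CUDA_NO_HALF*__ suppression macros are left undefined.
__global__ void scale_half(const uint32_t num_elements, const __half scale, __half* __restrict__ data) {
	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
	if (i >= num_elements) return;

	data[i] = data[i] * scale; // operator*(__half, __half) from cuda_fp16.h
}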


@@ -359,7 +369,7 @@ using vector_halfp_t = vector_t<__half, N_HALFS>;
template <typename T>
struct PitchedPtr {
TCNN_HOST_DEVICE PitchedPtr() : ptr{nullptr}, stride_in_bytes{sizeof(T)} {}
TCNN_HOST_DEVICE PitchedPtr(T* ptr, size_t stride_in_elements, size_t offset = 0) : ptr{ptr + offset}, stride_in_bytes{(uint32_t)(stride_in_elements * sizeof(T))} {}
TCNN_HOST_DEVICE PitchedPtr(T* ptr, size_t stride_in_elements, size_t offset = 0, size_t extra_stride_bytes = 0) : ptr{ptr + offset}, stride_in_bytes{(uint32_t)(stride_in_elements * sizeof(T) + extra_stride_bytes)} {}

template <typename U>
TCNN_HOST_DEVICE explicit PitchedPtr(PitchedPtr<U> other) : ptr{(T*)other.ptr}, stride_in_bytes{other.stride_in_bytes} {}
@@ -368,6 +378,14 @@ struct PitchedPtr {
return (T*)((const char*)ptr + y * stride_in_bytes);
}

TCNN_HOST_DEVICE void operator+=(uint32_t y) {
ptr = (T*)((const char*)ptr + y * stride_in_bytes);
}

TCNN_HOST_DEVICE void operator-=(uint32_t y) {
ptr = (T*)((const char*)ptr - y * stride_in_bytes);
}

TCNN_HOST_DEVICE explicit operator bool() const {
return ptr;
}
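A short usage sketch of the extended PitchedPtr: the new extra_stride_bytes constructor argument lets the byte stride exceed stride_in_elements * sizeof(T), and the new operator+=/operator-= shift the base pointer by whole rows. The pitched_ptr_sketch function, its buffer, and the assumption that operator()(y) returns a pointer to row y are illustrative only:

#include <tiny-cuda-nn/common.h>

// Illustrative only (not part of the commit): view a padded buffer as rows of
// 4 floats with 2 floats of per-row padding, then step the view by whole rows.
void pitched_ptr_sketch() {
	float buffer[6 * 3] = {};
	tcnn::PitchedPtr<float> view{buffer, /*stride_in_elements=*/4, /*offset=*/0, /*extra_stride_bytes=*/2 * sizeof(float)};

	float* row1 = view(1); // buffer + 6 floats, since the stride is 4 * sizeof(float) + 8 bytes
	view += 2;             // advance the base pointer by two rows
	float* row3 = view(1); // now addresses what was originally row 3
	(void)row1; (void)row3;
}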
15 changes: 5 additions & 10 deletions include/tiny-cuda-nn/common_device.h
@@ -513,17 +513,15 @@ __global__ void add(const uint32_t num_elements, const T1* data_in_1, const T2*
}

template <typename T>
__global__ void add(const uint32_t num_elements, const T* __restrict__ data_in, T* __restrict__ data_in_out)
{
__global__ void add(const uint32_t num_elements, const T* __restrict__ data_in, T* __restrict__ data_in_out) {
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= num_elements) return;

data_in_out[i] = data_in[i] + data_in_out[i];
}

template <typename T>
__global__ void trim(const uint32_t num_elements, const uint32_t stride, const uint32_t dims, const T* __restrict__ data_in, T* __restrict__ data_out)
{
__global__ void trim(const uint32_t num_elements, const uint32_t stride, const uint32_t dims, const T* __restrict__ data_in, T* __restrict__ data_out) {
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= num_elements) return;

@@ -534,8 +532,7 @@ __global__ void trim(const uint32_t num_elements, const uint32_t stride, const u
}

template <typename T>
__global__ void trim_and_cast(const uint32_t num_elements, const uint32_t stride, const uint32_t dims, const T* __restrict__ data_in, float* __restrict__ data_out)
{
__global__ void trim_and_cast(const uint32_t num_elements, const uint32_t stride, const uint32_t dims, const T* __restrict__ data_in, float* __restrict__ data_out) {
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= num_elements) return;

@@ -546,17 +543,15 @@ __global__ void trim_and_cast(const uint32_t num_elements, const uint32_t stride
}

template <typename T>
__global__ void cast(const uint32_t num_elements, const float* __restrict__ full_precision, T* __restrict__ target)
{
__global__ void cast(const uint32_t num_elements, const float* __restrict__ full_precision, T* __restrict__ target) {
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= num_elements) return;

target[i] = (T)full_precision[i];
}

template <typename T>
__global__ void cast_from(const uint32_t num_elements, const T* __restrict__ precision, float* __restrict__ full_precision)
{
__global__ void cast_from(const uint32_t num_elements, const T* __restrict__ precision, float* __restrict__ full_precision) {
const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= num_elements) return;

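These element-wise kernels all follow the same one-thread-per-element pattern with an explicit bounds check rather than a grid-stride loop. A hedged sketch of the usual launch for the add kernel above; the add_into wrapper and the 128-thread block size are conventional choices for illustration, not something the diff prescribes:

#include <tiny-cuda-nn/common_device.h>

// Illustrative only: launch the element-wise `add` kernel shown above.
void add_into(cudaStream_t stream, uint32_t num_elements, const float* data_in, float* data_in_out) {
	const uint32_t n_threads = 128; // assumption: a typical block size
	const uint32_t n_blocks = tcnn::div_round_up(num_elements, n_threads);
	tcnn::add<float><<<n_blocks, n_threads, 0, stream>>>(num_elements, data_in, data_in_out);
}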
45 changes: 8 additions & 37 deletions include/tiny-cuda-nn/cutlass_matmul.h
@@ -346,37 +346,6 @@ using SplitKGemm = cutlass::gemm::device::GemmSplitKParallel<
EPILOGUE
>;

inline std::map<cudaStream_t, GPUMemory<uint8_t>>& cutlass_workspaces() {
static std::map<cudaStream_t, GPUMemory<uint8_t>> s_workspaces;
return s_workspaces;
}

inline uint8_t* cutlass_get_workspace(size_t size, cudaStream_t stream) {
GPUMemory<uint8_t>& workspace = cutlass_workspaces()[stream];
if (size > workspace.size()) {
size *= 2;
#ifdef TCNN_VERBOSE_MEMORY_ALLOCS
std::cout << "CUTLASS GEMM: Allocating temporary workspace of " << bytes_to_string(size) << "." << std::endl;
#endif

// Allocate twice the requested size to make sure we're not constantly allocating small increments.
workspace.resize(size);
}
return workspace.data();
}

inline void cutlass_free_workspace(cudaStream_t stream) {
if (cutlass_workspaces().count(stream) == 0) {
return;
}

#ifdef TCNN_VERBOSE_MEMORY_ALLOCS
std::cout << "CUTLASS GEMM: Freeing temporary workspace of " << bytes_to_string(cutlass_workspaces().at(stream).size()) << "." << std::endl;
#endif
cutlass_workspaces().erase(stream);
}


template <class Gemm>
void fc_multiply_impl(cudaStream_t stream, const typename Gemm::Arguments& args) {
// Using the arguments, query for extra workspace required for matrix multiplication computation
@@ -386,7 +355,8 @@ void fc_multiply_impl(cudaStream_t stream, const typename Gemm::Arguments& args)
Gemm gemm_op;

// Initialize CUTLASS kernel with arguments and workspace pointer
cutlass::Status status = gemm_op.initialize(args, cutlass_get_workspace(workspace_size, stream), stream);
auto workspace = borrow_workspace(stream, workspace_size);
cutlass::Status status = gemm_op.initialize(args, workspace.data(), stream);
CUTLASS_CHECK(status);

// Launch initialized CUTLASS kernel
@@ -403,7 +373,8 @@ void fc_multiply_split_k_impl(cudaStream_t stream, const typename Gemm::Argument
Gemm gemm_op;

// Initialize CUTLASS kernel with arguments and workspace pointer
cutlass::Status status = gemm_op.initialize(args, cutlass_get_workspace(workspace_size, stream));
auto workspace = borrow_workspace(stream, workspace_size);
cutlass::Status status = gemm_op.initialize(args, workspace.data());
CUTLASS_CHECK(status);

// Launch initialized CUTLASS kernel
@@ -589,8 +560,8 @@ void fc_multiply_split_k(cudaStream_t stream, const GPUMatrix<TypeA, LayoutA>& A
}
}

template <typename config, typename TypeA, typename TypeB, MatrixLayout LayoutB, typename TypeC, typename TypeD>
void fc_multiply_split_k(cudaStream_t stream, const GPUMatrixDynamic<TypeA>& A, const GPUMatrix<TypeB, LayoutB>& B, const GPUMatrixDynamic<TypeC>& C, GPUMatrixDynamic<TypeD>& D, int split_k_slices = 1) {
template <typename config, typename TypeA, typename TypeB, typename TypeC, typename TypeD>
void fc_multiply_split_k(cudaStream_t stream, const GPUMatrixDynamic<TypeA>& A, const GPUMatrixDynamic<TypeB>& B, const GPUMatrixDynamic<TypeC>& C, GPUMatrixDynamic<TypeD>& D, int split_k_slices = 1) {
if (A.layout() == CM) {
auto A_CM = GPUMatrix<TypeA, CM>{A};
fc_multiply_split_k<config>(stream, A_CM, B, C, D, split_k_slices);
@@ -600,8 +571,8 @@ void fc_multiply_split_k(cudaStream_t stream, const GPUMatrixDynamic<TypeA>& A,
}
}

template <typename config, typename TypeA, typename TypeB, MatrixLayout LayoutB, typename TypeD>
void fc_multiply_split_k(cudaStream_t stream, const GPUMatrixDynamic<TypeA>& A, const GPUMatrix<TypeB, LayoutB>& B, GPUMatrixDynamic<TypeD>& D, int split_k_slices) {
template <typename config, typename TypeA, typename TypeB, typename TypeD>
void fc_multiply_split_k(cudaStream_t stream, const GPUMatrixDynamic<TypeA>& A, const GPUMatrixDynamic<TypeB>& B, GPUMatrixDynamic<TypeD>& D, int split_k_slices) {
fc_multiply_split_k<config>(stream, A, B, D, D, split_k_slices);
}

2 changes: 1 addition & 1 deletion include/tiny-cuda-nn/encodings/composite.h
@@ -63,7 +63,7 @@ class CompositeEncoding : public Encoding<T> {
}

if (total_nested_dims_to_encode > n_dims_to_encode) {
throw std::runtime_error{"CompositeEncoding: nested encodings must not encode more dims than composite"};
throw std::runtime_error{"CompositeEncoding:' nested encodings must not encode more dims than composite"};
}

uint32_t unspecified_dims_to_encode = n_dims_to_encode - total_nested_dims_to_encode;
33 changes: 16 additions & 17 deletions include/tiny-cuda-nn/encodings/grid.h
@@ -549,6 +549,11 @@ class GridEncodingTemplated : public GridEncoding<T> {
if (n_features % N_FEATURES_PER_LEVEL != 0) {
throw std::runtime_error{"GridEncoding: number of grid features must be a multiple of n_features_per_level"};
}

// Only needs temporary storage if gradients are computed with different precision from T.
if (!std::is_same<grad_t, T>::value) {
m_grid_gradient_tmp.resize(m_n_params);
}
}

void encode(
@@ -564,12 +569,10 @@ class GridEncodingTemplated : public GridEncoding<T> {
}

GPUMemory<float>* positions = &m_positions[stream];
GPUMemory<T>* encoded_positions = &m_encoded_positions[stream];
positions->enlarge(num_elements * N_POS_DIMS);

if (positions->size() < num_elements * N_POS_DIMS) {
positions->resize(num_elements * N_POS_DIMS * 2);
encoded_positions->resize(num_elements * m_n_features * 2);
}
auto workspace = borrow_workspace(stream, num_elements * m_n_features * sizeof(T));
auto encoded_positions = (vector_t<T, N_FEATURES_PER_LEVEL>*)workspace.data();

SyncedMultiStream synced_streams{stream, m_n_to_pad > 0 ? 2u : 1u};

@@ -610,7 +613,7 @@
m_grid_type,
is_inference ? m_grid_inference : m_grid,
positions->data(),
(vector_t<T, N_FEATURES_PER_LEVEL>*)encoded_positions->data(),
encoded_positions,
dy_dx
);

@@ -619,7 +622,7 @@
const uint32_t blocks_transpose = div_round_up(num_elements, threads_transpose.y);
transpose_encoded_position<vector_t<T, N_FEATURES_PER_LEVEL>><<<blocks_transpose, threads_transpose, 0, synced_streams.get(0)>>>(
num_elements,
(const vector_t<T, N_FEATURES_PER_LEVEL>*)encoded_positions->data(),
encoded_positions,
PitchedPtr<vector_t<T, N_FEATURES_PER_LEVEL>>{outputs}
);
}
@@ -639,9 +642,11 @@

{
GPUMemory<float>* positions = &m_positions[stream];
GPUMemory<T>* encoded_positions = &m_encoded_positions[stream];

if (positions->size() < num_elements || encoded_positions->size() < num_elements) {
auto workspace = borrow_workspace(stream, num_elements * m_n_features * sizeof(T));
auto dL_dy_soa = (T*)workspace.data();

if (positions->size() < num_elements) {
throw std::runtime_error{"GridEncoding: backward(stream) called without calling encode(stream) beforehand."};
}

@@ -650,7 +655,7 @@
const uint32_t blocks_transpose = div_round_up(num_elements, threads_transpose.y);
transpose_gradients<vector_t<T, N_FEATURES_PER_LEVEL>><<<blocks_transpose, threads_transpose, 0, stream>>>(
num_elements,
(vector_t<T, N_FEATURES_PER_LEVEL>*)encoded_positions->data(),
(vector_t<T, N_FEATURES_PER_LEVEL>*)dL_dy_soa,
PitchedPtr<const vector_t<T, N_FEATURES_PER_LEVEL>>{dL_dy}
);

@@ -685,7 +690,7 @@ class GridEncodingTemplated : public GridEncoding<T> {
m_grid_type,
grid_gradient,
positions->data(), // positions SoA
(const vector_t<T, N_FEATURES_PER_THREAD>*)encoded_positions->data() // gradients SoA
(const vector_t<T, N_FEATURES_PER_THREAD>*)dL_dy_soa // gradients SoA
);

if (!std::is_same<grad_t, T>::value) {
@@ -743,11 +748,6 @@

// Initialize the hashgrid from the GPU, because the number of parameters can be quite large.
generate_random_uniform<float>(rnd, n_params(), params_full_precision, -1e-4f, 1e-4f);

// Only needs temporary storage if gradients are computed with different precision from T.
if (!std::is_same<grad_t, T>::value) {
m_grid_gradient_tmp.resize(n_params());
}
}

size_t n_params() const override {
Expand Down Expand Up @@ -797,7 +797,6 @@ class GridEncodingTemplated : public GridEncoding<T> {
T* m_grid_gradient;

mutable std::map<cudaStream_t, GPUMemory<float>> m_positions;
mutable std::map<cudaStream_t, GPUMemory<T>> m_encoded_positions;
};

template <typename T, uint32_t N_FEATURES_PER_LEVEL>
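For context, the backward pass above now borrows a per-stream scratch buffer of num_elements * m_n_features values of T and transposes the incoming pitched gradients into that SoA layout before accumulating into the grid, instead of keeping a persistent m_encoded_positions buffer per stream. A generic, hedged sketch of such an AoS-to-SoA transpose follows; transpose_aos_to_soa is an illustration of the idea, not the library's transpose_gradients kernel, which additionally handles pitched strides and per-level feature vectors:

#include <cstdint>

// Illustrative only: convert array-of-structures (all features of one element
// contiguous) into structure-of-arrays (all elements of one feature contiguous).
template <typename T>
__global__ void transpose_aos_to_soa(uint32_t num_elements, uint32_t n_features, const T* __restrict__ aos, T* __restrict__ soa) {
	const uint32_t i = threadIdx.x + blockIdx.x * blockDim.x;
	if (i >= num_elements * n_features) return;

	const uint32_t element = i / n_features;
	const uint32_t feature = i % n_features;

	soa[feature * num_elements + element] = aos[element * n_features + feature];
}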
2 changes: 1 addition & 1 deletion include/tiny-cuda-nn/encodings/spherical_harmonics.h
@@ -69,7 +69,7 @@ __global__ void kernel_sh(
float z = data_in(i)[2] * 2.f - 1.f;

// Let compiler figure out how to sequence/reorder these calculations w.r.t. branches
float xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z, xyz=xy*z;
float xy=x*y, xz=x*z, yz=y*z, x2=x*x, y2=y*y, z2=z*z;
float x4=x2*x2, y4=y2*y2, z4=z2*z2;
float x6=x4*x2, y6=y4*y2, z6=z4*z2;

77 changes: 77 additions & 0 deletions include/tiny-cuda-nn/gpu_memory.h
@@ -36,6 +36,7 @@
#include <cuda.h>

#include <atomic>
#include <map>
#include <stdexcept>
#include <stdint.h>
#include <string>
@@ -406,4 +407,80 @@ class GPUMemory {
}
};

class Workspace {
public:
Workspace() = default;
Workspace(Workspace&& other) = default;
Workspace(const Workspace& other) = delete;

GPUMemory<uint8_t>& mem() {
return m_memory;
}

void borrow() {
if (m_in_use) {
throw std::runtime_error{"Attempted to borrow workspace that was already borrowed."};
}
m_in_use = true;
}

void release() {
if (!m_in_use) {
throw std::runtime_error{"Attempted to return workspace that was not borrowed."};
}
m_in_use = false;
}

private:
GPUMemory<uint8_t> m_memory;
bool m_in_use = false;
};

class BorrowedWorkspace {
public:
BorrowedWorkspace(Workspace* workspace) : m_workspace{workspace} {
m_workspace->borrow();
}

~BorrowedWorkspace() {
// A moved-from handle holds a null m_workspace (see the move constructor below), so guard the release.
if (m_workspace) {
m_workspace->release();
}
}

BorrowedWorkspace(const BorrowedWorkspace& other) = delete;

BorrowedWorkspace(BorrowedWorkspace&& other) {
std::swap(m_workspace, other.m_workspace);
}

GPUMemory<uint8_t>& mem() {
return m_workspace->mem();
}

uint8_t* data() {
return m_workspace->mem().data();
}

private:
Workspace* m_workspace = nullptr;
};

inline std::map<cudaStream_t, Workspace>& workspaces() {
static std::map<cudaStream_t, Workspace> s_workspaces;
return s_workspaces;
}

inline BorrowedWorkspace borrow_workspace(cudaStream_t stream) {
return BorrowedWorkspace{&workspaces()[stream]};
}

inline BorrowedWorkspace borrow_workspace(cudaStream_t stream, size_t n_bytes) {
auto workspace = borrow_workspace(stream);
workspace.mem().enlarge(n_bytes);
return workspace;
}

inline void free_workspace(cudaStream_t stream) {
workspaces().erase(stream);
}

TCNN_NAMESPACE_END
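The Workspace/BorrowedWorkspace pair generalizes the per-stream CUTLASS workspace helpers removed from cutlass_matmul.h: each stream owns one lazily grown buffer, and the RAII handle marks it borrowed for the duration of a scope. A brief usage sketch; workspace_sketch and its buffer size are illustrative:

#include <tiny-cuda-nn/gpu_memory.h>

// Illustrative only: borrow the per-stream scratch buffer and let the RAII
// handle release it at the end of the scope. The allocation itself stays
// cached in the stream's Workspace until free_workspace(stream) is called.
void workspace_sketch(cudaStream_t stream, size_t n_bytes) {
	auto scratch = tcnn::borrow_workspace(stream, n_bytes); // grows the cached buffer if needed
	uint8_t* ptr = scratch.data();
	// ... enqueue kernels on `stream` that read and write `ptr` ...
	(void)ptr;
} // `scratch` is released here; the same stream may borrow the workspace again

Because the buffer is keyed by stream, releasing the handle before previously enqueued work completes appears safe as long as every user of that workspace also enqueues on the same stream, so ordering is preserved.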
4 changes: 2 additions & 2 deletions include/tiny-cuda-nn/network.h
@@ -50,8 +50,8 @@ class Network : public DifferentiableObject<T, PARAMS_T, PARAMS_T> {
public:
virtual ~Network() { }

virtual void inference_mixed_precision(cudaStream_t stream, const GPUMatrix<T>& input, GPUMatrixDynamic<PARAMS_T>& output, bool use_inference_matrices = true) = 0;
void inference_mixed_precision(const GPUMatrix<T>& input, GPUMatrixDynamic<PARAMS_T>& output, bool use_inference_matrices = true) {
virtual void inference_mixed_precision(cudaStream_t stream, const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<PARAMS_T>& output, bool use_inference_matrices = true) = 0;
void inference_mixed_precision(const GPUMatrixDynamic<T>& input, GPUMatrixDynamic<PARAMS_T>& output, bool use_inference_matrices = true) {
inference_mixed_precision(nullptr, input, output, use_inference_matrices);
}

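With the input parameter widened from GPUMatrix<T> (fixed column-major layout) to GPUMatrixDynamic<T>, inference can now be fed row-major batches as well, matching the commit title. A hedged sketch of a call site; run_inference, the GPUMatrixDynamic constructor of the form {data, rows, cols, layout}, and the input_width()/padded_output_width() accessors are assumptions to be checked against gpu_matrix.h and object.h:

#include <tiny-cuda-nn/network.h>

// Illustrative only (not from the diff): feed a row-major input batch to inference.
void run_inference(cudaStream_t stream, tcnn::Network<float, __half>& network,
	float* input_device, __half* output_device, uint32_t batch_size) {
	tcnn::GPUMatrixDynamic<float> input{input_device, network.input_width(), batch_size, tcnn::RM};
	tcnn::GPUMatrixDynamic<__half> output{output_device, network.padded_output_width(), batch_size, tcnn::RM};

	network.inference_mixed_precision(stream, input, output);
}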