Add high level control of fp32 matmul precision; disable TF32 for matmuls by default

pytorch#76440

CC @mruberry @ptrblck

Pull Request resolved: pytorch#76509
Approved by: https://github.com/ngimel
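
A minimal sketch of how the new knob added in this commit is meant to be used (names taken from the diff below; the TF32 path only matters on GPUs that support it, e.g. Ampere):

import torch

# New default: fp32 matmuls keep full float32 precision (TF32 off).
print(torch.get_float32_matmul_precision())  # "highest"

# Opt back in to TF32-accelerated fp32 matmuls.
torch.set_float32_matmul_precision("high")
print(torch.backends.cuda.matmul.allow_tf32)  # True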
eqy authored and pytorchmergebot committed May 4, 2022
1 parent 679fc90 commit e838137
Showing 9 changed files with 103 additions and 3 deletions.
4 changes: 4 additions & 0 deletions CMakeLists.txt
@@ -792,6 +792,10 @@ if(NOT MSVC)
string(APPEND CMAKE_CXX_FLAGS " -Wno-range-loop-analysis")
string(APPEND CMAKE_CXX_FLAGS " -Wno-pass-failed")
endif()
if(CMAKE_COMPILER_IS_GNUCXX AND (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 6.0.0))
# Suppress issue: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=43407
string(APPEND CMAKE_CXX_FLAGS " -Wno-attributes")
endif()
if(CMAKE_COMPILER_IS_GNUCXX AND NOT (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0.0))
string(APPEND CMAKE_CXX_FLAGS " -Wno-stringop-overflow")
endif()
37 changes: 35 additions & 2 deletions aten/src/ATen/Context.cpp
@@ -5,6 +5,7 @@
#include <c10/core/TensorOptions.h>
#include <c10/core/CPUAllocator.h>

#include <algorithm>
#include <mutex>
#include <sstream>
#include <stdexcept>
@@ -138,11 +139,43 @@ void Context::setBenchmarkCuDNN(bool b) {
}

 bool Context::allowTF32CuBLAS() const {
-  return allow_tf32_cublas;
+  return float32_matmul_precision != at::Float32MatmulPrecision::HIGHEST;
 }

 void Context::setAllowTF32CuBLAS(bool b) {
-  allow_tf32_cublas = b;
+  float32_matmul_precision = b ? at::Float32MatmulPrecision::HIGH : at::Float32MatmulPrecision::HIGHEST;
 }

Float32MatmulPrecision Context::float32MatmulPrecision() const {
  return float32_matmul_precision;
}

void Context::setFloat32MatmulPrecision(Float32MatmulPrecision p) {
  float32_matmul_precision = p;
}

void Context::setFloat32MatmulPrecision(const std::string &s) {
  auto match = [this](const std::string & s_) {
    // TODO: consider if CuDNN field needs to also be set for potential future CuDNN ops like multi-headed attention
    if (s_ == "highest") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
      return true;
    } else if (s_ == "high") {
      float32_matmul_precision = at::Float32MatmulPrecision::HIGH;
      return true;
    } else if (s_ == "medium") {
      float32_matmul_precision = at::Float32MatmulPrecision::MEDIUM;
      return true;
    }
    return false;
  };
  if (match(s)) { return; }
  // Fall back to a case-insensitive match: lowercase a copy of s in place.
  // (Transforming into an empty destination string is undefined behavior.)
  std::string sl = s;
  std::transform(sl.begin(), sl.end(), sl.begin(),
      [](unsigned char c) -> unsigned char { return std::tolower(c); });
  if (match(sl)) { return; }
  TORCH_WARN(s, " is not one of 'highest', 'high', or 'medium'; the current "
      "setFloat32MatmulPrecision call has no effect.");
}

at::LinalgBackend Context::linalgPreferredBackend() const {
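The parser above tries an exact match first, then a lowercased copy, and otherwise warns without changing state. A sketch of that behavior as seen through the Python bindings added later in this commit:

import torch

torch.set_float32_matmul_precision("HIGH")   # matched case-insensitively
print(torch.get_float32_matmul_precision())  # "high"

torch.set_float32_matmul_precision("hi")     # no match: warns, setting unchanged
print(torch.get_float32_matmul_precision())  # still "high"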
7 changes: 6 additions & 1 deletion aten/src/ATen/Context.h
@@ -22,6 +22,8 @@ namespace at {

class Tensor;

enum class TORCH_API Float32MatmulPrecision {HIGHEST, HIGH, MEDIUM};

class TORCH_API Context {
public:
Context();
@@ -204,10 +206,13 @@
// https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
void alertCuBLASConfigNotDeterministic() const;

void setFloat32MatmulPrecision(const std::string & s);
bool allowTF32CuDNN() const;
void setAllowTF32CuDNN(bool);
bool allowTF32CuBLAS() const;
void setAllowTF32CuBLAS(bool);
Float32MatmulPrecision float32MatmulPrecision() const;
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
bool allowFP16ReductionCuBLAS() const;
void setAllowFP16ReductionCuBLAS(bool);
at::QEngine qEngine() const;
@@ -245,8 +250,8 @@ class TORCH_API Context {
   bool _deterministic_algorithms = false;
   bool _deterministic_algorithms_warn_only = false;
   bool benchmark_cudnn = false;
+  Float32MatmulPrecision float32_matmul_precision = at::Float32MatmulPrecision::HIGHEST;
   bool allow_tf32_cudnn = true;
-  bool allow_tf32_cublas = true;
   bool allow_fp16_reduction_cublas = true;
   bool enabled_mkldnn = true;
   at::LinalgBackend linalg_preferred_backend = at::LinalgBackend::Default;
2 changes: 2 additions & 0 deletions docs/source/torch.rst
@@ -604,6 +604,8 @@ Utilities
is_deterministic_algorithms_warn_only_enabled
set_deterministic_debug_mode
get_deterministic_debug_mode
set_float32_matmul_precision
get_float32_matmul_precision
set_warn_always
is_warn_always_enabled
vmap
11 changes: 11 additions & 0 deletions test/test_cuda.py
@@ -580,6 +580,17 @@ def test_cublas_allow_tf32_get_set(self):
    self.assertEqual(torch._C._get_cublas_allow_tf32(), not orig)
    torch.backends.cuda.matmul.allow_tf32 = orig

def test_float32_matmul_precision_get_set(self):
    self.assertEqual(torch.get_float32_matmul_precision(), 'highest')
    self.assertFalse(torch.backends.cuda.matmul.allow_tf32)
    for p in ('medium', 'high'):
        torch.set_float32_matmul_precision(p)
        self.assertEqual(torch.get_float32_matmul_precision(), p)
        self.assertTrue(torch.backends.cuda.matmul.allow_tf32)
    torch.set_float32_matmul_precision('highest')
    self.assertEqual(torch.get_float32_matmul_precision(), 'highest')
    self.assertFalse(torch.backends.cuda.matmul.allow_tf32)

def test_cublas_allow_fp16_reduced_precision_reduction_get_set(self):
    orig = torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction
    self.assertEqual(torch._C._get_cublas_allow_fp16_reduced_precision_reduction(), orig)
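For intuition about what the tested flag controls, a hedged sketch (assumes a CUDA device; a numeric difference only shows up on TF32-capable hardware such as Ampere):

import torch

a = torch.randn(2048, 2048, device="cuda")
b = torch.randn(2048, 2048, device="cuda")

torch.set_float32_matmul_precision("highest")
exact = a @ b  # full float32 internal precision

torch.set_float32_matmul_precision("high")
fast = a @ b   # TF32 allowed: faster, reduced internal mantissa precision

print((exact - fast).abs().max())              # small but typically nonzero
torch.set_float32_matmul_precision("highest")  # restore the new default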
2 changes: 2 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -672,6 +672,8 @@ def _get_cudnn_allow_tf32() -> _bool: ... # THPModule_allowTF32CuDNN
def _set_cudnn_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuDNN
def _get_cublas_allow_tf32() -> _bool: ... # THPModule_allowTF32CuBLAS
def _set_cublas_allow_tf32(arg: _bool) -> None: ... # THPModule_setAllowTF32CuBLAS
def _get_float32_matmul_precision() -> str: ... # THPModule_float32MatmulPrecision
def _set_float32_matmul_precision(arg: str) -> None: ... # THPModule_setFloat32MatmulPrecision
def _get_cublas_allow_fp16_reduced_precision_reduction() -> _bool: ... # THPModule_allowFP16ReductionCuBLAS
def _set_cublas_allow_fp16_reduced_precision_reduction(arg: _bool) -> None: ... # THPModule_setAllowFP16ReductionCuBLAS
# NB: There is no Capsule type in typing, see
18 changes: 18 additions & 0 deletions torch/__init__.py
@@ -47,6 +47,7 @@
'are_deterministic_algorithms_enabled',
'is_deterministic_algorithms_warn_only_enabled',
'set_deterministic_debug_mode', 'get_deterministic_debug_mode',
'set_float32_matmul_precision', 'get_float32_matmul_precision',
'set_warn_always', 'is_warn_always_enabled',
]

@@ -569,6 +570,23 @@ def get_deterministic_debug_mode() -> builtins.int:
    else:
        return 0

def get_float32_matmul_precision() -> builtins.str:
    r"""Returns the current value of float32 matrix multiplication precision. Refer to
    :func:`torch.set_float32_matmul_precision` documentation for more details.
    """
    return _C._get_float32_matmul_precision()

def set_float32_matmul_precision(precision):
    r"""Sets the internal precision of float32 matrix multiplications.

    Original RFC: https://github.com/pytorch/pytorch/issues/76440

    Args:
        precision(str): one of "highest" (default), "high", or "medium".
            "highest" avoids internally reducing precision; "high" and
            "medium" both allow reduced-precision formats such as TF32.
    """
    _C._set_float32_matmul_precision(precision)
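
Because setAllowTF32CuBLAS now writes HIGH or HIGHEST into the same field, the legacy boolean flag and this function are two views of one setting; a sketch:

import torch

torch.backends.cuda.matmul.allow_tf32 = True
print(torch.get_float32_matmul_precision())   # "high"

torch.backends.cuda.matmul.allow_tf32 = False
print(torch.get_float32_matmul_precision())   # "highest"

torch.set_float32_matmul_precision("medium")
print(torch.backends.cuda.matmul.allow_tf32)  # True ("medium" != "highest")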

def set_warn_always(b):
    r"""When this flag is False (default) then some PyTorch warnings may only
    appear once per process. This helps avoid excessive warning information.
23 changes: 23 additions & 0 deletions torch/csrc/Module.cpp
@@ -393,6 +393,27 @@ PyObject *THPModule_allowTF32CuDNN(PyObject *_unused, PyObject *noargs)
else Py_RETURN_FALSE;
}

PyObject *THPModule_setFloat32MatmulPrecision(PyObject *_unused, PyObject *arg)
{
  THPUtils_assert(THPUtils_checkString(arg), "set_float32_matmul_precision expects a str, "
      "but got %s", THPUtils_typename(arg));
  std::string s = THPUtils_unpackString(arg);
  at::globalContext().setFloat32MatmulPrecision(s);
  Py_RETURN_NONE;
}

PyObject *THPModule_float32MatmulPrecision(PyObject *_unused, PyObject *noargs)
{
  std::string s = "highest";
  auto p = at::globalContext().float32MatmulPrecision();
  if (p == at::Float32MatmulPrecision::HIGH) {
    s = "high";
  } else if (p == at::Float32MatmulPrecision::MEDIUM) {
    s = "medium";
  }
  return THPUtils_packString(s);
}

PyObject *THPModule_setUserEnabledCuDNN(PyObject *_unused, PyObject *arg)
{
THPUtils_assert(PyBool_Check(arg), "set_enabled_cudnn expects a bool, "
@@ -686,6 +707,8 @@ static PyMethodDef TorchMethods[] = {
{"_set_warnAlways", THPModule_setWarnAlways, METH_O, nullptr},
{"_get_cublas_allow_tf32", THPModule_allowTF32CuBLAS, METH_NOARGS, nullptr},
{"_set_cublas_allow_tf32", THPModule_setAllowTF32CuBLAS, METH_O, nullptr},
{"_get_float32_matmul_precision", THPModule_float32MatmulPrecision, METH_NOARGS, nullptr},
{"_set_float32_matmul_precision", THPModule_setFloat32MatmulPrecision, METH_O, nullptr},
{"_get_cublas_allow_fp16_reduced_precision_reduction", THPModule_allowFP16ReductionCuBLAS, METH_NOARGS, nullptr},
{"_set_cublas_allow_fp16_reduced_precision_reduction", THPModule_setAllowFP16ReductionCuBLAS, METH_O, nullptr},
{"_vmapmode_increment_nesting", THPModule_vmapmode_increment_nesting, METH_NOARGS, nullptr},
2 changes: 2 additions & 0 deletions torch/overrides.py
@@ -226,6 +226,8 @@ def get_ignored_functions() -> Set[Callable]:
torch.is_deterministic_algorithms_warn_only_enabled,
torch.set_deterministic_debug_mode,
torch.get_deterministic_debug_mode,
torch.set_float32_matmul_precision,
torch.get_float32_matmul_precision,
torch.unify_type_list,
torch.is_warn_always_enabled,
torch.set_warn_always,