RFC: Controlling fp32 matmul’s internal precision #76440
Description
Float32 matmuls are one of the most important operations in PyTorch, and considerable attention has been paid to running them faster with specialized math libraries and new math modes. For example, NVIDIA’s Ampere (and later) CUDA devices support computing fp32 matrix multiplication using the tf32 math mode, and Google’s TPUs have bfloat16 and bfloat16_3x math modes. All of these perform some of the internal matmul computations with lower-precision floating-point values, trading numerical accuracy for performance.
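To make the accuracy trade-off concrete, here is a small pure-Python sketch (no PyTorch, GPU, or TPU required) that emulates how these modes round their inputs: tf32 keeps 10 explicit mantissa bits, bfloat16 keeps 7, and a bfloat16_3x-style product splits each fp32 value into a high and a low bfloat16 part and keeps three of the four cross products. The helper names are hypothetical, rounding is round-to-nearest-even on the bit pattern, and products are accumulated in Python's fp64 for simplicity; real hardware accumulates in fp32, and its exact algorithms may differ.

```python
import struct


def to_f32(x: float) -> float:
    """Round a Python float (fp64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def round_mantissa(x: float, kept_bits: int) -> float:
    """Round a float32 value to `kept_bits` explicit mantissa bits using
    round-to-nearest-even on the bit pattern (NaN/inf are not special-cased)."""
    bits = struct.unpack("<I", struct.pack("<f", to_f32(x)))[0]
    drop = 23 - kept_bits  # float32 has 23 explicit mantissa bits
    if drop <= 0:
        return to_f32(x)
    half = 1 << (drop - 1)
    lsb = (bits >> drop) & 1  # ties round toward an even kept mantissa
    bits = (bits + half - 1 + lsb) & ~((1 << drop) - 1) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def tf32(x: float) -> float:
    return round_mantissa(x, 10)  # tf32: 10 explicit mantissa bits


def bf16(x: float) -> float:
    return round_mantissa(x, 7)  # bfloat16: 7 explicit mantissa bits


def dot(a, b, cast):
    """Dot product with inputs cast to a reduced format, accumulated in fp64."""
    return sum(cast(x) * cast(y) for x, y in zip(a, b))


def dot_bf16_3x(a, b):
    """bfloat16_3x-style dot product: x ~ hi + lo with hi, lo in bfloat16,
    keeping hi*hi + hi*lo + lo*hi and dropping the smallest lo*lo term."""
    total = 0.0
    for x, y in zip(a, b):
        xh, yh = bf16(x), bf16(y)
        xl, yl = bf16(to_f32(x - xh)), bf16(to_f32(y - yh))
        total += xh * yh + xh * yl + xl * yh
    return total


x = 1 + 2**-9 + 2**-12 + 2**-20  # exactly representable in float32
a, b = [x], [1.0]
exact = x
errors = {
    "fp32": abs(dot(a, b, to_f32) - exact),
    "tf32": abs(dot(a, b, tf32) - exact),
    "bf16_3x": abs(dot_bf16_3x(a, b) - exact),
    "bf16": abs(dot(a, b, bf16) - exact),
}
for name, err in errors.items():
    print(f"{name:8s} abs error = {err:.3e}")
```

For this input the errors order as bf16 > tf32 > bf16_3x > fp32 (exact here), which mirrors why bfloat16_3x is described as a higher-precision mode than plain bfloat16.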
Today in PyTorch these less accurate but faster float32 matmuls are the default, but this has resulted in user confusion (see, in particular, this Dev Discuss post summarizing the TensorFloat32 discussion). While many programs benefit from these faster matmuls, we think it’s better for PyTorch to bias towards numerical accuracy by default and to require users to explicitly opt in to these faster math modes when they’re available. We think this is the best way to keep PyTorch users in control and give them clarity about what their programs are doing.
To enable more generic control over the precision of matrix multiplication operations, we propose adding a device-agnostic math mode setting, modeled after JAX’s float32 matmul precision UX. This setting would work as follows:
- Add a new function, torch.set_float32_matmul_precision
- This function accepts values from a FLOAT32_MATMUL_PRECISION enum, as well as the strings corresponding to those values
- The enum has the following values:
  - HIGHEST (corresponding string: “highest”; this is the conceptual default value)
  - HIGH (corresponding string: “high”)
  - MEDIUM (corresponding string: “medium”)
- The behavior of each value when passed to the function is as follows:
  - HIGHEST sets torch.backends.cuda.matmul.allow_tf32 = False
    - Non-native device types should also respect this setting
    - For the XLA device type the corresponding precision setting would be “HIGHEST”
  - HIGH sets torch.backends.cuda.matmul.allow_tf32 = True
    - For the XLA device type the corresponding precision setting would be “HIGH” (bfloat16_3x precision on TPUs)
  - MEDIUM sets torch.backends.cuda.matmul.allow_tf32 = True
    - For the XLA device type the corresponding precision setting would be “DEFAULT” (bfloat16 precision on TPUs)
- The default values of the following flags will change:
  - torch.backends.cuda.matmul.allow_tf32 will default to False
  - Non-native device types should change their default float32 matmul mode to be consistent with these settings
(Update: these settings were chosen for consistency with JAX's fp32 matmul precision settings; see here.)
Note that the CPU device type is currently unaffected by this setting.
We think that in the future this mechanism can be used to provide a way for programs to specify their desired float32 matrix multiplication mode across a variety of devices.
cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @lezcano @ptrblck @JackCaoG @csarofeen @jeffdaily @Balandat @ngimel