Description
Hello,

Following the release of `tch==0.14.0` and the publication of the changes introduced in #758, I have noticed that the implementation is problematic for some devices. A large number of half-precision operations (including layer normalization and linear layers) are not implemented for Torch on CPU (they are available for CUDA). An example of a case that fails with these changes can be seen here; the model weights have been saved as `float16`. Note that loading these weights into BART in the Python version of Transformers would also fail for the same reason.
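For illustration, here is a minimal sketch of the kind of operation that fails (a hypothetical reproduction, not the exact case linked above): a half-precision layer normalization on CPU returns an error on builds where the corresponding CPU kernel is missing, while the same call succeeds on CUDA.

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    // Create half-precision inputs on CPU (created as float, then cast,
    // to avoid relying on half-precision random kernels).
    let x = Tensor::randn(&[2, 4], (Kind::Float, Device::Cpu)).to_kind(Kind::Half);
    let weight = Tensor::ones(&[4], (Kind::Half, Device::Cpu));
    let bias = Tensor::zeros(&[4], (Kind::Half, Device::Cpu));

    // On CPU builds without a half-precision layer norm kernel this returns
    // an Err along the lines of "... not implemented for 'Half'";
    // on a CUDA device the same call succeeds.
    match x.f_layer_norm(&[4], Some(&weight), Some(&bias), 1e-5, false) {
        Ok(y) => println!("layer_norm succeeded: {:?}", y.size()),
        Err(e) => eprintln!("layer_norm failed: {e}"),
    }
}
```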
I am unsure about MPS and Vulkan, but it may make sense to either restrict the automatic precision casting to CUDA devices or to exclude CPU.
I have tested a small fix that solves the issue above, but other solutions may exist (for example, always casting the loaded tensors to the `VarStore` precision, therefore loading the weights as fp16 if and only if the `VarStore` was converted to half beforehand; a sketch of this alternative follows the snippet below).

The following skips the casting for CPU devices in the `copy_data_with_precision_update` method:
```rust
/// Copies the data from the source tensor to the destination
///
/// Updates the precision of the destination to match the source
fn copy_data_with_precision_update(
    &self,
    src: &Tensor,
    dst: &mut Tensor,
) -> Result<(), TchError> {
    // Skip the in-place kind update on CPU, where many half-precision
    // kernels are missing; the copy below casts the source to the
    // destination's existing kind instead.
    if self.device != Device::Cpu {
        dst.set_data(&dst.to_kind(src.kind()));
    }
    dst.f_copy_(src)
}
```
I am happy to help with a fix; it would be good to hear what your preferred way forward would be.