Fix illegal memory access with multi_tensor_apply size above INT_MAX (#1825)

Currently, multi_tensor_apply causes an illegal memory access due to an overflow in the `size` field of `TensorListMetadata`. This can be reproduced using the following standalone script:

```python
import torch, amp_C
from apex.multi_tensor_apply import multi_tensor_applier

multi_tensor_adam = amp_C.multi_tensor_adam

size = 2**32 + 1
g_32 = [torch.zeros(size, dtype=torch.float32, device='cuda')]
p_32 = [torch.zeros(size, dtype=torch.float32, device='cuda')]
m_32 = [torch.zeros(size, dtype=torch.float32, device='cuda')]
v_32 = [torch.zeros(size, dtype=torch.float32, device='cuda')]
_dummy_overflow_buf = torch.zeros(1, dtype=torch.int32, device='cuda')

multi_tensor_applier(multi_tensor_adam, _dummy_overflow_buf,
                     [g_32, p_32, m_32, v_32],
                     0.0, 0.9, 0.95, 1e-08, 1, 1, 1, 0.1)

print(g_32)
```
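For context, a sketch of the failure mode: the commit message indicates that `TensorListMetadata` keeps each tensor's element count in a 32-bit `size` field, so a count of 2**32 + 1 silently wraps when assigned, and index arithmetic derived from it no longer matches the tensor's true extent. The C++ below is a minimal, self-contained illustration of that truncation and of the kind of widening fix one would expect; the struct layout and field names are simplified assumptions, not the actual apex source.

```cpp
// Minimal sketch of the 32-bit size truncation described above.
// TensorListMetadataSketch/Fixed are illustrative stand-ins, not apex's
// real TensorListMetadata definition.
#include <cstdint>
#include <cstdio>

struct TensorListMetadataSketch {
  int size;      // 32-bit field: element counts above INT_MAX wrap
};

struct TensorListMetadataFixed {
  int64_t size;  // 64-bit field: holds any realistic element count
};

int main() {
  const int64_t numel = (int64_t(1) << 32) + 1;  // the repro's 2**32 + 1

  TensorListMetadataSketch bad;
  bad.size = static_cast<int>(numel);  // truncates to the low 32 bits -> 1

  TensorListMetadataFixed good;
  good.size = numel;                   // preserved exactly

  std::printf("true numel:   %lld\n", static_cast<long long>(numel));
  std::printf("32-bit field: %d\n", bad.size);
  std::printf("64-bit field: %lld\n", static_cast<long long>(good.size));
}
```

With a 32-bit field, the kernel's chunking logic is handed a size of 1 for a tensor that actually holds 2**32 + 1 elements, and intermediate 32-bit products such as a chunk index times a chunk size can likewise overflow, which is how out-of-bounds accesses can arise.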