diff --git a/lagom/utils/seeding.py b/lagom/utils/seeding.py index e7672119..a3a80708 100644 --- a/lagom/utils/seeding.py +++ b/lagom/utils/seeding.py @@ -17,7 +17,8 @@ def set_global_seeds(seed): """ random.seed(seed) np.random.seed(seed) - torch.manual_seed(seed) # both torch and torch.cuda internally + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) class Seeder(object):