Skip to content

Commit

Permalink
Removed no-cuda flag and creation of CPU context.
Browse files Browse the repository at this point in the history
  • Loading branch information
ptrvilya committed Oct 11, 2019
1 parent 72795a7 commit 6d754d5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 16 deletions.
3 changes: 0 additions & 3 deletions adaptis/utils/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@ def get_common_arguments():
parser.add_argument('--thread-pool', action='store_true', default=False,
help='use ThreadPool for dataloader workers')

parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')

parser.add_argument('--ngpus', type=int,
default=len(mx.test_utils.list_gpus()),
help='number of GPUs')
Expand Down
21 changes: 8 additions & 13 deletions adaptis/utils/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,20 +46,15 @@ def init_experiment(experiment_name, add_exp_args, script_path=None):
fh.setFormatter(formatter)
logger.addHandler(fh)

if args.no_cuda:
logger.info('Using CPU')
args.kvstore = 'local'
args.ctx = mx.cpu(0)
if args.gpus:
args.ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')]
args.ngpus = len(args.ctx)
else:
if args.gpus:
args.ctx = [mx.gpu(int(i)) for i in args.gpus.split(',')]
args.ngpus = len(args.ctx)
else:
args.ctx = [mx.gpu(i) for i in range(args.ngpus)]
logger.info(f'Number of GPUs: {args.ngpus}')

if args.ngpus < 2:
args.syncbn = False
args.ctx = [mx.gpu(i) for i in range(args.ngpus)]
logger.info(f'Number of GPUs: {args.ngpus}')

if args.ngpus < 2:
args.syncbn = False

logger.info(args)

Expand Down

0 comments on commit 6d754d5

Please sign in to comment.