Skip to content

Commit

Permalink
Fix: optimize model selection logic to avoid cuda out of memory error. (
Browse files Browse the repository at this point in the history
  • Loading branch information
remiliacn authored Mar 21, 2023
1 parent 234b0c7 commit e4d219d
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
19 changes: 19 additions & 0 deletions modules/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Optional, List, Tuple

from torch.cuda import get_device_properties

from modules.device import torch_gc
from modules.options import cmd_opts

Expand All @@ -17,6 +19,23 @@ def prepare_model():
else:
model = model.float()
else:
if cmd_opts.precision is None:
total_vram_in_gb = get_device_properties(0).total_memory / 1e9
print(f'GPU memory: {total_vram_in_gb:.2f} GB')

if total_vram_in_gb > 30:
cmd_opts.precision = 'fp32'
elif total_vram_in_gb > 13:
cmd_opts.precision = 'fp16'
elif total_vram_in_gb > 10:
cmd_opts.precision = 'int8'
else:
cmd_opts.precision = 'int4'

print(f'Choosing precision {cmd_opts.precision} according to your VRAM.'
f' If you want to decide precision yourself,'
f' please add argument --precision when launching the application.')

if cmd_opts.precision == "fp16":
model = model.half().cuda()
elif cmd_opts.precision == "int4":
Expand Down
2 changes: 1 addition & 1 deletion modules/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

parser.add_argument("--port", type=int, default="17860")
parser.add_argument("--model-path", type=str, default="THUDM/chatglm-6b")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp32", "fp16", "int4", "int8"], default="fp16")
parser.add_argument("--precision", type=str, help="evaluate at this precision", choices=["fp32", "fp16", "int4", "int8"])
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
parser.add_argument("--cpu", action='store_true', help="use cpu")
parser.add_argument("--share", action='store_true', help="use gradio share")
Expand Down

0 comments on commit e4d219d

Please sign in to comment.