Skip to content

Commit

Permalink
Re-enable NVML monitoring for WSL (#6119)
Browse files Browse the repository at this point in the history
  • Loading branch information
charlesbluca authored May 4, 2022
1 parent c11c8ee commit baf05c0
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 12 deletions.
51 changes: 40 additions & 11 deletions distributed/diagnostics/nvml.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
from platform import uname

from packaging.version import parse as parse_version

import dask

try:
Expand All @@ -10,7 +12,9 @@

nvmlInitialized = False
nvmlLibraryNotFound = False
nvmlWslInsufficientDriver = False
nvmlOwnerPID = None
minimumWslVersion = "512.15"


def _in_wsl():
Expand All @@ -21,17 +25,21 @@ def _in_wsl():


def init_once():
global nvmlInitialized, nvmlLibraryNotFound, nvmlOwnerPID
global nvmlInitialized, nvmlLibraryNotFound, nvmlWslInsufficientDriver, nvmlOwnerPID

if dask.config.get("distributed.diagnostics.nvml") is False or _in_wsl():
nvmlInitialized = False
# nvml monitoring disabled
if dask.config.get("distributed.diagnostics.nvml") is False:
return

if pynvml is None or (nvmlInitialized is True and nvmlOwnerPID == os.getpid()):
# nvml is already initialized on this process
if nvmlInitialized and nvmlOwnerPID == os.getpid():
return

nvmlInitialized = True
nvmlOwnerPID = os.getpid()
# nvml failed to initialize due to missing / outdated requirements
if pynvml is None or nvmlLibraryNotFound or nvmlWslInsufficientDriver:
return

# attempt to initialize nvml
try:
pynvml.nvmlInit()
except (
Expand All @@ -40,11 +48,26 @@ def init_once():
pynvml.NVMLError_Unknown,
):
nvmlLibraryNotFound = True
return

# set a minimum driver version for WSL so we can assume certain queries work
if (
not nvmlLibraryNotFound
and parse_version(pynvml.nvmlSystemGetDriverVersion().decode())
< parse_version(minimumWslVersion)
and _in_wsl()
):
nvmlWslInsufficientDriver = True
return

# initialization was successful
nvmlInitialized = True
nvmlOwnerPID = os.getpid()


def device_get_count():
init_once()
if nvmlLibraryNotFound or not nvmlInitialized:
if not nvmlInitialized:
return 0
else:
return pynvml.nvmlDeviceGetCount()
Expand All @@ -53,8 +76,17 @@ def device_get_count():
def _pynvml_handles():
count = device_get_count()
if count == 0:
if pynvml is None:
raise RuntimeError(
"NVML monitoring requires PyNVML and NVML to be installed"
)
if nvmlLibraryNotFound:
raise RuntimeError("PyNVML is installed, but NVML is not")
if nvmlWslInsufficientDriver:
raise RuntimeError(
"Outdated NVIDIA drivers for WSL, please upgrade to "
f"{minimumWslVersion} or newer"
)
else:
raise RuntimeError("No GPUs available")

Expand All @@ -80,13 +112,10 @@ def has_cuda_context():
index of the device for which there's a CUDA context.
"""
init_once()
if nvmlLibraryNotFound or not nvmlInitialized:
if not nvmlInitialized:
return False
for index in range(device_get_count()):
handle = pynvml.nvmlDeviceGetHandleByIndex(index)
# TODO: WSL doesn't support this NVML query yet; when NVML monitoring is enabled
# there we may need to wrap this in a try/except block.
# See https://github.com/dask/distributed/pull/5568
if hasattr(pynvml, "nvmlDeviceGetComputeRunningProcesses_v2"):
running_processes = pynvml.nvmlDeviceGetComputeRunningProcesses_v2(handle)
else:
Expand Down
18 changes: 17 additions & 1 deletion distributed/diagnostics/tests/test_nvml.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,23 @@ def test_enable_disable_nvml():

with dask.config.set({"distributed.diagnostics.nvml": True}):
nvml.init_once()
assert nvml.nvmlInitialized is True
assert (
nvml.nvmlInitialized
^ nvml.nvmlLibraryNotFound
^ nvml.nvmlWslInsufficientDriver
)


def test_wsl_monitoring_enabled():
try:
pynvml.nvmlShutdown()
except pynvml.NVMLError_Uninitialized:
pass
else:
nvml.nvmlInitialized = False

nvml.init_once()
assert nvml.nvmlWslInsufficientDriver is False


def run_has_cuda_context(queue):
Expand Down

0 comments on commit baf05c0

Please sign in to comment.