bm.benchmark() fails to de-allocate GPU memory. garbage collection fail? #186
Report
I'm running bm.benchmark() with three embedding_obsm_keys on a ~2.7M-cell dataset. It previously ran, albeit slowly, on a smaller dataset using only two keys. I can reproduce the failure on a local RTX 4080 and on dual remote V100s.
CODE (almost directly from the tutorial):
import scanpy as sc
from scib_metrics.benchmark import Benchmarker, BioConservation
# faiss_brute_force_nn is the GPU brute-force k-NN helper from the scib-metrics
# acceleration tutorial (a sketch is given below this snippet).

adata_input = "artifacts/cohortv2.2.0/cohort.final_adata.h5ad"  # ~2.7m cells x 3k genes

# Set CPUs to use for parallel computing
sc._settings.ScanpyConfig.n_jobs = -1

batch_key = "sample"
label_key = "cell_type"

adata = sc.read_h5ad(adata_input)  # type: ignore
adata.obsm["Unintegrated"] = adata.obsm["X_pca"]

biocons = BioConservation(isolated_labels=False)

bm = Benchmarker(
    adata,
    batch_key=batch_key,
    label_key=label_key,
    embedding_obsm_keys=["Unintegrated", "X_scvi", "X_pca_harmony"],
    pre_integrated_embedding_obsm_key="X_pca",
    bio_conservation_metrics=biocons,
    n_jobs=-1,
)

bm.prepare(neighbor_computer=faiss_brute_force_nn)  # fills roughly half of GPU memory (scales with different GPUs)
bm.benchmark()
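For reference, faiss_brute_force_nn is the GPU brute-force nearest-neighbor helper from the scib-metrics acceleration tutorial, not part of the library itself; a minimal sketch (assuming faiss-gpu is installed and that scib_metrics.nearest_neighbors.NeighborsResults takes indices and distances):

import faiss
import numpy as np
from scib_metrics.nearest_neighbors import NeighborsResults

def faiss_brute_force_nn(X: np.ndarray, k: int) -> NeighborsResults:
    """Brute-force k-NN on the GPU with faiss."""
    X = np.ascontiguousarray(X, dtype=np.float32)
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(X.shape[1])
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(X)
    distances, indices = gpu_index.search(X, k)  # faiss returns squared L2 distances
    del index, gpu_index
    return NeighborsResults(indices=indices, distances=np.sqrt(distances))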
FAILURE:
Computing neighbors: 100%|██████████| 3/3 [04:21<00:00, 87.18s/it]
Embeddings: 0%| | 0/3 [00:00<?, ?it/s] 2024-11-25 23:57:13.747881: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.23GiB (rounded to 2395844096)requested by op
2024-11-25 23:57:13.748060: W external/tsl/tsl/framework/bfc_allocator.cc:494] **********************_********************__**********__*****************_________________*********
E1125 23:57:13.748128 17 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2395843920 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 2.23GiB
constant allocation: 0B
maybe_live_out allocation: 2.23GiB
preallocated temp allocation: 0B
total allocation: 4.46GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 2.23GiB
Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=2]" source_file="/opt/scripts/main/artifact_metrics.py" source_line=76
XLA Label: fusion
Shape: s32[3327561,90,2]
==========================
Buffer 2:
Size: 1.12GiB
Entry Parameter Subshape: s32[3327561,90,1]
==========================
Buffer 3:
Size: 1.12GiB
Entry Parameter Subshape: s32[3327561,90,1]
==========================
Embeddings: 0%| | 0/3 [7:09:01<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/opt/scripts/main/artifact_metrics.py", line 131, in <module>
main(args)
File "/opt/scripts/main/artifact_metrics.py", line 96, in main
report_df = get_artifact_metrics(adata, args.batch_key, args.label_key, report_dir)
File "/opt/scripts/main/artifact_metrics.py", line 76, in get_artifact_metrics
bm.benchmark()
File "/usr/local/lib/python3.10/site-packages/scib_metrics/benchmark/_core.py", line 221, in benchmark
metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn)
File "/usr/local/lib/python3.10/site-packages/scib_metrics/benchmark/_core.py", line 92, in <lambda>
ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH])
File "/usr/local/lib/python3.10/site-packages/scib_metrics/_lisi.py", line 66, in ilisi_knn
lisi = lisi_knn(X, batches, perplexity=perplexity)
File "/usr/local/lib/python3.10/site-packages/scib_metrics/_lisi.py", line 36, in lisi_knn
simpson = compute_simpson_index(
File "/usr/local/lib/python3.10/site-packages/scib_metrics/utils/_lisi.py", line 132, in compute_simpson_index
out = jax.vmap(simpson_fn)(knn_dists, knn_labels, self_mask)
File "/usr/local/lib/python3.10/site-packages/scib_metrics/utils/_lisi.py", line 89, in _compute_simpson_index_cell
return jnp.where(H == 0, -1, _non_zero_H_simpson())
File "/usr/local/lib/python3.10/site-packages/scib_metrics/utils/_lisi.py", line 86, in _non_zero_H_simpson
sumP = jnp.bincount(knn_labels_row, weights=P, length=n_batches)
File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1453, in bincount
return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 529, in add
return scatter._scatter_update(self.array, self.index, values,
File "/usr/local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 80, in _scatter_update
return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
File "/usr/local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 131, in _scatter_impl
out = scatter_op(
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2395843920 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 2.23GiB
constant allocation: 0B
maybe_live_out allocation: 2.23GiB
preallocated temp allocation: 0B
total allocation: 4.46GiB
total fragmentation: 0B (0.00%)
Peak buffers:
Buffer 1:
Size: 2.23GiB
Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=2]" source_file="/opt/scripts/main/artifact_metrics.py" source_line=76
XLA Label: fusion
Shape: s32[3327561,90,2]
==========================
Buffer 2:
Size: 1.12GiB
Entry Parameter Subshape: s32[3327561,90,1]
==========================
Buffer 3:
Size: 1.12GiB
Entry Parameter Subshape: s32[3327561,90,1]
==========================
bm.prepare(neighbor_computer=faiss_brute_force_nn) works fine, and its speed scales nicely with resources (thanks, JAX!). Which exact metric the benchmark fails on seems to vary, but in general, on each iteration over the embedding keys at line 211 of scib_metrics/benchmark/_core.py:

for emb_key, ad in tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green"):

GPU memory usage steps up as metric_fn(ad) is called, where I believe each ad is loaded into GPU memory. The proper behavior should probably be to free that GPU memory and/or re-use it for the next emb_key.
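One possible mitigation, untested here, is to keep XLA from pre-allocating and holding a large pool, so each metric only takes the device memory it needs and returns it afterwards; a minimal sketch using JAX's documented memory flags (they must be set before JAX is first imported):

import os

# Set before JAX initializes, e.g. at the very top of artifact_metrics.py.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"   # don't reserve most of the GPU up front
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"  # allocate/free on demand; slower, but returns memory

import jax

# ... run bm.prepare(...) and bm.benchmark() as in the snippet above ...

# Compiled-program caches can also be dropped between heavy steps:
jax.clear_caches()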
More corroborating evidence of the issue:
If I subsample the adata so it is small enough not to hit the memory ceiling across the three iterations, I can follow the step-wise jumps in GPU memory utilization at each iteration and get bm.benchmark() to finish. But the memory is not freed on the GPU when bm.benchmark() returns. Even after deleting the instance or reassigning bm, GPU garbage collection doesn't happen. E.g., if I run this in an interactive Python session, the memory allocation is held until the Python process is killed.
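For what it's worth, in an interactive session one can check whether JAX itself is still holding device buffers after bm.benchmark() returns, and delete them by hand; a rough sketch (this assumes the held memory is live JAX arrays rather than the XLA allocator's reserved pool, which nvidia-smi would still report as used):

import jax

held = jax.live_arrays()  # arrays JAX still keeps alive on the default backend
print(f"{len(held)} live arrays, {sum(a.size * a.dtype.itemsize for a in held) / 1e9:.2f} GB")

for a in held:
    a.delete()        # explicitly release each device buffer
jax.clear_caches()    # and drop compiled programs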
Version information
$ cat requirements.txt
argparse==1.4.0
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
anndata==0.10.7
annoy==1.17.3
array_api_compat==1.7.1
async-timeout==4.0.3
attrs==23.2.0
blosc2==2.6.2
certifi==2024.6.2
charset-normalizer==3.3.2
chex==0.1.86
contextlib2==21.6.0
contourpy==1.2.1
cycler==0.12.1
Cython==3.0.10
docrep==0.3.2
etils==1.7.0
exceptiongroup==1.2.1
faiss-gpu==1.7.2
filelock==3.14.0
flax==0.8.4
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.6.0
h5py==3.11.0
harmonypy==0.0.9
idna==3.7
igraph==0.11.5
imageio==2.34.1
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.5
lazy_loader==0.4
leidenalg==0.10.2
lightning==2.1.4
lightning-utilities==0.11.2
llvmlite==0.42.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
mdurl==0.1.2
ml-collections==0.1.1
ml-dtypes==0.4.0
mpmath==1.3.0
msgpack==1.0.8
mudata==0.2.3
multidict==6.0.5
multipledispatch==1.0.0
muon==0.1.5
natsort==8.4.0
ndindex==1.8
nest-asyncio==1.6.0
networkx==3.3
numba==0.59.1
numexpr==2.10.0
numpy==1.26.4
numpyro==0.15.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvcc-cu12==12.2.140
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.2.140
nvidia-nvtx-cu12==12.1.105
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.15
packaging==24.0
pandas==2.2.2
pathlib==1.0.1
patsy==0.5.6
pillow==10.3.0
pip==23.0.1
plottable==0.1.5
protobuf==5.27.0
py-cpuinfo==9.0.0
Pygments==2.18.0
pymde==0.1.18
pynndescent==0.5.12
pyparsing==3.1.2
pyro-api==0.1.2
pyro-ppl==1.9.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.2.5
pytz==2024.1
PyYAML==6.0.1
requests==2.32.3
rich==13.7.1
scanpy==1.9.8
scib-metrics==0.5.1
scikit-image==0.23.2
scikit-learn==1.5.0
scikit-misc==0.3.1
scipy==1.13.1
scrublet==0.2.3
scvi-tools==1.2.0
seaborn==0.13.2
session-info==1.0.0
setuptools==65.5.0
six==1.16.0
statsmodels==0.14.2
stdlib-list==0.10.0
sympy==1.12.1
tables==3.9.2
tensorstore==0.1.60
texttable==1.7.0
threadpoolctl==3.5.0
tifffile==2024.5.22
toolz==0.12.1
torch==2.3.0
torchmetrics==1.4.0.post0
torchvision==0.18.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.12.1
tzdata==2024.1
umap-learn==0.5.6
urllib3==2.2.1
yarl==1.9.4
zipp==3.19.1
jax==0.4.28
$ cat Dockerfile
ARG CUDA_VERSION
FROM us-central1-docker.pkg.dev/dnastack-asap-parkinsons/workflow-images/util:1.1.1 as scripts
FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04
ENV CUDA_VERSION "${CUDA_VERSION}"
LABEL MAINTAINER="Karen Fang <karen@dnastack.com>"
ARG IMAGE_NAME
ENV IMAGE_NAME "${IMAGE_NAME}"
ARG IMAGE_TAG
ENV IMAGE_TAG "${IMAGE_TAG}"
ARG JAX_VERSION
ENV JAX_VERSION "${JAX_VERSION}"
ARG CUDA_VERSION
ENV CUDA_VERSION "${CUDA_VERSION}"
ENV DEBIAN_FRONTEND noninteractive
RUN apt-get -qq update \
&& apt-get -qq install \
build-essential \
wget \
time \
xxd \
curl \
zlib1g-dev \
libncursesw5-dev \
libssl-dev \
libsqlite3-dev \
tk-dev \
libgdbm-dev \
libc6-dev \
libbz2-dev \
libffi-dev \
liblzma-dev
ARG PYTHON3_VERSION
ENV PYTHON3_VERSION "${PYTHON3_VERSION}"
RUN curl -O https://www.python.org/ftp/python/${PYTHON3_VERSION}/Python-${PYTHON3_VERSION}.tar.xz && \
tar -xvf Python-${PYTHON3_VERSION}.tar.xz --directory /opt/ && \
rm Python-${PYTHON3_VERSION}.tar.xz
RUN cd /opt/Python-${PYTHON3_VERSION} && \
./configure && \
make && \
make altinstall
ENV PATH "${PATH}:/opt/Python-${PYTHON3_VERSION}"
RUN ln -s /opt/Python-${PYTHON3_VERSION}/python /opt/Python-${PYTHON3_VERSION}/python3
COPY ./requirements.txt /opt/requirements.txt
RUN python3 -m pip install -r /opt/requirements.txt
# CUDA-enabled jaxlib is needed
RUN python3 -m pip install --upgrade "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
# Add override scripts
COPY scripts/main /opt/scripts/main
COPY scripts/utility /opt/scripts/utility
# Add resources
COPY resources /opt/resources
# gcloud sdk; needed to upload output files
ARG GCLOUD_CLI_VERSION
ENV GCLOUD_CLI_VERSION "${GCLOUD_CLI_VERSION}"
RUN wget "https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-${GCLOUD_CLI_VERSION}-linux-x86_64.tar.gz" \
&& tar -zxvf "google-cloud-cli-${GCLOUD_CLI_VERSION}-linux-x86_64.tar.gz" --directory /opt \
&& rm "google-cloud-cli-${GCLOUD_CLI_VERSION}-linux-x86_64.tar.gz"
ENV PATH "${PATH}:/opt/google-cloud-sdk/bin"
COPY --from=scripts /opt/scripts /opt/scripts
ENV PATH "${PATH}:/opt/scripts"