bm.benchmark() fails to de-allocate GPU memory. garbage collection fail? #186

Open
@ergonyc

Description

Report

I'm running bm.benchmark() with three embedding_obsm_keys on a ~2.7M-cell dataset. It previously ran, albeit slowly, on a smaller dataset using only two keys.
I can reproduce the failure on a local 4080 and on dual remote V100s.

CODE (almost directly from the tutorial):

import scanpy as sc
from scib_metrics.benchmark import Benchmarker, BioConservation

# faiss_brute_force_nn is the FAISS GPU neighbor function defined in the tutorial (definition not shown here)

adata_input = "artifacts/cohortv2.2.0/cohort.final_adata.h5ad"  # ~2.7M cells x 3k genes
# Set CPUs to use for parallel computing
sc._settings.ScanpyConfig.n_jobs = -1
batch_key = "sample"
label_key = "cell_type"
adata = sc.read_h5ad(adata_input)  # type: ignore

# Use the plain PCA embedding as the unintegrated baseline
adata.obsm["Unintegrated"] = adata.obsm["X_pca"]
biocons = BioConservation(isolated_labels=False)
bm = Benchmarker(
    adata,
    batch_key=batch_key,
    label_key=label_key,
    embedding_obsm_keys=["Unintegrated", "X_scvi", "X_pca_harmony"],
    pre_integrated_embedding_obsm_key="X_pca",
    bio_conservation_metrics=biocons,
    n_jobs=-1,
)
bm.prepare(neighbor_computer=faiss_brute_force_nn)  # uses the full GPU and about half its memory (scales with different GPUs)
bm.benchmark()

FAILURE:

Computing neighbors: 100%|██████████| 3/3 [04:21<00:00, 87.18s/it]
Embeddings:   0%|          | 0/3 [00:00<?, ?it/s]                                                2024-11-25 23:57:13.747881: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 2.23GiB (rounded to 2395844096)requested by op        
2024-11-25 23:57:13.748060: W external/tsl/tsl/framework/bfc_allocator.cc:494] **********************_********************__**********__*****************_________________*********
E1125 23:57:13.748128      17 pjrt_stream_executor_client.cc:2826] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2395843920 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    2.23GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.23GiB
     preallocated temp allocation:         0B
                 total allocation:    4.46GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 2.23GiB
		Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=2]" source_file="/opt/scripts/main/artifact_metrics.py" source_line=76
		XLA Label: fusion
		Shape: s32[3327561,90,2]
		==========================

	Buffer 2:
		Size: 1.12GiB
		Entry Parameter Subshape: s32[3327561,90,1]
		==========================

	Buffer 3:
		Size: 1.12GiB
		Entry Parameter Subshape: s32[3327561,90,1]
		==========================


Embeddings:   0%|          | 0/3 [7:09:01<?, ?it/s]
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/opt/scripts/main/artifact_metrics.py", line 131, in <module>
    main(args)
  File "/opt/scripts/main/artifact_metrics.py", line 96, in main
    report_df = get_artifact_metrics(adata, args.batch_key, args.label_key, report_dir)
  File "/opt/scripts/main/artifact_metrics.py", line 76, in get_artifact_metrics
    bm.benchmark()
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/benchmark/_core.py", line 221, in benchmark
    metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn)
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/benchmark/_core.py", line 92, in <lambda>
    ilisi_knn = lambda ad, fn: fn(ad.uns["90_neighbor_res"], ad.obs[_BATCH])
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/_lisi.py", line 66, in ilisi_knn
    lisi = lisi_knn(X, batches, perplexity=perplexity)
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/_lisi.py", line 36, in lisi_knn
    simpson = compute_simpson_index(
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/utils/_lisi.py", line 132, in compute_simpson_index
    out = jax.vmap(simpson_fn)(knn_dists, knn_labels, self_mask)
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/utils/_lisi.py", line 89, in _compute_simpson_index_cell
    return jnp.where(H == 0, -1, _non_zero_H_simpson())
  File "/usr/local/lib/python3.10/site-packages/scib_metrics/utils/_lisi.py", line 86, in _non_zero_H_simpson
    sumP = jnp.bincount(knn_labels_row, weights=P, length=n_batches)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1453, in bincount
    return zeros(length, _dtype(weights)).at[clip(x, 0)].add(weights)
  File "/usr/local/lib/python3.10/site-packages/jax/_src/numpy/array_methods.py", line 529, in add
    return scatter._scatter_update(self.array, self.index, values,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 80, in _scatter_update
    return _scatter_impl(x, y, scatter_op, treedef, static_idx, dynamic_idx,
  File "/usr/local/lib/python3.10/site-packages/jax/_src/ops/scatter.py", line 131, in _scatter_impl
    out = scatter_op(
ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 2395843920 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:    2.23GiB
              constant allocation:         0B
        maybe_live_out allocation:    2.23GiB
     preallocated temp allocation:         0B
                 total allocation:    4.46GiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 2.23GiB
		Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=2]" source_file="/opt/scripts/main/artifact_metrics.py" source_line=76
		XLA Label: fusion
		Shape: s32[3327561,90,2]
		==========================

	Buffer 2:
		Size: 1.12GiB
		Entry Parameter Subshape: s32[3327561,90,1]
		==========================

	Buffer 3:
		Size: 1.12GiB
		Entry Parameter Subshape: s32[3327561,90,1]
		==========================
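
As a sanity check on the log above, the failed allocation is exactly the size of one 32-bit integer array with the logged shape. The following is a minimal sketch of that arithmetic. The helper name `buffer_bytes` is hypothetical; the shapes and the 4-byte item size (s32) are read straight from the "Peak buffers" section.

```python
import math


def buffer_bytes(shape, itemsize=4):
    """Size in bytes of a dense array; itemsize=4 corresponds to s32 in the XLA log."""
    return math.prod(shape) * itemsize


GIB = 2**30

# Shapes copied from the "Peak buffers" section of the log.
concat_out = buffer_bytes((3327561, 90, 2))  # Buffer 1: output of the concatenate
param = buffer_bytes((3327561, 90, 1))       # Buffers 2 and 3: its two inputs

print(f"concatenate output: {concat_out} bytes = {concat_out / GIB:.2f} GiB")
print(f"each input:         {param} bytes = {param / GIB:.2f} GiB")
print(f"inputs + output:    {(2 * param + concat_out) / GIB:.2f} GiB")
```

The concatenate output comes to 2395843920 bytes, the same number reported in the RESOURCE_EXHAUSTED message, and the two inputs plus the output add up to the 4.46GiB "total allocation". So this single operation needs roughly 4.5 GiB of free device memory for the neighbor indices alone, before anything left over from earlier iterations is counted.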

The bm.prepare(neighbor_computer=faiss_brute_force_nn) step works fine, and its speed scales nicely with resources. (Thanks Jax!) Which exact metric fails seems to vary, but in general the failure happens while iterating over the embedding keys:

line 211 of benchmark/_core.py: for emb_key, ad in tqdm(self._emb_adatas.items(), desc="Embeddings", position=0, colour="green"):
The memory usage steps up on each iteration as metric_fn(ad) is called, which is where I believe each ad is loaded into GPU memory. The proper behavior should probably be to free that GPU memory and/or reuse it for the next emb_key.
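
One way to confirm the per-iteration step is to log what JAX's allocator reports before and after each call. The sketch below is only an illustration: `jax_bytes_in_use` and `track_memory` are hypothetical helper names, and the only JAX API it relies on is `Device.memory_stats()`, which returns a dictionary including `bytes_in_use` on GPU backends and `None` on backends that do not report statistics.

```python
from contextlib import contextmanager


def jax_bytes_in_use(device_index=0):
    """Bytes currently held by live JAX buffers on one device, or None if unavailable."""
    import jax  # imported lazily so the helpers can be defined before JAX is configured

    stats = jax.local_devices()[device_index].memory_stats()
    # memory_stats() returns None on backends without allocator statistics (e.g. CPU).
    return None if not stats else stats.get("bytes_in_use")


@contextmanager
def track_memory(label, read_bytes=jax_bytes_in_use, log=print):
    """Log memory in use before and after the wrapped block."""
    before = read_bytes()
    try:
        yield
    finally:
        after = read_bytes()
        if before is None or after is None:
            log(f"{label}: memory statistics unavailable")
        else:
            gib = 2**30
            log(
                f"{label}: {before / gib:.2f} GiB -> {after / gib:.2f} GiB "
                f"(delta {(after - before) / gib:+.2f} GiB)"
            )
```

Wrapping the call as `with track_memory("benchmark"): bm.benchmark()`, or wrapping each metric call inside the loop, would show whether `bytes_in_use` keeps growing across embeddings. Note that nvidia-smi reports the whole pool reserved by the process, while `bytes_in_use` counts only buffers that are still live, so the two numbers can differ.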

More corroborating evidence:
If I subsample the adata so that it is small enough not to hit the memory ceiling over the three iterations, I can watch GPU memory utilization jump at each iteration, and bm.benchmark() finishes. But the memory is not freed from the GPU when bm.benchmark() returns. Even after deleting the instance or reassigning bm, GPU garbage collection doesn't happen. E.g., if I run in an interactive Python session, the memory stays allocated until the Python process is killed.
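
Part of this may be JAX's documented allocator behavior rather than a leak. By default JAX preallocates a large fraction of GPU memory on first use and keeps that pool for the lifetime of the process, so nvidia-smi will not show it shrinking after `del bm`. The sketch below sets the environment variables described in the JAX GPU memory allocation documentation to test this. The function name `configure_jax_gpu_memory` is hypothetical, the variables only affect JAX (not FAISS or PyTorch allocations), and they must be set before `jax` is first imported.

```python
import os


def configure_jax_gpu_memory(preallocate=False, allocator=None, mem_fraction=None):
    """Set JAX/XLA GPU memory options; call this before importing jax."""
    # Disable (or enable) the up-front pool reservation.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if allocator is not None:
        # "platform" allocates on demand and returns memory when buffers are freed.
        # The JAX docs describe it as slow and intended for debugging.
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator
    if mem_fraction is not None:
        # Fraction of GPU memory the pool may use when preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"


# Diagnostic configuration: no pool, memory returned to the driver when freed.
configure_jax_gpu_memory(preallocate=False, allocator="platform")

# import jax  # must happen only after the variables above are set
```

If memory reported by nvidia-smi still climbs across embeddings and never drops with the `platform` allocator, that would point to references to device arrays being kept alive rather than to the allocator's pooling.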

Version information

$ cat requirements.txt

argparse==1.4.0
absl-py==2.1.0
aiohttp==3.9.5
aiosignal==1.3.1
anndata==0.10.7
annoy==1.17.3
array_api_compat==1.7.1
async-timeout==4.0.3
attrs==23.2.0
blosc2==2.6.2
certifi==2024.6.2
charset-normalizer==3.3.2
chex==0.1.86
contextlib2==21.6.0
contourpy==1.2.1
cycler==0.12.1
Cython==3.0.10
docrep==0.3.2
etils==1.7.0
exceptiongroup==1.2.1
faiss-gpu==1.7.2
filelock==3.14.0
flax==0.8.4
fonttools==4.53.0
frozenlist==1.4.1
fsspec==2024.6.0
h5py==3.11.0
harmonypy==0.0.9
idna==3.7
igraph==0.11.5
imageio==2.34.1
importlib_resources==6.4.0
Jinja2==3.1.4
joblib==1.4.2
kiwisolver==1.4.5
lazy_loader==0.4
leidenalg==0.10.2
lightning==2.1.4
lightning-utilities==0.11.2
llvmlite==0.42.0
markdown-it-py==3.0.0
MarkupSafe==2.1.5
matplotlib==3.9.0
mdurl==0.1.2
ml-collections==0.1.1
ml-dtypes==0.4.0
mpmath==1.3.0
msgpack==1.0.8
mudata==0.2.3
multidict==6.0.5
multipledispatch==1.0.0
muon==0.1.5
natsort==8.4.0
ndindex==1.8
nest-asyncio==1.6.0
networkx==3.3
numba==0.59.1
numexpr==2.10.0
numpy==1.26.4
numpyro==0.15.0
nvidia-cublas-cu12==12.1.3.1
nvidia-cuda-cupti-cu12==12.1.105
nvidia-cuda-nvcc-cu12==12.2.140
nvidia-cuda-nvrtc-cu12==12.1.105
nvidia-cuda-runtime-cu12==12.1.105
nvidia-cudnn-cu12==8.9.2.26
nvidia-cufft-cu12==11.0.2.54
nvidia-curand-cu12==10.3.2.106
nvidia-cusolver-cu12==11.4.5.107
nvidia-cusparse-cu12==12.1.0.106
nvidia-nccl-cu12==2.20.5
nvidia-nvjitlink-cu12==12.2.140
nvidia-nvtx-cu12==12.1.105
opt-einsum==3.3.0
optax==0.2.2
orbax-checkpoint==0.5.15
packaging==24.0
pandas==2.2.2
pathlib==1.0.1
patsy==0.5.6
pillow==10.3.0
pip==23.0.1
plottable==0.1.5
protobuf==5.27.0
py-cpuinfo==9.0.0
Pygments==2.18.0
pymde==0.1.18
pynndescent==0.5.12
pyparsing==3.1.2
pyro-api==0.1.2
pyro-ppl==1.9.1
python-dateutil==2.9.0.post0
pytorch-lightning==2.2.5
pytz==2024.1
PyYAML==6.0.1
requests==2.32.3
rich==13.7.1
scanpy==1.9.8
scib-metrics==0.5.1
scikit-image==0.23.2
scikit-learn==1.5.0
scikit-misc==0.3.1
scipy==1.13.1
scrublet==0.2.3
scvi-tools==1.2.0
seaborn==0.13.2
session-info==1.0.0
setuptools==65.5.0
six==1.16.0
statsmodels==0.14.2
stdlib-list==0.10.0
sympy==1.12.1
tables==3.9.2
tensorstore==0.1.60
texttable==1.7.0
threadpoolctl==3.5.0
tifffile==2024.5.22
toolz==0.12.1
torch==2.3.0
torchmetrics==1.4.0.post0
torchvision==0.18.0
tqdm==4.66.4
triton==2.3.0
typing_extensions==4.12.1
tzdata==2024.1
umap-learn==0.5.6
urllib3==2.2.1
yarl==1.9.4
zipp==3.19.1
jax==0.4.28

$ cat Dockerfile

ARG CUDA_VERSION
FROM us-central1-docker.pkg.dev/dnastack-asap-parkinsons/workflow-images/util:1.1.1 as scripts

FROM nvcr.io/nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04
ENV CUDA_VERSION "${CUDA_VERSION}"

LABEL MAINTAINER="Karen Fang <karen@dnastack.com>"

ARG IMAGE_NAME
ENV IMAGE_NAME "${IMAGE_NAME}"
ARG IMAGE_TAG
ENV IMAGE_TAG "${IMAGE_TAG}"
ARG JAX_VERSION
ENV JAX_VERSION "${JAX_VERSION}"
ARG CUDA_VERSION
ENV CUDA_VERSION "${CUDA_VERSION}"

ENV DEBIAN_FRONTEND noninteractive

RUN apt-get -qq update \
	&& apt-get -qq install \
		build-essential \
		wget \
		time \
		xxd \
		curl \
		zlib1g-dev \
		libncursesw5-dev \
		libssl-dev \
		libsqlite3-dev \
		tk-dev \
		libgdbm-dev \
		libc6-dev \
		libbz2-dev \
		libffi-dev \
		liblzma-dev

ARG PYTHON3_VERSION
ENV PYTHON3_VERSION "${PYTHON3_VERSION}"
RUN curl -O https://www.python.org/ftp/python/${PYTHON3_VERSION}/Python-${PYTHON3_VERSION}.tar.xz && \
	tar -xvf Python-${PYTHON3_VERSION}.tar.xz --directory /opt/ && \
	rm Python-${PYTHON3_VERSION}.tar.xz
RUN cd /opt/Python-${PYTHON3_VERSION} && \
	./configure && \
	make && \
	make altinstall

ENV PATH "${PATH}:/opt/Python-${PYTHON3_VERSION}"

RUN ln -s /opt/Python-${PYTHON3_VERSION}/python /opt/Python-${PYTHON3_VERSION}/python3

COPY ./requirements.txt /opt/requirements.txt
RUN python3 -m pip install -r /opt/requirements.txt

# CUDA-enabled jaxlib is needed
RUN python3 -m pip install --upgrade "jax[cuda12_pip]==${JAX_VERSION}" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# Add override scripts
COPY scripts/main /opt/scripts/main
COPY scripts/utility /opt/scripts/utility

# Add resources
COPY resources /opt/resources

# gcloud sdk; needed to upload output files
ARG GCLOUD_CLI_VERSION
ENV GCLOUD_CLI_VERSION "${GCLOUD_CLI_VERSION}"
RUN wget "https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-cli-${GCLOUD_CLI_VERSION}-linux-x86_64.tar.gz" \
	&& tar -zxvf "google-cloud-cli-${GCLOUD_CLI_VERSION}-linux-x86_64.tar.gz" --directory /opt \
	&& rm "google-cloud-cli-${GCLOUD_CLI_VERSION}-linux-x86_64.tar.gz"

ENV PATH "${PATH}:/opt/google-cloud-sdk/bin"

COPY --from=scripts /opt/scripts /opt/scripts
ENV PATH "${PATH}:/opt/scripts"

      `bm.benchmark()` fails to de-allocate GPU memory. garbage collection fail? · Issue #186 · YosefLab/scib-metrics