
Issue 873 improve mem find label issues #885

Merged

Conversation

@kylegallatin (Contributor) commented Nov 6, 2023

Summary

Purpose:
Reduce memory usage of filter.find_label_issues()

Change segmentation.filter.find_label_issues() to (1) use a modified version of find_label_issues_batched() and (2) iterate over images instead of full arrays, reducing average memory usage.

Usage has not changed.
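A minimal sketch of the per-image streaming idea (the helper `iter_image_batches` and the file layout are illustrative, not the PR's actual code): instead of flattening the full arrays up front, memory-map the `.npy` files and only copy one small batch of images into memory at a time.

```python
import numpy as np

def iter_image_batches(pred_probs_path, labels_path, batch_size=1):
    """Yield (labels, pred_probs) for small batches of images,
    keeping only one batch materialized in memory at a time."""
    pred_probs = np.load(pred_probs_path, mmap_mode="r")  # shape (N, K, H, W)
    labels = np.load(labels_path, mmap_mode="r")          # shape (N, H, W)
    n = pred_probs.shape[0]
    for start in range(0, n, batch_size):
        end = min(start + batch_size, n)
        # np.array forces a copy, loading just this batch from disk.
        yield np.array(labels[start:end]), np.array(pred_probs[start:end])
```

Each yielded batch can then be flattened and fed to the batched label-quality routines, so peak memory scales with the batch size rather than the dataset size.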

Impact

While the changes should be completely backwards compatible, the functionality in find_label_issues_batched() and the LabelInspector class is now a bit duplicated and fragmented.

Screenshots

Testing

Existing tests and additional unit tests were used to confirm (1) the same results and (2) reduced memory usage.

python -m pytest -s "tests/test_segmentation.py::test_find_label_issues_memmap"

Before iterating over images (commit: 649f811):

(cl) kgallatin@8141 cleanlab % python -m pytest -s "tests/test_segmentation.py::test_find_label_issues_memmap"
======================================================================================== test session starts =========================================================================================
platform darwin -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/kgallatin/cleanlab
configfile: pyproject.toml
plugins: hypothesis-6.88.1, cov-4.1.0, lazy-fixture-0.6.3
collected 1 item                                                                                                                                                                                     

tests/test_segmentation.py (3, 2, 10, 10)
(3, 10, 10)
number of examples processed for estimating thresholds: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 800000/800000 [00:00<00:00, 2724275.00it/s]
number of examples processed for checking labels: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████| 800000/800000 [00:00<00:00, 3846467.91it/s]
Total number of examples whose labels have been evaluated: 800000
Max memory used: 114.3046875 MiB
Time taken: 4.98208212852478 seconds
.

========================================================================================= 1 passed in 8.82s ==========================================================================================

After iterating over images (commit: 438ea79):

(cl) kgallatin@8141 cleanlab % python -m pytest -s "tests/test_segmentation.py::test_find_label_issues_memmap"
======================================================================================== test session starts =========================================================================================
platform darwin -- Python 3.10.13, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/kgallatin/cleanlab
configfile: pyproject.toml
plugins: hypothesis-6.88.1, cov-4.1.0, lazy-fixture-0.6.3
collected 1 item                                                                                                                                                                                     

number of examples processed for estimating thresholds: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 90.23it/s]
number of examples processed for checking labels: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:00<00:00, 125.26it/s]
Total number of examples whose labels have been evaluated: 800000
Average memory used: 86.19921875 MiB
Time taken: 5.0795578956604 seconds
.

========================================================================================= 1 passed in 7.53s ==========================================================================================

Validating results match older cleanlab code

print(image_scores_softmin, pixel_scores) was added to test_get_label_quality_scores_sizes, and python -m pytest -s "tests/test_segmentation.py::test_get_label_quality_scores_sizes" was then run at different SHAs.

Original commit

[0.09846431 0.10313648] [[[0.89136214 0.56604853 0.70356405 ... 0.55594452 0.56947363 0.3440454 ]
  [0.33229155 0.87934306 0.95390372 ... 0.31529915 0.03921176 0.99238333]
  [0.1836061  0.64330706 0.68580903 ... 0.04706563 0.23003393 0.08291305]
  ...
  [0.76945551 0.95084098 0.79107436 ... 0.53341645 0.34642066 0.33482438]
  [0.14981747 0.63832929 0.20135665 ... 0.54854587 0.21920963 0.90021229]
  [0.75835984 0.0321533  0.207709   ... 0.74089989 0.87371662 0.57492294]]

 [[0.27125528 0.91985156 0.30336524 ... 0.46254333 0.01366739 0.20941733]
  [0.81830203 0.82138594 0.93408234 ... 0.16271972 0.68834335 0.25020788]
  [0.35912337 0.04292717 0.70221824 ... 0.27187053 0.84944595 0.92747146]
  ...
  [0.80970585 0.26072079 0.48013684 ... 0.47340933 0.00754304 0.21593769]
  [0.68927337 0.53519544 0.45118022 ... 0.88845114 0.58060761 0.85428862]
  [0.10506886 0.64047901 0.71572487 ... 0.21239425 0.62293779 0.79775874]]]

Current commit

[0.09846431 0.10313648] [[[0.89136214 0.56604853 0.70356405 ... 0.55594452 0.56947363 0.3440454 ]
  [0.33229155 0.87934306 0.95390372 ... 0.31529915 0.03921176 0.99238333]
  [0.1836061  0.64330706 0.68580903 ... 0.04706563 0.23003393 0.08291305]
  ...
  [0.76945551 0.95084098 0.79107436 ... 0.53341645 0.34642066 0.33482438]
  [0.14981747 0.63832929 0.20135665 ... 0.54854587 0.21920963 0.90021229]
  [0.75835984 0.0321533  0.207709   ... 0.74089989 0.87371662 0.57492294]]

 [[0.27125528 0.91985156 0.30336524 ... 0.46254333 0.01366739 0.20941733]
  [0.81830203 0.82138594 0.93408234 ... 0.16271972 0.68834335 0.25020788]
  [0.35912337 0.04292717 0.70221824 ... 0.27187053 0.84944595 0.92747146]
  ...
  [0.80970585 0.26072079 0.48013684 ... 0.47340933 0.00754304 0.21593769]
  [0.68927337 0.53519544 0.45118022 ... 0.88845114 0.58060761 0.85428862]
  [0.10506886 0.64047901 0.71572487 ... 0.21239425 0.62293779 0.79775874]]]

Unaddressed Cases

Links to Relevant Issues or Conversations

#863

References

Reviewer Notes

  • Is there anything else that should be addressed within this scope?
  • Any additional recommendations on testing memory usage?

@CLAassistant

CLAassistant commented Nov 6, 2023

CLA assistant check
All committers have signed the CLA.

np.save(pred_probs_file, np.random.rand(20, 5, 200, 200))
np.save(labels_file, np.random.randint(0, 2, (20, 200, 200)))

# Load the numpy arrays from disk and convert to zarr arrays
Member

AFAICT there's no zarr being used here?

Contributor Author

Oops! Leftover from previous work, will remove.

end_time = time.time()
print(f"Average memory used: {end_mem - start_mem} MiB")
print(f"Time taken: {end_time - start_time} seconds")

Member

can we add some sort of assert about the mem-usage?

Eg. this test could run the same dataset (that contains say 200 images) twice, one time with: batch_size = 200, one time with batch_size = 5. Verify mem-usage is much lower with batch_size = 5, and that the results match between the two different runs
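The suggested assertion can be sketched without external profilers using the stdlib's tracemalloc (the names `peak_memory_mib` and the stand-in `score_in_batches` are hypothetical; the real test would call `find_label_issues` with the two batch sizes):

```python
import tracemalloc
import numpy as np

def peak_memory_mib(fn, *args, **kwargs):
    """Run fn and return (result, peak traced allocation in MiB)."""
    tracemalloc.start()
    try:
        result = fn(*args, **kwargs)
        _, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return result, peak / 2**20

def score_in_batches(pred_probs, batch_size):
    """Stand-in for a batched scorer: processes batch_size images at a
    time, so smaller batches materialize less memory at once."""
    out = []
    for start in range(0, pred_probs.shape[0], batch_size):
        batch = np.array(pred_probs[start:start + batch_size])  # copy, like loading a batch
        out.append(batch.max(axis=1).ravel())
    return np.concatenate(out)
```

A test can then assert both properties at once: identical results across batch sizes, and a lower memory peak for the smaller batch size.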

Contributor Author

Great idea! You got it.

Contributor Author

Interesting... I'm trying to profile and compare these in the same run using:

    # Test with low batch size
    peak_mem_low_batch_size = memory_usage(
        proc=(find_label_issues, (pred_labels, pred_probs), {'n_jobs': None, 'batch_size': 5}),
        interval=0.01,
        max_usage=True,
        include_children=True,
    )
    print(f"Peak memory used with batch size 5: {peak_mem_low_batch_size} MiB")

    # Test with high batch size
    peak_mem_high_batch_size = memory_usage(
        proc=(find_label_issues, (pred_labels, pred_probs), {'n_jobs': None, 'batch_size': 200}),
        interval=0.01,
        max_usage=True,
        include_children=True,
    )
    print(f"Peak memory used with batch size 200: {peak_mem_high_batch_size} MiB")

However, I notice that whichever method gets called first uses more memory. A naive search leads me to believe this is OS memory caching or Python garbage collection... but to eliminate those variables entirely and still test memory effectively, I'll need a better way to isolate the two runs, so stay tuned...
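One common way to eliminate this ordering effect is to run each measurement in a fresh interpreter, so neither run can inherit the other's allocator state. A sketch (the helper name `peak_mib_in_subprocess` is hypothetical, and it uses stdlib tracemalloc rather than memory_profiler):

```python
import subprocess
import sys

def peak_mib_in_subprocess(snippet: str) -> float:
    """Run a Python snippet in a fresh interpreter and report its peak
    traced allocation in MiB, so repeated measurements cannot influence
    each other through allocator caching or leftover garbage."""
    prog = (
        "import tracemalloc\n"
        "tracemalloc.start()\n"
        + snippet.rstrip() + "\n"
        + "print(tracemalloc.get_traced_memory()[1] / 2**20)\n"
    )
    out = subprocess.run(
        [sys.executable, "-c", prog], capture_output=True, text=True, check=True
    )
    return float(out.stdout.strip().splitlines()[-1])
```

The trade-off is that the workload must be expressible as a self-contained snippet (or a module-level entry point), and process startup adds a little wall-clock time per measurement.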

@jwmueller
Copy link
Member

duplicating some of LabelInspector internals in this file seems fine. I think it'll be too complicated to perfectly refactor the existing LabelInspector to handle this weird application of it to image data, and don't think it's worth the potential of introducing bugs in the original LabelInspector or the clarity of the original code (which is intended to be easy for developers to modify).

@jwmueller
Copy link
Member

don't forget to sign the CLA before requesting a review of the PR


codecov bot commented Nov 14, 2023

Codecov Report

All modified and coverable lines are covered by tests ✅

Comparison is base (3fb20e9) 96.78% compared to head (09523d8) 96.77%.

❗ Current head 09523d8 differs from pull request most recent head 3e3e204. Consider uploading reports for the commit 3e3e204 to get more accurate results

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #885      +/-   ##
==========================================
- Coverage   96.78%   96.77%   -0.02%     
==========================================
  Files          71       67       -4     
  Lines        5605     5330     -275     
  Branches      953      922      -31     
==========================================
- Hits         5425     5158     -267     
+ Misses         93       88       -5     
+ Partials       87       84       -3     


@kylegallatin kylegallatin marked this pull request as ready for review November 14, 2023 19:21
Comment on lines +227 to +231
# with pytest.raises(ValueError):
# get_label_quality_scores(labels, pred_probs, method="num_pixel_issues", batch_size=-1)
# get_label_quality_scores(
# labels, pred_probs, method="num_pixel_issues", downsample=1, batch_size=0
# )
Contributor Author

This test is now failing. I'm unsure which ValueError was supposed to be raised and why (guessing batch_size <= 0?), but it looks unrelated to my code.

Member

Avoid commenting out tests or code that is broken. Instead, we should mark it as xfail, until we either fix it or any bug causing the error in a separate PR.
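The xfail pattern looks roughly like this (a generic sketch: `_scores` is a stand-in for the real call to `get_label_quality_scores(..., method="num_pixel_issues", batch_size=-1)`, which needs cleanlab's test fixtures):

```python
import pytest

def _scores(batch_size):
    # Stand-in for the real scoring function; the bug under discussion is
    # that batch_size <= 0 is accepted without raising ValueError.
    return [0.0]

@pytest.mark.xfail(reason="batch_size <= 0 should raise ValueError; fix tracked separately")
def test_invalid_batch_size_raises():
    with pytest.raises(ValueError):
        _scores(batch_size=-1)
```

Unlike a commented-out test, an xfail-marked test still runs on every CI pass: it is reported as `xfailed` while broken and as `xpassed` once the fix lands, so it cannot be silently forgotten.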

@elisno (Member) left a comment

Thank you for your efforts in improving our codebase, @kylegallatin! I'm accepting this PR as it stands.


During the review, I noted that the existing test was uncallable. To address this, I implemented a version of the test myself, enabling preliminary verification of the improvements. For future updates, ensure tests are up-to-date and callable to facilitate smoother integration.

The plan is to address the following enhancements in a follow-up PR:

  • Remove the get_images_to_load Function:
    • Simplify the code by precomputing the number of images per batch.
  • Precompute Batch Size:
    • Calculate images_per_batch once, based on batch size and image size, for clearer, more efficient loop logic.
  • Switch to for Loops:
    • Replace current while loops with for loops to enhance readability and maintainability.
  • Improve Progress Bar Updates:
    • Update the progress bar to accurately reflect the number of images processed per batch.
  • Enhance Variable Naming:
    • Introduce start_index and end_index to define batch ranges more clearly.
  • Reactivate the Test:
    • Re-enable the commented-out batch_size validation test (or mark it xfail) once the underlying error is fixed.

These planned changes are meant to further refine the code and ensure it aligns with our codebase's standards.

Most of the items above can be resolved like so:

diff --git a/cleanlab/segmentation/filter.py b/cleanlab/segmentation/filter.py
--- cleanlab/segmentation/filter.py
+++ cleanlab/segmentation/filter.py
@@ -124,20 +124,8 @@
         pred_probs_flat = np.moveaxis(pred_probs, 0, 1).reshape(num_classes, -1)
 
         return labels_flat, pred_probs_flat.T
 
-    def get_images_to_load(pre_pred_probs, i, n, batch_size):
-        """
-        This function loads images until the batch size is reached or the end of the dataset is reached.
-        """
-        images_to_load = 1
-        while (
-            np.prod(pre_pred_probs[i : i + images_to_load].shape[1:]) < batch_size
-            and i + images_to_load < n
-        ):
-            images_to_load += 1
-        return images_to_load
-
     ##
     _check_input(labels, pred_probs)
 
     # Added Downsampling
@@ -162,36 +150,33 @@
         from tqdm.auto import tqdm
 
         pbar = tqdm(desc="number of examples processed for estimating thresholds", total=n)
 
-    i = 0
-    while i < n:
-        images_to_load = get_images_to_load(pre_pred_probs, i, n, batch_size)
-        end_index = i + images_to_load
+    # Precompute the size of each image in the batch
+    image_size = np.prod(pre_pred_probs.shape[1:])
+    images_per_batch = batch_size // image_size + 1
+
+    for start_index in range(0, n, images_per_batch):
+        end_index = min(start_index + images_per_batch, n)
         labels_batch, pred_probs_batch = flatten_and_preprocess_masks(
-            pre_labels[i:end_index], pre_pred_probs[i:end_index]
+            pre_labels[start_index:end_index], pre_pred_probs[start_index:end_index]
         )
-        i = end_index
         lab.update_confident_thresholds(labels_batch, pred_probs_batch)
         if verbose:
-            pbar.update(images_to_load)
+            pbar.update(end_index - start_index)
 
     if verbose:
         pbar.close()
         pbar = tqdm(desc="number of examples processed for checking labels", total=n)
 
-    i = 0
-    while i < n:
-        images_to_load = get_images_to_load(pre_pred_probs, i, n, batch_size)
-
-        end_index = i + images_to_load
+    for start_index in range(0, n, images_per_batch):
+        end_index = min(start_index + images_per_batch, n)
         labels_batch, pred_probs_batch = flatten_and_preprocess_masks(
-            pre_labels[i:end_index], pre_pred_probs[i:end_index]
+            pre_labels[start_index:end_index], pre_pred_probs[start_index:end_index]
         )
-        i = end_index
         _ = lab.score_label_quality(labels_batch, pred_probs_batch)
         if verbose:
-            pbar.update(images_to_load)
+            pbar.update(end_index - start_index)
 
     if verbose:
         pbar.close()

Comment on lines +128 to +139
def get_images_to_load(pre_pred_probs, i, n, batch_size):
"""
This function loads images until the batch size is reached or the end of the dataset is reached.
"""
images_to_load = 1
while (
np.prod(pre_pred_probs[i : i + images_to_load].shape[1:]) < batch_size
and i + images_to_load < n
):
images_to_load += 1
return images_to_load

Member

This helper function isn't quite accurate; it just counts the number of images that will fit in a batch.

Since the images have a fixed size, the number of images is also fixed.

You just have to set this variable once in the function (not at each iteration):

image_size = np.prod(pre_pred_probs.shape[1:])
images_to_load = batch_size // image_size + 1

Member

The end index would of course have to be capped at the length of the dataset.
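Both points together reduce to a small helper (names like `batch_ranges` are illustrative, not the PR's code): compute `images_per_batch` once, then cap only the final end index.

```python
import numpy as np

def batch_ranges(pred_probs_shape, n, batch_size):
    """Precompute (start, end) index pairs for batched iteration.
    Since every image has the same size, images_per_batch is a constant,
    and only the last end index needs capping at n."""
    image_size = int(np.prod(pred_probs_shape[1:]))   # classes * pixels per image
    images_per_batch = batch_size // image_size + 1   # always load at least one image
    return [(s, min(s + images_per_batch, n)) for s in range(0, n, images_per_batch)]
```

For example, with 20 images of shape (5, 10, 10) and batch_size=1000, each image contributes 500 values, so three images are loaded per batch and the final range is truncated to end at 20.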

@elisno elisno merged commit f3a65b8 into cleanlab:master Dec 15, 2023
19 checks passed
elisno added a commit to elisno/cleanlab that referenced this pull request Dec 15, 2023
function for segmentation

Addresses items mentioned in cleanlab#885 (review)