Issue 873 improve mem find label issues #885
Conversation
tests/test_segmentation.py (outdated)
```python
np.save(pred_probs_file, np.random.rand(20, 5, 200, 200))
np.save(labels_file, np.random.randint(0, 2, (20, 200, 200)))

# Load the numpy arrays from disk and convert to zarr arrays
```
AFAICT there's no `zarr` being used here?
Oops! Leftover from previous work, will remove.
```python
end_time = time.time()
print(f"Average memory used: {end_mem - start_mem} MiB")
print(f"Time taken: {end_time - start_time} seconds")
```
Can we add some sort of assert about the memory usage? E.g. this test could run the same dataset (containing, say, 200 images) twice: once with batch_size = 200 and once with batch_size = 5. Verify that memory usage is much lower with batch_size = 5 and that the results match between the two runs.
Great idea! You got it.
Interesting... I'm trying to profile and compare these in the same run using:

```python
from memory_profiler import memory_usage

# Test with low batch size
peak_mem_low_batch_size = memory_usage(
    proc=(find_label_issues, (pred_labels, pred_probs), {'n_jobs': None, 'batch_size': 5}),
    interval=0.01,
    max_usage=True,
    include_children=True,
)
print(f"Peak memory used with batch size 5: {peak_mem_low_batch_size} MiB")

# Test with high batch size
peak_mem_high_batch_size = memory_usage(
    proc=(find_label_issues, (pred_labels, pred_probs), {'n_jobs': None, 'batch_size': 200}),
    interval=0.01,
    max_usage=True,
    include_children=True,
)
print(f"Peak memory used with batch size 200: {peak_mem_high_batch_size} MiB")
```
However, I notice that whichever call runs first uses more memory. A naive search suggests this is due to OS memory caching or Python garbage collection... but to eliminate those variables entirely and still test memory usage effectively, I'll need to think of a better way to isolate the two runs, so stay tuned...
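One way to sidestep that ordering effect might be to run each configuration in its own freshly spawned subprocess, so the two measurements share no allocator or cache state. A rough sketch, assuming `memory_profiler` (a recent version where `max_usage=True` returns a single number) and the same `find_label_issues`, `pred_labels`, and `pred_probs` as in the snippet above; the helper names here are placeholders, not part of the actual test:

```python
import multiprocessing as mp
from concurrent.futures import ProcessPoolExecutor

import numpy as np
from memory_profiler import memory_usage

from cleanlab.segmentation.filter import find_label_issues


def _peak_mem_and_issues(pred_labels, pred_probs, batch_size):
    """Profile one find_label_issues call; intended to run in a fresh subprocess."""
    return memory_usage(
        proc=(find_label_issues, (pred_labels, pred_probs), {"n_jobs": None, "batch_size": batch_size}),
        interval=0.01,
        max_usage=True,
        include_children=True,
        retval=True,  # also return the issues array so the results can be compared
    )


def compare_peak_memory(pred_labels, pred_probs):
    ctx = mp.get_context("spawn")  # fresh interpreter per measurement, nothing pre-warmed
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        peak_small, issues_small = pool.submit(_peak_mem_and_issues, pred_labels, pred_probs, 5).result()
    with ProcessPoolExecutor(max_workers=1, mp_context=ctx) as pool:
        peak_large, issues_large = pool.submit(_peak_mem_and_issues, pred_labels, pred_probs, 200).result()

    assert np.array_equal(issues_small, issues_large)  # results should not depend on batch_size
    assert peak_small < peak_large  # smaller batches should peak lower
    return peak_small, peak_large
```

Because each measurement gets its own worker process, the comparison should be order-independent, which would address the first-call effect described above.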
duplicating some of …
Don't forget to sign the CLA before requesting a review of the PR.
Codecov Report

All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           master     #885      +/-   ##
==========================================
- Coverage   96.78%   96.77%   -0.02%
==========================================
  Files          71       67       -4
  Lines        5605     5330     -275
  Branches      953      922      -31
==========================================
- Hits         5425     5158     -267
+ Misses         93       88       -5
+ Partials       87       84       -3
```

☔ View full report in Codecov by Sentry.
```python
# with pytest.raises(ValueError):
#     get_label_quality_scores(labels, pred_probs, method="num_pixel_issues", batch_size=-1)
#     get_label_quality_scores(
#         labels, pred_probs, method="num_pixel_issues", downsample=1, batch_size=0
#     )
```
This test is now failing. I'm unsure which `ValueError` was supposed to be raised and why (guessing `batch_size <= 0`?), but it looks unrelated to my code.
Avoid commenting out tests or code that is broken. Instead, we should mark it as `xfail` until we either fix it, or fix whatever bug is causing the error, in a separate PR.
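For reference, a minimal sketch of marking the check as an expected failure instead of commenting it out; the test name, import path, and input shapes below are assumptions, not the actual test in this repo:

```python
import numpy as np
import pytest

from cleanlab.segmentation.rank import get_label_quality_scores  # assumed import path


@pytest.mark.xfail(
    reason="ValueError not currently raised for non-positive batch_size; see #885",
    strict=False,  # report XPASS rather than failing once the underlying bug is fixed
)
def test_invalid_batch_size_raises():
    # Small synthetic inputs mirroring the segmentation shapes used elsewhere in this file
    labels = np.random.randint(0, 2, (2, 20, 20))
    pred_probs = np.random.rand(2, 2, 20, 20)
    pred_probs = pred_probs / pred_probs.sum(axis=1, keepdims=True)  # normalize over the class axis
    with pytest.raises(ValueError):
        get_label_quality_scores(labels, pred_probs, method="num_pixel_issues", batch_size=-1)
```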
Thank you for your efforts in improving our codebase, @kylegallatin! I'm accepting this PR as it stands.
During the review, I noted that the existing test was uncallable. To address this, I implemented a version of the test myself, enabling preliminary verification of the improvements. For future updates, ensure tests are up-to-date and callable to facilitate smoother integration.
The plan is to address the following enhancements in a follow-up PR:

- Remove the `get_images_to_load` function: simplify the code by precomputing the number of images per batch.
- Precompute batch size: calculate `images_per_batch` once, based on batch size and image size, for clearer, more efficient loop logic.
- Switch to `for` loops: replace the current `while` loops with `for` loops to enhance readability and maintainability.
- Improve progress bar updates: update the progress bar to accurately reflect the number of images processed per batch.
- Enhance variable naming: introduce `start_index` and `end_index` to define batch ranges more clearly.
- Reactivate test: address the comment at #885 (comment).

These planned changes are meant to further refine the code and ensure it aligns with our codebase's standards.

Most of the items above can be resolved like so:
```diff
diff --git a/cleanlab/segmentation/filter.py b/cleanlab/segmentation/filter.py
--- a/cleanlab/segmentation/filter.py
+++ b/cleanlab/segmentation/filter.py
@@ -124,20 +124,8 @@
         pred_probs_flat = np.moveaxis(pred_probs, 0, 1).reshape(num_classes, -1)
         return labels_flat, pred_probs_flat.T
-    def get_images_to_load(pre_pred_probs, i, n, batch_size):
-        """
-        This function loads images until the batch size is reached or the end of the dataset is reached.
-        """
-        images_to_load = 1
-        while (
-            np.prod(pre_pred_probs[i : i + images_to_load].shape[1:]) < batch_size
-            and i + images_to_load < n
-        ):
-            images_to_load += 1
-        return images_to_load
-
     ##
     _check_input(labels, pred_probs)
     # Added Downsampling
@@ -162,36 +150,33 @@
         from tqdm.auto import tqdm
         pbar = tqdm(desc="number of examples processed for estimating thresholds", total=n)
-    i = 0
-    while i < n:
-        images_to_load = get_images_to_load(pre_pred_probs, i, n, batch_size)
-        end_index = i + images_to_load
+    # Precompute the size of each image in the batch
+    image_size = np.prod(pre_pred_probs.shape[1:])
+    images_per_batch = batch_size // image_size + 1
+
+    for start_index in range(0, n, images_per_batch):
+        end_index = min(start_index + images_per_batch, n)
         labels_batch, pred_probs_batch = flatten_and_preprocess_masks(
-            pre_labels[i:end_index], pre_pred_probs[i:end_index]
+            pre_labels[start_index:end_index], pre_pred_probs[start_index:end_index]
         )
-        i = end_index
         lab.update_confident_thresholds(labels_batch, pred_probs_batch)
         if verbose:
-            pbar.update(images_to_load)
+            pbar.update(end_index - start_index)
     if verbose:
         pbar.close()
         pbar = tqdm(desc="number of examples processed for checking labels", total=n)
-    i = 0
-    while i < n:
-        images_to_load = get_images_to_load(pre_pred_probs, i, n, batch_size)
-
-        end_index = i + images_to_load
+    for start_index in range(0, n, images_per_batch):
+        end_index = min(start_index + images_per_batch, n)
         labels_batch, pred_probs_batch = flatten_and_preprocess_masks(
-            pre_labels[i:end_index], pre_pred_probs[i:end_index]
+            pre_labels[start_index:end_index], pre_pred_probs[start_index:end_index]
         )
-        i = end_index
         _ = lab.score_label_quality(labels_batch, pred_probs_batch)
         if verbose:
-            pbar.update(images_to_load)
+            pbar.update(end_index - start_index)
     if verbose:
         pbar.close()
```
```python
def get_images_to_load(pre_pred_probs, i, n, batch_size):
    """
    This function loads images until the batch size is reached or the end of the dataset is reached.
    """
    images_to_load = 1
    while (
        np.prod(pre_pred_probs[i : i + images_to_load].shape[1:]) < batch_size
        and i + images_to_load < n
    ):
        images_to_load += 1
    return images_to_load
```
This helper function isn't quite accurate; it just counts the number of images that will fit in a batch.
Since the images have a fixed size, the number of images is also fixed.
You just have to set this variable once in the function (not at each iteration):

```python
image_size = np.prod(pre_pred_probs.shape[1:])
images_to_load = batch_size // image_size + 1
```
The end index would of course have to be capped at the length of the dataset.
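To make that concrete, here is a small sketch using the synthetic shapes from the test above; the `batch_size` value is an illustrative assumption, and (as the formula above implies) it is counted in flattened array elements rather than images:

```python
import numpy as np

# Shapes matching the synthetic test data: 20 images, 5 classes, 200x200 pixels
pre_pred_probs = np.random.rand(20, 5, 200, 200)
n = pre_pred_probs.shape[0]
batch_size = 100_000  # illustrative value

image_size = int(np.prod(pre_pred_probs.shape[1:]))  # 5 * 200 * 200 = 200,000 elements per image
images_per_batch = batch_size // image_size + 1      # 100,000 // 200,000 + 1 = 1 image per batch

for start_index in range(0, n, images_per_batch):
    end_index = min(start_index + images_per_batch, n)  # capped at the dataset length
    batch = pre_pred_probs[start_index:end_index]        # shape (<= images_per_batch, 5, 200, 200)
    # ... flatten and score this batch ...
```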
… function for segmentation: addresses items mentioned in cleanlab#885 (review)
Summary
Purpose: Reduce memory usage of `filter.find_label_issues()`.

Change `segmentation.filter.find_label_issues()` to (1) use a modified version of `find_label_issues_batched()` and (2) iterate over images instead of full arrays, to reduce the average memory usage. Usage has not changed.
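For intuition on why per-image iteration lowers the footprint, a tiny back-of-the-envelope sketch (the shapes are the synthetic test shapes used in this PR, not real data, and the flattening mirrors the helper shown in the diff above):

```python
import numpy as np

# Full-array approach: flatten everything at once
pred_probs = np.random.rand(20, 5, 200, 200)                 # float64
full_flat = np.moveaxis(pred_probs, 0, 1).reshape(5, -1).T   # one (800_000, 5) copy ≈ 32 MB

# Per-image approach: only one image's worth of flattened data is alive at a time
per_image_flat = np.moveaxis(pred_probs[:1], 0, 1).reshape(5, -1).T  # (40_000, 5) ≈ 1.6 MB

print(full_flat.nbytes / 1e6, "MB vs", per_image_flat.nbytes / 1e6, "MB")
```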
Impact

While the changes should be completely backwards compatible, the functionality in `find_label_issues_batched()` and the `LabelInspector` class is now a bit duplicated and fragmented.

Screenshots
Testing
Existing tests and additional unit tests were used to confirm (1) the same results and (2) reduced memory usage.
`python -m pytest -s "tests/test_segmentation.py::test_find_label_issues_memmap"`

Before iterating over images (commit 649f811):

After iterating over images (commit 438ea79):
Validating results match older cleanlab code

`print(image_scores_softmin, pixel_scores)` was added to `test_get_label_quality_scores_sizes`, and then `python -m pytest -s "tests/test_segmentation.py::test_get_label_quality_scores_sizes"` was run at different SHAs (a sketch of a more scriptable comparison follows below).

Original commit:

Current commit:
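As a more scriptable alternative to eyeballing the printed scores, the arrays could be saved to disk at each SHA and compared numerically; a rough sketch (the file names are placeholders, and an exact-equality check could replace `allclose` if the runs are expected to be bit-identical):

```python
import numpy as np

# At each SHA, after running the test, dump the scores, e.g.:
#   np.save("image_scores_old.npy", image_scores_softmin)
#   np.save("pixel_scores_old.npy", pixel_scores)
# then repeat on the new SHA with "_new" file names and compare:
old_image_scores = np.load("image_scores_old.npy")
new_image_scores = np.load("image_scores_new.npy")
old_pixel_scores = np.load("pixel_scores_old.npy")
new_pixel_scores = np.load("pixel_scores_new.npy")

assert np.allclose(old_image_scores, new_image_scores)
assert np.allclose(old_pixel_scores, new_pixel_scores)
print("Scores match between the two commits.")
```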
Unaddressed Cases
Links to Relevant Issues or Conversations
#863
References
Reviewer Notes